| """Space 3: Train Voice (Whisper + F5-TTS fine-tuning) |
| |
| Downloads audio from Hub -> Whisper transcription -> F5-TTS fine-tune -> saves model to Hub. |
GPU: A100 (Whisper medium + F5-TTS training)
| """ |
| import gc |
| import json |
| import logging |
| import os |
| import shutil |
| import subprocess |
| import sys |
| import traceback |
| from pathlib import Path |
|
|
| import gradio as gr |
| import numpy as np |
| import soundfile as sf |
| import torch |
|
|
| from hub_utils import download_step, upload_step, list_projects, HF_DATASET_REPO_ID |
|
|
| logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s") |
| logger = logging.getLogger(__name__) |
|
|
| |
# Storage layout: prefer the persistent /data volume on HF Spaces (when it
# exists and is writable); otherwise fall back to a local ./data directory.
IS_HF_SPACE = os.environ.get("SPACE_ID") is not None
_data_path = Path("/data")
if IS_HF_SPACE and _data_path.exists() and os.access(_data_path, os.W_OK):
    BASE_DIR = _data_path
else:
    BASE_DIR = Path("data")


# Working directories used across the pipeline.
AUDIO_DIR = BASE_DIR / "audio"  # downloaded speaker wav segments
VOICE_MODEL_DIR = BASE_DIR / "voice_model"  # fine-tuned checkpoints + reference.wav
TEMP_DIR = BASE_DIR / "temp"  # scratch space for dataset preparation
HF_CACHE_DIR = BASE_DIR / "hf_cache"  # Hugging Face model cache


for d in [AUDIO_DIR, VOICE_MODEL_DIR, TEMP_DIR, HF_CACHE_DIR]:
    d.mkdir(parents=True, exist_ok=True)


# Redirect the Hugging Face caches into our (possibly persistent) storage.
os.environ["HF_HOME"] = str(HF_CACHE_DIR)
os.environ["TRANSFORMERS_CACHE"] = str(HF_CACHE_DIR)


AUDIO_SAMPLE_RATE = 16000  # 16 kHz; unused in this file -- presumably shared with sibling Spaces, confirm
F5_SAMPLE_RATE = 24000  # sample rate the F5-TTS dataset is resampled to
VOICE_FINETUNE_EPOCHS = 100  # UI default for the epochs slider
VOICE_FINETUNE_LR = 1e-5  # UI default for the learning-rate field
# NOTE(review): forwarded to finetune_cli --batch_size_per_gpu; presumably
# frame-based per F5-TTS convention rather than a sample count -- confirm.
VOICE_FINETUNE_BATCH_SIZE = 3200
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


APP_VERSION = "1.0.0"
|
|
|
|
| def _clear_cache(): |
| gc.collect() |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
| torch.cuda.synchronize() |
|
|
|
|
| |
|
|
def _transcribe_segments(segment_paths, progress_callback=None):
    """Transcribe each audio segment with Whisper, keeping non-empty results.

    Returns a list of {"audio_path": str, "text": str} dicts; the model is
    freed before returning so the GPU is clear for the training stage.
    """
    import whisper

    logger.info("Loading Whisper for transcription...")
    model = whisper.load_model("medium", device=DEVICE)

    total = len(segment_paths)
    transcripts = []
    for idx, seg_path in enumerate(segment_paths):
        if progress_callback:
            # Transcription owns the 0.0-0.3 slice of the progress bar.
            progress_callback(idx / total * 0.3, f"Transcribiendo segmento {idx+1}/{total}...")
        text = model.transcribe(seg_path, language="es", fp16=torch.cuda.is_available())["text"].strip()
        if text:
            transcripts.append({"audio_path": seg_path, "text": text})

    del model
    _clear_cache()
    logger.info(f"Transcribed {len(transcripts)} segments")
    return transcripts
|
|
|
|
| |
|
|
def _prepare_finetune_dataset(transcripts):
    """Build the F5-TTS fine-tuning dataset (wavs/ + pipe-delimited metadata.csv).

    Each transcript's audio is resampled to F5_SAMPLE_RATE and, when longer
    than ~15 s, split into roughly 10 s chunks with the transcript words
    divided proportionally across the chunks.

    Fixes over the previous version:
    - "|" is stripped from transcript text: it is the metadata.csv field
      delimiter, so an embedded pipe would corrupt later CSV parsing.
    - The last chunk of a long recording now extends to the end of the audio
      and takes the remaining words; previously the integer-division
      remainder of both was silently dropped.

    Args:
        transcripts: list of {"audio_path": str, "text": str} dicts.

    Returns:
        Path to the dataset directory (TEMP_DIR / "voice_finetune_data").
    """
    dataset_dir = TEMP_DIR / "voice_finetune_data"
    if dataset_dir.exists():
        shutil.rmtree(dataset_dir)
    dataset_dir.mkdir(parents=True)
    wavs_dir = dataset_dir / "wavs"
    wavs_dir.mkdir()

    metadata = []
    for i, item in enumerate(transcripts):
        audio, sr = sf.read(item["audio_path"])
        if sr != F5_SAMPLE_RATE:
            import torchaudio
            audio_tensor = torch.from_numpy(audio).float()
            if audio_tensor.dim() == 1:
                audio_tensor = audio_tensor.unsqueeze(0)
            resampler = torchaudio.transforms.Resample(sr, F5_SAMPLE_RATE)
            audio_tensor = resampler(audio_tensor)
            audio = audio_tensor.squeeze(0).numpy()

        # Sanitize the text: "|" is reserved as the metadata.csv delimiter.
        text = item["text"].replace("|", " ").strip()
        if not text:
            continue

        max_samples = 15 * F5_SAMPLE_RATE
        min_samples = 2 * F5_SAMPLE_RATE

        if len(audio) <= max_samples:
            clips = [(audio, text)]
        else:
            n_parts = max(1, len(audio) // (10 * F5_SAMPLE_RATE))
            part_size = len(audio) // n_parts
            clips = []
            words = text.split()
            words_per_part = max(1, len(words) // n_parts)
            for j in range(n_parts):
                is_last = j == n_parts - 1
                start = j * part_size
                # Last chunk absorbs the division remainder instead of dropping it.
                end = len(audio) if is_last else min((j + 1) * part_size, len(audio))
                if end - start < min_samples:
                    continue
                text_start = j * words_per_part
                text_end = len(words) if is_last else min((j + 1) * words_per_part, len(words))
                part_text = " ".join(words[text_start:text_end])
                if part_text:
                    clips.append((audio[start:end], part_text))

        for j, (clip_audio, clip_text) in enumerate(clips):
            fname = f"clip_{i:04d}_{j:02d}.wav"
            wav_path = wavs_dir / fname
            sf.write(str(wav_path), clip_audio, F5_SAMPLE_RATE)
            duration = len(clip_audio) / F5_SAMPLE_RATE
            metadata.append({"audio_file": str(wav_path.resolve()), "text": clip_text, "duration": round(duration, 2)})

    meta_path = dataset_dir / "metadata.csv"
    with open(meta_path, "w") as f:
        f.write("audio_file|text\n")
        for item in metadata:
            f.write(f"{item['audio_file']}|{item['text']}\n")

    logger.info(f"Prepared {len(metadata)} clips for fine-tuning")
    return dataset_dir
|
|
|
|
def _ensure_vocab_file():
    """Ensure the pretrained F5-TTS pinyin vocab.txt is present, downloading it if missing."""
    from importlib.resources import files as pkg_files

    vocab_path = Path(pkg_files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt"))
    if vocab_path.exists():
        return

    logger.info("Downloading pretrained vocab.txt for F5-TTS...")
    import urllib.request
    url = "https://raw.githubusercontent.com/SWivid/F5-TTS/main/data/Emilia_ZH_EN_pinyin/vocab.txt"
    vocab_path.parent.mkdir(parents=True, exist_ok=True)
    urllib.request.urlretrieve(url, str(vocab_path))
|
|
|
|
def _prepare_arrow_dataset(dataset_dir, progress_callback=None):
    """Convert the wav/metadata dataset into the Arrow layout F5-TTS training expects.

    Reads metadata.csv, converts the texts to pinyin, filters out missing or
    out-of-range clips, and writes the Arrow dataset plus duration.json and
    vocab.txt into dataset_dir / "arrow_data".
    """
    if progress_callback:
        progress_callback(0.32, "Preparando dataset Arrow...")

    _ensure_vocab_file()

    meta_csv = dataset_dir / "metadata.csv"
    arrow_dir = dataset_dir / "arrow_data"
    arrow_dir.mkdir(parents=True, exist_ok=True)

    import csv
    from datasets import Dataset as HFDataset
    from f5_tts.model.utils import convert_char_to_pinyin

    # Parse "audio|text" rows, skipping the header line.
    with open(meta_csv, "r", encoding="utf-8-sig") as f:
        reader = csv.reader(f, delimiter="|")
        next(reader, None)
        audio_text_pairs = [(row[0].strip(), row[1].strip()) for row in reader if len(row) >= 2]

    if not audio_text_pairs:
        raise RuntimeError("No audio-text pairs found in metadata.csv")

    pinyin_texts = convert_char_to_pinyin([text for _, text in audio_text_pairs], polyphone=True)

    # Keep clips that exist on disk and have a sane duration (0.3 s - 30 s).
    valid_audio_paths, valid_texts, durations = [], [], []
    for idx, (audio_path, _text) in enumerate(audio_text_pairs):
        if not Path(audio_path).exists():
            continue
        duration = sf.info(audio_path).duration
        if duration < 0.3 or duration > 30:
            continue
        valid_audio_paths.append(audio_path)
        valid_texts.append(pinyin_texts[idx])
        durations.append(duration)

    if not valid_audio_paths:
        raise RuntimeError("No valid audio clips after filtering")

    ds = HFDataset.from_dict(
        {"audio_path": valid_audio_paths, "text": valid_texts, "duration": durations}
    )
    ds.save_to_disk(str(arrow_dir / "raw"))
    ds.to_parquet(str(arrow_dir / "raw.parquet"))

    with open(arrow_dir / "duration.json", "w") as f:
        json.dump({"duration": durations}, f)

    # Training reuses the pretrained vocab; place a copy next to the data.
    from importlib.resources import files as pkg_files
    pretrained_vocab = Path(pkg_files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt"))
    if pretrained_vocab.exists():
        shutil.copy2(str(pretrained_vocab), str(arrow_dir / "vocab.txt"))

    logger.info(f"Arrow dataset: {len(valid_audio_paths)} samples, {sum(durations)/3600:.2f}h total")
    return arrow_dir
|
|
|
|
def finetune_voice(segment_paths, epochs, learning_rate, batch_size, progress_callback=None):
    """Fine-tune F5-TTS on a speaker's audio segments.

    Pipeline: Whisper transcription -> clip dataset -> Arrow dataset ->
    ``f5_tts.train.finetune_cli`` subprocess -> copy resulting checkpoints
    (plus a reference wav) into VOICE_MODEL_DIR.

    Args:
        segment_paths: list of wav paths with the target speaker's voice.
        epochs: number of fine-tuning epochs.
        learning_rate: optimizer learning rate passed to finetune_cli.
        batch_size: forwarded to ``--batch_size_per_gpu`` (presumably
            frame-based per F5-TTS convention -- TODO confirm).
        progress_callback: optional ``fn(fraction, message)`` for UI updates.

    Returns:
        str: path to VOICE_MODEL_DIR.

    Raises:
        ValueError: no segments given, or nothing could be transcribed.
        RuntimeError: the fine-tuning subprocess exited non-zero.
    """
    if not segment_paths:
        raise ValueError("No audio segments found.")

    _clear_cache()

    # Stage 1: transcription (reports progress 0.0 - 0.3).
    transcripts = _transcribe_segments(segment_paths, progress_callback)
    if not transcripts:
        raise ValueError("Could not transcribe any audio segments")

    # Stage 2: build the clip dataset and its Arrow representation.
    if progress_callback:
        progress_callback(0.3, "Preparando dataset...")
    dataset_dir = _prepare_finetune_dataset(transcripts)
    arrow_dir = _prepare_arrow_dataset(dataset_dir, progress_callback)

    if progress_callback:
        progress_callback(0.35, "Iniciando fine-tuning F5-TTS...")

    VOICE_MODEL_DIR.mkdir(parents=True, exist_ok=True)

    # finetune_cli resolves datasets under <f5_tts pkg>/../../data, so the
    # Arrow data must be copied to "<that root>/<dataset_name>_<tokenizer>".
    from importlib.resources import files as pkg_files
    f5_data_root = Path(pkg_files("f5_tts").joinpath("../../data"))
    f5_data_root.mkdir(parents=True, exist_ok=True)

    dataset_name = "voice_finetune"
    tokenizer = "char"
    expected_dir = f5_data_root / f"{dataset_name}_{tokenizer}"
    if expected_dir.exists():
        shutil.rmtree(expected_dir)
    shutil.copytree(str(arrow_dir), str(expected_dir))

    # Stage 3: run the fine-tuning CLI as a subprocess.
    cmd = [
        sys.executable, "-m", "f5_tts.train.finetune_cli",
        "--exp_name", "F5TTS_v1_Base",
        "--dataset_name", dataset_name,
        "--learning_rate", str(learning_rate),
        "--batch_size_per_gpu", str(batch_size),
        "--epochs", str(epochs),
        "--finetune",
        "--save_per_updates", "500",
        "--last_per_updates", "200",
        "--num_warmup_updates", "100",
        "--tokenizer", tokenizer,
    ]

    logger.info(f"Running F5-TTS finetune: {' '.join(cmd)}")

    # Stream the trainer's combined stdout/stderr into our log as it runs
    # (text=True + bufsize=1 => line-buffered reads).
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
    for line in process.stdout:
        line = line.strip()
        if line:
            logger.info(f"[F5-TTS] {line}")
            if progress_callback and "step" in line.lower():
                progress_callback(0.4, f"Training: {line[:80]}...")

    process.wait()
    if process.returncode != 0:
        raise RuntimeError(f"F5-TTS fine-tuning failed with exit code {process.returncode}")

    # Stage 4: locate checkpoints. finetune_cli's output directory is not
    # known exactly, so several likely roots are globbed for checkpoint files.
    # NOTE(review): _glob and voice_model_resolved are unused -- removal candidates.
    import glob as _glob
    cwd = Path.cwd()
    home = Path.home()
    search_roots = [cwd, home, Path("/app"), Path("/home/user")]
    found_any = False
    voice_model_resolved = VOICE_MODEL_DIR.resolve()
    for root in search_roots:
        if not root.exists():
            continue
        for pattern in ["**/model_last*.pt", "**/model_last*.safetensors",
                        "**/ckpts/**/*.pt", "**/ckpts/**/*.safetensors"]:
            for f in root.glob(pattern):
                dest = VOICE_MODEL_DIR / f.name
                # Avoid copying a file onto itself when a search root already
                # contains VOICE_MODEL_DIR.
                if f.resolve() == dest.resolve():
                    logger.info(f"Skipping (same file): {f}")
                    found_any = True
                    continue
                shutil.copy2(str(f), str(dest))
                logger.info(f"Copied checkpoint: {f} -> {dest}")
                found_any = True
    if not found_any:
        # Nothing matched: dump directory listings to help locate where the
        # trainer actually wrote its checkpoints.
        for d in [cwd / "ckpts", home / "ckpts", Path("/app/ckpts")]:
            if d.exists():
                all_files = list(d.rglob("*"))
                logger.info(f"Files in {d}: {[str(f) for f in all_files[:20]]}")
            else:
                logger.info(f"Directory {d} does not exist")
        logger.info(f"CWD={cwd}, contents: {[str(f) for f in cwd.iterdir()][:20]}")

    # Keep one original segment next to the model as the TTS reference audio.
    ref_path = VOICE_MODEL_DIR / "reference.wav"
    if segment_paths:
        shutil.copy2(segment_paths[0], str(ref_path))

    # Clean up scratch data and free GPU memory before returning.
    shutil.rmtree(dataset_dir, ignore_errors=True)
    _clear_cache()

    return str(VOICE_MODEL_DIR)
|
|
|
|
| |
|
|
def download_audio_from_hub(project_name, progress=gr.Progress()):
    """Gradio handler: pull the project's step2 audio segments from the Hub into AUDIO_DIR."""
    if not project_name or not project_name.strip():
        return "Error: Debes introducir un nombre de proyecto"
    name = project_name.strip()
    try:
        # Start from a clean slate so stale segments never leak between runs.
        if AUDIO_DIR.exists():
            shutil.rmtree(AUDIO_DIR)
        AUDIO_DIR.mkdir(parents=True)

        download_step(name, "step2_audio", str(BASE_DIR))
        # Downloads land under BASE_DIR/<project>/step2_audio; flatten into AUDIO_DIR.
        staged = BASE_DIR / name / "step2_audio"
        if staged.exists():
            for item in staged.iterdir():
                shutil.move(str(item), str(AUDIO_DIR / item.name))
        shutil.rmtree(BASE_DIR / name, ignore_errors=True)

        segments = sorted(AUDIO_DIR.glob("segment_*.wav"))
        return f"OK - Descargados {len(segments)} segmentos de audio"
    except Exception as e:
        return f"Error: {e}"
|
|
|
|
def train_voice_handler(project_name, epochs, lr, progress=gr.Progress()):
    """Gradio handler: fine-tune the voice model on the segments already in AUDIO_DIR."""
    if not project_name or not project_name.strip():
        return "Error: Debes introducir un nombre de proyecto"

    segment_paths = sorted(str(wav) for wav in AUDIO_DIR.glob("segment_*.wav"))
    if not segment_paths:
        return "Error: No hay segmentos de audio. Descarga primero desde el Hub."

    logger.info(f"=== Voice Training Started === epochs={epochs}, lr={lr}")
    try:
        result = finetune_voice(
            segment_paths,
            epochs=int(epochs),
            learning_rate=lr,
            batch_size=VOICE_FINETUNE_BATCH_SIZE,
            # Bridge the pipeline's (fraction, message) callback to gr.Progress.
            progress_callback=lambda frac, msg: progress(frac, desc=msg),
        )
        logger.info(f"=== Voice Training Complete === {result}")
        return f"OK - Modelo de voz guardado en: {result}"
    except Exception as e:
        logger.error(f"=== Voice Training Failed ===\n{traceback.format_exc()}")
        return f"Error: {e}"
|
|
|
|
def save_to_hub(project_name):
    """Gradio handler: upload the trained voice model directory to the Hub (step3_voice)."""
    if not project_name or not project_name.strip():
        return "Error: Debes introducir un nombre de proyecto"
    name = project_name.strip()

    # Require at least one checkpoint before attempting an upload.
    checkpoints = [
        ckpt
        for pattern in ("*.pt", "*.safetensors")
        for ckpt in VOICE_MODEL_DIR.glob(pattern)
    ]
    if not checkpoints:
        return "Error: No hay modelo de voz para guardar. Entrena primero."

    try:
        return upload_step(name, "step3_voice", str(VOICE_MODEL_DIR))
    except Exception as e:
        return f"Error: {e}"
|
|
|
|
| |
|
|
# Gradio UI: three sequential steps (download audio -> train voice -> save model).
with gr.Blocks(title="Talking Head - Voice Train", theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# Talking Head - Entrenar Voz `v{APP_VERSION}`\nWhisper transcripcion + F5-TTS fine-tuning")

    # The project name is the folder key used on the Hub for every step.
    project_name = gr.Textbox(
        label="Nombre del proyecto",
        placeholder="mi_proyecto",
        info="Obligatorio. Se usa como carpeta en el Hub.",
    )

    # Step 1: pull the audio segments produced by the previous Space.
    gr.Markdown("### 1. Descargar audio del Hub")
    download_btn = gr.Button("Descargar audio del Hub", variant="secondary")
    download_status = gr.Textbox(label="Estado descarga", interactive=False)

    # Step 2: fine-tune the voice model; epochs and learning rate are tunable.
    gr.Markdown("### 2. Entrenar modelo de voz")
    with gr.Row():
        voice_epochs = gr.Slider(10, 300, value=VOICE_FINETUNE_EPOCHS, step=10, label="Epochs")
        voice_lr = gr.Number(value=VOICE_FINETUNE_LR, label="Learning Rate")
    train_btn = gr.Button("Entrenar Voz", variant="primary")
    train_status = gr.Textbox(label="Estado entrenamiento", interactive=False)

    # Step 3: push the resulting checkpoints back to the Hub.
    gr.Markdown("### 3. Guardar modelo en Hub")
    save_btn = gr.Button("Guardar en Hub", variant="secondary")
    save_status = gr.Textbox(label="Estado guardado", interactive=False)

    # Wire each button to its handler; status boxes show the returned message.
    download_btn.click(download_audio_from_hub, inputs=[project_name], outputs=[download_status])
    train_btn.click(train_voice_handler, inputs=[project_name, voice_epochs, voice_lr], outputs=[train_status])
    save_btn.click(save_to_hub, inputs=[project_name], outputs=[save_status])
|
|
if __name__ == "__main__":
    # queue() enables streamed progress updates; ssr_mode=False presumably
    # works around Gradio SSR issues on Spaces -- confirm before upgrading Gradio.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)
|
|