# Uploaded via huggingface_hub (commit b2e5988)
"""Space 3: Train Voice (Whisper + F5-TTS fine-tuning)
Downloads audio from Hub -> Whisper transcription -> F5-TTS fine-tune -> saves model to Hub.
GPU: A100 (Whisper large-v3 + F5-TTS training)
"""
import gc
import json
import logging
import os
import shutil
import subprocess
import sys
import traceback
from pathlib import Path
import gradio as gr
import numpy as np
import soundfile as sf
import torch
from hub_utils import download_step, upload_step, list_projects, HF_DATASET_REPO_ID
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
logger = logging.getLogger(__name__)
# ── Config ──
# True when running inside a Hugging Face Space (the platform sets SPACE_ID).
IS_HF_SPACE = os.environ.get("SPACE_ID") is not None
_data_path = Path("/data")
# Prefer the persistent /data volume on Spaces when it exists and is writable;
# otherwise fall back to a local ./data directory.
if IS_HF_SPACE and _data_path.exists() and os.access(_data_path, os.W_OK):
    BASE_DIR = _data_path
else:
    BASE_DIR = Path("data")
AUDIO_DIR = BASE_DIR / "audio"              # downloaded audio segments
VOICE_MODEL_DIR = BASE_DIR / "voice_model"  # fine-tuned checkpoints + reference clip
TEMP_DIR = BASE_DIR / "temp"                # scratch space for dataset preparation
HF_CACHE_DIR = BASE_DIR / "hf_cache"        # Hugging Face model/dataset cache
for d in [AUDIO_DIR, VOICE_MODEL_DIR, TEMP_DIR, HF_CACHE_DIR]:
    d.mkdir(parents=True, exist_ok=True)
# Point HF caching at our writable dir (TRANSFORMERS_CACHE kept for older
# transformers versions that ignore HF_HOME).
os.environ["HF_HOME"] = str(HF_CACHE_DIR)
os.environ["TRANSFORMERS_CACHE"] = str(HF_CACHE_DIR)
AUDIO_SAMPLE_RATE = 16000  # sample rate of incoming segments (Hz) — presumably Whisper-ready; confirm upstream
F5_SAMPLE_RATE = 24000     # F5-TTS training sample rate (Hz)
VOICE_FINETUNE_EPOCHS = 100
VOICE_FINETUNE_LR = 1e-5
# Passed to finetune_cli as --batch_size_per_gpu; NOTE(review): presumably
# measured in frames, not samples — confirm against F5-TTS docs.
VOICE_FINETUNE_BATCH_SIZE = 3200
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
APP_VERSION = "1.0.0"
def _clear_cache():
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
# ── Whisper transcription ──
def _transcribe_segments(segment_paths, progress_callback=None):
    """Transcribe each audio segment with Whisper (Spanish).

    Returns a list of {"audio_path": str, "text": str} dicts; segments whose
    transcription comes back empty are dropped. The model is unloaded and the
    CUDA cache cleared before returning.
    """
    import whisper
    logger.info("Loading Whisper for transcription...")
    model = whisper.load_model("medium", device=DEVICE)
    total = len(segment_paths)
    transcripts = []
    for idx, path in enumerate(segment_paths):
        if progress_callback:
            progress_callback(idx / total * 0.3, f"Transcribiendo segmento {idx+1}/{total}...")
        outcome = model.transcribe(path, language="es", fp16=torch.cuda.is_available())
        text = outcome["text"].strip()
        if text:
            transcripts.append({"audio_path": path, "text": text})
    del model
    _clear_cache()
    logger.info(f"Transcribed {len(transcripts)} segments")
    return transcripts
# ── Dataset preparation ──
def _prepare_finetune_dataset(transcripts):
    """Build the on-disk fine-tune dataset from transcribed segments.

    For each transcript: downmixes to mono, resamples to F5_SAMPLE_RATE,
    splits clips longer than 15 s into ~10 s parts (text divided
    proportionally by word count), writes the wavs, and emits a
    pipe-delimited metadata.csv.

    Args:
        transcripts: list of {"audio_path": str, "text": str} dicts.

    Returns:
        Path of the dataset directory (TEMP_DIR / "voice_finetune_data").
    """
    dataset_dir = TEMP_DIR / "voice_finetune_data"
    if dataset_dir.exists():
        shutil.rmtree(dataset_dir)
    dataset_dir.mkdir(parents=True)
    wavs_dir = dataset_dir / "wavs"
    wavs_dir.mkdir()
    metadata = []
    for i, item in enumerate(transcripts):
        audio, sr = sf.read(item["audio_path"])
        # Downmix multi-channel audio to mono: soundfile returns (frames,
        # channels), and torchaudio's Resample acts on the LAST dim — feeding
        # it stereo as-is would resample across the channel axis.
        if getattr(audio, "ndim", 1) > 1:
            audio = audio.mean(axis=1)
        if sr != F5_SAMPLE_RATE:
            import torchaudio
            audio_tensor = torch.from_numpy(audio).float()
            if audio_tensor.dim() == 1:
                audio_tensor = audio_tensor.unsqueeze(0)
            resampler = torchaudio.transforms.Resample(sr, F5_SAMPLE_RATE)
            audio_tensor = resampler(audio_tensor)
            audio = audio_tensor.squeeze(0).numpy()
        max_samples = 15 * F5_SAMPLE_RATE  # clips above 15 s get split
        min_samples = 2 * F5_SAMPLE_RATE   # parts below 2 s are discarded
        if len(audio) <= max_samples:
            clips = [(audio, item["text"])]
        else:
            n_parts = max(1, len(audio) // (10 * F5_SAMPLE_RATE))
            part_size = len(audio) // n_parts
            clips = []
            words = item["text"].split()
            words_per_part = max(1, len(words) // n_parts)
            for j in range(n_parts):
                start = j * part_size
                is_last = j == n_parts - 1
                # The last part absorbs the integer-division remainder so no
                # trailing audio samples or words are silently dropped.
                end = len(audio) if is_last else min((j + 1) * part_size, len(audio))
                if end - start < min_samples:
                    continue
                text_start = j * words_per_part
                text_end = len(words) if is_last else min((j + 1) * words_per_part, len(words))
                part_text = " ".join(words[text_start:text_end])
                if part_text:
                    clips.append((audio[start:end], part_text))
        for j, (clip_audio, clip_text) in enumerate(clips):
            fname = f"clip_{i:04d}_{j:02d}.wav"
            wav_path = wavs_dir / fname
            sf.write(str(wav_path), clip_audio, F5_SAMPLE_RATE)
            duration = len(clip_audio) / F5_SAMPLE_RATE
            metadata.append({"audio_file": str(wav_path.resolve()), "text": clip_text, "duration": round(duration, 2)})
    meta_path = dataset_dir / "metadata.csv"
    # Explicit UTF-8 so Spanish accented characters survive regardless of locale.
    with open(meta_path, "w", encoding="utf-8") as f:
        f.write("audio_file|text\n")
        for item in metadata:
            f.write(f"{item['audio_file']}|{item['text']}\n")
    logger.info(f"Prepared {len(metadata)} clips for fine-tuning")
    return dataset_dir
def _ensure_vocab_file():
    """Ensure the pretrained pinyin vocab.txt for F5-TTS exists, downloading it if missing."""
    from importlib.resources import files as pkg_files
    target = Path(pkg_files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt"))
    if not target.exists():
        target.parent.mkdir(parents=True, exist_ok=True)
        logger.info("Downloading pretrained vocab.txt for F5-TTS...")
        import urllib.request
        src_url = "https://raw.githubusercontent.com/SWivid/F5-TTS/main/data/Emilia_ZH_EN_pinyin/vocab.txt"
        urllib.request.urlretrieve(src_url, str(target))
def _prepare_arrow_dataset(dataset_dir, progress_callback=None):
    """Convert the wav+metadata.csv dataset into the Arrow layout for F5-TTS.

    Reads the pipe-delimited metadata.csv, converts all texts to pinyin,
    filters out missing files and clips outside the 0.3-30 s range, then
    saves an HF Dataset (plus parquet, duration.json and vocab.txt) under
    dataset_dir / "arrow_data".

    Args:
        dataset_dir: directory produced by _prepare_finetune_dataset.
        progress_callback: optional fn(fraction, message) for UI progress.

    Returns:
        Path of the arrow_data directory.

    Raises:
        RuntimeError: no audio-text pairs, or nothing left after filtering.
    """
    if progress_callback:
        progress_callback(0.32, "Preparando dataset Arrow...")
    _ensure_vocab_file()
    meta_csv = dataset_dir / "metadata.csv"
    arrow_dir = dataset_dir / "arrow_data"
    arrow_dir.mkdir(parents=True, exist_ok=True)
    import csv
    from datasets import Dataset as HFDataset
    from f5_tts.model.utils import convert_char_to_pinyin
    audio_text_pairs = []
    # utf-8-sig tolerates a BOM that some editors prepend.
    with open(meta_csv, "r", encoding="utf-8-sig") as f:
        reader = csv.reader(f, delimiter="|")
        next(reader, None)  # skip the "audio_file|text" header row
        for row in reader:
            if len(row) >= 2:
                audio_text_pairs.append((row[0].strip(), row[1].strip()))
    if not audio_text_pairs:
        raise RuntimeError("No audio-text pairs found in metadata.csv")
    texts = [pair[1] for pair in audio_text_pairs]
    # Convert in one batch; pinyin_texts stays index-aligned with audio_text_pairs.
    pinyin_texts = convert_char_to_pinyin(texts, polyphone=True)
    valid_audio_paths = []
    valid_texts = []
    durations = []
    for i, (audio_path, text) in enumerate(audio_text_pairs):
        if not Path(audio_path).exists():
            continue
        audio_info = sf.info(audio_path)
        duration = audio_info.duration
        # Drop clips too short or too long for training.
        if duration < 0.3 or duration > 30:
            continue
        valid_audio_paths.append(audio_path)
        valid_texts.append(pinyin_texts[i])
        durations.append(duration)
    if not valid_audio_paths:
        raise RuntimeError("No valid audio clips after filtering")
    ds = HFDataset.from_dict({
        "audio_path": valid_audio_paths,
        "text": valid_texts,
        "duration": durations,
    })
    # Saved in both disk-dataset and parquet form — presumably so finetune_cli
    # can load either; NOTE(review): confirm which one it actually reads.
    ds.save_to_disk(str(arrow_dir / "raw"))
    ds.to_parquet(str(arrow_dir / "raw.parquet"))
    with open(arrow_dir / "duration.json", "w") as f:
        json.dump({"duration": durations}, f)
    # Copy the pretrained vocab next to the data so the tokenizer matches the base model.
    from importlib.resources import files as pkg_files
    pretrained_vocab = Path(pkg_files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt"))
    if pretrained_vocab.exists():
        shutil.copy2(str(pretrained_vocab), str(arrow_dir / "vocab.txt"))
    logger.info(f"Arrow dataset: {len(valid_audio_paths)} samples, {sum(durations)/3600:.2f}h total")
    return arrow_dir
def finetune_voice(segment_paths, epochs, learning_rate, batch_size, progress_callback=None):
    """Full pipeline: transcribe segments, build the dataset, run F5-TTS fine-tuning.

    Launches f5_tts.train.finetune_cli as a subprocess, streams its output to
    the log, then searches common locations for the produced checkpoints and
    copies them (plus a reference clip) into VOICE_MODEL_DIR.

    Args:
        segment_paths: wav files to train on; the first one also becomes reference.wav.
        epochs, learning_rate, batch_size: forwarded to finetune_cli.
        progress_callback: optional fn(fraction, message) for UI progress.

    Returns:
        str path of VOICE_MODEL_DIR.

    Raises:
        ValueError: no segments given or none could be transcribed.
        RuntimeError: the fine-tune subprocess exited non-zero.
    """
    if not segment_paths:
        raise ValueError("No audio segments found.")
    _clear_cache()
    transcripts = _transcribe_segments(segment_paths, progress_callback)
    if not transcripts:
        raise ValueError("Could not transcribe any audio segments")
    if progress_callback:
        progress_callback(0.3, "Preparando dataset...")
    dataset_dir = _prepare_finetune_dataset(transcripts)
    arrow_dir = _prepare_arrow_dataset(dataset_dir, progress_callback)
    if progress_callback:
        progress_callback(0.35, "Iniciando fine-tuning F5-TTS...")
    VOICE_MODEL_DIR.mkdir(parents=True, exist_ok=True)
    # Stage the arrow dataset where the CLI looks for it — presumably
    # <f5_tts pkg>/../../data/{dataset_name}_{tokenizer}; TODO confirm
    # against the finetune_cli source.
    from importlib.resources import files as pkg_files
    f5_data_root = Path(pkg_files("f5_tts").joinpath("../../data"))
    f5_data_root.mkdir(parents=True, exist_ok=True)
    dataset_name = "voice_finetune"
    tokenizer = "char"
    expected_dir = f5_data_root / f"{dataset_name}_{tokenizer}"
    if expected_dir.exists():
        shutil.rmtree(expected_dir)
    shutil.copytree(str(arrow_dir), str(expected_dir))
    cmd = [
        sys.executable, "-m", "f5_tts.train.finetune_cli",
        "--exp_name", "F5TTS_v1_Base",
        "--dataset_name", dataset_name,
        "--learning_rate", str(learning_rate),
        "--batch_size_per_gpu", str(batch_size),
        "--epochs", str(epochs),
        "--finetune",
        "--save_per_updates", "500",
        "--last_per_updates", "200",
        "--num_warmup_updates", "100",
        "--tokenizer", tokenizer,
    ]
    logger.info(f"Running F5-TTS finetune: {' '.join(cmd)}")
    # Merge stderr into stdout and relay the trainer's output line by line.
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
    for line in process.stdout:
        line = line.strip()
        if line:
            logger.info(f"[F5-TTS] {line}")
            if progress_callback and "step" in line.lower():
                progress_callback(0.4, f"Training: {line[:80]}...")
    process.wait()
    if process.returncode != 0:
        raise RuntimeError(f"F5-TTS fine-tuning failed with exit code {process.returncode}")
    # F5-TTS saves checkpoints — search broadly for them
    # NOTE(review): _glob and voice_model_resolved below are never used —
    # candidates for removal.
    import glob as _glob
    cwd = Path.cwd()
    home = Path.home()
    search_roots = [cwd, home, Path("/app"), Path("/home/user")]
    found_any = False
    voice_model_resolved = VOICE_MODEL_DIR.resolve()
    for root in search_roots:
        if not root.exists():
            continue
        for pattern in ["**/model_last*.pt", "**/model_last*.safetensors",
                        "**/ckpts/**/*.pt", "**/ckpts/**/*.safetensors"]:
            for f in root.glob(pattern):
                dest = VOICE_MODEL_DIR / f.name
                # Skip if source and destination are the same file
                if f.resolve() == dest.resolve():
                    logger.info(f"Skipping (same file): {f}")
                    found_any = True
                    continue
                shutil.copy2(str(f), str(dest))
                logger.info(f"Copied checkpoint: {f} -> {dest}")
                found_any = True
    if not found_any:
        # Log what's in common checkpoint dirs for debugging
        for d in [cwd / "ckpts", home / "ckpts", Path("/app/ckpts")]:
            if d.exists():
                all_files = list(d.rglob("*"))
                logger.info(f"Files in {d}: {[str(f) for f in all_files[:20]]}")
            else:
                logger.info(f"Directory {d} does not exist")
        # Also check cwd contents
        logger.info(f"CWD={cwd}, contents: {[str(f) for f in cwd.iterdir()][:20]}")
    # Keep the first source segment alongside the model as the reference clip.
    ref_path = VOICE_MODEL_DIR / "reference.wav"
    if segment_paths:
        shutil.copy2(segment_paths[0], str(ref_path))
    shutil.rmtree(dataset_dir, ignore_errors=True)
    _clear_cache()
    return str(VOICE_MODEL_DIR)
# ── Gradio handlers ──
def download_audio_from_hub(project_name, progress=gr.Progress()):
    """Gradio handler: fetch the project's step2 audio segments from the Hub into AUDIO_DIR.

    Returns a status string ("OK - ..." or "Error: ...") for the UI textbox.
    """
    if not project_name or not project_name.strip():
        return "Error: Debes introducir un nombre de proyecto"
    name = project_name.strip()
    try:
        # Start from a clean AUDIO_DIR so stale segments never mix in.
        if AUDIO_DIR.exists():
            shutil.rmtree(AUDIO_DIR)
        AUDIO_DIR.mkdir(parents=True)
        download_step(name, "step2_audio", str(BASE_DIR))
        # The Hub download lands in BASE_DIR/{name}/step2_audio/; flatten it into AUDIO_DIR.
        src = BASE_DIR / name / "step2_audio"
        if src.exists():
            for entry in src.iterdir():
                shutil.move(str(entry), str(AUDIO_DIR / entry.name))
        shutil.rmtree(BASE_DIR / name, ignore_errors=True)
        segments = sorted(AUDIO_DIR.glob("segment_*.wav"))
    except Exception as e:
        return f"Error: {e}"
    return f"OK - Descargados {len(segments)} segmentos de audio"
def train_voice_handler(project_name, epochs, lr, progress=gr.Progress()):
    """Gradio handler: fine-tune the voice model on the downloaded audio segments.

    Returns a status string ("OK - ..." or "Error: ...") for the UI textbox.
    """
    if not project_name or not project_name.strip():
        return "Error: Debes introducir un nombre de proyecto"
    segment_paths = sorted(str(p) for p in AUDIO_DIR.glob("segment_*.wav"))
    if not segment_paths:
        return "Error: No hay segmentos de audio. Descarga primero desde el Hub."
    logger.info(f"=== Voice Training Started === epochs={epochs}, lr={lr}")
    try:
        model_dir = finetune_voice(
            segment_paths,
            epochs=int(epochs),
            learning_rate=lr,
            batch_size=VOICE_FINETUNE_BATCH_SIZE,
            progress_callback=lambda frac, msg: progress(frac, desc=msg),
        )
    except Exception as e:
        logger.error(f"=== Voice Training Failed ===\n{traceback.format_exc()}")
        return f"Error: {e}"
    logger.info(f"=== Voice Training Complete === {model_dir}")
    return f"OK - Modelo de voz guardado en: {model_dir}"
def save_to_hub(project_name):
    """Gradio handler: upload the trained voice-model files to the Hub (step3_voice).

    Returns upload_step's result on success, or an "Error: ..." string.
    """
    if not project_name or not project_name.strip():
        return "Error: Debes introducir un nombre de proyecto"
    name = project_name.strip()
    checkpoints = [*VOICE_MODEL_DIR.glob("*.pt"), *VOICE_MODEL_DIR.glob("*.safetensors")]
    if not checkpoints:
        return "Error: No hay modelo de voz para guardar. Entrena primero."
    try:
        return upload_step(name, "step3_voice", str(VOICE_MODEL_DIR))
    except Exception as e:
        return f"Error: {e}"
# ── UI ──
with gr.Blocks(title="Talking Head - Voice Train", theme=gr.themes.Soft()) as demo:
gr.Markdown(f"# Talking Head - Entrenar Voz `v{APP_VERSION}`\nWhisper transcripcion + F5-TTS fine-tuning")
project_name = gr.Textbox(
label="Nombre del proyecto",
placeholder="mi_proyecto",
info="Obligatorio. Se usa como carpeta en el Hub.",
)
gr.Markdown("### 1. Descargar audio del Hub")
download_btn = gr.Button("Descargar audio del Hub", variant="secondary")
download_status = gr.Textbox(label="Estado descarga", interactive=False)
gr.Markdown("### 2. Entrenar modelo de voz")
with gr.Row():
voice_epochs = gr.Slider(10, 300, value=VOICE_FINETUNE_EPOCHS, step=10, label="Epochs")
voice_lr = gr.Number(value=VOICE_FINETUNE_LR, label="Learning Rate")
train_btn = gr.Button("Entrenar Voz", variant="primary")
train_status = gr.Textbox(label="Estado entrenamiento", interactive=False)
gr.Markdown("### 3. Guardar modelo en Hub")
save_btn = gr.Button("Guardar en Hub", variant="secondary")
save_status = gr.Textbox(label="Estado guardado", interactive=False)
download_btn.click(download_audio_from_hub, inputs=[project_name], outputs=[download_status])
train_btn.click(train_voice_handler, inputs=[project_name, voice_epochs, voice_lr], outputs=[train_status])
save_btn.click(save_to_hub, inputs=[project_name], outputs=[save_status])
if __name__ == "__main__":
demo.queue().launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)