f5-tts-hungarian / inference_example.py
Maxdorger29's picture
Upload folder using huggingface_hub
75fb62d verified
"""
F5-TTS Hungarian — Inference Example
Zero-shot voice cloning for Hungarian text-to-speech.
Usage:
python inference_example.py --ref_audio your_voice.wav --text "Szia, ez egy teszt."
Requirements:
pip install torch torchaudio soundfile numpy f5-tts faster-whisper
"""
import argparse
import sys
import os
import time
import numpy as np
import soundfile as sf
import torch
import torchaudio
# ── Monkey-patch torchaudio for cross-platform compatibility ──
# Keep a handle on the original loader so it can be restored if needed.
_orig_load = torchaudio.load


def _patched_load(fp, **kw):
    """Replacement for ``torchaudio.load`` that decodes audio with soundfile.

    Args:
        fp: Path (str / path-like) or an open binary file object.
        **kw: Subset of ``torchaudio.load`` keywords that is honoured:
            ``frame_offset`` (first frame to read, default 0),
            ``num_frames`` (frames to read, -1 for all, default -1) and
            ``channels_first`` (default True). Other keywords are ignored.

    Returns:
        Tuple ``(waveform, sample_rate)`` where ``waveform`` is a float32
        tensor shaped (channels, frames), or (frames, channels) when
        ``channels_first=False``.
    """
    # soundfile reads from open file objects directly; only path-likes are
    # stringified (str() on a file object would yield its repr, not a path).
    source = fp if hasattr(fp, "read") else str(fp)
    data, sr = sf.read(
        source,
        dtype="float32",
        always_2d=True,  # always (frames, channels), mono included
        start=kw.get("frame_offset", 0),
        frames=kw.get("num_frames", -1),
    )
    if kw.get("channels_first", True):
        # Copy after transposing so the tensor is contiguous and safe to .view().
        data = np.ascontiguousarray(data.T)
    return torch.from_numpy(data), sr


torchaudio.load = _patched_load
# Keep a handle on the original writer so it can be restored if needed.
_orig_save = torchaudio.save


def _patched_save(fp, waveform, sample_rate, **kw):
    """Replacement for ``torchaudio.save`` that writes audio with soundfile.

    Args:
        fp: Destination path (converted with ``str``).
        waveform: Audio tensor. 1-D tensors are written as mono; 2-D tensors
            are interpreted as (channels, frames) unless
            ``channels_first=False`` is passed.
        sample_rate: Sample rate written to the file header.
        **kw: Only ``channels_first`` is honoured; other keywords are ignored.
    """
    # .numpy() raises on CUDA tensors and on tensors that require grad.
    wav = waveform.detach().cpu()
    if wav.dim() == 2:
        # soundfile expects (frames, channels); torchaudio's default layout is
        # (channels, frames), so multi-channel audio must be transposed.
        if kw.get("channels_first", True):
            wav = wav.transpose(0, 1)
        wav = wav.contiguous()
    elif wav.dim() > 2:
        # Same as the previous behaviour: drop a leading singleton dimension.
        wav = wav.squeeze(0)
    sf.write(str(fp), wav.numpy(), sample_rate)


torchaudio.save = _patched_save
from f5_tts.api import F5TTS
def transcribe_reference(audio_path: str) -> str:
    """Transcribe the reference audio as Hungarian with Whisper large-v3-turbo.

    CRITICAL: The ref_text MUST match the actual content of ref_audio.
    If they don't match, the generated audio will be garbled.

    Args:
        audio_path: Path to the reference audio file.

    Returns:
        The stripped segment texts joined with single spaces.

    Exits the process with status 1 when faster-whisper is not installed.
    """
    # Guard only the import: an ImportError raised later (while loading the
    # model or decoding) must not be reported as "faster-whisper not installed".
    try:
        from faster_whisper import WhisperModel
    except ImportError:
        print("WARNING: faster-whisper not installed. You must provide --ref_text manually.")
        print("Install: pip install faster-whisper")
        sys.exit(1)

    print("Transcribing reference audio with Whisper...")
    use_cuda = torch.cuda.is_available()
    whisper = WhisperModel("large-v3-turbo", device="cuda" if use_cuda else "cpu")
    try:
        segments, _info = whisper.transcribe(audio_path, language="hu", beam_size=5)
        # Joining consumes the segments iterable returned by transcribe().
        text = " ".join(s.text.strip() for s in segments)
    finally:
        # Release the Whisper model before the TTS model is loaded, even when
        # transcription fails.
        del whisper
        if use_cuda:
            torch.cuda.empty_cache()
    print(f"Transcription: {text}")
    return text
def trim_adaptive(audio: np.ndarray, sr: int,
                  max_trim_ms: int = 400,
                  energy_window_ms: int = 10,
                  threshold_ratio: float = 0.15) -> np.ndarray:
    """Trim a leading artifact using energy-based detection.

    The model sometimes adds prefix vowels or consonant deformations.
    The first ``max_trim_ms`` of audio is split into non-overlapping windows
    of ``energy_window_ms``; everything before the first window whose RMS
    exceeds ``threshold_ratio`` times the loudest window is dropped, keeping
    one extra window of lead-in.

    Args:
        audio: Samples, indexed by time along the first axis.
        sr: Sample rate in Hz.
        max_trim_ms: Length of the search region, i.e. the upper bound on
            how much can be trimmed.
        energy_window_ms: Length of each energy window.
        threshold_ratio: Fraction of the peak window RMS that counts as onset.

    Returns:
        ``audio`` with the leading part removed, or ``audio`` unchanged when
        it is shorter than the search region, the window is shorter than one
        sample, or no window exceeds the threshold.
    """
    max_samples = int(sr * max_trim_ms / 1000)
    window_samples = int(sr * energy_window_ms / 1000)
    # A zero-length window (very low sr or window) cannot be analysed.
    if window_samples <= 0 or len(audio) < max_samples:
        return audio

    # Use every complete window in the search region, including the last one.
    n_windows = max_samples // window_samples
    if n_windows <= 0:
        return audio

    # float64 avoids overflow when squaring integer PCM samples.
    region = np.asarray(audio[:n_windows * window_samples], dtype=np.float64)
    energies = np.sqrt(np.mean(region.reshape(n_windows, -1) ** 2, axis=1))

    threshold = energies.max() * threshold_ratio
    onsets = np.flatnonzero(energies > threshold)
    if onsets.size == 0:
        # Entire search region is silent: nothing to trim.
        return audio

    # Back off by one window so the onset itself is not clipped.
    trim_point = max(0, (int(onsets[0]) - 1) * window_samples)
    return audio[trim_point:]
def main():
    """Command-line entry point: clone a voice and synthesize Hungarian speech.

    Parses the CLI arguments, validates the input paths, obtains the reference
    transcript (from ``--ref_text`` or by transcribing ``--ref_audio``), loads
    the F5-TTS checkpoint, generates audio for ``--text``, optionally trims
    the leading artifact and writes the result to ``--output``.

    Exits with status 1 when a required input file is missing.
    """
    parser = argparse.ArgumentParser(description="F5-TTS Hungarian — Zero-shot Voice Cloning")
    parser.add_argument("--text", type=str, required=True,
                        help="Text to synthesize in Hungarian")
    parser.add_argument("--ref_audio", type=str, required=True,
                        help="Path to reference audio (5-15 seconds, WAV)")
    parser.add_argument("--ref_text", type=str, default=None,
                        help="Exact transcription of ref_audio. If not provided, Whisper will transcribe it.")
    parser.add_argument("--output", type=str, default="output.wav",
                        help="Output WAV file path")
    parser.add_argument("--ckpt", type=str, default="model_last_final.safetensors",
                        help="Path to model checkpoint (.safetensors or .pt)")
    parser.add_argument("--vocab", type=str, default="vocab.txt",
                        help="Path to vocabulary file")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device: cuda or cpu")
    parser.add_argument("--no_trim", action="store_true",
                        help="Disable adaptive artifact trimming")
    args = parser.parse_args()

    # Validate inputs up front so a bad path fails before any model is loaded.
    if not os.path.isfile(args.ref_audio):
        print(f"Error: Reference audio not found: {args.ref_audio}")
        sys.exit(1)
    if not os.path.isfile(args.ckpt):
        print(f"Error: Checkpoint not found: {args.ckpt}")
        sys.exit(1)
    # An empty --vocab is passed through unchanged; only a non-empty path
    # that does not exist is rejected here.
    if args.vocab and not os.path.isfile(args.vocab):
        print(f"Error: Vocabulary file not found: {args.vocab}")
        sys.exit(1)

    # Fall back to CPU when CUDA was requested (the default) but is missing,
    # matching the device selection used for Whisper transcription.
    device = args.device
    if device.startswith("cuda") and not torch.cuda.is_available():
        print("WARNING: CUDA is not available; falling back to CPU.")
        device = "cpu"

    # Get reference text
    if args.ref_text is None:
        ref_text = transcribe_reference(args.ref_audio)
    else:
        ref_text = args.ref_text

    # Load model
    print(f"Loading F5-TTS Hungarian model from {args.ckpt}...")
    t0 = time.time()
    model = F5TTS(
        model="F5TTS_v1_Base",
        ckpt_file=args.ckpt,
        vocab_file=args.vocab,
        device=device,
        use_ema=True,
    )
    print(f"Model loaded in {time.time()-t0:.1f}s")

    # Generate
    print(f"Generating: \"{args.text[:80]}{'...' if len(args.text) > 80 else ''}\"")
    t0 = time.time()
    wav, sr, _ = model.infer(
        ref_file=args.ref_audio,
        ref_text=ref_text,
        gen_text=args.text,
    )
    gen_time = time.time() - t0
    # Duration of the generated audio, measured before trimming.
    duration = len(wav) / sr
    # Real-time factor; guard against an empty result to avoid dividing by zero.
    rtf = gen_time / duration if duration > 0 else float("nan")

    # Trim artifacts
    if not args.no_trim:
        wav = trim_adaptive(wav, sr)

    # Save
    sf.write(args.output, wav, sr)
    print(f"✅ Generated {duration:.1f}s audio in {gen_time:.1f}s (RTF: {rtf:.2f})")
    print(f"Saved to: {args.output}")


if __name__ == "__main__":
    main()