# File size: 5,803 Bytes
# 75fb62d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""
F5-TTS Hungarian — Inference Example
Zero-shot voice cloning for Hungarian text-to-speech.

Usage:
    python inference_example.py --ref_audio your_voice.wav --text "Szia, ez egy teszt."

Requirements:
    pip install torch torchaudio soundfile numpy f5-tts faster-whisper
"""
import argparse
import sys
import os
import time
import numpy as np
import soundfile as sf
import torch
import torchaudio

# ── Monkey-patch: torchaudio.load is swapped for a soundfile-backed reader
#    (cross-platform compatibility) ──
_orig_load = torchaudio.load  # handle to the stock loader; nothing in this block calls it


def _patched_load(fp, **kw):
    """Read an audio file with soundfile and return it in torchaudio's shape.

    Args:
        fp: Location of the file to read; passed through ``str()`` first.
        **kw: Accepted so existing ``torchaudio.load`` call sites keep
            working; every keyword is ignored.

    Returns:
        A ``(tensor, sample_rate)`` pair. The tensor is float32 and 2-D: a
        1-D read gains a leading axis of length 1, a 2-D read is transposed.
    """
    samples, rate = sf.read(str(fp), dtype="float32")
    # 1-D data gets a leading axis; 2-D data (presumably frames x channels)
    # is flipped so the first axis is the channel axis.
    layout = samples[np.newaxis, :] if samples.ndim == 1 else samples.T
    return torch.from_numpy(layout), rate


torchaudio.load = _patched_load

_orig_save = torchaudio.save  # handle to the stock writer; nothing in this block calls it


def _patched_save(fp, waveform, sample_rate, **kw):
    """Write a waveform tensor to disk with soundfile instead of torchaudio.

    Args:
        fp: Destination path; passed through ``str()`` first.
        waveform: Tensor holding the audio. A leading axis of length 1 is
            dropped. A tensor that is still 2-D afterwards is treated as
            (channels, frames) and transposed for soundfile, unless the
            caller passed ``channels_first=False``.
        sample_rate: Sample rate handed to ``sf.write``.
        **kw: Accepted so existing ``torchaudio.save`` call sites keep
            working; only ``channels_first`` is consulted.
    """
    # .numpy() refuses tensors that track gradients or live on an accelerator.
    data = waveform.detach().cpu()
    if data.dim() > 1:
        data = data.squeeze(0)
    # soundfile wants (frames, channels); torchaudio's default is channels first.
    if data.dim() == 2 and kw.get("channels_first", True):
        data = data.t()
    sf.write(str(fp), data.contiguous().numpy(), sample_rate)


torchaudio.save = _patched_save

from f5_tts.api import F5TTS


def transcribe_reference(audio_path: str) -> str:
    """Transcribe reference audio using Whisper large-v3-turbo.

    CRITICAL: The ref_text MUST match the actual content of ref_audio.
    If they don't match, the generated audio will be garbled.

    Args:
        audio_path: Path of the reference recording handed to faster-whisper.

    Returns:
        The stripped segment texts joined with single spaces.

    Raises:
        SystemExit: With status 1 when ``faster_whisper`` cannot be imported.
    """
    # Guard only the import: an ImportError raised later, during model loading
    # or decoding, must not be misreported as "faster-whisper not installed".
    try:
        from faster_whisper import WhisperModel
    except ImportError:
        print("WARNING: faster-whisper not installed. You must provide --ref_text manually.")
        print("Install: pip install faster-whisper")
        sys.exit(1)

    print("Transcribing reference audio with Whisper...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    whisper = WhisperModel("large-v3-turbo", device=device)
    segments, _info = whisper.transcribe(audio_path, language="hu", beam_size=5)
    text = " ".join(s.text.strip() for s in segments)
    print(f"Transcription: {text}")

    # Drop the recogniser before the TTS model is loaded.
    del whisper
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return text


def trim_adaptive(audio: np.ndarray, sr: int,
                  max_trim_ms: int = 400,
                  energy_window_ms: int = 10,
                  threshold_ratio: float = 0.15) -> np.ndarray:
    """Trim a leading artifact using energy-based detection.

    The model sometimes adds prefix vowels or consonant deformations.
    The first ``max_trim_ms`` of the signal is cut into consecutive windows
    of ``energy_window_ms``; the RMS energy of each window is compared with
    ``threshold_ratio`` times the loudest window, and everything before the
    window preceding the first one above that threshold is dropped.

    Args:
        audio: Samples to trim; the first axis is treated as time.
        sr: Sample rate in Hz.
        max_trim_ms: Length of the search region, i.e. the most that can be
            removed.
        energy_window_ms: Length of one energy window.
        threshold_ratio: Fraction of the peak window energy that a window
            must exceed to count as the start of real content.

    Returns:
        ``audio`` with the leading part removed. The input is returned
        untouched when it is shorter than the search region, when the
        region holds no complete window, or when no window exceeds the
        threshold (for example pure silence).
    """
    max_samples = int(sr * max_trim_ms / 1000)
    # Never let the window collapse to zero samples at low sample rates.
    window_samples = max(1, int(sr * energy_window_ms / 1000))

    if max_samples <= 0 or len(audio) < max_samples:
        return audio

    # Every complete window in the search region, including the last one.
    n_windows = max_samples // window_samples
    if n_windows == 0:
        return audio

    # float64 so that squaring integer PCM cannot wrap around.
    region = np.asarray(audio[:n_windows * window_samples], dtype=np.float64)
    energies = np.sqrt(np.mean(region.reshape(n_windows, -1) ** 2, axis=1))

    threshold = energies.max() * threshold_ratio
    loud = np.flatnonzero(energies > threshold)
    if loud.size == 0:
        return audio

    # Keep one window of lead-in before the first loud window.
    trim_point = max(0, (int(loud[0]) - 1) * window_samples)
    return audio[trim_point:]


def main():
    """Command-line entry point: clone a voice and synthesize Hungarian speech.

    Parses the command line, checks that the reference audio, checkpoint and
    vocabulary files exist, obtains the reference transcript (from
    ``--ref_text`` or by calling ``transcribe_reference``), loads the F5-TTS
    model, generates audio for ``--text``, optionally trims the leading
    artifact with ``trim_adaptive`` and writes the result to ``--output``.

    Raises:
        SystemExit: With status 1 when a required input file is missing
            (argparse itself exits on malformed arguments).
    """
    parser = argparse.ArgumentParser(description="F5-TTS Hungarian — Zero-shot Voice Cloning")
    parser.add_argument("--text", type=str, required=True,
                        help="Text to synthesize in Hungarian")
    parser.add_argument("--ref_audio", type=str, required=True,
                        help="Path to reference audio (5-15 seconds, WAV)")
    parser.add_argument("--ref_text", type=str, default=None,
                        help="Exact transcription of ref_audio. If not provided, Whisper will transcribe it.")
    parser.add_argument("--output", type=str, default="output.wav",
                        help="Output WAV file path")
    parser.add_argument("--ckpt", type=str, default="model_last_final.safetensors",
                        help="Path to model checkpoint (.safetensors or .pt)")
    parser.add_argument("--vocab", type=str, default="vocab.txt",
                        help="Path to vocabulary file")
    parser.add_argument("--device", type=str, default=None,
                        help="Device: cuda or cpu (default: cuda when available, otherwise cpu)")
    parser.add_argument("--no_trim", action="store_true",
                        help="Disable adaptive artifact trimming")
    args = parser.parse_args()

    # Validate inputs before any model is loaded.
    if not os.path.isfile(args.ref_audio):
        print(f"Error: Reference audio not found: {args.ref_audio}")
        sys.exit(1)
    if not os.path.isfile(args.ckpt):
        print(f"Error: Checkpoint not found: {args.ckpt}")
        sys.exit(1)
    # An empty --vocab is passed through unchecked; any other value must exist.
    if args.vocab and not os.path.isfile(args.vocab):
        print(f"Error: Vocabulary file not found: {args.vocab}")
        sys.exit(1)

    # An explicit --device wins; otherwise fall back to CPU when CUDA is absent.
    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")

    # Get reference text
    if args.ref_text is None:
        ref_text = transcribe_reference(args.ref_audio)
    else:
        ref_text = args.ref_text

    # Load model
    print(f"Loading F5-TTS Hungarian model from {args.ckpt}...")
    t0 = time.time()
    model = F5TTS(
        model="F5TTS_v1_Base",
        ckpt_file=args.ckpt,
        vocab_file=args.vocab,
        device=device,
        use_ema=True,
    )
    print(f"Model loaded in {time.time()-t0:.1f}s")

    # Generate
    print(f"Generating: \"{args.text[:80]}{'...' if len(args.text) > 80 else ''}\"")
    t0 = time.time()
    wav, sr, _ = model.infer(
        ref_file=args.ref_audio,
        ref_text=ref_text,
        gen_text=args.text,
    )
    gen_time = time.time() - t0
    # Duration of the raw generation, measured before any trimming.
    duration = len(wav) / sr

    # Trim artifacts
    if not args.no_trim:
        wav = trim_adaptive(wav, sr)

    # Save
    sf.write(args.output, wav, sr)

    # Guard the real-time factor against an empty generation.
    rtf = gen_time / duration if duration > 0 else float("inf")
    print(f"✅ Generated {duration:.1f}s audio in {gen_time:.1f}s (RTF: {rtf:.2f})")
    print(f"Saved to: {args.output}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()