File size: 11,849 Bytes
cf7fa42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
#!/usr/bin/env python3
"""
F5-TTS Voice Cloning Script (Portuguese/Multi-lingual)
Wraps AgentF5TTSChunk for convenient CLI usage.

Usage:
  Single mode: python voice_clone.py --text "Olá mundo" --ref-audio voice.wav --checkpoint models/model.safetensors
  Batch mode:  python voice_clone.py --srt subtitles.srt --ref-dir ./speakers --checkpoint models/model.safetensors
"""

import argparse
import os
import re
import sys
import logging
import torch
from typing import List, Dict, Optional, Tuple

# Configure logging once at import time: INFO level, timestamped records.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
)
# Module-level logger shared by the functions in this script.
logger = logging.getLogger(__name__)

try:
    from tqdm import tqdm
except ImportError:
    class _NoProgress:
        """Minimal stand-in for a tqdm progress bar when tqdm is not installed.

        Iterating yields the wrapped iterable's items unchanged; the
        progress-reporting methods are accepted and ignored so that code
        written against a real tqdm bar keeps working.
        """

        def __init__(self, iterable):
            self._iterable = iterable

        def __iter__(self):
            return iter(self._iterable)

        def __len__(self):
            return len(self._iterable)

        def set_description(self, *args, **kwargs):
            """Accept and ignore a description update."""

        def update(self, *args, **kwargs):
            """Accept and ignore a manual progress increment."""

        def close(self):
            """Nothing to release."""

    # Fallback if tqdm is not installed. Keyword arguments such as
    # ``desc``/``unit`` are accepted for call compatibility and ignored.
    def tqdm(iterable, **kwargs):
        return _NoProgress(iterable)

try:
    from AgentF5TTSChunk import AgentF5TTS
except ImportError:
    # If not in same dir, try adding current dir to path
    sys.path.append(os.getcwd())
    try:
        from AgentF5TTSChunk import AgentF5TTS
    except ImportError as exc:
        # ImportError is also raised when AgentF5TTSChunk.py exists but one of
        # its own imports is missing, so report the real cause instead of
        # always claiming the file was not found.
        logger.error(f"Error: AgentF5TTSChunk could not be imported: {exc}")
        sys.exit(1)


def parse_srt(srt_file: str) -> List[Dict]:
    """
    Parse an SRT file and extract its subtitle entries.

    The file is read as UTF-8 (a leading byte-order mark is tolerated) and
    split into blocks on blank or whitespace-only lines. Within a block the
    first line is taken as the numeric ID when it consists of digits;
    otherwise the first line must contain '-->' and the ID is assigned
    sequentially. All lines after the timestamp line are joined with single
    spaces to form the text.

    Args:
        srt_file: Path to the SRT file.

    Returns:
        List of dicts with the keys 'id' (int), 'timestamp' (the raw
        timestamp line) and 'text', in file order. Blocks without any text
        are omitted; malformed blocks are skipped with a warning.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    logger.info(f"Parsing SRT file: {srt_file}")

    # 'utf-8-sig' drops a leading BOM. With plain 'utf-8' the BOM stays glued
    # to the first ID, which then fails isdigit() and the first subtitle is
    # discarded as malformed.
    with open(srt_file, 'r', encoding='utf-8-sig') as f:
        content = f.read()

    # Normalize newlines
    content = content.replace('\r\n', '\n')

    # Separate subtitle blocks on one or more empty lines. Lines holding only
    # spaces/tabs count as empty, so they cannot merge two subtitles into one.
    blocks = re.split(r'\n\s*\n', content.strip())

    subtitles = []
    for block in blocks:
        lines = [l.strip() for l in block.split('\n') if l.strip()]
        if len(lines) < 2:
            # Needs at least two non-empty lines to carry a timestamp and text.
            continue

        try:
            if lines[0].isdigit():
                # Regular block: ID line followed by the timestamp line.
                subtitle_id = int(lines[0])
                timestamp_line_idx = 1
            else:
                # ID line missing: accept the block only if it starts with
                # the timestamp, and number it sequentially.
                if '-->' not in lines[0]:
                    logger.warning(f"Skipping malformed block (no timestamp): {block[:50]}...")
                    continue
                subtitle_id = len(subtitles) + 1
                timestamp_line_idx = 0

            timestamp = lines[timestamp_line_idx]
            # Remaining lines are the text
            text = ' '.join(lines[timestamp_line_idx + 1:]).strip()

            if text:
                subtitles.append({
                    'id': subtitle_id,
                    'timestamp': timestamp,
                    'text': text
                })
        except (ValueError, IndexError) as e:
            # e.g. int() rejecting a "digit" character such as a superscript.
            logger.warning(f"Skipping malformed block: {block[:50]}... Error: {e}")
            continue

    logger.info(f"Parsed {len(subtitles)} subtitle entries")
    return subtitles


def find_reference_audio(reference_dir: str, subtitle_id: int, audio_prefix: str = 'segment') -> Optional[str]:
    """
    Fallback lookup: locate a reference clip by subtitle ID (e.g., segment_001.wav).

    Candidate names are built from three stems -- "<prefix>_<id, 3 digits>",
    "<prefix>_<id>" and "<prefix><id, 3 digits>" -- each tried with the
    extensions .wav, .mp3 and .MP4, in that order.

    Returns the path of the first candidate that exists inside
    ``reference_dir``, or None when ``reference_dir`` is empty/None or no
    candidate exists.
    """
    if not reference_dir:
        return None

    stems = (
        f"{audio_prefix}_{subtitle_id:03d}",
        f"{audio_prefix}_{subtitle_id}",
        f"{audio_prefix}{subtitle_id:03d}",
    )
    extensions = (".wav", ".mp3", ".MP4")

    # Stem-major order, so every extension of a stem is probed before the
    # next stem is considered.
    candidates = (
        os.path.join(reference_dir, stem + ext)
        for stem in stems
        for ext in extensions
    )
    return next((path for path in candidates if os.path.exists(path)), None)


def resolve_speaker_ref(agent: "AgentF5TTS", text: str, reference_dir: Optional[str], default_ref: Optional[str] = None) -> Tuple[str, Optional[str]]:
    """
    Use the agent's logic to parse speaker/emotion, then resolve a reference file.

    Args:
        agent: Object providing ``_determine_speaker_emotion(text)``, which is
            expected to return a ``(speaker, emotion)`` pair.
        text: Subtitle/input text, possibly starting with a ``[speaker:...]`` tag.
        reference_dir: Directory searched for speaker reference files; when
            empty or None no lookup is done.
        default_ref: Reference path returned when no speaker file matches.

    Returns:
        ``(clean_text, ref_audio)``: the text with ``[speaker:...]`` tags
        removed, and either the first matching speaker file or ``default_ref``.

    Lookup order inside ``reference_dir`` (each name tried as .wav, then .mp3):
    ``<speaker>_<emotion>``, ``<speaker>``, then the lowercase forms of both.
    The emotion-specific names are only tried when an emotion other than
    "neutral" was detected.
    """
    # Use the agent's internal parser
    # Note: Accessing protected member _determine_speaker_emotion
    speaker, emotion = agent._determine_speaker_emotion(text)

    # Remove tags from text
    clean_text = re.sub(r'\[speaker:.*?\]\s*', '', text).strip()

    if not (speaker and reference_dir):
        return clean_text, default_ref

    has_emotion = bool(emotion) and emotion != "neutral"

    # Exact-case names first, lowercase fallbacks afterwards.
    stems = []
    if has_emotion:
        stems.append(f"{speaker}_{emotion}")
    stems.append(speaker)
    if has_emotion:
        stems.append(f"{speaker.lower()}_{emotion.lower()}")
    stems.append(speaker.lower())

    tried = set()
    for stem in stems:
        if stem in tried:
            # Already-lowercase names would otherwise be probed twice.
            continue
        tried.add(stem)
        # Both extensions are tried for every stem, including the lowercase
        # fallbacks (which previously only covered .wav).
        for ext in (".wav", ".mp3"):
            path = os.path.join(reference_dir, stem + ext)
            if os.path.exists(path):
                logger.debug(f"Role matched: {os.path.basename(path)} (Speaker: {speaker}, Emotion: {emotion})")
                return clean_text, path

    return clean_text, default_ref


def _positive_float(value: str) -> float:
    """argparse ``type`` callback accepting only finite floats greater than zero."""
    try:
        number = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid float value: {value!r}") from None
    # The chained comparison is False for NaN, infinity, zero and negatives.
    if not (0 < number < float("inf")):
        raise argparse.ArgumentTypeError(f"must be a finite number greater than 0, got {value!r}")
    return number


def parse_args() -> argparse.Namespace:
    """
    Build the command-line parser and parse ``sys.argv``.

    Exactly one of ``--text`` (single mode) or ``--srt`` (batch mode) is
    required, ``--ref-audio`` and ``--ref-dir`` are mutually exclusive, and
    ``--checkpoint`` is mandatory. ``--speed`` must be a finite number greater
    than zero; a non-positive speed factor is rejected here at parse time
    rather than being handed on to synthesis.

    Returns:
        The parsed ``argparse.Namespace``. On invalid arguments argparse
        prints a usage message and exits with status 2.
    """
    parser = argparse.ArgumentParser(
        description='F5-TTS Voice Cloning Script (Wraps AgentF5TTS)',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
EXAMPLES:
  # Single Mode
  python voice_clone.py --text "Olá, tudo bem?" --ref-audio ref.wav --checkpoint models/model.safetensors

  # Batch Mode (SRT)
  python voice_clone.py --srt subs.srt --ref-dir ./speakers --checkpoint models/model.safetensors
        """
    )

    # Input Mode
    mode_group = parser.add_mutually_exclusive_group(required=True)
    mode_group.add_argument('--text', type=str, help='Text to synthesize')
    mode_group.add_argument('--srt', type=str, help='Path to SRT subtitle file')

    # Reference Audio
    ref_group = parser.add_mutually_exclusive_group()
    ref_group.add_argument('--ref-audio', type=str, help='[Single] Reference audio path')
    ref_group.add_argument('--ref-dir', type=str, help='[Batch] Directory with reference audios (speakers or segments)')
    # Alias for backward compatibility or typo tolerance
    ref_group.add_argument('--reference-dir', dest='ref_dir', help=argparse.SUPPRESS)

    # Reference Text (Optional, prevents model from transcribing audio)
    parser.add_argument('--ref-text', type=str, default="", help='Reference text for the reference audio (optional)')

    # Model Configuration
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to F5-TTS safetensors checkpoint')
    parser.add_argument('--vocoder', type=str, default='vocos', choices=['vocos', 'bigvgan'], help='Vocoder type')
    parser.add_argument('--device', type=str, default=None, help='Device (cuda:0, cpu, mps)')
    parser.add_argument('--speed', type=_positive_float, default=1.0, help='Speed factor for speech generation (default: 1.0)')

    # Output Configuration
    parser.add_argument('--output', type=str, default='outputs', help='Output directory')
    parser.add_argument('--output-prefix', type=str, default='clone', help='Output filename prefix')
    parser.add_argument('--skip-existing', action='store_true', help='Skip existing output files')

    # Batch specialized
    parser.add_argument('--audio-prefix', type=str, default='segment', help='Prefix for ID-based reference lookup')

    return parser.parse_args()


def _select_device(requested: Optional[str]) -> str:
    """Return the device string to use: the explicit request, else cuda, mps or cpu."""
    if requested:
        return requested
    if torch.cuda.is_available():
        return "cuda"
    # Probe defensively in case this torch build has no ``backends.mps``.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


def _run_single(agent, args) -> int:
    """Synthesize ``args.text`` into ``<output>/output_single.wav``; return an exit status."""
    logger.info("-" * 40)
    logger.info("SINGLE MODE PROCESSING")
    logger.info("-" * 40)

    if not args.ref_audio or not os.path.exists(args.ref_audio):
        logger.error(f"Reference audio not found: {args.ref_audio}")
        return 1

    # abspath first: dirname("voice.wav") is '' and an empty directory would
    # silently disable the speaker-tag lookup for files in the working directory.
    ref_dir = os.path.dirname(os.path.abspath(args.ref_audio))

    # Try to parse speaker tags just in case
    clean_text, effective_ref = resolve_speaker_ref(
        agent,
        args.text,
        ref_dir,
        default_ref=args.ref_audio
    )

    output_path = os.path.join(args.output, "output_single.wav")
    logger.info(f"Text: {clean_text}")
    logger.info(f"Ref:  {effective_ref}")

    try:
        agent.infer(
            ref_file=effective_ref,
            ref_text=args.ref_text,
            gen_text=clean_text,
            file_wave=output_path,
            remove_silence=True,
            speed=args.speed
        )
    except Exception as e:
        # Top-level boundary: report the failure and signal it via the status.
        logger.error(f"✗ Error: {e}")
        return 1

    logger.info(f"✓ Saved: {output_path}")
    return 0


def _run_batch(agent, args) -> int:
    """Synthesize every entry of ``args.srt``; return 0 only if no entry failed."""
    logger.info("-" * 40)
    logger.info("BATCH MODE PROCESSING")
    logger.info("-" * 40)

    try:
        subtitles = parse_srt(args.srt)
    except (OSError, UnicodeDecodeError) as e:
        logger.error(f"Could not read SRT file {args.srt}: {e}")
        return 1

    if not subtitles:
        logger.error("No subtitles found.")
        return 1

    logger.info(f"Processing {len(subtitles)} entries...")
    success = 0
    errors = 0
    skipped = 0

    # Use tqdm for progress bar
    pbar = tqdm(subtitles, desc="Synthesizing", unit="line")
    # pbar may be a plain iterable without set_description when tqdm is
    # unavailable, so look the method up instead of assuming it exists.
    set_description = getattr(pbar, "set_description", None)

    for sub in pbar:
        sid = sub['id']
        raw_text = sub['text']

        # Update progress bar description
        if set_description is not None:
            set_description(f"Processing ID {sid}")

        # Determine Output Path
        out_name = f"{args.output_prefix}_{sid:03d}.wav"
        out_path = os.path.join(args.output, out_name)

        if args.skip_existing and os.path.exists(out_path):
            skipped += 1
            continue

        # Resolve Speaker/Reference: a single --ref-audio is the default for
        # every entry; otherwise fall back to an ID-based file in --ref-dir.
        if args.ref_audio:
            default_ref = args.ref_audio
        else:
            default_ref = find_reference_audio(args.ref_dir, sid, args.audio_prefix)

        clean_text, ref_audio = resolve_speaker_ref(agent, raw_text, args.ref_dir, default_ref)

        if not ref_audio or not os.path.exists(ref_audio):
            logger.warning(f"ID {sid}: No reference audio found. Skipping.")
            errors += 1
            continue

        # Generate via Agent
        try:
            agent.infer(
                ref_file=ref_audio,
                ref_text=args.ref_text if args.ref_audio else "",  # Use ref_text only if using single ref audio
                gen_text=clean_text,
                file_wave=out_path,
                remove_silence=True,
                speed=args.speed
            )
            success += 1
        except Exception as e:
            # Keep going: one failed entry must not abort the whole batch.
            logger.error(f"ID {sid}: Generation failed: {e}")
            errors += 1

    logger.info("-" * 40)
    logger.info(f"Done. Success: {success}, Skipped: {skipped}, Errors: {errors}")
    return 1 if errors else 0


def main():
    """
    CLI entry point: parse arguments, create the agent and run the chosen mode.

    Returns:
        Process exit status: 0 on success, 1 when the agent could not be
        initialized, required input was missing, or any synthesis failed.
        Callers that ignore the return value are unaffected.
    """
    args = parse_args()

    # Device Setup
    device = _select_device(args.device)
    logger.info(f"Using device: {device}")

    # Create Output Dir
    os.makedirs(args.output, exist_ok=True)

    # Initialize Agent
    logger.info(f"Initializing AgentF5TTS with checkpoint: {args.checkpoint}")
    try:
        agent = AgentF5TTS(
            ckpt_file=args.checkpoint,
            vocoder_name=args.vocoder,
            device=device
        )
    except Exception as e:
        logger.error(f"Failed to initialize agent: {e}")
        return 1

    if args.text:
        return _run_single(agent, args)
    if args.srt:
        return _run_batch(agent, args)

    # Reached when --text/--srt was given as an empty string; previously this
    # fell through silently without producing any output.
    logger.error("Nothing to synthesize: --text/--srt value is empty.")
    return 1


if __name__ == "__main__":
    # Propagate main()'s status so shell callers can detect failures.
    sys.exit(main())