File size: 24,287 Bytes
5d5b597
 
 
 
 
 
39c69de
 
 
5d5b597
39c69de
5d5b597
 
 
af2bc7c
 
 
 
 
 
 
 
 
 
 
 
 
 
39c69de
 
 
5d5b597
39c69de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d5b597
39c69de
 
af2bc7c
 
 
 
 
 
 
 
 
 
 
 
 
 
39c69de
5d5b597
39c69de
 
af2bc7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d5b597
 
 
 
 
 
 
 
 
 
 
 
39c69de
af2bc7c
 
 
 
 
39c69de
af2bc7c
 
 
 
 
 
 
 
 
 
 
 
 
5d5b597
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39c69de
5d5b597
 
 
 
 
 
 
39c69de
5d5b597
 
 
 
 
 
 
 
 
 
 
39c69de
 
 
 
 
 
 
 
 
 
 
 
 
5d5b597
 
39c69de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d5b597
 
39c69de
5d5b597
 
 
 
 
39c69de
 
 
 
5d5b597
39c69de
 
 
 
 
5d5b597
 
39c69de
5d5b597
39c69de
5d5b597
 
39c69de
5d5b597
 
 
 
 
 
39c69de
 
5d5b597
 
 
 
 
 
 
 
 
 
39c69de
 
 
5d5b597
 
 
 
 
 
 
 
 
 
 
 
 
39c69de
 
5d5b597
 
39c69de
af2bc7c
 
 
 
 
 
39c69de
af2bc7c
 
 
 
 
 
 
 
 
 
 
 
5d5b597
 
 
 
 
 
 
 
 
 
 
 
 
39c69de
5d5b597
 
 
39c69de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d5b597
 
39c69de
5d5b597
 
 
39c69de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d5b597
 
 
 
39c69de
5d5b597
 
 
39c69de
 
 
af2bc7c
 
 
 
 
 
39c69de
 
5d5b597
 
 
39c69de
5d5b597
 
 
 
39c69de
5d5b597
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39c69de
 
5d5b597
 
 
 
 
 
39c69de
 
5d5b597
 
 
39c69de
5d5b597
 
 
 
 
 
 
 
39c69de
5d5b597
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39c69de
5d5b597
 
 
 
 
39c69de
 
5d5b597
 
 
39c69de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d5b597
 
 
39c69de
 
5d5b597
 
 
 
 
 
39c69de
 
5d5b597
39c69de
 
5d5b597
39c69de
 
 
 
 
 
5d5b597
 
 
 
 
 
39c69de
 
 
 
 
 
5d5b597
 
 
 
 
 
 
 
 
 
 
 
 
39c69de
 
 
 
 
 
 
5d5b597
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39c69de
 
af2bc7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39c69de
5d5b597
 
39c69de
5d5b597
 
 
 
 
 
39c69de
5d5b597
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
import os
import torch
import soundfile as sf
import logging
import argparse
import gradio as gr
import json
import threading
import queue
from datetime import datetime
from pathlib import Path
from mira.model import MiraTTS

# Shared MiraTTS instance. Stays None until the __main__ block assigns it or
# generate_audio() / voice_creation_callback() load it lazily on first use.
MODEL = None

# Safe device detection with fallback
def get_device():
    """Pick the compute device, preferring CUDA when it is actually usable.

    Returns:
        "cuda" if ``torch.cuda.is_available()`` is true and the current CUDA
        device can be queried without error, otherwise "cpu". Any exception
        raised while probing is logged as a warning and treated as "no CUDA".
    """
    chosen = "cpu"
    try:
        cuda_reported = torch.cuda.is_available()
        if cuda_reported:
            # Touch the device so a broken driver surfaces here, not later
            torch.cuda.current_device()
            chosen = "cuda"
    except Exception as e:
        logging.warning(f"CUDA not available or driver error: {e}")
    return chosen

# Device string picked at import time; initialize_model() and the __main__
# block may overwrite it later.
DEVICE = get_device()
# JSON file (relative to the working directory) where history is persisted.
HISTORY_FILE = "generation_history.json"
# Work queue drained by background_worker(): items are (callback, args)
# pairs, and None tells the worker to stop.
GENERATION_QUEUE = queue.Queue()
# NOTE(review): this lock is not acquired anywhere in the visible code —
# confirm whether it is still needed.
PROCESSING_LOCK = threading.Lock()

class GenerationHistory:
    """Keep a capped, newest-first list of generation records on disk.

    Records are plain dicts. The list is written back to a JSON file after
    every mutation so that it survives restarts.
    """

    # Maximum number of records kept; older ones are dropped on insert.
    MAX_ENTRIES = 100

    def __init__(self, history_file=HISTORY_FILE):
        """Remember the backing file and load whatever it already holds.

        Args:
            history_file: Path of the JSON file used for persistence.
        """
        self.history_file = history_file
        self.history = self.load_history()

    def load_history(self):
        """Read the history list from the JSON file.

        Returns:
            The stored list of records, or an empty list when the file is
            missing, cannot be read or parsed, or does not contain a list.
        """
        if not os.path.exists(self.history_file):
            return []
        try:
            with open(self.history_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except Exception as e:
            logging.error(f"Error loading history: {e}")
            return []
        if not isinstance(data, list):
            # add_entry() needs list semantics; any other JSON value would
            # only fail later, far away from the real cause.
            logging.error(
                f"Error loading history: expected a list, got {type(data).__name__}"
            )
            return []
        return data

    def save_history(self):
        """Write the current history to the JSON file.

        Failures are logged and otherwise ignored so that a persistence
        problem never aborts a generation.
        """
        try:
            # Serialise before opening the file: opening in 'w' mode truncates
            # it, so a record that cannot be serialised must fail first.
            payload = json.dumps(self.history, indent=2, ensure_ascii=False)
            with open(self.history_file, 'w', encoding='utf-8') as f:
                f.write(payload)
        except Exception as e:
            logging.error(f"Error saving history: {e}")

    def add_entry(self, entry):
        """Insert *entry* as the newest record, trim to the cap and persist.

        Args:
            entry: Record to store; it must be JSON-serialisable to be saved.
        """
        self.history.insert(0, entry)  # newest first
        del self.history[self.MAX_ENTRIES:]
        self.save_history()

    def get_history(self):
        """Return the in-memory list of records, newest first."""
        return self.history

    def clear_history(self):
        """Drop every record and persist the now-empty list."""
        self.history = []
        self.save_history()

# Global history manager shared by all callbacks. Constructing it reads
# HISTORY_FILE from disk, so importing this module performs file I/O.
HISTORY_MANAGER = GenerationHistory()

def initialize_model(model_dir="YatharthS/MiraTTS", device=None):
    """Construct a MiraTTS model and settle which device it runs on.

    Args:
        model_dir: Model directory or identifier handed to ``MiraTTS``.
        device: Optional device override. "cuda" is honoured only when CUDA
            can actually be used, otherwise the global ``DEVICE`` is set to
            "cpu". Any other non-empty value is stored as-is. A falsy value
            leaves ``DEVICE`` untouched.

    Returns:
        The constructed model. When ``DEVICE`` is "cuda" and the model has a
        ``to`` method, the result of ``model.to(DEVICE)`` is returned.

    Raises:
        Exception: Whatever model construction raises, after it is logged.
    """
    global DEVICE

    if device == "cuda":
        try:
            if torch.cuda.is_available():
                torch.cuda.current_device()  # Test CUDA access
                DEVICE = device
            else:
                logging.warning("CUDA requested but not available, falling back to CPU")
                DEVICE = "cpu"
        except Exception as e:
            logging.warning(f"CUDA test failed: {e}, falling back to CPU")
            DEVICE = "cpu"
    elif device:
        # Non-CUDA overrides are taken at face value
        DEVICE = device

    logging.info(f"Loading MiraTTS model from: {model_dir}")
    logging.info(f"Using device: {DEVICE}")

    try:
        model = MiraTTS(model_dir)

        movable = hasattr(model, 'to')
        if movable and DEVICE == "cuda":
            try:
                model = model.to(DEVICE)
            except Exception as e:
                # Keep the unmoved model and record that we are on CPU now
                logging.warning(f"Failed to move model to CUDA: {e}, using CPU")
                DEVICE = "cpu"

        return model
    except Exception as e:
        logging.error(f"Error initializing model: {e}")
        raise

def generate_audio(text, prompt_audio_path):
    """Synthesize speech for *text* in the voice of a reference recording.

    The shared ``MODEL`` is loaded on first use. On CPU, generation runs under
    ``torch.inference_mode()``. Otherwise it runs under CUDA autocast, and if
    that attempt raises it is retried once under plain inference mode.

    Args:
        text: Text to speak.
        prompt_audio_path: Path of the reference audio given to
            ``MODEL.encode_audio``.

    Returns:
        A ``(audio, sample_rate)`` tuple. ``sample_rate`` is always 48000 and
        ``audio`` has a dtype of float32, float64, int16 or int32.

    Raises:
        Exception: Any failure is logged and re-raised unchanged.
    """
    global MODEL

    if MODEL is None:
        MODEL = initialize_model()

    try:
        # Encode the prompt audio
        context_tokens = MODEL.encode_audio(prompt_audio_path)

        # Move context tokens to device if needed
        if torch.is_tensor(context_tokens) and DEVICE == "cuda":
            try:
                context_tokens = context_tokens.to(DEVICE)
            except Exception as e:
                logging.warning(f"Failed to move tensors to CUDA: {e}")

        if DEVICE == "cpu":
            # No autocast on this path, so there is nothing to fall back
            # from: a failure here is a real generation error.
            with torch.inference_mode():
                audio = MODEL.generate(text, context_tokens)
        else:
            try:
                with torch.autocast(device_type="cuda"):
                    audio = MODEL.generate(text, context_tokens)
            except Exception as e:
                # Fallback to simple generation if autocast fails
                logging.warning(f"Autocast failed: {e}, using standard generation")
                with torch.inference_mode():
                    audio = MODEL.generate(text, context_tokens)

        if torch.is_tensor(audio):
            audio = audio.detach().cpu()
            # numpy has no bfloat16, and float16 is widened further down
            # anyway, so widen half-precision tensors before converting.
            if audio.dtype in (torch.float16, torch.bfloat16):
                audio = audio.float()
            audio = audio.numpy()

        # Cast anything else to float32 before it is written with soundfile
        if audio.dtype not in ['float32', 'float64', 'int16', 'int32']:
            audio = audio.astype('float32')

        return audio, 48000  # Return audio and sample rate
    except Exception as e:
        logging.error(f"Error during generation: {e}")
        raise

def run_tts(text, prompt_audio_path, save_dir="results", mode="clone"):
    """Generate speech, write it to a WAV file and record it in the history.

    Args:
        text: Text to synthesize.
        prompt_audio_path: Reference audio path passed to ``generate_audio``.
        save_dir: Directory for the output file; created if missing.
        mode: Label stored in the history entry. The reference audio path is
            recorded only when this is "clone".

    Returns:
        Path of the written WAV file.

    Raises:
        Exception: Errors from generation or file writing propagate.
    """
    logging.info(f"Saving audio to: {save_dir}")

    # Ensure the save directory exists
    os.makedirs(save_dir, exist_ok=True)

    # Microsecond resolution: with whole seconds, two requests started within
    # the same second shared one path and the later one overwrote the earlier.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    save_path = os.path.join(save_dir, f"mira_tts_{timestamp}.wav")

    logging.info("Starting MiraTTS inference...")

    audio, sample_rate = generate_audio(text, prompt_audio_path)

    sf.write(save_path, audio, samplerate=sample_rate)

    logging.info(f"Audio saved at: {save_path}")

    # The history keeps a short preview next to the full text
    preview = (text[:100] + "...") if len(text) > 100 else text
    history_entry = {
        "timestamp": datetime.now().isoformat(),
        "text": preview,
        "full_text": text,
        "mode": mode,
        "file_path": save_path,
        "reference_audio": prompt_audio_path if mode == "clone" else None,
        "device": DEVICE
    }
    HISTORY_MANAGER.add_entry(history_entry)

    return save_path

def background_worker():
    """Run queued tasks until a ``None`` sentinel is received.

    Each queue item is a ``(callback, args)`` pair and is invoked as
    ``callback(*args)``. An exception raised by a task (or by a malformed
    item) is logged and the loop moves on to the next item.
    """
    while True:
        # Fetch outside the try block so that task_done() below is only ever
        # called for an item that was actually taken from the queue.
        task = GENERATION_QUEUE.get()
        try:
            if task is None:  # Poison pill to stop the worker
                break

            callback, args = task
            callback(*args)

        except Exception as e:
            logging.error(f"Error in background worker: {e}")
        finally:
            # Also runs on the sentinel's break, so join() can complete
            GENERATION_QUEUE.task_done()

# Start background worker thread at import time. daemon=True keeps the thread
# from blocking interpreter shutdown.
# NOTE(review): nothing in the visible code puts work on GENERATION_QUEUE —
# confirm the worker is still in use.
worker_thread = threading.Thread(target=background_worker, daemon=True)
worker_thread.start()

def voice_clone_callback(text, prompt_audio_upload, prompt_audio_record, progress=gr.Progress()):
    """Handle the Voice Clone tab's generate button.

    Args:
        text: Text to synthesize; blank text is ignored.
        prompt_audio_upload: File path of an uploaded reference, if any.
        prompt_audio_record: File path of a recorded reference, used only
            when nothing was uploaded.
        progress: Gradio progress tracker.

    Returns:
        ``(audio_path, history_markdown)``. ``audio_path`` is None when the
        input is incomplete or generation failed.
    """
    if not text.strip():
        return None, get_history_display()

    # An uploaded file wins over a microphone recording
    reference_audio = prompt_audio_upload or prompt_audio_record
    if not reference_audio:
        return None, get_history_display()

    progress(0, desc="Initializing...")

    try:
        for fraction, label in ((0.3, "Encoding audio..."), (0.6, "Generating speech...")):
            progress(fraction, desc=label)
        generated_path = run_tts(text, reference_audio, mode="clone")
        progress(1.0, desc="Complete!")
        return generated_path, get_history_display()
    except Exception as e:
        logging.error(f"Error in voice cloning: {e}")
        return None, get_history_display()

def voice_creation_callback(text, temperature, top_p, top_k, progress=gr.Progress()):
    """Handle the Voice Creation tab: synthesize *text* with custom sampling.

    The first example recording found among a fixed list of candidate paths
    is used as the voice reference. Generation itself is delegated to
    ``generate_audio`` so that both tabs share one pipeline.

    NOTE(review): ``set_params`` is applied to the shared ``MODEL``, so these
    values presumably also affect later voice-clone requests — confirm.

    Args:
        text: Text to synthesize; blank text is ignored.
        temperature: Sampling temperature passed to ``MODEL.set_params``.
        top_p: Nucleus-sampling threshold passed to ``MODEL.set_params``.
        top_k: Top-k cutoff passed to ``MODEL.set_params``.
        progress: Gradio progress tracker.

    Returns:
        ``(audio_path, history_markdown)``. ``audio_path`` is None when the
        text is blank, no reference recording exists, or generation failed.
    """
    global MODEL

    if not text.strip():
        return None, get_history_display()

    progress(0, desc="Initializing...")

    try:
        # Loading inside the try block makes a failed model load follow the
        # same log-and-return-None path as every other error here.
        if MODEL is None:
            MODEL = initialize_model()

        # Set custom generation parameters
        MODEL.set_params(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_new_tokens=1024,
            repetition_penalty=1.2
        )

        progress(0.3, desc="Loading default voice...")

        # Use a default voice context: first candidate that exists wins
        possible_paths = (
            "/models3/src/MiraTTS/models/MiraTTS/example1.wav",
            "models/MiraTTS/example1.wav",
            "./models/MiraTTS/example1.wav",
        )
        default_audio = next(
            (path for path in possible_paths if os.path.exists(path)), None
        )
        if default_audio is None:
            logging.warning("No default audio found for voice creation")
            return None, get_history_display()

        progress(0.6, desc="Generating speech...")
        audio, sample_rate = generate_audio(text, default_audio)

        # Microsecond resolution keeps two requests made within the same
        # second from overwriting each other's output file.
        os.makedirs("results", exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        save_path = os.path.join("results", f"mira_tts_creation_{timestamp}.wav")
        sf.write(save_path, audio, samplerate=sample_rate)

        # Add to history
        history_entry = {
            "timestamp": datetime.now().isoformat(),
            "text": (text[:100] + "...") if len(text) > 100 else text,
            "full_text": text,
            "mode": "creation",
            "file_path": save_path,
            "reference_audio": None,
            "device": DEVICE,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k
        }
        HISTORY_MANAGER.add_entry(history_entry)

        progress(1.0, desc="Complete!")
        return save_path, get_history_display()

    except Exception as e:
        logging.error(f"Error in voice creation: {e}")
        return None, get_history_display()

def get_history_display():
    """Render the 20 newest history records as a Markdown document.

    Returns:
        A Markdown string with one numbered section per record, or a short
        placeholder sentence when the history is empty.
    """
    records = HISTORY_MANAGER.get_history()
    if not records:
        return "No generation history yet."

    # Collect fragments and join once at the end
    fragments = ["# Generation History\n\n"]

    for position, record in enumerate(records[:20], start=1):
        when = datetime.fromisoformat(record['timestamp']).strftime("%Y-%m-%d %H:%M:%S")
        mode_label = record['mode'].capitalize()
        preview = record['text']
        base_name = os.path.basename(record['file_path'])

        fragments.append(f"### {position}. {when} - {mode_label}\n")
        fragments.append(f"**Text:** {preview}\n")
        fragments.append(f"**File:** `{base_name}`\n")
        fragments.append(f"**Device:** {record.get('device', 'N/A')}\n")

        # Sampling parameters are shown only for records that carry them
        if record.get('temperature'):
            fragments.append(
                f"**Params:** T={record.get('temperature')}, p={record.get('top_p')}, k={record.get('top_k')}\n"
            )

        fragments.append("\n---\n\n")

    return "".join(fragments)

def get_history_files():
    """List the history's audio files that still exist on disk.

    Returns:
        A list of ``(file_path, file_name)`` tuples in history order.
    """
    available = []
    for record in HISTORY_MANAGER.get_history():
        if os.path.exists(record['file_path']):
            available.append(
                (record['file_path'], os.path.basename(record['file_path']))
            )
    return available

def clear_history_callback():
    """Wipe the stored history and return values that reset the History tab.

    Returns:
        ``(history_markdown, files)`` where ``files`` is an empty list.
    """
    HISTORY_MANAGER.clear_history()
    refreshed_markdown = get_history_display()
    no_files = []
    return refreshed_markdown, no_files

def build_ui():
    """Assemble the Gradio Blocks application.

    The interface has four tabs: Voice Clone, Voice Creation, History and
    About. The history panels on the first two tabs are rendered once when
    the UI is built and afterwards updated by their own button callbacks.

    Returns:
        The ``gr.Blocks`` instance, ready for ``launch()``.
    """
    
    with gr.Blocks(title="MiraTTS Web Interface", theme=gr.themes.Soft()) as demo:
        # Title
        gr.HTML('<h1 style="text-align: center;">MiraTTS - High Quality Voice Synthesis</h1>')
        
        # Device info
        device_info = f"🖥️ Running on: **{DEVICE.upper()}**"
        if DEVICE == "cuda":
            try:
                device_info += f" (GPU: {torch.cuda.get_device_name(0)})"
            except Exception:
                # The GPU name is cosmetic, so fall back to a generic label.
                # Not a bare except: Ctrl-C and SystemExit must still propagate.
                device_info += " (GPU)"
        else:
            device_info += " (CPU mode - slower but works without GPU)"
        gr.Markdown(device_info)
        
        # Description
        gr.Markdown("""
        MiraTTS is a highly optimized Text-to-Speech model based on Spark-TTS with LMDeploy acceleration.
        It provides high-quality 48kHz audio output with background processing support.
        """)
        
        with gr.Tabs():
            # Voice Clone Tab
            with gr.TabItem("🎤 Voice Clone"):
                gr.Markdown("### Clone any voice using a reference audio sample")
                
                with gr.Row():
                    prompt_audio_upload = gr.Audio(
                        sources="upload",
                        type="filepath",
                        label="Upload Reference Audio (recommended: 3-30 seconds, 16kHz+)",
                    )
                    prompt_audio_record = gr.Audio(
                        sources="microphone",
                        type="filepath",
                        label="Record Reference Audio",
                    )
                
                text_input = gr.Textbox(
                    label="Text to Synthesize",
                    lines=3,
                    placeholder="Enter the text you want to convert to speech...",
                    value="Hello! This is a demonstration of MiraTTS voice cloning capabilities."
                )
                
                with gr.Row():
                    clone_button = gr.Button("🎵 Generate Audio", variant="primary")
                    clear_button = gr.Button("🗑️ Clear")
                
                audio_output_clone = gr.Audio(
                    label="Generated Audio",
                    autoplay=True
                )
                
                history_display_clone = gr.Markdown(get_history_display())
                
                clone_button.click(
                    voice_clone_callback,
                    inputs=[text_input, prompt_audio_upload, prompt_audio_record],
                    outputs=[audio_output_clone, history_display_clone],
                )
                
                # Reset both reference inputs, the text box and the player
                clear_button.click(
                    lambda: (None, None, "", None),
                    outputs=[prompt_audio_upload, prompt_audio_record, text_input, audio_output_clone]
                )
            
            # Voice Creation Tab
            with gr.TabItem("✨ Voice Creation"):
                gr.Markdown("### Create synthetic voices with custom parameters")
                
                with gr.Row():
                    with gr.Column():
                        text_input_creation = gr.Textbox(
                            label="Text to Synthesize",
                            lines=3,
                            placeholder="Enter text here...",
                            value="You can create customized voices by adjusting the generation parameters below."
                        )
                        
                        with gr.Row():
                            temperature = gr.Slider(
                                minimum=0.1,
                                maximum=1.5,
                                step=0.1,
                                value=0.8,
                                label="Temperature (creativity)"
                            )
                            top_p = gr.Slider(
                                minimum=0.1,
                                maximum=1.0,
                                step=0.05,
                                value=0.95,
                                label="Top-p (nucleus sampling)"
                            )
                            top_k = gr.Slider(
                                minimum=1,
                                maximum=100,
                                step=1,
                                value=50,
                                label="Top-k (vocabulary size)"
                            )
                    
                    with gr.Column():
                        create_button = gr.Button("🎨 Create Voice", variant="primary")
                        audio_output_creation = gr.Audio(
                            label="Generated Audio",
                            autoplay=True
                        )
                
                history_display_creation = gr.Markdown(get_history_display())
                
                create_button.click(
                    voice_creation_callback,
                    inputs=[text_input_creation, temperature, top_p, top_k],
                    outputs=[audio_output_creation, history_display_creation],
                )
            
            # History Tab
            with gr.TabItem("📜 History"):
                gr.Markdown("### Review and download previous generations")
                
                with gr.Row():
                    refresh_button = gr.Button("🔄 Refresh History", variant="secondary")
                    clear_history_button = gr.Button("🗑️ Clear History", variant="stop")
                
                history_display_main = gr.Markdown(get_history_display())
                
                gr.Markdown("### Download Files")
                file_browser = gr.File(
                    label="Generated Audio Files",
                    file_count="multiple",
                    interactive=False
                )
                
                def refresh_history():
                    """Return fresh history Markdown and the existing file paths."""
                    files = get_history_files()
                    return get_history_display(), [f[0] for f in files]
                
                refresh_button.click(
                    refresh_history,
                    outputs=[history_display_main, file_browser]
                )
                
                clear_history_button.click(
                    clear_history_callback,
                    outputs=[history_display_main, file_browser]
                )
                
                # Populate the history view when the page first loads
                # (demo.load fires on page load, not when this tab is opened)
                demo.load(
                    refresh_history,
                    outputs=[history_display_main, file_browser]
                )
            
            # About Tab
            with gr.TabItem("ℹ️ About"):
                gr.Markdown(f"""
                ## About MiraTTS
                
                MiraTTS is an optimized version of Spark-TTS with the following features:
                
                - **Ultra-fast generation**: Over 100x realtime speed using LMDeploy optimization
                - **High quality**: Generates crisp 48kHz audio outputs
                - **Memory efficient**: Works within 6GB VRAM or on CPU
                - **Low latency**: As low as 100ms generation time (GPU)
                - **Voice cloning**: Clone any voice from a short audio sample
                - **Background processing**: Non-blocking audio generation
                - **Generation history**: Review and download all generated audio
                
                ### Current Configuration
                - **Device**: {DEVICE.upper()}
                - **Base model**: Spark-TTS-0.5B
                - **Optimization**: LMDeploy + FlashSR
                - **Sample rate**: 48kHz
                - **Model size**: ~500M parameters
                
                ### Usage Tips
                - For voice cloning, use clear audio samples between 3-30 seconds
                - Ensure reference audio is at least 16kHz quality
                - Longer text inputs may require more memory
                - Adjust generation parameters for different voice styles
                - CPU mode is slower but works without GPU
                - Check the History tab to download previous generations
                
                ### Performance Notes
                - **GPU**: ~100-200ms per generation
                - **CPU**: ~2-5 seconds per generation (depending on CPU)
                """)
    
    return demo

def parse_arguments():
    """Parse the web interface's command-line options from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with ``model_dir``, ``device``,
        ``server_name``, ``server_port`` and ``share``.
    """
    cli = argparse.ArgumentParser(description="MiraTTS Gradio Web Interface")

    # One (flag, keyword arguments) pair per option, registered in order
    option_table = (
        ("--model_dir", {
            "type": str,
            "default": "YatharthS/MiraTTS",
            "help": "Path to the MiraTTS model directory or HuggingFace model ID",
        }),
        ("--device", {
            "type": str,
            "default": None,
            "choices": ["cuda", "cpu"],
            "help": "Device to run model on (default: auto-detect)",
        }),
        ("--server_name", {
            "type": str,
            "default": "127.0.0.1",
            "help": "Server host/IP for Gradio app",
        }),
        ("--server_port", {
            "type": int,
            "default": 7860,
            "help": "Server port for Gradio app",
        }),
        ("--share", {
            "action": "store_true",
            "help": "Create a public shareable link",
        }),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli.parse_args()

if __name__ == "__main__":
    # Timestamped INFO-level logging for the whole process
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s',
        level=logging.INFO
    )

    args = parse_arguments()

    # Apply an explicit --device choice; "cuda" is accepted only when usable
    if args.device == "cuda":
        try:
            if torch.cuda.is_available():
                torch.cuda.current_device()  # Test CUDA access
                DEVICE = args.device
            else:
                logging.warning("CUDA requested but not available, falling back to CPU")
                DEVICE = "cpu"
        except Exception as e:
            logging.warning(f"CUDA test failed: {e}, falling back to CPU")
            DEVICE = "cpu"
    elif args.device:
        DEVICE = args.device

    logging.info(f"Device selected: {DEVICE}")

    # Load the model up front so the first request does not pay for it
    logging.info("Initializing MiraTTS model...")
    MODEL = initialize_model(args.model_dir, args.device)

    logging.info("Building Gradio interface...")
    demo = build_ui()

    logging.info(f"Launching web interface on {args.server_name}:{args.server_port}")
    logging.info(f"Device: {DEVICE}")
    launch_options = {
        "server_name": args.server_name,
        "server_port": args.server_port,
        "share": args.share,
    }
    demo.launch(**launch_options)