Translsis committed on
Commit
39c69de
·
verified ·
1 Parent(s): 0b42a1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +280 -39
app.py CHANGED
@@ -4,15 +4,80 @@ import soundfile as sf
4
  import logging
5
  import argparse
6
  import gradio as gr
 
 
 
7
  from datetime import datetime
 
8
  from mira.model import MiraTTS
9
 
10
  MODEL = None
 
 
 
 
11
 
12
- def initialize_model(model_dir="YatharthS/MiraTTS"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  """Load the MiraTTS model once at the beginning."""
 
 
 
 
14
  logging.info(f"Loading MiraTTS model from: {model_dir}")
 
 
15
  model = MiraTTS(model_dir)
 
 
 
 
 
16
  return model
17
 
18
  def generate_audio(text, prompt_audio_path):
@@ -26,8 +91,13 @@ def generate_audio(text, prompt_audio_path):
26
  # Encode the prompt audio
27
  context_tokens = MODEL.encode_audio(prompt_audio_path)
28
 
 
 
 
 
29
  # Generate audio
30
- audio = MODEL.generate(text, context_tokens)
 
31
 
32
  # Convert to numpy array if it's a tensor and handle dtype
33
  if torch.is_tensor(audio):
@@ -44,7 +114,7 @@ def generate_audio(text, prompt_audio_path):
44
  logging.error(f"Error during generation: {e}")
45
  raise e
46
 
47
- def run_tts(text, prompt_audio_path, save_dir="results"):
48
  """Perform TTS inference and save the generated audio."""
49
  logging.info(f"Saving audio to: {save_dir}")
50
 
@@ -52,7 +122,7 @@ def run_tts(text, prompt_audio_path, save_dir="results"):
52
  os.makedirs(save_dir, exist_ok=True)
53
 
54
  # Generate unique filename using timestamp
55
- timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
56
  save_path = os.path.join(save_dir, f"mira_tts_{timestamp}.wav")
57
 
58
  logging.info("Starting MiraTTS inference...")
@@ -64,36 +134,76 @@ def run_tts(text, prompt_audio_path, save_dir="results"):
64
  sf.write(save_path, audio, samplerate=sample_rate)
65
 
66
  logging.info(f"Audio saved at: {save_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  return save_path
68
 
69
- def voice_clone_callback(text, prompt_audio_upload, prompt_audio_record):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  """Gradio callback for voice cloning using MiraTTS."""
71
  if not text.strip():
72
- return None
73
 
74
  # Use uploaded audio or recorded audio
75
  prompt_audio = prompt_audio_upload if prompt_audio_upload else prompt_audio_record
76
 
77
  if not prompt_audio:
78
- return None
79
-
 
 
80
  try:
81
- audio_output_path = run_tts(text, prompt_audio)
82
- return audio_output_path
 
 
 
83
  except Exception as e:
84
  logging.error(f"Error in voice cloning: {e}")
85
- return None
86
 
87
- def voice_creation_callback(text, temperature, top_p, top_k):
88
  """Gradio callback for creating synthetic voice with custom parameters."""
89
  if not text.strip():
90
- return None
91
 
92
  global MODEL
93
 
94
  if MODEL is None:
95
  MODEL = initialize_model()
96
 
 
 
97
  try:
98
  # Set custom generation parameters
99
  MODEL.set_params(
@@ -104,8 +214,9 @@ def voice_creation_callback(text, temperature, top_p, top_k):
104
  repetition_penalty=1.2
105
  )
106
 
107
- # Use a default voice context (you may want to provide default audio files)
108
- # Check multiple possible paths for example audio
 
109
  possible_paths = [
110
  "/models3/src/MiraTTS/models/MiraTTS/example1.wav",
111
  "models/MiraTTS/example1.wav",
@@ -119,9 +230,17 @@ def voice_creation_callback(text, temperature, top_p, top_k):
119
  break
120
 
121
  if default_audio:
 
 
122
  # Generate audio with dtype conversion
123
  context_tokens = MODEL.encode_audio(default_audio)
124
- audio = MODEL.generate(text, context_tokens)
 
 
 
 
 
 
125
 
126
  # Handle tensor conversion and dtype
127
  if torch.is_tensor(audio):
@@ -135,35 +254,95 @@ def voice_creation_callback(text, temperature, top_p, top_k):
135
 
136
  # Save the audio
137
  os.makedirs("results", exist_ok=True)
138
- timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
139
  save_path = os.path.join("results", f"mira_tts_creation_{timestamp}.wav")
140
  sf.write(save_path, audio, samplerate=48000)
141
 
142
- return save_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  else:
144
  logging.warning("No default audio found for voice creation")
145
- return None
146
 
147
  except Exception as e:
148
  logging.error(f"Error in voice creation: {e}")
149
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
  def build_ui():
152
  """Build the Gradio interface similar to SparkTTS."""
153
 
154
- with gr.Blocks(title="MiraTTS Web Interface") as demo:
155
  # Title
156
  gr.HTML('<h1 style="text-align: center;">MiraTTS - High Quality Voice Synthesis</h1>')
157
 
 
 
 
 
 
 
158
  # Description
159
  gr.Markdown("""
160
  MiraTTS is a highly optimized Text-to-Speech model based on Spark-TTS with LMDeploy acceleration.
161
- It provides over 100x realtime generation speed with high-quality 48kHz audio output.
162
  """)
163
 
164
  with gr.Tabs():
165
  # Voice Clone Tab
166
- with gr.TabItem("Voice Clone"):
167
  gr.Markdown("### Clone any voice using a reference audio sample")
168
 
169
  with gr.Row():
@@ -186,18 +365,20 @@ def build_ui():
186
  )
187
 
188
  with gr.Row():
189
- clone_button = gr.Button("Generate Audio", variant="primary")
190
- clear_button = gr.Button("Clear")
191
 
192
  audio_output_clone = gr.Audio(
193
  label="Generated Audio",
194
  autoplay=True
195
  )
196
 
 
 
197
  clone_button.click(
198
  voice_clone_callback,
199
  inputs=[text_input, prompt_audio_upload, prompt_audio_record],
200
- outputs=[audio_output_clone],
201
  )
202
 
203
  clear_button.click(
@@ -206,7 +387,7 @@ def build_ui():
206
  )
207
 
208
  # Voice Creation Tab
209
- with gr.TabItem("Voice Creation"):
210
  gr.Markdown("### Create synthetic voices with custom parameters")
211
 
212
  with gr.Row():
@@ -242,42 +423,90 @@ def build_ui():
242
  )
243
 
244
  with gr.Column():
245
- create_button = gr.Button("Create Voice", variant="primary")
246
  audio_output_creation = gr.Audio(
247
  label="Generated Audio",
248
  autoplay=True
249
  )
250
 
 
 
251
  create_button.click(
252
  voice_creation_callback,
253
  inputs=[text_input_creation, temperature, top_p, top_k],
254
- outputs=[audio_output_creation],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  )
256
 
257
  # About Tab
258
- with gr.TabItem("About"):
259
- gr.Markdown("""
260
  ## About MiraTTS
261
 
262
  MiraTTS is an optimized version of Spark-TTS with the following features:
263
 
264
  - **Ultra-fast generation**: Over 100x realtime speed using LMDeploy optimization
265
  - **High quality**: Generates crisp 48kHz audio outputs
266
- - **Memory efficient**: Works within 6GB VRAM
267
- - **Low latency**: As low as 100ms generation time
268
  - **Voice cloning**: Clone any voice from a short audio sample
 
 
269
 
270
- ### Model Information
271
- - Base model: Spark-TTS-0.5B
272
- - Optimization: LMDeploy + FlashSR
273
- - Sample rate: 48kHz
274
- - Model size: ~500M parameters
 
275
 
276
  ### Usage Tips
277
  - For voice cloning, use clear audio samples between 3-30 seconds
278
  - Ensure reference audio is at least 16kHz quality
279
  - Longer text inputs may require more memory
280
  - Adjust generation parameters for different voice styles
 
 
 
 
 
 
281
  """)
282
 
283
  return demo
@@ -291,6 +520,13 @@ def parse_arguments():
291
  default="YatharthS/MiraTTS",
292
  help="Path to the MiraTTS model directory or HuggingFace model ID"
293
  )
 
 
 
 
 
 
 
294
  parser.add_argument(
295
  "--server_name",
296
  type=str,
@@ -320,15 +556,20 @@ if __name__ == "__main__":
320
  # Parse arguments
321
  args = parse_arguments()
322
 
 
 
 
 
323
  # Initialize model
324
  logging.info("Initializing MiraTTS model...")
325
- MODEL = initialize_model(args.model_dir)
326
 
327
  # Build and launch interface
328
  logging.info("Building Gradio interface...")
329
  demo = build_ui()
330
 
331
  logging.info(f"Launching web interface on {args.server_name}:{args.server_port}")
 
332
  demo.launch(
333
  server_name=args.server_name,
334
  server_port=args.server_port,
 
4
  import logging
5
  import argparse
6
  import gradio as gr
7
+ import json
8
+ import threading
9
+ import queue
10
  from datetime import datetime
11
+ from pathlib import Path
12
  from mira.model import MiraTTS
13
 
14
  MODEL = None
15
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
+ HISTORY_FILE = "generation_history.json"
17
+ GENERATION_QUEUE = queue.Queue()
18
+ PROCESSING_LOCK = threading.Lock()
19
 
20
+ class GenerationHistory:
21
+ """Manage generation history with persistence."""
22
+
23
+ def __init__(self, history_file=HISTORY_FILE):
24
+ self.history_file = history_file
25
+ self.history = self.load_history()
26
+
27
+ def load_history(self):
28
+ """Load history from JSON file."""
29
+ if os.path.exists(self.history_file):
30
+ try:
31
+ with open(self.history_file, 'r', encoding='utf-8') as f:
32
+ return json.load(f)
33
+ except Exception as e:
34
+ logging.error(f"Error loading history: {e}")
35
+ return []
36
+ return []
37
+
38
+ def save_history(self):
39
+ """Save history to JSON file."""
40
+ try:
41
+ with open(self.history_file, 'w', encoding='utf-8') as f:
42
+ json.dump(self.history, f, indent=2, ensure_ascii=False)
43
+ except Exception as e:
44
+ logging.error(f"Error saving history: {e}")
45
+
46
+ def add_entry(self, entry):
47
+ """Add a new entry to history."""
48
+ self.history.insert(0, entry) # Add to beginning
49
+ # Keep only last 100 entries
50
+ if len(self.history) > 100:
51
+ self.history = self.history[:100]
52
+ self.save_history()
53
+
54
+ def get_history(self):
55
+ """Get all history entries."""
56
+ return self.history
57
+
58
+ def clear_history(self):
59
+ """Clear all history."""
60
+ self.history = []
61
+ self.save_history()
62
+
63
+ # Global history manager
64
+ HISTORY_MANAGER = GenerationHistory()
65
+
66
def initialize_model(model_dir="YatharthS/MiraTTS", device=None):
    """Load the MiraTTS model once at the beginning.

    Args:
        model_dir: Value passed straight to the ``MiraTTS`` constructor.
        device: Optional device name. When given it overrides the
            module-level ``DEVICE``; otherwise the current ``DEVICE`` is used.

    Returns:
        The constructed model, moved to ``DEVICE`` when it exposes ``.to``.
    """
    global DEVICE
    if device:
        DEVICE = device

    # An explicit "cuda" request on a machine without CUDA would only fail
    # later, at the first device transfer; fall back to CPU now instead.
    if DEVICE == "cuda" and not torch.cuda.is_available():
        logging.warning("CUDA requested but not available; falling back to CPU")
        DEVICE = "cpu"

    logging.info(f"Loading MiraTTS model from: {model_dir}")
    logging.info(f"Using device: {DEVICE}")

    model = MiraTTS(model_dir)

    # Move model to the selected device when the wrapper supports it.
    if hasattr(model, 'to'):
        moved = model.to(DEVICE)
        # NOTE(review): assumes .to() may be implemented in place and return
        # None on this wrapper; keep the original object then. Confirm
        # against the MiraTTS class.
        if moved is not None:
            model = moved

    return model
82
 
83
  def generate_audio(text, prompt_audio_path):
 
91
  # Encode the prompt audio
92
  context_tokens = MODEL.encode_audio(prompt_audio_path)
93
 
94
+ # Move context tokens to device if needed
95
+ if torch.is_tensor(context_tokens):
96
+ context_tokens = context_tokens.to(DEVICE)
97
+
98
  # Generate audio
99
+ with torch.inference_mode() if DEVICE == "cpu" else torch.cuda.amp.autocast():
100
+ audio = MODEL.generate(text, context_tokens)
101
 
102
  # Convert to numpy array if it's a tensor and handle dtype
103
  if torch.is_tensor(audio):
 
114
  logging.error(f"Error during generation: {e}")
115
  raise e
116
 
117
+ def run_tts(text, prompt_audio_path, save_dir="results", mode="clone"):
118
  """Perform TTS inference and save the generated audio."""
119
  logging.info(f"Saving audio to: {save_dir}")
120
 
 
122
  os.makedirs(save_dir, exist_ok=True)
123
 
124
  # Generate unique filename using timestamp
125
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
126
  save_path = os.path.join(save_dir, f"mira_tts_{timestamp}.wav")
127
 
128
  logging.info("Starting MiraTTS inference...")
 
134
  sf.write(save_path, audio, samplerate=sample_rate)
135
 
136
  logging.info(f"Audio saved at: {save_path}")
137
+
138
+ # Add to history
139
+ history_entry = {
140
+ "timestamp": datetime.now().isoformat(),
141
+ "text": text[:100] + "..." if len(text) > 100 else text,
142
+ "full_text": text,
143
+ "mode": mode,
144
+ "file_path": save_path,
145
+ "reference_audio": prompt_audio_path if mode == "clone" else None,
146
+ "device": DEVICE
147
+ }
148
+ HISTORY_MANAGER.add_entry(history_entry)
149
+
150
  return save_path
151
 
152
def background_worker():
    """Process tasks from GENERATION_QUEUE until a ``None`` sentinel arrives.

    Each task is a ``(callback, args)`` pair, run as ``callback(*args)``.
    A failing task is logged with its traceback and does not stop the loop.
    """
    while True:
        # Fetch outside the try block so task_done() in the finally clause
        # is only ever paired with an item that was actually dequeued.
        task = GENERATION_QUEUE.get()
        try:
            if task is None:  # Poison pill to stop the worker
                break

            callback, args = task
            callback(*args)

        except Exception as e:
            logging.exception(f"Error in background worker: {e}")
        finally:
            GENERATION_QUEUE.task_done()
167
+
168
+ # Start background worker thread
169
+ worker_thread = threading.Thread(target=background_worker, daemon=True)
170
+ worker_thread.start()
171
+
172
def voice_clone_callback(text, prompt_audio_upload, prompt_audio_record, progress=gr.Progress()):
    """Gradio callback for voice cloning using MiraTTS.

    Args:
        text: Text to synthesise; empty, whitespace-only or ``None`` input
            produces no audio.
        prompt_audio_upload: Uploaded reference audio, preferred when set.
        prompt_audio_record: Recorded reference audio, used as the fallback.
        progress: Gradio progress tracker. It is declared as a default
            argument because Gradio expects it there; it is not shared state.

    Returns:
        ``(audio_path_or_None, history_markdown)``. The audio element is
        ``None`` when input is missing or generation fails.
    """
    # Guard against None as well as blank text.
    if not text or not text.strip():
        return None, get_history_display()

    # Use uploaded audio or recorded audio
    prompt_audio = prompt_audio_upload or prompt_audio_record
    if not prompt_audio:
        return None, get_history_display()

    progress(0, desc="Initializing...")

    try:
        # run_tts does encoding and generation in a single call, so these
        # two updates are only coarse markers issued before the real work.
        progress(0.3, desc="Encoding audio...")
        progress(0.6, desc="Generating speech...")
        audio_output_path = run_tts(text, prompt_audio, mode="clone")
        progress(1.0, desc="Complete!")
        return audio_output_path, get_history_display()
    except Exception as e:
        # Keep the UI responsive on failure, but record the full traceback.
        logging.exception(f"Error in voice cloning: {e}")
        return None, get_history_display()
194
 
195
+ def voice_creation_callback(text, temperature, top_p, top_k, progress=gr.Progress()):
196
  """Gradio callback for creating synthetic voice with custom parameters."""
197
  if not text.strip():
198
+ return None, get_history_display()
199
 
200
  global MODEL
201
 
202
  if MODEL is None:
203
  MODEL = initialize_model()
204
 
205
+ progress(0, desc="Initializing...")
206
+
207
  try:
208
  # Set custom generation parameters
209
  MODEL.set_params(
 
214
  repetition_penalty=1.2
215
  )
216
 
217
+ progress(0.3, desc="Loading default voice...")
218
+
219
+ # Use a default voice context
220
  possible_paths = [
221
  "/models3/src/MiraTTS/models/MiraTTS/example1.wav",
222
  "models/MiraTTS/example1.wav",
 
230
  break
231
 
232
  if default_audio:
233
+ progress(0.6, desc="Generating speech...")
234
+
235
  # Generate audio with dtype conversion
236
  context_tokens = MODEL.encode_audio(default_audio)
237
+
238
+ # Move to device
239
+ if torch.is_tensor(context_tokens):
240
+ context_tokens = context_tokens.to(DEVICE)
241
+
242
+ with torch.inference_mode() if DEVICE == "cpu" else torch.cuda.amp.autocast():
243
+ audio = MODEL.generate(text, context_tokens)
244
 
245
  # Handle tensor conversion and dtype
246
  if torch.is_tensor(audio):
 
254
 
255
  # Save the audio
256
  os.makedirs("results", exist_ok=True)
257
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
258
  save_path = os.path.join("results", f"mira_tts_creation_{timestamp}.wav")
259
  sf.write(save_path, audio, samplerate=48000)
260
 
261
+ # Add to history
262
+ history_entry = {
263
+ "timestamp": datetime.now().isoformat(),
264
+ "text": text[:100] + "..." if len(text) > 100 else text,
265
+ "full_text": text,
266
+ "mode": "creation",
267
+ "file_path": save_path,
268
+ "reference_audio": None,
269
+ "device": DEVICE,
270
+ "temperature": temperature,
271
+ "top_p": top_p,
272
+ "top_k": top_k
273
+ }
274
+ HISTORY_MANAGER.add_entry(history_entry)
275
+
276
+ progress(1.0, desc="Complete!")
277
+ return save_path, get_history_display()
278
  else:
279
  logging.warning("No default audio found for voice creation")
280
+ return None, get_history_display()
281
 
282
  except Exception as e:
283
  logging.error(f"Error in voice creation: {e}")
284
+ return None, get_history_display()
285
+
286
def get_history_display():
    """Return the most recent history entries formatted as Markdown.

    At most the first 20 entries (newest first) are shown. Entries that
    lack a required field or carry an unparsable timestamp are skipped with
    a warning, so a damaged history file cannot break the callbacks that
    render this text.

    Returns:
        A Markdown string, or a placeholder sentence when history is empty.
    """
    history = HISTORY_MANAGER.get_history()

    if not history:
        return "No generation history yet."

    parts = ["# Generation History\n\n"]

    for idx, entry in enumerate(history[:20], start=1):  # Show last 20
        try:
            timestamp = datetime.fromisoformat(entry['timestamp']).strftime("%Y-%m-%d %H:%M:%S")
            mode = entry['mode'].capitalize()
            text_preview = entry['text']
            file_name = os.path.basename(entry['file_path'])
        except (KeyError, TypeError, ValueError, AttributeError) as e:
            logging.warning(f"Skipping malformed history entry: {e}")
            continue

        parts.append(f"### {idx}. {timestamp} - {mode}\n")
        parts.append(f"**Text:** {text_preview}\n")
        parts.append(f"**File:** `{file_name}`\n")
        parts.append(f"**Device:** {entry.get('device', 'N/A')}\n")

        # Compare against None so a temperature of 0 still shows its params.
        if entry.get('temperature') is not None:
            parts.append(
                f"**Params:** T={entry.get('temperature')}, p={entry.get('top_p')}, k={entry.get('top_k')}\n"
            )

        parts.append("\n---\n\n")

    return "".join(parts)
312
+
313
def get_history_files():
    """Collect ``(path, basename)`` pairs for history files still on disk."""
    available = []
    for record in HISTORY_MANAGER.get_history():
        path = record['file_path']
        # Entries whose audio file has been removed are left out.
        if os.path.exists(path):
            available.append((path, os.path.basename(path)))
    return available
318
+
319
def clear_history_callback():
    """Wipe the stored history and return the refreshed display values.

    Returns:
        ``(history_markdown, [])``: the re-rendered history text and an
        empty file list.
    """
    HISTORY_MANAGER.clear_history()
    refreshed_markdown = get_history_display()
    no_files = []
    return refreshed_markdown, no_files
323
 
324
  def build_ui():
325
  """Build the Gradio interface similar to SparkTTS."""
326
 
327
+ with gr.Blocks(title="MiraTTS Web Interface", theme=gr.themes.Soft()) as demo:
328
  # Title
329
  gr.HTML('<h1 style="text-align: center;">MiraTTS - High Quality Voice Synthesis</h1>')
330
 
331
+ # Device info
332
+ device_info = f"🖥️ Running on: **{DEVICE.upper()}**"
333
+ if DEVICE == "cuda":
334
+ device_info += f" (GPU: {torch.cuda.get_device_name(0)})"
335
+ gr.Markdown(device_info)
336
+
337
  # Description
338
  gr.Markdown("""
339
  MiraTTS is a highly optimized Text-to-Speech model based on Spark-TTS with LMDeploy acceleration.
340
+ It provides high-quality 48kHz audio output with background processing support.
341
  """)
342
 
343
  with gr.Tabs():
344
  # Voice Clone Tab
345
+ with gr.TabItem("🎤 Voice Clone"):
346
  gr.Markdown("### Clone any voice using a reference audio sample")
347
 
348
  with gr.Row():
 
365
  )
366
 
367
  with gr.Row():
368
+ clone_button = gr.Button("🎵 Generate Audio", variant="primary")
369
+ clear_button = gr.Button("🗑️ Clear")
370
 
371
  audio_output_clone = gr.Audio(
372
  label="Generated Audio",
373
  autoplay=True
374
  )
375
 
376
+ history_display_clone = gr.Markdown(get_history_display())
377
+
378
  clone_button.click(
379
  voice_clone_callback,
380
  inputs=[text_input, prompt_audio_upload, prompt_audio_record],
381
+ outputs=[audio_output_clone, history_display_clone],
382
  )
383
 
384
  clear_button.click(
 
387
  )
388
 
389
  # Voice Creation Tab
390
+ with gr.TabItem("Voice Creation"):
391
  gr.Markdown("### Create synthetic voices with custom parameters")
392
 
393
  with gr.Row():
 
423
  )
424
 
425
  with gr.Column():
426
+ create_button = gr.Button("🎨 Create Voice", variant="primary")
427
  audio_output_creation = gr.Audio(
428
  label="Generated Audio",
429
  autoplay=True
430
  )
431
 
432
+ history_display_creation = gr.Markdown(get_history_display())
433
+
434
  create_button.click(
435
  voice_creation_callback,
436
  inputs=[text_input_creation, temperature, top_p, top_k],
437
+ outputs=[audio_output_creation, history_display_creation],
438
+ )
439
+
440
+ # History Tab
441
+ with gr.TabItem("📜 History"):
442
+ gr.Markdown("### Review and download previous generations")
443
+
444
+ with gr.Row():
445
+ refresh_button = gr.Button("🔄 Refresh History", variant="secondary")
446
+ clear_history_button = gr.Button("🗑️ Clear History", variant="stop")
447
+
448
+ history_display_main = gr.Markdown(get_history_display())
449
+
450
+ gr.Markdown("### Download Files")
451
+ file_browser = gr.File(
452
+ label="Generated Audio Files",
453
+ file_count="multiple",
454
+ interactive=False
455
+ )
456
+
457
+ def refresh_history():
458
+ files = get_history_files()
459
+ return get_history_display(), [f[0] for f in files]
460
+
461
+ refresh_button.click(
462
+ refresh_history,
463
+ outputs=[history_display_main, file_browser]
464
+ )
465
+
466
+ clear_history_button.click(
467
+ clear_history_callback,
468
+ outputs=[history_display_main, file_browser]
469
+ )
470
+
471
+ # Auto-load files on tab open
472
+ demo.load(
473
+ refresh_history,
474
+ outputs=[history_display_main, file_browser]
475
  )
476
 
477
  # About Tab
478
+ with gr.TabItem("ℹ️ About"):
479
+ gr.Markdown(f"""
480
  ## About MiraTTS
481
 
482
  MiraTTS is an optimized version of Spark-TTS with the following features:
483
 
484
  - **Ultra-fast generation**: Over 100x realtime speed using LMDeploy optimization
485
  - **High quality**: Generates crisp 48kHz audio outputs
486
+ - **Memory efficient**: Works within 6GB VRAM or on CPU
487
+ - **Low latency**: As low as 100ms generation time (GPU)
488
  - **Voice cloning**: Clone any voice from a short audio sample
489
+ - **Background processing**: Non-blocking audio generation
490
+ - **Generation history**: Review and download all generated audio
491
 
492
+ ### Current Configuration
493
+ - **Device**: {DEVICE.upper()}
494
+ - **Base model**: Spark-TTS-0.5B
495
+ - **Optimization**: LMDeploy + FlashSR
496
+ - **Sample rate**: 48kHz
497
+ - **Model size**: ~500M parameters
498
 
499
  ### Usage Tips
500
  - For voice cloning, use clear audio samples between 3-30 seconds
501
  - Ensure reference audio is at least 16kHz quality
502
  - Longer text inputs may require more memory
503
  - Adjust generation parameters for different voice styles
504
+ - CPU mode is slower but works without GPU
505
+ - Check the History tab to download previous generations
506
+
507
+ ### Performance Notes
508
+ - **GPU**: ~100-200ms per generation
509
+ - **CPU**: ~2-5 seconds per generation (depending on CPU)
510
  """)
511
 
512
  return demo
 
520
  default="YatharthS/MiraTTS",
521
  help="Path to the MiraTTS model directory or HuggingFace model ID"
522
  )
523
+ parser.add_argument(
524
+ "--device",
525
+ type=str,
526
+ default=None,
527
+ choices=["cuda", "cpu"],
528
+ help="Device to run model on (default: auto-detect)"
529
+ )
530
  parser.add_argument(
531
  "--server_name",
532
  type=str,
 
556
  # Parse arguments
557
  args = parse_arguments()
558
 
559
+ # Set device if specified
560
+ if args.device:
561
+ DEVICE = args.device
562
+
563
  # Initialize model
564
  logging.info("Initializing MiraTTS model...")
565
+ MODEL = initialize_model(args.model_dir, args.device)
566
 
567
  # Build and launch interface
568
  logging.info("Building Gradio interface...")
569
  demo = build_ui()
570
 
571
  logging.info(f"Launching web interface on {args.server_name}:{args.server_port}")
572
+ logging.info(f"Device: {DEVICE}")
573
  demo.launch(
574
  server_name=args.server_name,
575
  server_port=args.server_port,