Translsis committed (verified)
Commit: af2bc7c · 1 parent: 39c69de

Update app.py

Files changed (1):
  1. app.py +100 -21
app.py CHANGED
@@ -12,7 +12,20 @@ from pathlib import Path
 from mira.model import MiraTTS
 
 MODEL = None
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Safe device detection with fallback
+def get_device():
+    """Safely detect available device."""
+    try:
+        if torch.cuda.is_available():
+            # Try to actually access CUDA to verify it works
+            torch.cuda.current_device()
+            return "cuda"
+    except Exception as e:
+        logging.warning(f"CUDA not available or driver error: {e}")
+    return "cpu"
+
+DEVICE = get_device()
 HISTORY_FILE = "generation_history.json"
 GENERATION_QUEUE = queue.Queue()
 PROCESSING_LOCK = threading.Lock()
@@ -67,18 +80,39 @@ def initialize_model(model_dir="YatharthS/MiraTTS", device=None):
     """Load the MiraTTS model once at the beginning."""
     global DEVICE
     if device:
-        DEVICE = device
+        # Verify the requested device is available
+        if device == "cuda":
+            try:
+                if not torch.cuda.is_available():
+                    logging.warning("CUDA requested but not available, falling back to CPU")
+                    DEVICE = "cpu"
+                else:
+                    torch.cuda.current_device()  # Test CUDA access
+                    DEVICE = device
+            except Exception as e:
+                logging.warning(f"CUDA test failed: {e}, falling back to CPU")
+                DEVICE = "cpu"
+        else:
+            DEVICE = device
 
     logging.info(f"Loading MiraTTS model from: {model_dir}")
     logging.info(f"Using device: {DEVICE}")
 
-    model = MiraTTS(model_dir)
-
-    # Move model to appropriate device
-    if hasattr(model, 'to'):
-        model = model.to(DEVICE)
-
-    return model
+    try:
+        model = MiraTTS(model_dir)
+
+        # Move model to appropriate device
+        if hasattr(model, 'to') and DEVICE == "cuda":
+            try:
+                model = model.to(DEVICE)
+            except Exception as e:
+                logging.warning(f"Failed to move model to CUDA: {e}, using CPU")
+                DEVICE = "cpu"
+
+        return model
+    except Exception as e:
+        logging.error(f"Error initializing model: {e}")
+        raise
 
 def generate_audio(text, prompt_audio_path):
     """Generate audio from text using MiraTTS with voice cloning."""
@@ -92,12 +126,25 @@ def generate_audio(text, prompt_audio_path):
     context_tokens = MODEL.encode_audio(prompt_audio_path)
 
     # Move context tokens to device if needed
-    if torch.is_tensor(context_tokens):
-        context_tokens = context_tokens.to(DEVICE)
+    if torch.is_tensor(context_tokens) and DEVICE == "cuda":
+        try:
+            context_tokens = context_tokens.to(DEVICE)
+        except Exception as e:
+            logging.warning(f"Failed to move tensors to CUDA: {e}")
 
-    # Generate audio
-    with torch.inference_mode() if DEVICE == "cpu" else torch.cuda.amp.autocast():
-        audio = MODEL.generate(text, context_tokens)
+    # Generate audio with appropriate context
+    try:
+        if DEVICE == "cpu":
+            with torch.inference_mode():
+                audio = MODEL.generate(text, context_tokens)
+        else:
+            with torch.cuda.amp.autocast():
+                audio = MODEL.generate(text, context_tokens)
+    except Exception as e:
+        # Fallback to simple generation if autocast fails
+        logging.warning(f"Autocast failed: {e}, using standard generation")
+        with torch.inference_mode():
+            audio = MODEL.generate(text, context_tokens)
 
     # Convert to numpy array if it's a tensor and handle dtype
     if torch.is_tensor(audio):
@@ -235,12 +282,25 @@ def voice_creation_callback(text, temperature, top_p, top_k, progress=gr.Progres
     # Generate audio with dtype conversion
     context_tokens = MODEL.encode_audio(default_audio)
 
-    # Move to device
-    if torch.is_tensor(context_tokens):
-        context_tokens = context_tokens.to(DEVICE)
+    # Move to device safely
+    if torch.is_tensor(context_tokens) and DEVICE == "cuda":
+        try:
+            context_tokens = context_tokens.to(DEVICE)
+        except Exception as e:
+            logging.warning(f"Failed to move tensors to CUDA: {e}")
 
-    with torch.inference_mode() if DEVICE == "cpu" else torch.cuda.amp.autocast():
-        audio = MODEL.generate(text, context_tokens)
+    try:
+        if DEVICE == "cpu":
+            with torch.inference_mode():
+                audio = MODEL.generate(text, context_tokens)
+        else:
+            with torch.cuda.amp.autocast():
+                audio = MODEL.generate(text, context_tokens)
+    except Exception as e:
+        # Fallback to simple generation
+        logging.warning(f"Autocast failed: {e}, using standard generation")
+        with torch.inference_mode():
+            audio = MODEL.generate(text, context_tokens)
 
     # Handle tensor conversion and dtype
     if torch.is_tensor(audio):
@@ -331,7 +391,12 @@ def build_ui():
     # Device info
    device_info = f"🖥️ Running on: **{DEVICE.upper()}**"
     if DEVICE == "cuda":
-        device_info += f" (GPU: {torch.cuda.get_device_name(0)})"
+        try:
+            device_info += f" (GPU: {torch.cuda.get_device_name(0)})"
+        except:
+            device_info += " (GPU)"
+    else:
+        device_info += " (CPU mode - slower but works without GPU)"
     gr.Markdown(device_info)
 
     # Description
@@ -558,7 +623,21 @@ if __name__ == "__main__":
 
     # Set device if specified
    if args.device:
-        DEVICE = args.device
+        if args.device == "cuda":
+            try:
+                if not torch.cuda.is_available():
+                    logging.warning("CUDA requested but not available, falling back to CPU")
+                    DEVICE = "cpu"
+                else:
+                    torch.cuda.current_device()  # Test CUDA access
+                    DEVICE = args.device
+            except Exception as e:
+                logging.warning(f"CUDA test failed: {e}, falling back to CPU")
+                DEVICE = "cpu"
+        else:
+            DEVICE = args.device
+
+    logging.info(f"Device selected: {DEVICE}")
 
     # Initialize model
     logging.info("Initializing MiraTTS model...")