Rafiahmed Patel committed on
Commit
64846de
Β·
1 Parent(s): 2739b64

cpu for faster whisper

Browse files
Files changed (1) hide show
  1. app.py +29 -16
app.py CHANGED
@@ -29,16 +29,24 @@ logging.basicConfig(
29
  logger = logging.getLogger(__name__)
30
 
31
  # Configuration - Auto-detect GPU
32
- if torch.cuda.is_available():
33
- DEVICE = "cuda"
34
- COMPUTE_TYPE = "float16" # Use float16 for GPU
 
 
35
  logger_temp = logging.getLogger(__name__)
36
- logger_temp.info(f"πŸš€ GPU detected! Using CUDA with {torch.cuda.get_device_name(0)}")
37
  else:
38
- DEVICE = "cpu"
39
- COMPUTE_TYPE = "int8" # Use int8 for CPU
40
  logger_temp = logging.getLogger(__name__)
41
- logger_temp.info("Running on CPU")
 
 
 
 
 
 
 
42
 
43
  # Set temp directory to writable location
44
  os.environ['TMPDIR'] = '/tmp'
@@ -58,23 +66,25 @@ tts_model = None
58
  # ==================== Model Loading ====================
59
 
60
  def load_models():
61
- """Load models on startup"""
62
  global whisper_model, tts_model
63
 
64
  if whisper_model is None:
65
  logger.info("Loading Whisper model...")
66
  whisper_model = WhisperModel(
67
  "small",
68
- device=DEVICE,
69
- compute_type=COMPUTE_TYPE,
70
  cpu_threads=4
71
  )
72
  logger.info("βœ… Whisper model loaded!")
73
 
74
  if tts_model is None:
75
  logger.info("Loading TTS model...")
76
- tts_model = ChatterboxMultilingualTTS.from_pretrained(device=DEVICE)
77
- logger.info("βœ… TTS model loaded!")
 
 
78
 
79
  return whisper_model, tts_model
80
 
@@ -560,10 +570,13 @@ def create_interface():
560
  # ==================== Main ====================
561
 
562
  if __name__ == "__main__":
563
- # Load models at startup
564
- logger.info("Initializing models...")
565
- load_models()
566
- logger.info("Models loaded successfully!")
 
 
 
567
 
568
  # Create and launch interface
569
  # .queue() is essential for long-running tasks like model generation
 
29
  logger = logging.getLogger(__name__)
30
 
31
# Configuration - Auto-detect GPU
# Note: faster-whisper uses ctranslate2 which doesn't work well with ZeroGPU,
# so we always use CPU for Whisper. TTS will use GPU when available.
logger_temp = logging.getLogger(__name__)
if torch.cuda.is_available() and not SPACES_AVAILABLE:
    # Only use GPU for local CUDA setups, not ZeroGPU
    TTS_DEVICE = "cuda"
    logger_temp.info(f"πŸš€ GPU detected! Using CUDA with {torch.cuda.get_device_name(0)} for TTS")
elif SPACES_AVAILABLE:
    # ZeroGPU: CUDA only materialises inside @spaces.GPU-decorated calls,
    # so the import-time default stays "cpu".
    TTS_DEVICE = "cpu"
    logger_temp.info("πŸš€ Running on ZeroGPU - TTS will use GPU inside decorated function")
else:
    TTS_DEVICE = "cpu"
    logger_temp.info("Running on CPU")

# Whisper always uses CPU (ctranslate2 compatibility)
WHISPER_DEVICE = "cpu"
WHISPER_COMPUTE_TYPE = "int8"
50
 
51
  # Set temp directory to writable location
52
  os.environ['TMPDIR'] = '/tmp'
 
66
  # ==================== Model Loading ====================
67
 
68
def load_models():
    """Load models (lazy loading for ZeroGPU compatibility).

    Populates the module-level ``whisper_model`` and ``tts_model`` singletons
    on first call; subsequent calls are no-ops that return the cached pair.

    Returns:
        tuple: ``(whisper_model, tts_model)`` — the loaded model instances.
    """
    global whisper_model, tts_model

    if whisper_model is None:
        logger.info("Loading Whisper model...")
        # Whisper stays on CPU with int8 quantisation (see module config:
        # ctranslate2 does not cooperate with ZeroGPU's CUDA handoff).
        whisper_kwargs = {
            "device": WHISPER_DEVICE,
            "compute_type": WHISPER_COMPUTE_TYPE,
            "cpu_threads": 4,
        }
        whisper_model = WhisperModel("small", **whisper_kwargs)
        logger.info("βœ… Whisper model loaded!")

    if tts_model is None:
        logger.info("Loading TTS model...")
        # In ZeroGPU, determine device at runtime: CUDA only becomes visible
        # inside a decorated request handler, not at import time.
        if SPACES_AVAILABLE and torch.cuda.is_available():
            tts_device = "cuda"
        else:
            tts_device = TTS_DEVICE
        tts_model = ChatterboxMultilingualTTS.from_pretrained(device=tts_device)
        logger.info(f"βœ… TTS model loaded on {tts_device}!")

    return whisper_model, tts_model
90
 
 
570
  # ==================== Main ====================
571
 
572
  if __name__ == "__main__":
573
+ # Load models at startup (except in ZeroGPU where GPU isn't available yet)
574
+ if not SPACES_AVAILABLE:
575
+ logger.info("Initializing models...")
576
+ load_models()
577
+ logger.info("Models loaded successfully!")
578
+ else:
579
+ logger.info("Running in ZeroGPU mode - models will be loaded on first request")
580
 
581
  # Create and launch interface
582
  # .queue() is essential for long-running tasks like model generation