liuyang committed
Commit 28a7e7e
1 Parent(s): 0cb30bb

Refactor model loading in app.py to return both Whisper and diarization models, enhancing GPU utilization during transcription processes.
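For context, here is a minimal, self-contained sketch of the loading pattern this commit moves to, assuming a ZeroGPU Space (the `spaces` package), `faster-whisper`, and `pyannote.audio`. The model names and the `_load_models()` interface come from the diff below; the standalone `transcribe()` caller, and folding the diarizer load into `_load_models()` rather than creating it at import time as app.py does, are illustrative simplifications, not the app's exact code.

import os
import spaces
import torch
from faster_whisper import WhisperModel
from pyannote.audio import Pipeline

# Lazy global holders: filled the first time _load_models() runs inside a
# GPU worker, then reused by later calls in that worker.
_whisper = None
_diarizer = None


@spaces.GPU  # a GPU is only guaranteed to exist inside decorated calls
def _load_models():
    """Load the models once per worker and hand them back to the caller."""
    global _whisper, _diarizer
    if _whisper is None:
        _whisper = WhisperModel("large-v3-turbo",
                                device="cuda",
                                compute_type="float16")
    if _diarizer is None and os.getenv("HF_TOKEN"):
        try:
            _diarizer = Pipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=os.getenv("HF_TOKEN"),
            ).to(torch.device("cuda"))
        except Exception as e:
            print(f"Could not load diarization model: {e}")
    # Returning the models lets each @spaces.GPU-decorated caller unpack
    # them directly instead of reading globals primed by a separate call.
    return _whisper, _diarizer


@spaces.GPU  # each call gets its own GPU slice
def transcribe(audio_path):
    whisper, _ = _load_models()
    segments, _info = whisper.transcribe(audio_path)
    return " ".join(segment.text for segment in segments)

Under ZeroGPU, each @spaces.GPU call may run in a fresh worker process, so having every decorated method call _load_models() itself, rather than depending on a one-shot warm_models() priming module globals, keeps the models available wherever the call actually executes.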

Files changed (1)
  1. app.py +11 -31
app.py CHANGED
@@ -33,7 +33,6 @@ from faster_whisper.vad import VadOptions
  import requests
  import base64
  from pyannote.audio import Pipeline
- from huggingface_hub import snapshot_download
 
  import os, sys, importlib.util, pathlib, ctypes, tempfile, wave, math
  spec = importlib.util.find_spec("nvidia.cudnn")
@@ -49,8 +48,6 @@ try:
  except OSError as e:
      sys.exit(f"❌ Could not load {cnn_so} : {e}")
 
- model_cache_path = "large-v3-turbo" # fallback to model name
-
  # Lazy global holder ----------------------------------------------------------
  _whisper = None
  _diarizer = None
@@ -74,30 +71,18 @@ except Exception as e:
      print(f"Could not load diarization model: {e}")
      _diarizer = None
 
- # ---------------------------------------------------------------------
- # Leave _load_models() UNdecorated
+ @spaces.GPU  # GPU is guaranteed to exist *inside* this function
  def _load_models():
      global _whisper, _diarizer
      if _whisper is None:
-         _whisper = WhisperModel(model_cache_path,
-                                 device="cuda",
-                                 compute_type="float16")
-     if _diarizer is None:
-         _diarizer = (
-             Pipeline.from_pretrained(
-                 "pyannote/speaker-diarization-3.1",
-                 use_auth_token=os.getenv("HF_TOKEN"),
-             ).to(torch.device("cuda"))
+         print("Loading Whisper model...")
+         _whisper = WhisperModel(
+             "large-v3-turbo",
+             device="cuda",
+             compute_type="float16",
          )
-     # do NOT return the models
-     return None
- # ---------------------------------------------------------------------
-
- # One‐shot GPU warming function
- @spaces.GPU
- def warm_models():
-     _load_models()  # runs in the GPU worker, models stay there
-     return "ready"  # <-- picklable
+         print("Whisper model loaded successfully")
+     return _whisper, _diarizer
 
  # -----------------------------------------------------------------------------
  class WhisperTranscriber:
@@ -124,8 +109,7 @@ class WhisperTranscriber:
      @spaces.GPU  # each call gets a GPU slice
      def transcribe_full_audio(self, audio_path, language=None, translate=False, prompt=None):
          """Transcribe the entire audio file without speaker diarization"""
-         #whisper, _ = _load_models()  # models live on the GPU
-         whisper = _whisper
+         whisper, _ = _load_models()  # models live on the GPU
 
          print("Transcribing full audio...")
          start_time = time.time()
@@ -218,8 +202,7 @@ class WhisperTranscriber:
      @spaces.GPU  # each call gets a GPU slice
      def transcribe_audio_segments(self, audio_segments, language=None, translate=False, prompt=None):
          """Transcribe multiple audio segments using faster_whisper"""
-         #whisper, diarizer = _load_models()  # models live on the GPU
-         whisper = _whisper
+         whisper, diarizer = _load_models()  # models live on the GPU
 
          print(f"Transcribing {len(audio_segments)} audio segments...")
          start_time = time.time()
@@ -293,9 +276,7 @@ class WhisperTranscriber:
      @spaces.GPU  # each call gets a GPU slice
      def perform_diarization(self, audio_path, num_speakers=None):
          """Perform speaker diarization"""
-         #whisper, diarizer = _load_models()  # models live on the GPU
-         whisper = _whisper
-         diarizer = _diarizer
+         whisper, diarizer = _load_models()  # models live on the GPU
 
          if diarizer is None:
              print("Diarization model not available, creating single speaker segment")
@@ -634,5 +615,4 @@ with demo:
      """)
 
  if __name__ == "__main__":
-     warm_models()  # prime the GPU worker once at startup
      demo.launch(debug=True)