liuyang committed
Commit e48217c · 1 Parent(s): 6d56dd1

restore diarization initialization

Files changed (1)
  1. app.py +28 -13
app.py CHANGED
@@ -32,11 +32,39 @@ from faster_whisper import WhisperModel
 from faster_whisper.vad import VadOptions
 import requests
 import base64
+from pyannote.audio import Pipeline
 
 # Lazy global holder ----------------------------------------------------------
 _whisper = None
 _diarizer = None
 
+
+# Create global diarization pipeline
+try:
+    print("Loading diarization model...")
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+    torch.set_float32_matmul_precision('high')
+
+    _diarizer = Pipeline.from_pretrained(
+        "pyannote/speaker-diarization-3.1",
+        use_auth_token=os.getenv("HF_TOKEN"),
+        torch_dtype=torch.float16,
+    ).to(torch.device("cuda"))
+    _diarizer.model.half()  # FP16
+
+    for m in _diarizer.model.modules():  # compact LSTM weights
+        if isinstance(m, torch.nn.LSTM):
+            m.flatten_parameters()
+
+    _diarizer.model = torch.compile(_diarizer.model, mode="reduce-overhead")
+    print("Diarization model loaded successfully")
+except Exception as e:
+    import traceback
+    traceback.print_exc()
+    print(f"Could not load diarization model: {e}")
+    _diarizer = None
+
 @spaces.GPU  # GPU is guaranteed to exist *inside* this function
 def _load_models():
     global _whisper, _diarizer
@@ -48,19 +76,6 @@ def _load_models():
         compute_type="float16",
     )
     print("Whisper model loaded successfully")
-    if _diarizer is None:
-        print("Loading diarization model...")
-        try:
-            from pyannote.audio import Pipeline
-            _diarizer = Pipeline.from_pretrained(
-                "pyannote/speaker-diarization-3.1",
-                use_auth_token=os.getenv("HF_TOKEN"),
-                torch_dtype=torch.float16,
-            ).to(torch.device("cuda"))
-            print("Diarization model loaded successfully")
-        except Exception as e:
-            print(f"Could not load diarization model: {e}")
-            _diarizer = None
     return _whisper, _diarizer
 
 # -----------------------------------------------------------------------------
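For context, a minimal sketch (not part of this commit) of how the module-level _diarizer and the lazily loaded Whisper model might be consumed later in app.py. The diff moves diarizer creation to import time while Whisper stays lazy inside the GPU-decorated loader, and the sketch follows that split; the function name, audio path, and transcription parameters below are illustrative assumptions, since this diff only covers model initialization.

@spaces.GPU
def transcribe_and_diarize(audio_path):
    # Hypothetical consumer of the globals initialized above.
    whisper, diarizer = _load_models()

    # faster-whisper returns a segment generator plus transcription info.
    segments, info = whisper.transcribe(audio_path, vad_filter=True)
    transcript = [{"start": s.start, "end": s.end, "text": s.text} for s in segments]

    speaker_turns = []
    if diarizer is not None:
        # pyannote pipeline returns an Annotation of speaker turns.
        annotation = diarizer(audio_path)
        speaker_turns = [
            {"start": turn.start, "end": turn.end, "speaker": label}
            for turn, _, label in annotation.itertracks(yield_label=True)
        ]

    return transcript, speaker_turns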