Approximetal commited on
Commit
c4db34f
·
verified ·
1 Parent(s): 0c63007

Update inference_gradio.py

Browse files
Files changed (1) hide show
  1. inference_gradio.py +7 -2
inference_gradio.py CHANGED
@@ -396,6 +396,7 @@ class WhisperxModel:
396
  def __init__(self, model_name):
397
  from pathlib import Path
398
  import whisperx.vad as wx_vad
 
399
  from whisperx import load_model
400
  prompt = None # "This might be a blend of Simplified Chinese and English speech, do not translate, only transcription be allowed."
401
 
@@ -408,8 +409,9 @@ class WhisperxModel:
408
  )
409
  vad_fp = None
410
  else:
411
- # Monkey-patch whisperx.vad.load_vad_model so it loads our local
412
- # segmentation model without enforcing the baked-in SHA256 check.
 
413
  def _patched_load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
414
  import torch
415
  from pyannote.audio import Model
@@ -430,7 +432,10 @@ class WhisperxModel:
430
  vad_pipeline.instantiate(hyperparameters)
431
  return vad_pipeline
432
 
 
 
433
  wx_vad.load_vad_model = _patched_load_vad_model
 
434
 
435
  self.model = load_model(
436
  model_name,
 
396
  def __init__(self, model_name):
397
  from pathlib import Path
398
  import whisperx.vad as wx_vad
399
+ import whisperx.asr as wx_asr
400
  from whisperx import load_model
401
  prompt = None # "This might be a blend of Simplified Chinese and English speech, do not translate, only transcription be allowed."
402
 
 
409
  )
410
  vad_fp = None
411
  else:
412
+ # Monkey-patch whisperx *before* constructing the pipeline so it
413
+ # loads our local segmentation model without enforcing the
414
+ # baked-in SHA256 checksum.
415
  def _patched_load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
416
  import torch
417
  from pyannote.audio import Model
 
432
  vad_pipeline.instantiate(hyperparameters)
433
  return vad_pipeline
434
 
435
+ # asr.py does `from .vad import load_vad_model`, so we must patch
436
+ # both the `vad` module and the alias in `asr`.
437
  wx_vad.load_vad_model = _patched_load_vad_model
438
+ wx_asr.load_vad_model = _patched_load_vad_model
439
 
440
  self.model = load_model(
441
  model_name,