Capstone04 commited on
Commit
c10f9e3
·
verified ·
1 Parent(s): 5fd6c86

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. asr_diarization/pipeline.py +7 -3
asr_diarization/pipeline.py CHANGED
@@ -21,12 +21,16 @@ class ASR_Diarization:
21
  # Load diarization model
22
  self.diar_pipeline = Pipeline.from_pretrained(diar_model, use_auth_token=HF_TOKEN)
23
 
24
- # Load ASR model with timestamps
 
 
25
  self.asr_pipeline = hf_pipeline(
26
  "automatic-speech-recognition",
27
- model=asr_model,
 
 
28
  device=0 if self.device == "cuda" else -1,
29
- return_timestamps=True,
30
  )
31
 
32
  def run_diarization(self, audio_path):
 
21
  # Load diarization model
22
  self.diar_pipeline = Pipeline.from_pretrained(diar_model, use_auth_token=HF_TOKEN)
23
 
24
+ processor = WhisperProcessor.from_pretrained(asr_model, token=HF_TOKEN)
25
+ model = WhisperForConditionalGeneration.from_pretrained(asr_model, token=HF_TOKEN).to(self.device)
26
+
27
  self.asr_pipeline = hf_pipeline(
28
  "automatic-speech-recognition",
29
+ model=model,
30
+ tokenizer=processor.tokenizer,
31
+ feature_extractor=processor.feature_extractor,
32
  device=0 if self.device == "cuda" else -1,
33
+ return_timestamps=True
34
  )
35
 
36
  def run_diarization(self, audio_path):