Capstone04 committed on
Commit
70eeab8
·
verified ·
1 Parent(s): 236dd81

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. asr_diarization/pipeline.py +16 -16
asr_diarization/pipeline.py CHANGED
@@ -14,33 +14,33 @@ from jiwer import wer, Compose, ToLowerCase, RemovePunctuation, RemoveMultipleSp
14
  class ASR_Diarization:
15
  def __init__(self, HF_TOKEN,
16
  diar_model="pyannote/speaker-diarization-3.1",
17
- asr_model="Capstone04/TrainedWhisper"):
18
  self.HF_TOKEN = HF_TOKEN
19
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
  # Load diarization model
22
  self.diar_pipeline = Pipeline.from_pretrained(diar_model, use_auth_token=HF_TOKEN)
23
 
24
- # processor = WhisperProcessor.from_pretrained(asr_model, token=HF_TOKEN)
25
- # model = WhisperForConditionalGeneration.from_pretrained(asr_model, token=HF_TOKEN).to(self.device)
26
 
27
- # self.asr_pipeline = hf_pipeline(
28
- # "automatic-speech-recognition",
29
- # model=model,
30
- # tokenizer=processor.tokenizer,
31
- # feature_extractor=processor.feature_extractor,
32
- # device=0 if self.device == "cuda" else -1,
33
- # return_timestamps=True
34
- # )
35
-
36
- model_id = "Capstone04/TrainedWhisper"
37
  self.asr_pipeline = hf_pipeline(
38
  "automatic-speech-recognition",
39
- model=model_id,
40
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
41
- device=0 if torch.cuda.is_available() else -1,
 
 
42
  )
43
 
 
 
 
 
 
 
 
 
44
  def run_diarization(self, audio_path):
45
  diarization = self.diar_pipeline(audio_path)
46
  return [
 
14
  class ASR_Diarization:
15
  def __init__(self, HF_TOKEN,
16
  diar_model="pyannote/speaker-diarization-3.1",
17
+ asr_model="openai/whisper-large-v3"):
18
  self.HF_TOKEN = HF_TOKEN
19
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
  # Load diarization model
22
  self.diar_pipeline = Pipeline.from_pretrained(diar_model, use_auth_token=HF_TOKEN)
23
 
24
+ processor = WhisperProcessor.from_pretrained(asr_model, token=HF_TOKEN)
25
+ model = WhisperForConditionalGeneration.from_pretrained(asr_model, token=HF_TOKEN).to(self.device)
26
 
 
 
 
 
 
 
 
 
 
 
27
  self.asr_pipeline = hf_pipeline(
28
  "automatic-speech-recognition",
29
+ model=model,
30
+ tokenizer=processor.tokenizer,
31
+ feature_extractor=processor.feature_extractor,
32
+ device=0 if self.device == "cuda" else -1,
33
+ return_timestamps=True
34
  )
35
 
36
+ # model_id = "Capstone04/TrainedWhisper"
37
+ # self.asr_pipeline = hf_pipeline(
38
+ # "automatic-speech-recognition",
39
+ # model=model_id,
40
+ # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
41
+ # device=0 if torch.cuda.is_available() else -1,
42
+ # )
43
+
44
  def run_diarization(self, audio_path):
45
  diarization = self.diar_pipeline(audio_path)
46
  return [