Capstone04 commited on
Commit
7a22c1a
·
verified ·
1 Parent(s): 4e1b90a

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. asr_diarization/pipeline.py +9 -1
asr_diarization/pipeline.py CHANGED
@@ -14,13 +14,21 @@ from jiwer import wer, Compose, ToLowerCase, RemovePunctuation, RemoveMultipleSp
14
  class ASR_Diarization:
15
  def __init__(self, HF_TOKEN,
16
  diar_model="pyannote/speaker-diarization-3.1",
17
- asr_model="Capstone04/TrainedWhisper"):
 
18
  self.HF_TOKEN = HF_TOKEN
19
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
  # Load diarization model
22
  self.diar_pipeline = Pipeline.from_pretrained(diar_model, use_auth_token=HF_TOKEN)
23
 
 
 
 
 
 
 
 
24
  processor = WhisperProcessor.from_pretrained(asr_model, token=HF_TOKEN)
25
  model = WhisperForConditionalGeneration.from_pretrained(asr_model, token=HF_TOKEN).to(self.device)
26
 
 
14
  class ASR_Diarization:
15
  def __init__(self, HF_TOKEN,
16
  diar_model="pyannote/speaker-diarization-3.1",
17
+ asr_model="Capstone04/TrainedWhisper",
18
+ model_path = "None"):
19
  self.HF_TOKEN = HF_TOKEN
20
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
21
 
22
  # Load diarization model
23
  self.diar_pipeline = Pipeline.from_pretrained(diar_model, use_auth_token=HF_TOKEN)
24
 
25
+ if model_path and os.path.exists(model_path):
26
+ print(f"🔄 Loading custom ASR model from: {model_path}")
27
+ actual_asr_model = model_path
28
+ else:
29
+ print(f"🔄 Loading default ASR model: {asr_model}")
30
+ actual_asr_model = asr_model
31
+
32
  processor = WhisperProcessor.from_pretrained(asr_model, token=HF_TOKEN)
33
  model = WhisperForConditionalGeneration.from_pretrained(asr_model, token=HF_TOKEN).to(self.device)
34