Capstone04 committed on
Commit
2d43080
·
verified ·
1 Parent(s): 7a22c1a

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. asr_diarization/pipeline.py +4 -11
asr_diarization/pipeline.py CHANGED
@@ -15,13 +15,14 @@ class ASR_Diarization:
15
  def __init__(self, HF_TOKEN,
16
  diar_model="pyannote/speaker-diarization-3.1",
17
  asr_model="Capstone04/TrainedWhisper",
18
- model_path = "None"):
19
  self.HF_TOKEN = HF_TOKEN
20
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
21
 
22
  # Load diarization model
23
  self.diar_pipeline = Pipeline.from_pretrained(diar_model, use_auth_token=HF_TOKEN)
24
 
 
25
  if model_path and os.path.exists(model_path):
26
  print(f"🔄 Loading custom ASR model from: {model_path}")
27
  actual_asr_model = model_path
@@ -29,8 +30,8 @@ class ASR_Diarization:
29
  print(f"🔄 Loading default ASR model: {asr_model}")
30
  actual_asr_model = asr_model
31
 
32
- processor = WhisperProcessor.from_pretrained(asr_model, token=HF_TOKEN)
33
- model = WhisperForConditionalGeneration.from_pretrained(asr_model, token=HF_TOKEN).to(self.device)
34
 
35
  self.asr_pipeline = hf_pipeline(
36
  "automatic-speech-recognition",
@@ -41,14 +42,6 @@ class ASR_Diarization:
41
  return_timestamps=True
42
  )
43
 
44
- # model_id = "Capstone04/TrainedWhisper"
45
- # self.asr_pipeline = hf_pipeline(
46
- # "automatic-speech-recognition",
47
- # model=model_id,
48
- # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
49
- # device=0 if torch.cuda.is_available() else -1,
50
- # )
51
-
52
  def run_diarization(self, audio_path):
53
  diarization = self.diar_pipeline(audio_path)
54
  return [
 
15
  def __init__(self, HF_TOKEN,
16
  diar_model="pyannote/speaker-diarization-3.1",
17
  asr_model="Capstone04/TrainedWhisper",
18
+ model_path=None): # NEW: model_path parameter
19
  self.HF_TOKEN = HF_TOKEN
20
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
21
 
22
  # Load diarization model
23
  self.diar_pipeline = Pipeline.from_pretrained(diar_model, use_auth_token=HF_TOKEN)
24
 
25
+ # MODIFIED: Use custom model_path if it is provided and exists on disk; otherwise fall back to asr_model
26
  if model_path and os.path.exists(model_path):
27
  print(f"🔄 Loading custom ASR model from: {model_path}")
28
  actual_asr_model = model_path
 
30
  print(f"🔄 Loading default ASR model: {asr_model}")
31
  actual_asr_model = asr_model
32
 
33
+ processor = WhisperProcessor.from_pretrained(actual_asr_model, token=HF_TOKEN)
34
+ model = WhisperForConditionalGeneration.from_pretrained(actual_asr_model, token=HF_TOKEN).to(self.device)
35
 
36
  self.asr_pipeline = hf_pipeline(
37
  "automatic-speech-recognition",
 
42
  return_timestamps=True
43
  )
44
 
 
 
 
 
 
 
 
 
45
  def run_diarization(self, audio_path):
46
  diarization = self.diar_pipeline(audio_path)
47
  return [