Upload folder using huggingface_hub
Browse files- asr_diarization/pipeline.py +4 -11
asr_diarization/pipeline.py
CHANGED
|
@@ -15,13 +15,14 @@ class ASR_Diarization:
|
|
| 15 |
def __init__(self, HF_TOKEN,
|
| 16 |
diar_model="pyannote/speaker-diarization-3.1",
|
| 17 |
asr_model="Capstone04/TrainedWhisper",
|
| 18 |
-
model_path
|
| 19 |
self.HF_TOKEN = HF_TOKEN
|
| 20 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 21 |
|
| 22 |
# Load diarization model
|
| 23 |
self.diar_pipeline = Pipeline.from_pretrained(diar_model, use_auth_token=HF_TOKEN)
|
| 24 |
|
|
|
|
| 25 |
if model_path and os.path.exists(model_path):
|
| 26 |
print(f"🔄 Loading custom ASR model from: {model_path}")
|
| 27 |
actual_asr_model = model_path
|
|
@@ -29,8 +30,8 @@ class ASR_Diarization:
|
|
| 29 |
print(f"🔄 Loading default ASR model: {asr_model}")
|
| 30 |
actual_asr_model = asr_model
|
| 31 |
|
| 32 |
-
processor = WhisperProcessor.from_pretrained(
|
| 33 |
-
model = WhisperForConditionalGeneration.from_pretrained(
|
| 34 |
|
| 35 |
self.asr_pipeline = hf_pipeline(
|
| 36 |
"automatic-speech-recognition",
|
|
@@ -41,14 +42,6 @@ class ASR_Diarization:
|
|
| 41 |
return_timestamps=True
|
| 42 |
)
|
| 43 |
|
| 44 |
-
# model_id = "Capstone04/TrainedWhisper"
|
| 45 |
-
# self.asr_pipeline = hf_pipeline(
|
| 46 |
-
# "automatic-speech-recognition",
|
| 47 |
-
# model=model_id,
|
| 48 |
-
# torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
| 49 |
-
# device=0 if torch.cuda.is_available() else -1,
|
| 50 |
-
# )
|
| 51 |
-
|
| 52 |
def run_diarization(self, audio_path):
|
| 53 |
diarization = self.diar_pipeline(audio_path)
|
| 54 |
return [
|
|
|
|
| 15 |
def __init__(self, HF_TOKEN,
|
| 16 |
diar_model="pyannote/speaker-diarization-3.1",
|
| 17 |
asr_model="Capstone04/TrainedWhisper",
|
| 18 |
+
model_path=None): # NEW: model_path parameter
|
| 19 |
self.HF_TOKEN = HF_TOKEN
|
| 20 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 21 |
|
| 22 |
# Load diarization model
|
| 23 |
self.diar_pipeline = Pipeline.from_pretrained(diar_model, use_auth_token=HF_TOKEN)
|
| 24 |
|
| 25 |
+
# MODIFIED: Use custom model_path if provided, otherwise use asr_model
|
| 26 |
if model_path and os.path.exists(model_path):
|
| 27 |
print(f"🔄 Loading custom ASR model from: {model_path}")
|
| 28 |
actual_asr_model = model_path
|
|
|
|
| 30 |
print(f"🔄 Loading default ASR model: {asr_model}")
|
| 31 |
actual_asr_model = asr_model
|
| 32 |
|
| 33 |
+
processor = WhisperProcessor.from_pretrained(actual_asr_model, token=HF_TOKEN)
|
| 34 |
+
model = WhisperForConditionalGeneration.from_pretrained(actual_asr_model, token=HF_TOKEN).to(self.device)
|
| 35 |
|
| 36 |
self.asr_pipeline = hf_pipeline(
|
| 37 |
"automatic-speech-recognition",
|
|
|
|
| 42 |
return_timestamps=True
|
| 43 |
)
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
def run_diarization(self, audio_path):
|
| 46 |
diarization = self.diar_pipeline(audio_path)
|
| 47 |
return [
|