File size: 958 Bytes
b7c959f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
import torch
from transformers import AutoFeatureExtractor, WhisperForAudioClassification
import librosa
# Default checkpoint directory for the fine-tuned Whisper language-ID classifier.
_DEFAULT_LANG_ID_MODEL = "/home/investigacion/disco4TB/workspace_pablo/firvox_whisper_research/whisper-medium-fleurs-lang-id/lang_identification_models_noFirVox_audios"

# Memoized (feature_extractor, model) pairs keyed by (model_dir, device), so the
# multi-GB checkpoint is loaded from disk once instead of on every call.
_LANG_ID_CACHE = {}


def _load_lang_id(model_dir, device):
    """Load (and cache) the feature extractor and audio-classification model."""
    key = (model_dir, device)
    if key not in _LANG_ID_CACHE:
        extractor = AutoFeatureExtractor.from_pretrained(model_dir)
        model = WhisperForAudioClassification.from_pretrained(model_dir).to(device)
        model.eval()  # inference only: disable dropout / training-mode behavior
        _LANG_ID_CACHE[key] = (extractor, model)
    return _LANG_ID_CACHE[key]


def get_language(audio_path, model_dir=_DEFAULT_LANG_ID_MODEL, device=None):
    """Identify the spoken language of an audio file.

    Parameters
    ----------
    audio_path : str
        Path to an audio file readable by ``librosa.load``.
    model_dir : str, optional
        Directory of the fine-tuned Whisper language-ID checkpoint.
        Defaults to the original hard-coded path.
    device : str, optional
        Torch device string. Defaults to ``"cuda"`` when available,
        otherwise ``"cpu"`` (the original hard-coded ``"cuda"`` and
        crashed on GPU-less machines).

    Returns
    -------
    str
        Predicted language label from ``model.config.id2label``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    feature_extractor, model = _load_lang_id(model_dir, device)
    # Resample to 16 kHz mono, the rate Whisper feature extractors expect.
    audio, sr = librosa.load(audio_path, sr=16000)
    inputs = feature_extractor(audio, sampling_rate=sr, return_tensors="pt")
    input_features = inputs.input_features.to(device)
    with torch.no_grad():
        logits = model(input_features).logits
    # logits has shape (1, num_labels); argmax over the label axis gives the
    # class id (the original flattened-tensor argmax only worked because the
    # batch size is 1 — dim=-1 makes the intent explicit).
    predicted_class_id = torch.argmax(logits, dim=-1).item()
    return model.config.id2label[predicted_class_id]