import math
import os
import sys
import numpy as np
import torch
from faster_whisper import WhisperModel
from pyannote.audio import Pipeline
from pydub import AudioSegment
from speechbrain.inference.classifiers import EncoderClassifier

model_size = "large-v3"

# Run on GPU with FP16
# model = WhisperModel(model_size, device="cuda", compute_type="float16")

# VoxLingua107 class index -> language name, used to decode the classifier outputs
INDEX_TO_LANG = {
    0: 'Abkhazian', 1: 'Afrikaans', 2: 'Amharic', 3: 'Arabic', 4: 'Assamese',
    5: 'Azerbaijani', 6: 'Bashkir', 7: 'Belarusian', 8: 'Bulgarian', 9: 'Bengali',
    10: 'Tibetan', 11: 'Breton', 12: 'Bosnian', 13: 'Catalan', 14: 'Cebuano',
    15: 'Czech', 16: 'Welsh', 17: 'Danish', 18: 'German', 19: 'Greek',
    20: 'English', 21: 'Esperanto', 22: 'Spanish', 23: 'Estonian', 24: 'Basque',
    25: 'Persian', 26: 'Finnish', 27: 'Faroese', 28: 'French', 29: 'Galician',
    30: 'Guarani', 31: 'Gujarati', 32: 'Manx', 33: 'Hausa', 34: 'Hawaiian',
    35: 'Hindi', 36: 'Croatian', 37: 'Haitian', 38: 'Hungarian', 39: 'Armenian',
    40: 'Interlingua', 41: 'Indonesian', 42: 'Icelandic', 43: 'Italian', 44: 'Hebrew',
    45: 'Japanese', 46: 'Javanese', 47: 'Georgian', 48: 'Kazakh', 49: 'Central Khmer',
    50: 'Kannada', 51: 'Korean', 52: 'Latin', 53: 'Luxembourgish', 54: 'Lingala',
    55: 'Lao', 56: 'Lithuanian', 57: 'Latvian', 58: 'Malagasy', 59: 'Maori',
    60: 'Macedonian', 61: 'Malayalam', 62: 'Mongolian', 63: 'Marathi', 64: 'Malay',
    65: 'Maltese', 66: 'Burmese', 67: 'Nepali', 68: 'Dutch', 69: 'Norwegian Nynorsk',
    70: 'Norwegian', 71: 'Occitan', 72: 'Panjabi', 73: 'Polish', 74: 'Pushto',
    75: 'Portuguese', 76: 'Romanian', 77: 'Russian', 78: 'Sanskrit', 79: 'Scots',
    80: 'Sindhi', 81: 'Sinhala', 82: 'Slovak', 83: 'Slovenian', 84: 'Shona',
    85: 'Somali', 86: 'Albanian', 87: 'Serbian', 88: 'Sundanese', 89: 'Swedish',
    90: 'Swahili', 91: 'Tamil', 92: 'Telugu', 93: 'Tajik', 94: 'Thai',
    95: 'Turkmen', 96: 'Tagalog', 97: 'Turkish', 98: 'Tatar', 99: 'Ukrainian',
    100: 'Urdu', 101: 'Uzbek', 102: 'Vietnamese', 103: 'Waray', 104: 'Yiddish',
    105: 'Yoruba', 106: 'Chinese'
}
LANG_TO_INDEX = {v: k for k, v in INDEX_TO_LANG.items()}


def identify_languages(
    file_path,
    languages: list[str] = ["Russian", "Belarusian", "Ukrainian", "Kazakh"],
) -> dict[str, float]:
    """Score the requested languages for an audio file with the VoxLingua107 language-ID model."""
    language_id = EncoderClassifier.from_hparams(source="speechbrain/lang-id-voxlingua107-ecapa")
    signal = language_id.load_audio(file_path)
    lang_scores, _, _, _ = language_id.classify_batch(signal)
    # classify_batch returns log posteriors; exponentiate to get percentage scores
    all_scores = {INDEX_TO_LANG[i]: 100 * math.exp(score) for i, score in enumerate(lang_scores[0])}
    selected_scores = {lang: float(all_scores[lang]) for lang in languages}
    return selected_scores
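
# Illustrative only (hypothetical path): identify_languages("call.wav") returns a dict
# keyed by the requested language names ("Russian", "Belarusian", "Ukrainian", "Kazakh")
# with percentage scores; the four values sum to at most 100 because they are a subset
# of the 107-way softmax output.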


def detect_language_local(file_path):
    """Pick a Whisper language code: "ru" for East Slavic top scores, otherwise "kk" (Kazakh)."""
    language_scores = identify_languages(file_path)
    language_result = max(language_scores, key=language_scores.get)
    if language_result.lower() in ["russian", "belarusian", "ukrainian"]:
        selected_language = "ru"
    else:
        selected_language = "kk"
    return selected_language


def transcribe_and_diarize_audio(filename, language):
    """Return (plain transcript, speaker-attributed transcript with timestamps)."""
    diarized_segments = _diarize_audio(filename)
    combined_diarized_segments = _combine_segments_with_same_speaker(diarized_segments)
    transcribed_segments = _transcribe_audio(filename, language)
    pure_text = "\n".join(segment.text for segment in transcribed_segments)
    segments = _combine_diarized_and_transcribed_segments(
        combined_diarized_segments,
        transcribed_segments,
    )
    diarized_text = " ".join(
        "[%.1fs -> %.1fs] (%s) %s" % (segment["start"], segment["end"], segment["speaker"], segment["text"])
        for segment in segments
    )
    return pure_text, diarized_text


diarization_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token=os.getenv('HUGGING_FACE_TOKEN'),
)

model_size = "large-v3"
transcription_model = WhisperModel(
    model_size,
    device="cuda:0" if torch.cuda.is_available() else "cpu",
    # device="cpu",
    compute_type="int8",
    # compute_type="int8_float16",
    # compute_type="float32"
)


def get_audio_length_in_minutes(file_path):
    audio = AudioSegment.from_file(file_path)
    duration = len(audio)  # pydub reports length in milliseconds
    return round(duration / 60000, 2)


def _diarize_audio(filename):
    diarization_pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    # diarization_pipeline.to(torch.device("cpu"))
    diarization_result = diarization_pipeline(filename, max_speakers=2, min_speakers=2)
    diarized_segments = []
    for turn, _, speaker in diarization_result.itertracks(yield_label=True):
        diarized_segments.append(
            {
                "segment": {"start": turn.start, "end": turn.end},
                "speaker": speaker,
            }
        )
    return diarized_segments


def _combine_segments_with_same_speaker(segments):
    new_segments = []
    prev_segment = cur_segment = segments[0]
    for i in range(1, len(segments)):
        cur_segment = segments[i]
        # check if we have changed speaker ("label")
        if cur_segment["speaker"] != prev_segment["speaker"]:
            # add the start/end times for the super-segment to the new list
            new_segments.append(
                {
                    "segment": {
                        "start": prev_segment["segment"]["start"],
                        "end": cur_segment["segment"]["start"],
                    },
                    "speaker": prev_segment["speaker"],
                }
            )
            prev_segment = segments[i]
    # append the final super-segment so the last speaker turn is not dropped
    new_segments.append(
        {
            "segment": {
                "start": prev_segment["segment"]["start"],
                "end": cur_segment["segment"]["end"],
            },
            "speaker": prev_segment["speaker"],
        }
    )
    return new_segments
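
# Illustrative merge (hypothetical timestamps): consecutive turns
# [0.0-2.0 SPEAKER_00], [2.0-5.0 SPEAKER_00], [5.0-7.5 SPEAKER_01]
# collapse into [0.0-5.0 SPEAKER_00], [5.0-7.5 SPEAKER_01].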


def _transcribe_audio(filename, language):
    segments, _ = transcription_model.transcribe(
        filename,
        beam_size=20,
        language=language,
    )
    return list(segments)


def _combine_diarized_and_transcribed_segments(diarized_segments, transcribed_segments):
    # get the end timestamps for each chunk from the ASR output
    end_timestamps = np.array(
        [
            (chunk.end if chunk.end is not None else sys.float_info.max)
            for chunk in transcribed_segments
        ]
    )
    segmented_preds = []
    # align the diarizer timestamps and the ASR timestamps
    for segment in diarized_segments:
        # get the diarizer end timestamp
        end_time = segment["segment"]["end"]
        # find the ASR end timestamp that is closest to the diarizer's end timestamp and cut the transcript to here
        upto_idx = np.argmin(np.abs(end_timestamps - end_time))
        segmented_preds.append(
            {
                "speaker": segment["speaker"],
                "text": "".join(
                    [chunk.text for chunk in transcribed_segments[: upto_idx + 1]]
                ),
                "start": transcribed_segments[0].start,
                "end": transcribed_segments[upto_idx].end,
            }
        )
        # crop the transcribed segments and timestamp lists according to the latest timestamp (for faster argmin)
        transcribed_segments = transcribed_segments[upto_idx + 1:]
        end_timestamps = end_timestamps[upto_idx + 1:]
        if len(end_timestamps) == 0:
            break
    return segmented_preds
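

# Minimal usage sketch, assuming a local audio file (the "call.wav" default below is a
# placeholder) and HUGGING_FACE_TOKEN exported so the pyannote pipeline can be loaded.
if __name__ == "__main__":
    audio_path = sys.argv[1] if len(sys.argv) > 1 else "call.wav"  # hypothetical default path
    print(f"Audio length: {get_audio_length_in_minutes(audio_path)} min")
    language = detect_language_local(audio_path)  # "ru" or "kk"
    pure_text, diarized_text = transcribe_and_diarize_audio(audio_path, language)
    print(pure_text)
    print(diarized_text)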