File size: 7,227 Bytes
7987a90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import math
import os
import sys

import numpy as np
import torch
from faster_whisper import WhisperModel
from pyannote.audio import Pipeline
from pydub import AudioSegment
from speechbrain.inference.classifiers import EncoderClassifier

# Whisper checkpoint name.
# NOTE(review): this is re-assigned to the same value further down in the file
# right before `transcription_model` is built, so this first assignment looks
# redundant — confirm before removing either one.
model_size = "large-v3"

# Run on GPU with FP16
# model = WhisperModel(model_size, device="cpu", compute_type="float16")

# Output-index -> language-name table for the 107 classes of the
# speechbrain/lang-id-voxlingua107-ecapa language-ID classifier.
INDEX_TO_LANG = {
    0: 'Abkhazian', 1: 'Afrikaans', 2: 'Amharic', 3: 'Arabic', 4: 'Assamese',
    5: 'Azerbaijani', 6: 'Bashkir', 7: 'Belarusian', 8: 'Bulgarian', 9: 'Bengali',
    10: 'Tibetan', 11: 'Breton', 12: 'Bosnian', 13: 'Catalan', 14: 'Cebuano',
    15: 'Czech', 16: 'Welsh', 17: 'Danish', 18: 'German', 19: 'Greek',
    20: 'English', 21: 'Esperanto', 22: 'Spanish', 23: 'Estonian', 24: 'Basque',
    25: 'Persian', 26: 'Finnish', 27: 'Faroese', 28: 'French', 29: 'Galician',
    30: 'Guarani', 31: 'Gujarati', 32: 'Manx', 33: 'Hausa', 34: 'Hawaiian',
    35: 'Hindi', 36: 'Croatian', 37: 'Haitian', 38: 'Hungarian', 39: 'Armenian',
    40: 'Interlingua', 41: 'Indonesian', 42: 'Icelandic', 43: 'Italian', 44: 'Hebrew',
    45: 'Japanese', 46: 'Javanese', 47: 'Georgian', 48: 'Kazakh', 49: 'Central Khmer',
    50: 'Kannada', 51: 'Korean', 52: 'Latin', 53: 'Luxembourgish', 54: 'Lingala',
    55: 'Lao', 56: 'Lithuanian', 57: 'Latvian', 58: 'Malagasy', 59: 'Maori',
    60: 'Macedonian', 61: 'Malayalam', 62: 'Mongolian', 63: 'Marathi', 64: 'Malay',
    65: 'Maltese', 66: 'Burmese', 67: 'Nepali', 68: 'Dutch', 69: 'Norwegian Nynorsk',
    70: 'Norwegian', 71: 'Occitan', 72: 'Panjabi', 73: 'Polish', 74: 'Pushto',
    75: 'Portuguese', 76: 'Romanian', 77: 'Russian', 78: 'Sanskrit', 79: 'Scots',
    80: 'Sindhi', 81: 'Sinhala', 82: 'Slovak', 83: 'Slovenian', 84: 'Shona',
    85: 'Somali', 86: 'Albanian', 87: 'Serbian', 88: 'Sundanese', 89: 'Swedish',
    90: 'Swahili', 91: 'Tamil', 92: 'Telugu', 93: 'Tajik', 94: 'Thai',
    95: 'Turkmen', 96: 'Tagalog', 97: 'Turkish', 98: 'Tatar', 99: 'Ukrainian',
    100: 'Urdu', 101: 'Uzbek', 102: 'Vietnamese', 103: 'Waray', 104: 'Yiddish',
    105: 'Yoruba', 106: 'Chinese'
}
# Reverse lookup: language name -> classifier output index.
LANG_TO_INDEX = {v: k for k, v in INDEX_TO_LANG.items()}


def identify_languages(file_path, languages=None) -> dict[str, float]:
    """Score how likely the audio at *file_path* is in each requested language.

    Runs the VoxLingua107 ECAPA language-ID classifier over the whole file and
    converts its per-class log-probabilities to 0-100 percentage scores.

    Args:
        file_path: path to an audio file readable by speechbrain's loader.
        languages: iterable of language names (keys of ``INDEX_TO_LANG`` values)
            to report; defaults to Russian/Belarusian/Ukrainian/Kazakh.

    Returns:
        dict mapping each requested language name to a float score in [0, 100].
    """
    # BUG FIX: the default used to be a mutable list literal, which is shared
    # across calls (classic Python pitfall). Use a None sentinel + tuple.
    if languages is None:
        languages = ("Russian", "Belarusian", "Ukrainian", "Kazakh")
    # NOTE(review): the classifier is re-created (and its weights re-loaded) on
    # every call; consider caching it at module level like `transcription_model`.
    # Left as-is to keep behavior unchanged.
    language_id = EncoderClassifier.from_hparams(source="speechbrain/lang-id-voxlingua107-ecapa")
    signal = language_id.load_audio(file_path)
    lang_scores, _, _, _ = language_id.classify_batch(signal)
    # The classifier emits log-probabilities; exp() maps them back to probabilities.
    all_scores = {INDEX_TO_LANG[i]: 100 * math.exp(score) for i, score in enumerate(lang_scores[0])}
    return {lang: float(all_scores[lang]) for lang in languages}


def detect_language_local(file_path):
    """Choose a Whisper language code for *file_path*: "ru" or "kk".

    The top-scoring language among the default candidate set decides the
    outcome; any East Slavic result maps to "ru", everything else to "kk".
    """
    scores = identify_languages(file_path)
    best = max(scores, key=scores.get)
    east_slavic = {"russian", "belarusian", "ukrainian"}
    return "ru" if best.lower() in east_slavic else "kk"


def transcribe_and_diarize_audio(filename, language):
    """Transcribe *filename* and attribute the text to diarized speakers.

    Returns a tuple ``(pure_text, diarized_text)``: the raw transcript with one
    ASR segment per line, and a single string of speaker-attributed chunks with
    time ranges.
    """
    speaker_turns = _combine_segments_with_same_speaker(_diarize_audio(filename))
    asr_segments = _transcribe_audio(filename, language)

    pure_text = "\n".join(seg.text for seg in asr_segments)

    aligned = _combine_diarized_and_transcribed_segments(speaker_turns, asr_segments)
    chunks = [
        "[%.1fs -> %.1fs] (%s) %s" % (seg["start"], seg["end"], seg["speaker"], seg["text"])
        for seg in aligned
    ]
    diarized_text = " ".join(chunks)

    return pure_text, diarized_text


# Module-level model setup: both models are loaded once at import time.
# This triggers (potentially large) downloads/initialization as a side effect
# of importing this module.
diarization_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    # Gated Hugging Face model: an access token must be present in the env.
    use_auth_token=os.getenv('HUGGING_FACE_TOKEN'),
)
# NOTE(review): duplicates the `model_size` assignment near the top of the file.
model_size = "large-v3"
transcription_model = WhisperModel(
    model_size,
    device="cuda:0" if torch.cuda.is_available() else "cpu",
    # device="cpu",
    # int8 quantization: lower memory footprint at a small accuracy cost.
    compute_type="int8",
    # compute_type="int8_float16",
    # compute_type="float32"
)


def get_audio_length_in_minutes(file_path):
    """Return the duration of the audio file in minutes, rounded to 2 decimals."""
    duration_ms = len(AudioSegment.from_file(file_path))  # pydub length is in ms
    return round(duration_ms / 60000, 2)


def _diarize_audio(filename):
    """Run the pyannote pipeline on *filename*, assuming exactly two speakers.

    Returns a list of ``{"segment": {"start", "end"}, "speaker"}`` dicts, one
    per speaker turn, with times in seconds.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    diarization_pipeline.to(device)
    annotation = diarization_pipeline(filename, max_speakers=2, min_speakers=2)
    return [
        {"segment": {"start": turn.start, "end": turn.end}, "speaker": speaker}
        for turn, _, speaker in annotation.itertracks(yield_label=True)
    ]


def _combine_segments_with_same_speaker(segments):
    new_segments = []
    prev_segment = cur_segment = segments[0]

    for i in range(1, len(segments)):
        cur_segment = segments[i]
        # check if we have changed speaker ("label")
        if cur_segment["speaker"] != prev_segment["speaker"] and i < len(segments):
            # add the start/end times for the super-segment to the new list
            new_segments.append(
                {
                    "segment": {
                        "start": prev_segment["segment"]["start"],
                        "end": cur_segment["segment"]["start"],
                    },
                    "speaker": prev_segment["speaker"],
                }
            )
            prev_segment = segments[i]
    return new_segments


def _transcribe_audio(filename, language):
    """Transcribe *filename* with the module-level faster-whisper model.

    The language is forced (no auto-detection); the lazy segment generator is
    materialized into a list before returning.
    """
    segment_iter, _info = transcription_model.transcribe(
        filename,
        beam_size=20,
        language=language,
    )
    return list(segment_iter)


def _combine_diarized_and_transcribed_segments(diarized_segments, transcribed_segments):
    # get the end timestamps for each chunk from the ASR output
    end_timestamps = np.array(
        [
            (chunk.end if chunk.end is not None else sys.float_info.max)
            for chunk in transcribed_segments
        ]
    )
    segmented_preds = []

    # align the diarizer timestamps and the ASR timestamps
    for segment in diarized_segments:
        # get the diarizer end timestamp
        end_time = segment["segment"]["end"]
        # find the ASR end timestamp that is closest to the diarizer's end timestamp and cut the transcript to here
        upto_idx = np.argmin(np.abs(end_timestamps - end_time))

        segmented_preds.append(
            {
                "speaker": segment["speaker"],
                "text": "".join(
                    [chunk.text for chunk in transcribed_segments[: upto_idx + 1]]
                ),
                "start": transcribed_segments[0].start,
                "end": transcribed_segments[upto_idx].end,
            }
        )

        # crop the transcribed_segmentss and timestamp lists according to the latest timestamp (for faster argmin)
        transcribed_segments = transcribed_segments[upto_idx + 1:]
        end_timestamps = end_timestamps[upto_idx + 1:]

        if len(end_timestamps) == 0:
            break

    return segmented_preds