File size: 6,131 Bytes
994a3aa
 
 
 
 
 
 
 
 
21d573e
994a3aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21d573e
994a3aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6b4d67
994a3aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb2e16d
994a3aa
 
bb2e16d
bfbc67e
bb2e16d
994a3aa
bb2e16d
994a3aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb2e16d
994a3aa
 
 
 
 
303bfa4
994a3aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21d573e
994a3aa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import base64
import gc
import logging
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from typing import List

import torch
from faster_whisper import WhisperModel
from pydantic import ValidationError
from starlette.exceptions import HTTPException

from alignment import load_align_model, align
from config import InferenceConfig, model_settings
from diarize import DiarizationPipeline, assign_word_speakers
from schema import SingleSegment
from utils import load_audio

# Append the cuDNN library directory to LD_LIBRARY_PATH, keeping whatever the
# environment already provided.
original = os.environ.get("LD_LIBRARY_PATH", "")

cudnn_path = "/opt/conda/lib/python3.11/site-packages/nvidia/cudnn/lib/"
# Join only the non-empty parts: when LD_LIBRARY_PATH was unset, a plain
# `original + ":" + cudnn_path` would start with ":", and an empty entry in
# LD_LIBRARY_PATH is interpreted as the current working directory.
os.environ['LD_LIBRARY_PATH'] = ":".join(
    part for part in (original, cudnn_path) if part
)

logger = logging.getLogger(__name__)


class EndpointHandler():
    """Callable inference handler: ASR, then alignment and diarization.

    The three models are loaded once in ``__init__`` and reused for every
    request handled by ``__call__``.
    """

    def __init__(self, path=""):
        """Load the ASR, alignment and diarization models with device "cuda".

        Args:
            path: Accepted for interface compatibility; this handler does not
                read it.
        """
        device = "cuda"
        self.asr_pipeline = WhisperModel(
            model_settings.asr_model,
            device=device,
            compute_type="float16",
            download_root="cache"
        )

        model_a, metadata = load_align_model(
            language_code=model_settings.language,
            device=device,
            model_name=model_settings.align_model,
            model_dir="cache",
        )
        self.align_model = model_a
        self.align_metadata = metadata

        self.diarize_model = DiarizationPipeline(
            token=model_settings.hf_token,
            device=device
        )

    def __call__(self, inputs):
        """Transcribe, align and diarize one base64-encoded audio payload.

        Args:
            inputs: Mapping holding the base64-encoded audio under "inputs"
                and, optionally, a "parameters" mapping that is validated by
                ``InferenceConfig``. Both keys are popped from the mapping.

        Returns:
            A dict with the keys "result" (list of segments),
            "full_transcription" (stripped segment texts joined by spaces),
            "diarization" ("<speaker>: <text>" lines, empty when no speakers
            were assigned), "asr_model" and "speaker_embeddings" (``None``
            when diarization did not run).

        Raises:
            HTTPException: 400 when the payload is missing or not decodable,
                when the parameters fail validation, or on a RuntimeError
                during ASR/diarization; 500 on any other ASR/diarization
                failure.
        """
        try:
            file = base64.b64decode(inputs.pop("inputs"))
        except KeyError as e:
            logger.error('Missing "inputs" field in request')
            raise HTTPException(status_code=400, detail='Missing "inputs" field in request') from e
        except (TypeError, ValueError) as e:
            # binascii.Error (bad padding etc.) is a ValueError subclass.
            logger.error("Error decoding base64 audio: %s", e)
            raise HTTPException(status_code=400, detail=f"Error decoding base64 audio: {e}") from e

        # An explicit null "parameters" is treated like an absent one.
        parameters = inputs.pop("parameters", None) or {}
        try:
            parameters = InferenceConfig(**parameters)
        except ValidationError as e:
            logger.error(f"Error validating parameters: {e}")
            raise HTTPException(status_code=400, detail=f"Error validating parameters: {e}") from e

        logger.info(f"inference parameters: {parameters}")
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            tmp.write(file)
            tmp_path = tmp.name
        try:
            audio = load_audio(tmp_path)
        finally:
            # Remove the temp file even when decoding the audio fails.
            os.remove(tmp_path)

        # 1. Run ASR sequentially, before the parallel stage.
        segments, _info = self._run_asr(audio, parameters)

        # 2. Release cached GPU memory before alignment and diarization run
        #    side by side.
        gc.collect()
        torch.cuda.empty_cache()

        with ThreadPoolExecutor() as executor:
            align_future = executor.submit(self._run_alignment, segments, audio)
            diarization_future = executor.submit(self._run_diarization, audio, parameters)

            aligned = align_future.result()
            diarization_output, embeddings = diarization_future.result()

        result = {}
        speakers_assigned = False
        if diarization_output is not None and aligned:
            result = assign_word_speakers(
                diarization_output,
                aligned,
                embeddings
            )
            speakers_assigned = bool(result)
        elif isinstance(aligned, dict):
            # NOTE(review): assumes align() returns the same {"segments": [...]}
            # mapping that assign_word_speakers consumes — confirm. Falling
            # back to it keeps the transcription when diarization is absent.
            result = aligned

        final_segments = result.get("segments", []) if result else []

        # Final cleanup
        del diarization_output, segments, audio
        gc.collect()
        torch.cuda.empty_cache()

        return {
            "result": final_segments,
            "full_transcription": self._full_text(final_segments),
            "diarization": self._speaker_lines(final_segments) if speakers_assigned else [],
            "asr_model": model_settings.asr_model,
            "speaker_embeddings": embeddings
        }

    def _run_asr(self, audio, parameters):
        """Transcribe ``audio`` and return ``(segments, info)``.

        Each segment is a dict with "start", "end" and "text".

        Raises:
            HTTPException: 400 on RuntimeError, 500 on any other exception.
        """
        try:
            raw_segments, info = self.asr_pipeline.transcribe(
                audio,
                language=model_settings.language,
                condition_on_previous_text=False,
                word_timestamps=False,
                vad_filter=parameters.vad_filter
            )
            # Consume the segments inside the try block: faster-whisper
            # documents them as a lazy generator, so errors can surface here.
            align_segments: List[SingleSegment] = [
                {"start": seg.start, "end": seg.end, "text": seg.text}
                for seg in raw_segments
            ]
            return align_segments, info
        except RuntimeError as e:
            logger.error(f"ASR inference error: {str(e)}")
            raise HTTPException(status_code=400, detail=f"ASR inference error: {str(e)}") from e
        except Exception as e:
            logger.error(f"Unknown error during ASR inference: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Unknown error during ASR inference: {str(e)}") from e

    def _run_alignment(self, segments, audio):
        """Align ``segments`` against ``audio`` inside a dedicated CUDA stream."""
        # Presumably a separate stream lets this overlap with diarization.
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            return align(
                segments,
                self.align_model,
                self.align_metadata,
                audio,
                "cuda",
            )

    def _run_diarization(self, audio, parameters):
        """Diarize ``audio`` and return ``(diarize_segments, embeddings)``.

        Returns ``(None, None)`` when no diarization model is set, so the
        caller can always unpack two values.

        Raises:
            HTTPException: 400 on RuntimeError, 500 on any other exception.
        """
        if not self.diarize_model:
            return None, None
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            try:
                diarize_segments, embeddings = self.diarize_model(
                    audio,
                    min_speakers=parameters.min_speakers,
                    max_speakers=parameters.max_speakers,
                    num_speakers=parameters.num_speakers,
                    return_embeddings=True
                )
                return diarize_segments, embeddings
            except RuntimeError as e:
                logger.error(f"Diarization inference error: {str(e)}")
                raise HTTPException(status_code=400, detail=f"Diarization inference error: {str(e)}") from e
            except Exception as e:
                logger.error(f"Unknown error during diarization: {str(e)}")
                raise HTTPException(status_code=500, detail=f"Unknown error during diarization: {str(e)}") from e

    @staticmethod
    def _full_text(segments):
        """Join the stripped "text" of every segment with single spaces."""
        return " ".join(seg.get("text", "").strip() for seg in segments)

    @staticmethod
    def _speaker_lines(segments):
        """Format each segment as "<speaker>: <text>", defaulting to UNKNOWN."""
        return [
            f'{seg.get("speaker", "UNKNOWN")}: {seg.get("text", "").strip()}'
            for seg in segments
        ]