File size: 5,890 Bytes
e4dac16
 
c14b9cc
a7d837f
8bc2122
ada9222
77b5d4d
a7d837f
e4dac16
 
 
77b5d4d
 
c14b9cc
e4dac16
 
 
 
 
 
 
1dec55c
77b5d4d
cb447db
ada9222
2d31b3a
77b5d4d
 
e4dac16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7a1c32
77b5d4d
 
19111b8
7915180
77b5d4d
 
 
 
 
b7ae82a
81de640
c14b9cc
 
 
e4dac16
c14b9cc
e4dac16
 
a1fe107
e4dac16
 
 
 
 
 
85cf5ce
e4dac16
 
 
 
8ab5eff
 
 
 
85cf5ce
 
8ab5eff
85cf5ce
e4dac16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1dec55c
77b5d4d
e4dac16
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import gc
import logging
import copy
import logging
import base64
import os

import torch
from huggingface_hub import HfApi
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.utils.hook import ProgressHook
from pydantic import ValidationError
from starlette.exceptions import HTTPException
from torchaudio import functional as F
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
from transformers.pipelines.audio_utils import ffmpeg_read

from config import model_settings, InferenceConfig
from diarization_utils import SpeakerAligner, preprocess_inputs, diarize
import torch


# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)
# Hugging Face access token taken from the environment; None when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

class EndpointHandler:
    """Inference endpoint handler: speech recognition plus optional speaker diarization.

    The ASR model comes from ``model_settings.asr_model``. A pyannote diarization
    pipeline is loaded only when ``model_settings.diarization_model`` is set;
    otherwise ``self.diarization_pipeline`` is ``None`` and requests are
    transcribed without speaker attribution.
    """

    def __init__(self, path=""):
        """Load the ASR pipeline and, if configured, the diarization pipeline.

        Args:
            path: Accepted for the handler interface but not used here; models are
                resolved from ``model_settings`` instead.
        """
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda") if use_cuda else torch.device("cpu")
        # Half precision on GPU, full precision on CPU.
        torch_dtype = torch.float16 if use_cuda else torch.float32

        model_id = model_settings.asr_model
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id, torch_dtype=torch_dtype, use_safetensors=True, cache_dir="cache"
        )
        model.to(device)
        processor = AutoProcessor.from_pretrained(model_id)
        self.processor = processor

        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            device=device,
            generate_kwargs={"max_new_tokens": 400},
            # Chunked long-form inference settings.
            chunk_length_s=5,
            stride_length_s=(1, 1),
        )

        if model_settings.diarization_model:
            # Validate the token up front: the diarization pipeline doesn't raise
            # if there is no token.
            HfApi().whoami(model_settings.hf_token)
            self.diarization_pipeline = Pipeline.from_pretrained(
                checkpoint_path=model_settings.diarization_model,
                use_auth_token=model_settings.hf_token,
            )
            self.diarization_pipeline.to(device)
        else:
            self.diarization_pipeline = None

    def __call__(self, inputs):
        """Transcribe (and, if a diarization pipeline is loaded, diarize) one request.

        Args:
            inputs: Request payload dict. ``inputs["inputs"]`` must hold the
                base64-encoded audio file; the optional ``inputs["parameters"]``
                dict (or null) is validated into an ``InferenceConfig``. Both keys
                are popped from ``inputs``.

        Returns:
            dict with keys:
                ``"speakers"``: ``SpeakerAligner.align`` output, ``[]`` without diarization.
                ``"speakers_"``: ``diarize`` output, ``[]`` without diarization.
                ``"chunks"``, ``"text"``: taken from the ASR pipeline output.

        Raises:
            HTTPException: 400 for a missing/undecodable audio payload, invalid
                parameters, or a RuntimeError during inference; 500 for any other
                inference failure.
        """
        try:
            encoded_audio = inputs.pop("inputs")
        except KeyError as e:
            logger.error("Request payload is missing the 'inputs' field")
            raise HTTPException(status_code=400, detail="Missing 'inputs' field in request payload") from e
        try:
            file = base64.b64decode(encoded_audio)
        except (TypeError, ValueError) as e:
            # binascii.Error (bad padding / invalid data) is a ValueError subclass.
            logger.error(f"Error decoding base64 audio: {e}")
            raise HTTPException(status_code=400, detail=f"Error decoding base64 audio: {e}") from e

        # Treat an explicit null the same as an absent "parameters" field.
        raw_parameters = inputs.pop("parameters", None) or {}
        try:
            parameters = InferenceConfig(**raw_parameters)
        except ValidationError as e:
            logger.error(f"Error validating parameters: {e}")
            raise HTTPException(status_code=400, detail=f"Error validating parameters: {e}") from e

        logger.info("inference parameters: %s", parameters)
        audio_nparray = ffmpeg_read(file, parameters.sampling_rate)
        if parameters.sampling_rate != 16000:
            # Resample a private copy of the decoded audio so the original buffer is
            # never handed to torch (presumably it may be read-only — TODO confirm).
            copy_audio = copy.deepcopy(audio_nparray)
            resampled = F.resample(torch.from_numpy(copy_audio), parameters.sampling_rate, 16000).numpy()
        else:
            resampled = audio_nparray
        # Add a leading dimension for the diarization pipeline's "waveform" input.
        audio_tensor = torch.from_numpy(resampled).unsqueeze(0)

        generate_kwargs = {
            "task": parameters.task,
            # Fall back to "sv" when the request specifies no language.
            "language": parameters.language or "sv",
        }
        logger.info("params: %s", generate_kwargs)
        asr_inputs = {"array": resampled, "sampling_rate": 16000}
        try:
            asr_outputs = self.asr_pipeline(
                asr_inputs,
                generate_kwargs=generate_kwargs,
                return_timestamps=True,
            )
        except RuntimeError as e:
            logger.exception(f"ASR inference error: {str(e)}")
            raise HTTPException(status_code=400, detail=f"ASR inference error: {str(e)}") from e
        except Exception as e:
            logger.exception(f"Unknown error during ASR inference: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Unknown error during ASR inference: {str(e)}") from e

        if self.diarization_pipeline is not None:
            try:
                with ProgressHook() as progress_hook:
                    aligner = SpeakerAligner()
                    diarization = self.diarization_pipeline(
                        {"waveform": audio_tensor, "sample_rate": 16000},
                        hook=progress_hook,
                        num_speakers=parameters.num_speakers,
                        min_speakers=parameters.min_speakers,
                        max_speakers=parameters.max_speakers,
                    )
                    # Align the ASR outputs with the diarization segments.
                    speaker_transcriptions = aligner.align(
                        asr_outputs["text"], asr_outputs["chunks"], diarization
                    )
            except RuntimeError as e:
                logger.exception(f"Diarization inference error: {str(e)}")
                raise HTTPException(status_code=400, detail=f"Diarization inference error: {str(e)}") from e
            except Exception as e:
                logger.exception(f"Unknown error during diarization: {str(e)}")
                raise HTTPException(status_code=500, detail=f"Unknown error during diarization: {str(e)}") from e
            # Second, independent diarization path via the diarize() helper; its
            # result is returned separately under "speakers_".
            try:
                transcript_ = diarize(self.diarization_pipeline, file, parameters, asr_outputs)
            except RuntimeError as e:
                logger.exception(f"Diarization_ inference error: {str(e)}")
                raise HTTPException(status_code=400, detail=f"Diarization inference error: {str(e)}") from e
            except Exception as e:
                logger.exception(f"Unknown_ error during diarization: {str(e)}")
                raise HTTPException(status_code=500, detail=f"Unknown error during diarization: {str(e)}") from e
        else:
            # Previously only transcript_ was set here, so the return below raised
            # UnboundLocalError whenever diarization was disabled.
            speaker_transcriptions = []
            transcript_ = []

        return {
            "speakers": speaker_transcriptions,
            "speakers_": transcript_,
            "chunks": asr_outputs["chunks"],
            "text": asr_outputs["text"],
        }