import gc
import logging
import copy
import logging
import base64
import os
import torch
from huggingface_hub import HfApi
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.utils.hook import ProgressHook
from pydantic import ValidationError
from starlette.exceptions import HTTPException
from torchaudio import functional as F
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
from transformers.pipelines.audio_utils import ffmpeg_read
from config import model_settings, InferenceConfig
from diarization_utils import SpeakerAligner, preprocess_inputs, diarize
import torch
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Hugging Face access token read from the environment; None when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
class EndpointHandler():
    """Callable handler that transcribes audio and optionally diarizes it.

    The constructor builds a ``transformers`` automatic-speech-recognition
    pipeline from ``model_settings.asr_model`` and, when
    ``model_settings.diarization_model`` is set, a ``pyannote.audio``
    diarization pipeline. Calling the instance runs both on one audio payload.
    """

    def __init__(self, path=""):
        """Load the ASR pipeline and, if configured, the diarization pipeline.

        Args:
            path: Accepted for interface compatibility; this constructor does
                not read it (the model id comes from ``model_settings``).
        """
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda") if use_cuda else torch.device("cpu")
        # Half precision only on GPU; fall back to float32 on CPU.
        torch_dtype = torch.float16 if use_cuda else torch.float32

        model_id = model_settings.asr_model
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id, torch_dtype=torch_dtype, use_safetensors=True, cache_dir="cache"
        )
        model.to(device)

        processor = AutoProcessor.from_pretrained(model_id)
        self.processor = processor

        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            device=device,
            generate_kwargs={"max_new_tokens": 400},
            chunk_length_s=5,
            stride_length_s=(1, 1),
        )

        if model_settings.diarization_model:
            # diarization pipeline doesn't raise if there is no token, so
            # validate the token explicitly before loading the checkpoint
            HfApi().whoami(model_settings.hf_token)
            self.diarization_pipeline = Pipeline.from_pretrained(
                checkpoint_path=model_settings.diarization_model,
                use_auth_token=model_settings.hf_token,
            )
            self.diarization_pipeline.to(device)
        else:
            self.diarization_pipeline = None

    def __call__(self, inputs):
        """Transcribe one audio payload and, when enabled, attribute speakers.

        Args:
            inputs: Mapping with an ``"inputs"`` entry holding base64-encoded
                audio and an optional ``"parameters"`` entry with keyword
                arguments for ``InferenceConfig``. Both entries are popped
                from the mapping.

        Returns:
            A dict with four keys:
            ``"speakers"`` - result of ``SpeakerAligner.align`` (empty list
            when diarization is disabled);
            ``"speakers_"`` - result of ``diarize`` (empty list when
            diarization is disabled);
            ``"chunks"`` and ``"text"`` - taken from the ASR pipeline output.

        Raises:
            HTTPException: status 400 when the parameters fail validation or
                inference raises ``RuntimeError``; status 500 for any other
                exception raised during inference.
        """
        file = inputs.pop("inputs")
        file = base64.b64decode(file)

        # Treat a missing or explicit-None "parameters" entry as "no overrides".
        raw_parameters = inputs.pop("parameters", None) or {}
        try:
            parameters = InferenceConfig(**raw_parameters)
        except ValidationError as e:
            logger.error(f"Error validating parameters: {e}")
            raise HTTPException(status_code=400, detail=f"Error validating parameters: {e}") from e

        logger.info(f"inference parameters: {parameters}")

        audio_nparray = ffmpeg_read(file, parameters.sampling_rate)

        # Both pipelines below are fed 16 kHz audio; resample when the audio
        # was decoded at another rate.
        if parameters.sampling_rate != 16000:
            # Resample from a private copy so the decoded buffer is left untouched.
            copy_audio = copy.deepcopy(audio_nparray)
            resampled = F.resample(
                torch.from_numpy(copy_audio), parameters.sampling_rate, 16000
            ).numpy()
        else:
            resampled = audio_nparray

        # Add a leading axis for the diarization pipeline's "waveform" input.
        audio_tensor = torch.from_numpy(resampled).unsqueeze(0)

        generate_kwargs = {
            "task": parameters.task,
            # Default to "sv" when the request does not specify a language.
            "language": parameters.language or "sv",
        }
        logger.info(f'params: {generate_kwargs}')

        asr_inputs = {"array": resampled, "sampling_rate": 16000}
        try:
            asr_outputs = self.asr_pipeline(
                asr_inputs,
                generate_kwargs=generate_kwargs,
                return_timestamps=True,
            )
        except RuntimeError as e:
            logger.error(f"ASR inference error: {str(e)}")
            raise HTTPException(status_code=400, detail=f"ASR inference error: {str(e)}") from e
        except Exception as e:
            logger.error(f"Unknown error during ASR inference: {str(e)}")
            raise HTTPException(
                status_code=500, detail=f"Unknown error during ASR inference: {str(e)}"
            ) from e

        if self.diarization_pipeline:
            # Path 1: run the diarization pipeline directly on the in-memory
            # waveform and align its segments with the ASR chunks.
            try:
                aligner = SpeakerAligner()
                with ProgressHook() as progress_hook:
                    transcript = self.diarization_pipeline(
                        {"waveform": audio_tensor, "sample_rate": 16000},
                        hook=progress_hook,
                        num_speakers=parameters.num_speakers,
                        min_speakers=parameters.min_speakers,
                        max_speakers=parameters.max_speakers,
                    )
                # Align the ASR outputs with diarization segments
                speaker_transcriptions = aligner.align(
                    asr_outputs["text"], asr_outputs["chunks"], transcript
                )
            except RuntimeError as e:
                logger.error(f"Diarization inference error: {str(e)}")
                raise HTTPException(
                    status_code=400, detail=f"Diarization inference error: {str(e)}"
                ) from e
            except Exception as e:
                logger.error(f"Unknown error during diarization: {str(e)}")
                raise HTTPException(
                    status_code=500, detail=f"Unknown error during diarization: {str(e)}"
                ) from e

            # Path 2: the `diarize` helper, fed the raw decoded bytes; its
            # result is returned alongside path 1 under "speakers_".
            try:
                transcript_ = diarize(self.diarization_pipeline, file, parameters, asr_outputs)
            except RuntimeError as e:
                logger.error(f"Diarization_ inference error: {str(e)}")
                raise HTTPException(
                    status_code=400, detail=f"Diarization inference error: {str(e)}"
                ) from e
            except Exception as e:
                logger.error(f"Unknown_ error during diarization: {str(e)}")
                raise HTTPException(
                    status_code=500, detail=f"Unknown error during diarization: {str(e)}"
                ) from e
        else:
            # Diarization disabled: both speaker fields must still be defined,
            # otherwise building the response raises NameError.
            speaker_transcriptions = []
            transcript_ = []

        return {
            "speakers": speaker_transcriptions,
            "speakers_": transcript_,
            "chunks": asr_outputs["chunks"],
            "text": asr_outputs["text"],
        }