import base64
import copy
import gc
import logging
import os

import torch
from huggingface_hub import HfApi
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.utils.hook import ProgressHook
from pydantic import ValidationError
from starlette.exceptions import HTTPException
from torchaudio import functional as F
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
from transformers.pipelines.audio_utils import ffmpeg_read

from config import model_settings, InferenceConfig
from diarization_utils import SpeakerAligner, preprocess_inputs, diarize

logger = logging.getLogger(__name__)

HF_TOKEN = os.environ.get("HF_TOKEN")

# Sampling rate (Hz) that the audio is converted to before it is handed to
# the ASR and diarization pipelines.
_TARGET_SAMPLING_RATE = 16000


class EndpointHandler():
    """Inference handler combining speech recognition and speaker diarization.

    The ASR pipeline is always created. The diarization pipeline is created
    only when ``model_settings.diarization_model`` is set; otherwise
    ``self.diarization_pipeline`` is ``None`` and requests return empty
    speaker results.
    """

    def __init__(self, path=""):
        """Load the ASR model and, if configured, the diarization pipeline.

        Args:
            path: Accepted for interface compatibility; not used here.
        """
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda") if use_cuda else torch.device("cpu")
        torch_dtype = torch.float16 if use_cuda else torch.float32

        model_id = model_settings.asr_model

        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            use_safetensors=True,
            cache_dir="cache",
        )
        model.to(device)

        processor = AutoProcessor.from_pretrained(model_id)
        self.processor = processor

        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            device=device,
            generate_kwargs={"max_new_tokens": 400},
            chunk_length_s=5,
            stride_length_s=(1, 1),
        )

        if model_settings.diarization_model:
            # diarization pipeline doesn't raise if there is no token
            HfApi().whoami(model_settings.hf_token)
            self.diarization_pipeline = Pipeline.from_pretrained(
                checkpoint_path=model_settings.diarization_model,
                use_auth_token=model_settings.hf_token,
            )
            self.diarization_pipeline.to(device)
        else:
            self.diarization_pipeline = None

    def __call__(self, inputs):
        """Transcribe a base64-encoded audio payload and optionally diarize it.

        Args:
            inputs: Mapping with an ``"inputs"`` entry holding the
                base64-encoded audio and an optional ``"parameters"`` entry
                holding keyword arguments for ``InferenceConfig``. Both
                entries are popped from the mapping.

        Returns:
            A dict with the keys ``"speakers"`` (ASR chunks aligned to the
            diarization result), ``"speakers_"`` (result of ``diarize``),
            ``"chunks"`` and ``"text"`` (both taken from the ASR output).
            ``"speakers"`` and ``"speakers_"`` are empty lists when no
            diarization pipeline is configured.

        Raises:
            HTTPException: 400 for a missing/undecodable payload, invalid
                parameters, or a ``RuntimeError`` during inference; 500 for
                any other error raised during inference.
        """
        try:
            encoded_audio = inputs.pop("inputs")
        except KeyError as e:
            logger.error("Missing 'inputs' field in request payload")
            raise HTTPException(
                status_code=400,
                detail="Missing 'inputs' field in request payload",
            ) from e

        try:
            file = base64.b64decode(encoded_audio)
        except (ValueError, TypeError) as e:
            # binascii.Error (bad padding / alphabet) is a ValueError subclass;
            # TypeError covers payloads that are not str or bytes-like.
            logger.error("Error decoding base64 audio: %s", e)
            raise HTTPException(
                status_code=400,
                detail=f"Error decoding base64 audio: {e}",
            ) from e

        parameters = inputs.pop("parameters", {})
        if parameters is None:
            # An explicit null is treated like an absent "parameters" entry.
            parameters = {}
        try:
            parameters = InferenceConfig(**parameters)
        except ValidationError as e:
            logger.error(f"Error validating parameters: {e}")
            raise HTTPException(
                status_code=400,
                detail=f"Error validating parameters: {e}",
            ) from e

        logger.info(f"inference parameters: {parameters}")

        audio_nparray = ffmpeg_read(file, parameters.sampling_rate)

        if parameters.sampling_rate != _TARGET_SAMPLING_RATE:
            # Copy only when resampling is needed, so the tensor handed to
            # torchaudio is backed by its own buffer.
            copy_audio = copy.deepcopy(audio_nparray)
            resampled = F.resample(
                torch.from_numpy(copy_audio),
                parameters.sampling_rate,
                _TARGET_SAMPLING_RATE,
            ).numpy()
        else:
            resampled = audio_nparray

        generate_kwargs = {
            "task": parameters.task,
            # Falls back to "sv" when the request does not specify a language.
            "language": parameters.language if parameters.language else "sv",
        }
        logger.info(f'params: {generate_kwargs}')

        asr_inputs = {
            "array": resampled,
            "sampling_rate": _TARGET_SAMPLING_RATE,
        }

        try:
            asr_outputs = self.asr_pipeline(
                asr_inputs,
                generate_kwargs=generate_kwargs,
                return_timestamps=True,
            )
        except RuntimeError as e:
            logger.error(f"ASR inference error: {str(e)}")
            raise HTTPException(
                status_code=400,
                detail=f"ASR inference error: {str(e)}",
            ) from e
        except Exception as e:
            logger.error(f"Unknown error during ASR inference: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Unknown error during ASR inference: {str(e)}",
            ) from e

        if self.diarization_pipeline:
            # The diarization pipeline is given the waveform with an extra
            # leading dimension added.
            audio_tensor = torch.from_numpy(resampled).unsqueeze(0)

            try:
                with ProgressHook() as progress_hook:
                    transcript = self.diarization_pipeline(
                        {
                            "waveform": audio_tensor,
                            "sample_rate": _TARGET_SAMPLING_RATE,
                        },
                        hook=progress_hook,
                        num_speakers=parameters.num_speakers,
                        min_speakers=parameters.min_speakers,
                        max_speakers=parameters.max_speakers,
                    )
                # Align the ASR outputs with diarization segments
                aligner = SpeakerAligner()
                speaker_transcriptions = aligner.align(
                    asr_outputs["text"], asr_outputs["chunks"], transcript
                )
            except RuntimeError as e:
                logger.error(f"Diarization inference error: {str(e)}")
                raise HTTPException(
                    status_code=400,
                    detail=f"Diarization inference error: {str(e)}",
                ) from e
            except Exception as e:
                logger.error(f"Unknown error during diarization: {str(e)}")
                raise HTTPException(
                    status_code=500,
                    detail=f"Unknown error during diarization: {str(e)}",
                ) from e

            # NOTE(review): this is a second diarization pass through the
            # `diarize` helper; both results are returned to the caller.
            # Confirm whether both are still needed, as the work is duplicated.
            try:
                transcript_ = diarize(
                    self.diarization_pipeline, file, parameters, asr_outputs
                )
            except RuntimeError as e:
                logger.error(f"Diarization_ inference error: {str(e)}")
                raise HTTPException(
                    status_code=400,
                    detail=f"Diarization inference error: {str(e)}",
                ) from e
            except Exception as e:
                logger.error(f"Unknown_ error during diarization: {str(e)}")
                raise HTTPException(
                    status_code=500,
                    detail=f"Unknown error during diarization: {str(e)}",
                ) from e
        else:
            # No diarization configured: return empty speaker results rather
            # than failing on an unassigned name in the return statement.
            speaker_transcriptions = []
            transcript_ = []

        return {
            "speakers": speaker_transcriptions,
            "speakers_": transcript_,
            "chunks": asr_outputs["chunks"],
            "text": asr_outputs["text"],
        }