| import gc |
| import logging |
| import copy |
| import logging |
| import base64 |
| import os |
|
|
| import torch |
| from huggingface_hub import HfApi |
| from pyannote.audio import Pipeline |
| from pyannote.audio.pipelines.utils.hook import ProgressHook |
| from pydantic import ValidationError |
| from starlette.exceptions import HTTPException |
| from torchaudio import functional as F |
| from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor |
| from transformers.pipelines.audio_utils import ffmpeg_read |
|
|
| from config import model_settings, InferenceConfig |
| from diarization_utils import SpeakerAligner, preprocess_inputs, diarize |
| import torch |
|
|
|
|
# Logger named after this module, shared by the handler below.
logger = logging.getLogger(name=__name__)

# Hugging Face token taken from the environment; None when the variable is unset.
# NOTE(review): the handler below authenticates with model_settings.hf_token, not this value.
HF_TOKEN = os.getenv("HF_TOKEN")
|
|
class EndpointHandler:
    """Request handler combining speech recognition with optional speaker diarization.

    The ASR model is ``model_settings.asr_model``. A pyannote diarization
    pipeline is loaded only when ``model_settings.diarization_model`` is set;
    otherwise ``self.diarization_pipeline`` is ``None`` and diarization is skipped.
    """

    # Every waveform handed to ASR and diarization is at this rate (Hz).
    TARGET_SAMPLING_RATE = 16000
    # Generation language used when the request does not specify one.
    DEFAULT_LANGUAGE = "sv"

    def __init__(self, path=""):
        """Load the ASR pipeline and, if configured, the diarization pipeline.

        Args:
            path: Accepted for interface compatibility; unused here because
                model identifiers come from ``model_settings``.
        """
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda") if use_cuda else torch.device("cpu")
        # Half precision only on GPU; CPU inference stays in float32.
        torch_dtype = torch.float16 if use_cuda else torch.float32

        model_id = model_settings.asr_model
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id, torch_dtype=torch_dtype, use_safetensors=True, cache_dir="cache"
        )
        model.to(device)
        processor = AutoProcessor.from_pretrained(model_id)
        self.processor = processor

        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            device=device,
            generate_kwargs={"max_new_tokens": 400},
            chunk_length_s=5,
            stride_length_s=(1, 1),
        )

        if model_settings.diarization_model:
            # NOTE(review): presumably a startup check that hf_token is valid
            # before downloading the (possibly gated) diarization checkpoint.
            HfApi().whoami(model_settings.hf_token)
            self.diarization_pipeline = Pipeline.from_pretrained(
                checkpoint_path=model_settings.diarization_model,
                use_auth_token=model_settings.hf_token,
            )
            self.diarization_pipeline.to(device)
        else:
            self.diarization_pipeline = None

    def __call__(self, inputs):
        """Transcribe, and optionally diarize, one base64-encoded audio request.

        Args:
            inputs: Request dict. ``inputs["inputs"]`` is the base64-encoded
                audio file; the optional ``inputs["parameters"]`` dict is
                validated against ``InferenceConfig``. Both keys are popped,
                so the dict is mutated.

        Returns:
            dict with ``"speakers"`` (``SpeakerAligner.align`` output),
            ``"speakers_"`` (``diarize`` output), and ``"chunks"`` / ``"text"``
            from the ASR pipeline. Both speaker entries are ``[]`` when no
            diarization pipeline is configured.

        Raises:
            HTTPException: 400 for a missing or undecodable payload, invalid
                parameters, or a ``RuntimeError`` during inference; 500 for
                any other inference failure.
        """
        file, parameters = self._parse_request(inputs)
        logger.info(f"inference parameters: {parameters}")

        audio = self._load_audio(file, parameters.sampling_rate)
        asr_outputs = self._transcribe(audio, parameters)

        # Default both speaker outputs so the response is well-formed even
        # when diarization is disabled (previously a NameError).
        speaker_transcriptions = []
        transcript_ = []
        if self.diarization_pipeline is not None:
            speaker_transcriptions = self._align_speakers(audio, parameters, asr_outputs)
            transcript_ = self._diarize_with_helper(file, parameters, asr_outputs)

        return {
            "speakers": speaker_transcriptions,
            "speakers_": transcript_,
            "chunks": asr_outputs["chunks"],
            "text": asr_outputs["text"],
        }

    @classmethod
    def _build_generate_kwargs(cls, task, language):
        """Return generation kwargs, falling back to DEFAULT_LANGUAGE for a falsy language."""
        return {"task": task, "language": language or cls.DEFAULT_LANGUAGE}

    @staticmethod
    def _parse_request(inputs):
        """Pop and base64-decode the audio payload, then validate the parameters."""
        try:
            file = base64.b64decode(inputs.pop("inputs"))
        except (KeyError, TypeError, ValueError) as e:
            # ValueError covers binascii.Error raised for malformed base64.
            logger.exception(f"Invalid or missing 'inputs' payload: {e}")
            raise HTTPException(
                status_code=400, detail=f"Invalid or missing 'inputs' payload: {e}"
            ) from e

        raw_parameters = inputs.pop("parameters", {})
        try:
            parameters = InferenceConfig(**raw_parameters)
        except ValidationError as e:
            logger.error(f"Error validating parameters: {e}")
            raise HTTPException(
                status_code=400, detail=f"Error validating parameters: {e}"
            ) from e
        return file, parameters

    def _load_audio(self, file, sampling_rate):
        """Decode audio bytes with ffmpeg and resample to TARGET_SAMPLING_RATE if needed."""
        audio = ffmpeg_read(file, sampling_rate)
        if sampling_rate != self.TARGET_SAMPLING_RATE:
            # F.resample returns a new tensor, so no defensive copy is needed.
            audio = F.resample(
                torch.from_numpy(audio), sampling_rate, self.TARGET_SAMPLING_RATE
            ).numpy()
        return audio

    def _transcribe(self, audio, parameters):
        """Run the ASR pipeline with timestamps; map failures to HTTP errors."""
        generate_kwargs = self._build_generate_kwargs(parameters.task, parameters.language)
        logger.info(f'params: {generate_kwargs}')
        asr_inputs = {"array": audio, "sampling_rate": self.TARGET_SAMPLING_RATE}
        try:
            return self.asr_pipeline(
                asr_inputs,
                generate_kwargs=generate_kwargs,
                return_timestamps=True,
            )
        except RuntimeError as e:
            logger.exception(f"ASR inference error: {e}")
            raise HTTPException(status_code=400, detail=f"ASR inference error: {e}") from e
        except Exception as e:
            logger.exception(f"Unknown error during ASR inference: {e}")
            raise HTTPException(
                status_code=500, detail=f"Unknown error during ASR inference: {e}"
            ) from e

    def _align_speakers(self, audio, parameters, asr_outputs):
        """Diarize the in-memory waveform and align ASR chunks to speakers via SpeakerAligner."""
        # Adds a leading channel axis; assumes `audio` is 1-D mono — TODO confirm.
        audio_tensor = torch.from_numpy(audio).unsqueeze(0)
        try:
            with ProgressHook() as progress_hook:
                aligner = SpeakerAligner()
                transcript = self.diarization_pipeline(
                    {"waveform": audio_tensor, "sample_rate": self.TARGET_SAMPLING_RATE},
                    hook=progress_hook,
                    num_speakers=parameters.num_speakers,
                    min_speakers=parameters.min_speakers,
                    max_speakers=parameters.max_speakers,
                )
                return aligner.align(asr_outputs["text"], asr_outputs["chunks"], transcript)
        except RuntimeError as e:
            logger.exception(f"Diarization inference error: {e}")
            raise HTTPException(
                status_code=400, detail=f"Diarization inference error: {e}"
            ) from e
        except Exception as e:
            logger.exception(f"Unknown error during diarization: {e}")
            raise HTTPException(
                status_code=500, detail=f"Unknown error during diarization: {e}"
            ) from e

    def _diarize_with_helper(self, file, parameters, asr_outputs):
        """Run the module-level `diarize` helper on the raw file bytes."""
        try:
            return diarize(self.diarization_pipeline, file, parameters, asr_outputs)
        except RuntimeError as e:
            logger.exception(f"Diarization_ inference error: {e}")
            raise HTTPException(
                status_code=400, detail=f"Diarization inference error: {e}"
            ) from e
        except Exception as e:
            logger.exception(f"Unknown_ error during diarization: {e}")
            raise HTTPException(
                status_code=500, detail=f"Unknown error during diarization: {e}"
            ) from e