"""Custom handler required to process model in custom way."""
import base64
import json
import os
from typing import Any, Dict, List

import torch
from transformers import pipeline


class EndpointHandler:
    """HF class for custom model processing for an inference endpoint."""

    def __init__(self, path: str) -> None:
        """Load the Whisper ASR pipeline and the special-token id mapping.

        Args:
            path (str): Path to the model weights allowing to load the model.
                NOTE(review): currently unused — the model id below is
                hard-coded to the HF hub; confirm this is intentional.
        """
        # Use the first GPU when available, otherwise run on CPU.
        device = 0 if torch.cuda.is_available() else "cpu"
        self.pipeline = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-large-v2",  # hardcode HF hub link
            device=device,
            chunk_length_s=30,
        )
        # token_mapping.json (shipped next to this file) maps special tokens
        # such as "<|en|>" or "<|transcribe|>" to their token ids.
        parent_dir = os.path.dirname(os.path.abspath(__file__))
        with open(os.path.join(parent_dir, "token_mapping.json")) as file:
            self.token_mapping = json.load(file)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run speech recognition on a base64-encoded audio payload.

        Args:
            data (Dict[str, Any]): Request payload whose "inputs" dict holds:
                - "audio": base64-encoded audio bytes,
                - "language": ISO language code, or None / absent to let the
                  model detect the language,
                - "task": one of "translate", "transcribe"
                  (defaults to "transcribe" when absent).

        Returns:
            List[Dict[str, Any]]: Pipeline output with word-level timestamps.
        """
        inputs = data["inputs"]
        audio = base64.b64decode(inputs["audio"])  # bytes
        lang = inputs.get("language")  # ISO code; None -> auto-detect
        task = inputs.get("task", "transcribe")  # "translate" or "transcribe"

        # Set language and task (forced-token order: language, task, timestamp)
        if lang is None:
            # A None language token lets Whisper detect the language itself;
            # see line 1576 of
            # https://github.com/huggingface/transformers/blob/v4.27.2/src/transformers/models/whisper/modeling_whisper.py
            lang_id = None
        else:
            lang_id = self.token_mapping[f"<|{lang}|>"]
        task_id = self.token_mapping[f"<|{task}|>"]
        # Required to output timestamps
        timestamp_id = self.token_mapping["<|notimestamps|>"]
        forced_ids = [(1, lang_id), (2, task_id), (3, timestamp_id)]

        # Model inference
        output = self.pipeline(
            audio,
            return_timestamps="word",
            generate_kwargs={"forced_decoder_ids": forced_ids},
        )
        return output