|
|
"""Custom Hugging Face inference-endpoint handler for Whisper speech recognition."""
|
|
from typing import Dict, Any, List |
|
|
from transformers import pipeline |
|
|
import torch |
|
|
import base64 |
|
|
import json |
|
|
import os |
|
|
|
|
|
|
|
|
class EndpointHandler:
    """HF class for custom model processing for an inference endpoint.

    Wraps the ``openai/whisper-large-v2`` automatic-speech-recognition
    pipeline and forces the Whisper decoder's language/task/timestamp
    control tokens based on the request payload.
    """

    def __init__(self, path: str) -> None:
        """Load the ASR pipeline and the special-token id mapping.

        Args:
            path (str): Path to the model weights allowing to load the model.
                NOTE(review): currently unused — the pipeline always loads
                ``openai/whisper-large-v2`` from the hub; confirm intended.
        """
        # pipeline() accepts a CUDA device index (0) or the string "cpu".
        device = 0 if torch.cuda.is_available() else "cpu"
        self.pipeline = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-large-v2",
            device=device,
            # Long-form audio is split into 30-second windows for decoding.
            chunk_length_s=30,
        )
        # token_mapping.json lives next to this file and maps Whisper special
        # tokens (e.g. "<|en|>", "<|transcribe|>") to their vocabulary ids.
        parent_dir = os.path.dirname(os.path.abspath(__file__))
        with open(os.path.join(parent_dir, "token_mapping.json")) as file:
            self.token_mapping = json.load(file)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Transcribe or translate base64-encoded audio.

        Args:
            data (Dict[str, Any]): Request payload whose ``"inputs"`` dict
                contains:
                  - ``"audio"``: base64-encoded audio bytes;
                  - ``"language"``: source-language code (e.g. ``"en"``), or
                    ``None`` to let Whisper auto-detect the language;
                  - ``"task"``: ``"transcribe"`` or ``"translate"``
                    (defaults to ``"transcribe"`` when omitted).

        Returns:
            List[Dict[str, Any]]: Pipeline output with word-level timestamps.
        """
        inputs = data["inputs"]
        audio = base64.b64decode(inputs["audio"])
        lang = inputs.get("language")
        # Backward-compatible default: previously a missing "task" raised
        # KeyError; "transcribe" is Whisper's standard task.
        task = inputs.get("task", "transcribe")

        # Forced decoder prompt: position 1 = language token,
        # position 2 = task token, position 3 = no-timestamps token.
        forced_ids: List[Any] = []
        if lang is not None:
            # BUG FIX: the previous code appended (1, None) when no language
            # was supplied, feeding an invalid None token id to generate().
            # Omitting the entry lets Whisper auto-detect the language.
            forced_ids.append((1, self.token_mapping[f"<|{lang}|>"]))
        forced_ids.append((2, self.token_mapping[f"<|{task}|>"]))
        forced_ids.append((3, self.token_mapping["<|notimestamps|>"]))

        return self.pipeline(
            audio,
            return_timestamps="word",
            generate_kwargs={"forced_decoder_ids": forced_ids},
        )
|
|
|