# whisper-tiny-handler / handler.py
"""Custom handler for running Whisper speech recognition on a Hugging Face Inference Endpoint."""
import base64
import json
import os
from typing import Any, Dict

import torch
from transformers import pipeline


class EndpointHandler:
    """Custom Hugging Face Inference Endpoint handler for Whisper speech recognition."""

    def __init__(self, path: str) -> None:
        """Build the ASR pipeline and load the special-token mapping.

        Args:
            path (str): Path to the deployed repository. It is not used to load
                weights here; the model is pulled directly from the Hugging Face Hub.
        """
        # Run on the first GPU if available, otherwise on CPU.
        device = 0 if torch.cuda.is_available() else "cpu"
        self.pipeline = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-large-v2",  # hardcoded HF Hub model id
            device=device,
            chunk_length_s=30,
        )
        # token_mapping.json maps Whisper special tokens (e.g. "<|en|>",
        # "<|transcribe|>", "<|notimestamps|>") to their token ids.
        parent_dir = os.path.dirname(os.path.abspath(__file__))
        with open(os.path.join(parent_dir, "token_mapping.json")) as file:
            self.token_mapping = json.load(file)
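        # A sketch of how token_mapping.json could be generated from the model's
        # tokenizer (an assumption about how the file was produced, not code that
        # ships with this handler):
        #
        #   from transformers import WhisperTokenizer
        #   tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large-v2")
        #   mapping = {tok: tokenizer.convert_tokens_to_ids(tok)
        #              for tok in tokenizer.additional_special_tokens}
        #   with open("token_mapping.json", "w") as f:
        #       json.dump(mapping, f, indent=2)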

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Decode the request payload and run speech recognition.

        Args:
            data (Dict[str, Any]): Request body whose "inputs" entry holds
                "audio" (base64-encoded audio), "language" (ISO code or None)
                and "task" ("transcribe" or "translate").

        Returns:
            Dict[str, Any]: Pipeline output with the text and word-level timestamps.
        """
        inputs = data["inputs"]
        audio = base64.b64decode(inputs["audio"])  # raw audio bytes
        lang = inputs["language"]  # ISO language code, or None to auto-detect
        task = inputs["task"]  # one of "translate", "transcribe"
        # Build forced decoder ids (order: language, task, timestamps).
        if lang is None:
            # A None language id leaves language detection to the model; see line 1576 of
            # https://github.com/huggingface/transformers/blob/v4.27.2/src/transformers/models/whisper/modeling_whisper.py
            lang_id = None
        else:
            lang_id = self.token_mapping[f"<|{lang}|>"]
        task_id = self.token_mapping[f"<|{task}|>"]
        # "<|notimestamps|>" suppresses timestamp tokens in the decoded text;
        # word-level timestamps are requested separately via return_timestamps="word".
        timestamp_id = self.token_mapping["<|notimestamps|>"]
        forced_ids = [(1, lang_id), (2, task_id), (3, timestamp_id)]
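        # Hedged alternative (not what this handler does): the Whisper tokenizer also
        # exposes a helper that builds a similar forced-id list, e.g.:
        #   forced_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(
        #       language=lang, task=task, no_timestamps=True
        #   )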
        # Model inference
        output = self.pipeline(
            audio,
            return_timestamps="word",
            generate_kwargs={"forced_decoder_ids": forced_ids},
        )
        return output
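

# Minimal local usage sketch (an illustration, not part of the deployed handler).
# The audio file name and the language/task values below are assumptions.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    with open("sample.wav", "rb") as audio_file:  # hypothetical audio file
        encoded_audio = base64.b64encode(audio_file.read()).decode("utf-8")
    payload = {
        "inputs": {
            "audio": encoded_audio,
            "language": "en",  # ISO code, or None to let Whisper detect the language
            "task": "transcribe",  # or "translate"
        }
    }
    result = handler(payload)
    print(result["text"])
    print(result.get("chunks"))  # word-level timestamps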