File size: 2,077 Bytes
a569fc7 7e06129 86eacb5 487e4ae a569fc7 aa2a138 a569fc7 487e4ae 86eacb5 a569fc7 a295a8d 7e06129 86eacb5 a295a8d a569fc7 86eacb5 a569fc7 86eacb5 a569fc7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
"""Custom handler required to process model in custom way."""
from typing import Dict, Any, List
from transformers import pipeline
import torch
import base64
import json
import os
class EndpointHandler:
    """Hugging Face Inference Endpoint handler wrapping a Whisper ASR pipeline.

    The ``transformers`` automatic-speech-recognition pipeline is built once
    in ``__init__``. Each call decodes a base64 audio payload and runs the
    pipeline with decoder prompt tokens (language, task, timestamp marker)
    forced from the request.
    """

    # Task applied when the request does not specify one.
    _DEFAULT_TASK = "transcribe"

    def __init__(self, path: str) -> None:
        """Create the ASR pipeline and load the special-token id mapping.

        Args:
            path (str): Path to the model weights. Currently unused: the
                model is always loaded from the hard-coded Hub id below.

        Raises:
            FileNotFoundError: If ``token_mapping.json`` is not located in
                the same directory as this file.
        """
        # Use the first CUDA device when one is available, otherwise the CPU.
        device = 0 if torch.cuda.is_available() else "cpu"
        self.pipeline = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-large-v2",  # hardcode HF hub link
            device=device,
            chunk_length_s=30,
        )

        # The mapping file lives next to this handler, independent of the
        # process working directory.
        handler_dir = os.path.dirname(os.path.abspath(__file__))
        mapping_path = os.path.join(handler_dir, "token_mapping.json")
        with open(mapping_path, encoding="utf-8") as file:
            # Maps special-token strings such as "<|transcribe|>" to the
            # values placed in ``forced_decoder_ids``.
            self.token_mapping = json.load(file)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run speech recognition on one base64-encoded audio payload.

        Args:
            data (Dict[str, Any]): Request body. ``data["inputs"]`` must be a
                dict with:

                * ``"audio"``: base64-encoded audio bytes (required).
                * ``"language"``: code used to build the ``<|code|>`` token.
                  ``None`` or an absent key leaves the language unforced.
                * ``"task"``: ``"translate"`` or ``"transcribe"``. Defaults
                  to ``"transcribe"`` when absent or ``None``.

        Returns:
            The pipeline output, passed through unchanged.
            NOTE(review): the annotation says ``List[Dict[str, Any]]``; the
            actual shape depends on the pipeline - confirm.

        Raises:
            KeyError: If ``"inputs"`` or ``"audio"`` is missing, or if the
                requested language/task has no entry in the token mapping.
            binascii.Error: If the audio payload is not valid base64.
        """
        inputs = data["inputs"]
        audio = base64.b64decode(inputs["audio"])  # raw audio bytes

        # Both fields are optional: a missing key behaves like an explicit
        # ``None`` instead of raising KeyError.
        lang = inputs.get("language")  # ISO code
        task = inputs.get("task") or self._DEFAULT_TASK

        # Set language and task (order: language, task, timestamp)
        if lang is None:
            # A ``None`` id at position 1 leaves the language token unforced;
            # see line 1576 of
            # https://github.com/huggingface/transformers/blob/v4.27.2/src/transformers/models/whisper/modeling_whisper.py
            lang_id = None
        else:
            lang_id = self.token_mapping[f"<|{lang}|>"]
        task_id = self.token_mapping[f"<|{task}|>"]

        # NOTE(review): the previous comment here read "Required to output
        # timestamps", yet the token is Whisper's *no*-timestamps marker.
        # Presumably word-level timing comes from ``return_timestamps="word"``
        # below rather than from timestamp tokens - confirm.
        timestamp_id = self.token_mapping["<|notimestamps|>"]

        forced_ids = [(1, lang_id), (2, task_id), (3, timestamp_id)]

        # Model inference
        output = self.pipeline(
            audio,
            return_timestamps="word",
            generate_kwargs={"forced_decoder_ids": forced_ids},
        )
        return output
|