File size: 2,077 Bytes
a569fc7
 
 
 
7e06129
86eacb5
487e4ae
a569fc7
 
 
 
 
 
 
 
 
 
 
 
 
 
aa2a138
a569fc7
 
 
487e4ae
 
86eacb5
a569fc7
 
 
 
 
 
 
 
 
 
 
a295a8d
7e06129
86eacb5
a295a8d
a569fc7
86eacb5
 
 
 
 
 
 
 
 
a569fc7
86eacb5
a569fc7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""Custom handler required to process model in custom way."""
from typing import Dict, Any, List
from transformers import pipeline
import torch
import base64
import json
import os


class EndpointHandler:
    """Custom Hugging Face Inference Endpoint handler.

    Wraps a Whisper automatic-speech-recognition pipeline and translates
    ISO language / task names into forced decoder token ids so each
    request can choose transcription vs. translation and the language.
    """

    def __init__(self, path: str) -> None:
        """Load the ASR pipeline and the special-token id mapping.

        Args:
            path (str): Path to the model weights allowing to load the
                model. NOTE(review): currently unused — the model id is
                hard-coded to the HF Hub link below; confirm intended.
        """
        device = 0 if torch.cuda.is_available() else "cpu"
        self.pipeline = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-large-v2",  # hardcode HF hub link
            device=device,
            chunk_length_s=30,  # long-form audio is processed in 30 s chunks
        )
        # token_mapping.json maps special-token strings (e.g. "<|en|>")
        # to vocabulary ids; it is shipped next to this handler file.
        parent_dir = os.path.dirname(os.path.abspath(__file__))
        with open(os.path.join(parent_dir, "token_mapping.json")) as file:
            self.token_mapping = json.load(file)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run speech recognition on a single request.

        Args:
            data (Dict[str, Any]): Request payload; ``data["inputs"]`` must
                contain:
                  - "audio": base64-encoded audio bytes,
                  - "language": ISO language code, or None/absent to let the
                    model detect the language,
                  - "task": one of "translate", "transcribe".

        Returns:
            Dict[str, Any]: The pipeline output (transcribed text plus
            word-level timestamp chunks).
        """
        inputs = data["inputs"]
        audio = base64.b64decode(inputs["audio"])  # raw audio bytes
        lang = inputs.get("language")  # ISO code; None -> auto-detect
        task = inputs["task"]  # One of "translate", "transcribe"

        # Build forced decoder ids (order: language, task, timestamp).
        # A None language id leaves position 1 unforced so the model
        # detects the language itself (see line 1576,
        # https://github.com/huggingface/transformers/blob/v4.27.2/src/transformers/models/whisper/modeling_whisper.py).
        lang_id = None if lang is None else self.token_mapping[f"<|{lang}|>"]
        task_id = self.token_mapping[f"<|{task}|>"]
        # Forcing <|notimestamps|> keeps timestamp tokens out of the decoded
        # text; word-level timestamps come from return_timestamps="word".
        timestamp_id = self.token_mapping["<|notimestamps|>"]
        forced_ids = [(1, lang_id), (2, task_id), (3, timestamp_id)]

        # Model inference
        return self.pipeline(
            audio,
            return_timestamps="word",
            generate_kwargs={"forced_decoder_ids": forced_ids},
        )