Adriaanvh1 commited on
Commit
a569fc7
·
1 Parent(s): 587344a

custom handler

Browse files
Files changed (2) hide show
  1. handler.py +45 -0
  2. test_handler.py +16 -0
handler.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Custom handler required to process model in custom way."""
2
+ from typing import Dict, Any, List
3
+ from transformers import pipeline
4
+ import torch
5
+
6
+
7
class EndpointHandler:
    """HF custom handler running Whisper ASR for an inference endpoint."""

    def __init__(self, path: str) -> None:
        """Load the automatic-speech-recognition pipeline.

        Args:
            path (str): Path to the model weights allowing to load the model.
        """
        # Use the first GPU when available, otherwise fall back to CPU.
        device = 0 if torch.cuda.is_available() else "cpu"
        self.pipeline = pipeline(
            "automatic-speech-recognition",
            model=path,
            device=device,
            chunk_length_s=30,  # split long audio into 30-second chunks
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Custom processing pipeline for one inference request.

        Args:
            data (Dict[str, Any]): Request payload with keys:
                - "audio" (bytes): raw audio content,
                - "language": source language code (e.g. "nl"),
                - "task": one of "translate", "transcribe".

        Returns:
            Dict[str, Any]: Pipeline output; with ``return_timestamps="word"``
            it contains "text" (full transcript) and "chunks" (word-level
            timestamps).

        Raises:
            KeyError: If any required payload key is missing.
        """
        audio = data.pop("audio")  # bytes
        lang = data.pop("language")
        task = data.pop("task")  # One of "translate", "transcribe"

        # Force the decoder to the requested language/task instead of letting
        # Whisper auto-detect them. (A stray no-op `self.pipeline` expression
        # statement that followed this line was removed.)
        self.pipeline.model.config.forced_decoder_ids = (
            self.pipeline.tokenizer.get_decoder_prompt_ids(language=lang, task=task)
        )

        # Model inference with word-level timestamps.
        return self.pipeline(audio, return_timestamps="word")
test_handler.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Test the custom handler."""
from handler import EndpointHandler

# Instantiate the handler from a local copy of the Whisper tiny weights.
# (Removed an unused `import os`.)
handler = EndpointHandler(path="whisper-tiny")

# Read the sample audio as raw bytes, as the endpoint would receive it.
with open("frank.wav", "rb") as file:
    audio = file.read()

payload = {"audio": audio, "language": "nl", "task": "transcribe"}

output = handler(payload)

# Full transcript, a blank separator, then word-level timestamp chunks.
print(output["text"])
print("\n")
print(output["chunks"])