Commit
·
86eacb5
1
Parent(s):
2933ecf
Updated handler for language detection
Browse files- handler.py +14 -7
- token_mapping.json +108 -0
handler.py
CHANGED
|
@@ -2,8 +2,8 @@
|
|
| 2 |
from typing import Dict, Any, List
|
| 3 |
from transformers import pipeline
|
| 4 |
import torch
|
| 5 |
-
from io import BytesIO
|
| 6 |
import base64
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
class EndpointHandler:
|
|
@@ -22,6 +22,8 @@ class EndpointHandler:
|
|
| 22 |
device=device,
|
| 23 |
chunk_length_s=30,
|
| 24 |
)
|
|
|
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
@@ -35,14 +37,19 @@ class EndpointHandler:
|
|
| 35 |
"""
|
| 36 |
inputs = data["inputs"]
|
| 37 |
audio = base64.b64decode(inputs["audio"]) # bytes
|
| 38 |
-
lang = inputs["language"]
|
| 39 |
task = inputs["task"] # One of "translate", "transcribe"
|
| 40 |
|
| 41 |
-
# Set language and task
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
# Model inference
|
| 46 |
-
output = self.pipeline(audio, return_timestamps="word")
|
| 47 |
|
| 48 |
return output
|
|
|
|
| 2 |
from typing import Dict, Any, List
|
| 3 |
from transformers import pipeline
|
| 4 |
import torch
|
|
|
|
| 5 |
import base64
|
| 6 |
+
import json
|
| 7 |
|
| 8 |
|
| 9 |
class EndpointHandler:
|
|
|
|
| 22 |
device=device,
|
| 23 |
chunk_length_s=30,
|
| 24 |
)
|
| 25 |
+
with open("token_mapping.json") as file:
|
| 26 |
+
self.token_mapping = json.load(file)
|
| 27 |
|
| 28 |
|
| 29 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
|
|
| 37 |
"""
|
| 38 |
inputs = data["inputs"]
|
| 39 |
audio = base64.b64decode(inputs["audio"]) # bytes
|
| 40 |
+
lang = inputs["language"] # ISO code
|
| 41 |
task = inputs["task"] # One of "translate", "transcribe"
|
| 42 |
|
| 43 |
+
# Set language and task (order: language, task, timestamp)
|
| 44 |
+
if lang is None:
|
| 45 |
+
lang_id = None # line 1576, https://github.com/huggingface/transformers/blob/v4.27.2/src/transformers/models/whisper/modeling_whisper.py
|
| 46 |
+
else:
|
| 47 |
+
lang_id = self.token_mapping[f"<|{lang}|>"]
|
| 48 |
+
task_id = self.token_mapping[f"<|{task}|>"]
|
| 49 |
+
timestamp_id = self.token_mapping["<|notimestamps|>"] # Required to output timestamps
|
| 50 |
+
forced_ids = [(1, lang_id), (2, task_id), (3, timestamp_id)]
|
| 51 |
+
|
| 52 |
# Model inference
|
| 53 |
+
output = self.pipeline(audio, return_timestamps="word", generate_kwargs={"forced_decoder_ids": forced_ids})
|
| 54 |
|
| 55 |
return output
|
token_mapping.json
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"<|af|>": 50327,
|
| 3 |
+
"<|am|>": 50334,
|
| 4 |
+
"<|ar|>": 50272,
|
| 5 |
+
"<|as|>": 50350,
|
| 6 |
+
"<|az|>": 50304,
|
| 7 |
+
"<|ba|>": 50355,
|
| 8 |
+
"<|be|>": 50330,
|
| 9 |
+
"<|bg|>": 50292,
|
| 10 |
+
"<|bn|>": 50302,
|
| 11 |
+
"<|bo|>": 50347,
|
| 12 |
+
"<|br|>": 50309,
|
| 13 |
+
"<|bs|>": 50315,
|
| 14 |
+
"<|ca|>": 50270,
|
| 15 |
+
"<|cs|>": 50283,
|
| 16 |
+
"<|cy|>": 50297,
|
| 17 |
+
"<|da|>": 50285,
|
| 18 |
+
"<|de|>": 50261,
|
| 19 |
+
"<|el|>": 50281,
|
| 20 |
+
"<|en|>": 50259,
|
| 21 |
+
"<|es|>": 50262,
|
| 22 |
+
"<|et|>": 50307,
|
| 23 |
+
"<|eu|>": 50310,
|
| 24 |
+
"<|fa|>": 50300,
|
| 25 |
+
"<|fi|>": 50277,
|
| 26 |
+
"<|fo|>": 50338,
|
| 27 |
+
"<|fr|>": 50265,
|
| 28 |
+
"<|gl|>": 50319,
|
| 29 |
+
"<|gu|>": 50333,
|
| 30 |
+
"<|haw|>": 50352,
|
| 31 |
+
"<|ha|>": 50354,
|
| 32 |
+
"<|he|>": 50279,
|
| 33 |
+
"<|hi|>": 50276,
|
| 34 |
+
"<|hr|>": 50291,
|
| 35 |
+
"<|ht|>": 50339,
|
| 36 |
+
"<|hu|>": 50286,
|
| 37 |
+
"<|hy|>": 50312,
|
| 38 |
+
"<|id|>": 50275,
|
| 39 |
+
"<|is|>": 50311,
|
| 40 |
+
"<|it|>": 50274,
|
| 41 |
+
"<|ja|>": 50266,
|
| 42 |
+
"<|jw|>": 50356,
|
| 43 |
+
"<|ka|>": 50329,
|
| 44 |
+
"<|kk|>": 50316,
|
| 45 |
+
"<|km|>": 50323,
|
| 46 |
+
"<|kn|>": 50306,
|
| 47 |
+
"<|ko|>": 50264,
|
| 48 |
+
"<|la|>": 50294,
|
| 49 |
+
"<|lb|>": 50345,
|
| 50 |
+
"<|ln|>": 50353,
|
| 51 |
+
"<|lo|>": 50336,
|
| 52 |
+
"<|lt|>": 50293,
|
| 53 |
+
"<|lv|>": 50301,
|
| 54 |
+
"<|mg|>": 50349,
|
| 55 |
+
"<|mi|>": 50295,
|
| 56 |
+
"<|mk|>": 50308,
|
| 57 |
+
"<|ml|>": 50296,
|
| 58 |
+
"<|mn|>": 50314,
|
| 59 |
+
"<|mr|>": 50320,
|
| 60 |
+
"<|ms|>": 50282,
|
| 61 |
+
"<|mt|>": 50343,
|
| 62 |
+
"<|my|>": 50346,
|
| 63 |
+
"<|ne|>": 50313,
|
| 64 |
+
"<|nl|>": 50271,
|
| 65 |
+
"<|nn|>": 50342,
|
| 66 |
+
"<|nocaptions|>": 50362,
|
| 67 |
+
"<|notimestamps|>": 50363,
|
| 68 |
+
"<|no|>": 50288,
|
| 69 |
+
"<|oc|>": 50328,
|
| 70 |
+
"<|pa|>": 50321,
|
| 71 |
+
"<|pl|>": 50269,
|
| 72 |
+
"<|ps|>": 50340,
|
| 73 |
+
"<|pt|>": 50267,
|
| 74 |
+
"<|ro|>": 50284,
|
| 75 |
+
"<|ru|>": 50263,
|
| 76 |
+
"<|sa|>": 50344,
|
| 77 |
+
"<|sd|>": 50332,
|
| 78 |
+
"<|si|>": 50322,
|
| 79 |
+
"<|sk|>": 50298,
|
| 80 |
+
"<|sl|>": 50305,
|
| 81 |
+
"<|sn|>": 50324,
|
| 82 |
+
"<|so|>": 50326,
|
| 83 |
+
"<|sq|>": 50317,
|
| 84 |
+
"<|sr|>": 50303,
|
| 85 |
+
"<|startoflm|>": 50360,
|
| 86 |
+
"<|startofprev|>": 50361,
|
| 87 |
+
"<|startoftranscript|>": 50258,
|
| 88 |
+
"<|su|>": 50357,
|
| 89 |
+
"<|sv|>": 50273,
|
| 90 |
+
"<|sw|>": 50318,
|
| 91 |
+
"<|ta|>": 50287,
|
| 92 |
+
"<|te|>": 50299,
|
| 93 |
+
"<|tg|>": 50331,
|
| 94 |
+
"<|th|>": 50289,
|
| 95 |
+
"<|tk|>": 50341,
|
| 96 |
+
"<|tl|>": 50348,
|
| 97 |
+
"<|transcribe|>": 50359,
|
| 98 |
+
"<|translate|>": 50358,
|
| 99 |
+
"<|tr|>": 50268,
|
| 100 |
+
"<|tt|>": 50351,
|
| 101 |
+
"<|uk|>": 50280,
|
| 102 |
+
"<|ur|>": 50290,
|
| 103 |
+
"<|uz|>": 50337,
|
| 104 |
+
"<|vi|>": 50278,
|
| 105 |
+
"<|yi|>": 50335,
|
| 106 |
+
"<|yo|>": 50325,
|
| 107 |
+
"<|zh|>": 50260
|
| 108 |
+
}
|