Adriaanvh1 commited on
Commit
86eacb5
·
1 Parent(s): 2933ecf

Updated handler for language detection

Browse files
Files changed (2) hide show
  1. handler.py +14 -7
  2. token_mapping.json +108 -0
handler.py CHANGED
@@ -2,8 +2,8 @@
2
  from typing import Dict, Any, List
3
  from transformers import pipeline
4
  import torch
5
- from io import BytesIO
6
  import base64
 
7
 
8
 
9
  class EndpointHandler:
@@ -22,6 +22,8 @@ class EndpointHandler:
22
  device=device,
23
  chunk_length_s=30,
24
  )
 
 
25
 
26
 
27
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
@@ -35,14 +37,19 @@ class EndpointHandler:
35
  """
36
  inputs = data["inputs"]
37
  audio = base64.b64decode(inputs["audio"]) # bytes
38
- lang = inputs["language"]
39
  task = inputs["task"] # One of "translate", "transcribe"
40
 
41
- # Set language and task
42
- self.pipeline.model.config.forced_decoder_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(language=lang, task=task)
43
- self.pipeline
44
-
 
 
 
 
 
45
  # Model inference
46
- output = self.pipeline(audio, return_timestamps="word")
47
 
48
  return output
 
2
  from typing import Dict, Any, List
3
  from transformers import pipeline
4
  import torch
 
5
  import base64
6
+ import json
7
 
8
 
9
  class EndpointHandler:
 
22
  device=device,
23
  chunk_length_s=30,
24
  )
25
+ with open("token_mapping.json") as file:
26
+ self.token_mapping = json.load(file)
27
 
28
 
29
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
 
37
  """
38
  inputs = data["inputs"]
39
  audio = base64.b64decode(inputs["audio"]) # bytes
40
+ lang = inputs["language"] # ISO code
41
  task = inputs["task"] # One of "translate", "transcribe"
42
 
43
+ # Set language and task (order: language, task, timestamp)
44
+ if lang is None:
45
+ lang_id = None # line 1576, https://github.com/huggingface/transformers/blob/v4.27.2/src/transformers/models/whisper/modeling_whisper.py
46
+ else:
47
+ lang_id = self.token_mapping[f"<|{lang}|>"]
48
+ task_id = self.token_mapping[f"<|{task}|>"]
49
+ timestamp_id = self.token_mapping["<|notimestamps|>"] # Required to output timestamps
50
+ forced_ids = [(1, lang_id), (2, task_id), (3, timestamp_id)]
51
+
52
  # Model inference
53
+ output = self.pipeline(audio, return_timestamps="word", generate_kwargs={"forced_decoder_ids": forced_ids})
54
 
55
  return output
token_mapping.json ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "<|af|>": 50327,
3
+ "<|am|>": 50334,
4
+ "<|ar|>": 50272,
5
+ "<|as|>": 50350,
6
+ "<|az|>": 50304,
7
+ "<|ba|>": 50355,
8
+ "<|be|>": 50330,
9
+ "<|bg|>": 50292,
10
+ "<|bn|>": 50302,
11
+ "<|bo|>": 50347,
12
+ "<|br|>": 50309,
13
+ "<|bs|>": 50315,
14
+ "<|ca|>": 50270,
15
+ "<|cs|>": 50283,
16
+ "<|cy|>": 50297,
17
+ "<|da|>": 50285,
18
+ "<|de|>": 50261,
19
+ "<|el|>": 50281,
20
+ "<|en|>": 50259,
21
+ "<|es|>": 50262,
22
+ "<|et|>": 50307,
23
+ "<|eu|>": 50310,
24
+ "<|fa|>": 50300,
25
+ "<|fi|>": 50277,
26
+ "<|fo|>": 50338,
27
+ "<|fr|>": 50265,
28
+ "<|gl|>": 50319,
29
+ "<|gu|>": 50333,
30
+ "<|haw|>": 50352,
31
+ "<|ha|>": 50354,
32
+ "<|he|>": 50279,
33
+ "<|hi|>": 50276,
34
+ "<|hr|>": 50291,
35
+ "<|ht|>": 50339,
36
+ "<|hu|>": 50286,
37
+ "<|hy|>": 50312,
38
+ "<|id|>": 50275,
39
+ "<|is|>": 50311,
40
+ "<|it|>": 50274,
41
+ "<|ja|>": 50266,
42
+ "<|jw|>": 50356,
43
+ "<|ka|>": 50329,
44
+ "<|kk|>": 50316,
45
+ "<|km|>": 50323,
46
+ "<|kn|>": 50306,
47
+ "<|ko|>": 50264,
48
+ "<|la|>": 50294,
49
+ "<|lb|>": 50345,
50
+ "<|ln|>": 50353,
51
+ "<|lo|>": 50336,
52
+ "<|lt|>": 50293,
53
+ "<|lv|>": 50301,
54
+ "<|mg|>": 50349,
55
+ "<|mi|>": 50295,
56
+ "<|mk|>": 50308,
57
+ "<|ml|>": 50296,
58
+ "<|mn|>": 50314,
59
+ "<|mr|>": 50320,
60
+ "<|ms|>": 50282,
61
+ "<|mt|>": 50343,
62
+ "<|my|>": 50346,
63
+ "<|ne|>": 50313,
64
+ "<|nl|>": 50271,
65
+ "<|nn|>": 50342,
66
+ "<|nocaptions|>": 50362,
67
+ "<|notimestamps|>": 50363,
68
+ "<|no|>": 50288,
69
+ "<|oc|>": 50328,
70
+ "<|pa|>": 50321,
71
+ "<|pl|>": 50269,
72
+ "<|ps|>": 50340,
73
+ "<|pt|>": 50267,
74
+ "<|ro|>": 50284,
75
+ "<|ru|>": 50263,
76
+ "<|sa|>": 50344,
77
+ "<|sd|>": 50332,
78
+ "<|si|>": 50322,
79
+ "<|sk|>": 50298,
80
+ "<|sl|>": 50305,
81
+ "<|sn|>": 50324,
82
+ "<|so|>": 50326,
83
+ "<|sq|>": 50317,
84
+ "<|sr|>": 50303,
85
+ "<|startoflm|>": 50360,
86
+ "<|startofprev|>": 50361,
87
+ "<|startoftranscript|>": 50258,
88
+ "<|su|>": 50357,
89
+ "<|sv|>": 50273,
90
+ "<|sw|>": 50318,
91
+ "<|ta|>": 50287,
92
+ "<|te|>": 50299,
93
+ "<|tg|>": 50331,
94
+ "<|th|>": 50289,
95
+ "<|tk|>": 50341,
96
+ "<|tl|>": 50348,
97
+ "<|transcribe|>": 50359,
98
+ "<|translate|>": 50358,
99
+ "<|tr|>": 50268,
100
+ "<|tt|>": 50351,
101
+ "<|uk|>": 50280,
102
+ "<|ur|>": 50290,
103
+ "<|uz|>": 50337,
104
+ "<|vi|>": 50278,
105
+ "<|yi|>": 50335,
106
+ "<|yo|>": 50325,
107
+ "<|zh|>": 50260
108
+ }