Safetensors
wav2vec2-bert
indiejoseph committed on
Commit
08663a8
·
verified ·
1 Parent(s): f875841

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +390 -8
handler.py CHANGED
@@ -1,10 +1,160 @@
1
- from typing import Dict, List, Any
2
- from transformers import pipeline
3
- from pipeline import SpeechToJyutpingPipeline
4
- from model import Wav2Vec2BertForCantonese
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from transformers.pipelines import PIPELINE_REGISTRY
6
- from transformers import Wav2Vec2CTCTokenizer, SeamlessM4TFeatureExtractor, pipeline
7
- from model import Wav2Vec2BertForCantonese
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  PIPELINE_REGISTRY.register_pipeline(
10
  "speech-to-jyutping",
@@ -12,6 +162,231 @@ PIPELINE_REGISTRY.register_pipeline(
12
  )
13
 
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  class EndpointHandler:
16
  def __init__(self, path="hon9kon9ize/wav2vec2bert-jyutping"):
17
  feature_extractor = SeamlessM4TFeatureExtractor.from_pretrained(path)
@@ -33,10 +408,17 @@ class EndpointHandler:
33
  Return:
34
  A :obj:`list` | `dict`: will be serialized and returned
35
  """
36
- # get inputs
37
  inputs = data.pop("inputs", data)
 
 
 
 
 
 
 
38
 
39
  # run normal prediction
40
- prediction = self.pipeline(inputs)
41
 
42
  return prediction
 
1
+ import base64
2
+ import re
3
+ from itertools import groupby
4
+ from dataclasses import dataclass
5
+ from typing import Optional, Tuple, Union, Dict, List, Any
6
+ import torch
7
+ import torch.nn as nn
8
+ from transformers.modeling_outputs import ModelOutput
9
+ from transformers import (
10
+ Wav2Vec2BertProcessor,
11
+ Wav2Vec2CTCTokenizer,
12
+ Wav2Vec2BertModel,
13
+ Wav2Vec2CTCTokenizer,
14
+ Wav2Vec2BertPreTrainedModel,
15
+ SeamlessM4TFeatureExtractor,
16
+ pipeline,
17
+ Pipeline,
18
+ )
19
+ from transformers.models.wav2vec2_bert.modeling_wav2vec2_bert import (
20
+ _HIDDEN_STATES_START_POSITION,
21
+ )
22
  from transformers.pipelines import PIPELINE_REGISTRY
23
+ import torchaudio
24
+
25
# Jyutping onset (initial-consonant) symbols. Membership in this set marks a
# decoded token as the start of a new syllable when tones are stitched onto
# the transcription.
ONSETS = set("b d g gw z p t k kw c m n ng f h s l w j".split())
46
+
47
+
48
class SpeechToJyutpingPipeline(Pipeline):
    """Pipeline turning a Cantonese audio file into a Jyutping transcription.

    Stages:
      * ``preprocess``  – load the audio file, resample to 16 kHz, extract
        input features with the processor.
      * ``_forward``    – run the model; expose the jyutping and tone logits.
      * ``postprocess`` – CTC-decode the jyutping logits and aggregate the
        per-frame tone logits into one tone per syllable.
    """

    def _sanitize_parameters(self, **kwargs):
        """Build (once) the tone tokenizer, the processor and the onset-id set.

        NOTE: ``_sanitize_parameters`` is invoked by ``Pipeline.__call__`` on
        every call, so the heavy helper objects are created lazily and cached
        on the instance instead of being rebuilt each time.
        """
        if not hasattr(self, "tone_tokenizer"):
            self.tone_tokenizer = Wav2Vec2CTCTokenizer(
                "tone_vocab.json",
                unk_token="[UNK]",
                pad_token="[PAD]",
                word_delimiter_token="|",
            )
            self.processor = Wav2Vec2BertProcessor(
                feature_extractor=self.feature_extractor,
                tokenizer=self.tokenizer,
            )
            # token ids of all Jyutping onsets; used to find syllable starts
            self.onset_ids = {
                self.processor.tokenizer.convert_tokens_to_ids(onset)
                for onset in ONSETS
            }
        return {}, {}, {}

    def preprocess(self, inputs):
        """Load an audio file path, resample to 16 kHz and extract features."""
        waveform, original_sampling_rate = torchaudio.load(inputs)
        resampler = torchaudio.transforms.Resample(
            orig_freq=original_sampling_rate, new_freq=16000
        )
        resampled_array = resampler(waveform).numpy().flatten()

        input_features = self.processor(
            resampled_array, sampling_rate=16_000, return_tensors="pt"
        ).input_features
        return {"input_features": input_features.to(self.device)}

    def _forward(self, model_inputs):
        """Run the model and return the two logit heads."""
        outputs = self.model(
            input_features=model_inputs["input_features"],
        )
        return {
            "jyutping_logits": outputs.jyutping_logits,
            "tone_logits": outputs.tone_logits,
        }

    def postprocess(self, model_outputs):
        """Decode jyutping ids and attach one tone per syllable.

        Returns a dict with:
          * ``transcription`` – space-separated jyutping syllables with tones.
          * ``tone_probs``    – per-syllable tone probability vectors
            (numpy array, padding/unknown/sep columns stripped).
        """
        tone_logits = model_outputs["tone_logits"]
        predicted_ids = torch.argmax(model_outputs["jyutping_logits"], dim=-1)
        transcription = self.processor.batch_decode(predicted_ids)[0]

        symbols = [w for w in transcription.split(" ") if len(w) > 0]

        ids_w_index = [(i, _id.item()) for i, _id in enumerate(predicted_ids[0])]
        # drop frames that decoded to padding (i.e. no characters recognized)
        ids_w_index = [
            i for i in ids_w_index if i[1] != self.processor.tokenizer.pad_token_id
        ]
        # split the ids into per-word groups at the word delimiter, keeping the
        # first (frame_index, token_id) pair of each group
        split_ids_index = [
            list(group)[0]
            for k, group in groupby(
                ids_w_index,
                lambda x: x[1] == self.processor.tokenizer.word_delimiter_token_id,
            )
            if not k
        ]

        # there must be exactly one id-group per decoded word, otherwise the
        # frame/word alignment below would be wrong
        assert len(split_ids_index) == len(symbols)

        transcription = ""
        last_onset_index = -1
        tone_probs = []

        for cur_ids_w_index, cur_word in zip(split_ids_index, symbols):
            symbol_index, symbol_token_id = cur_ids_w_index
            if symbol_token_id in self.onset_ids:
                if last_onset_index > -1:
                    # flush the previous syllable: sum its frame-level tone
                    # logits and pick the most likely tone
                    tone_prob = torch.zeros(tone_logits.shape[-1]).to(
                        tone_logits.device
                    )
                    for i in range(last_onset_index, symbol_index):
                        tone_prob += tone_logits[0, i, :]
                    tone_prob[[0, 1, 2]] = 0.0  # set padding, unknown, sep to 0 prob
                    tone_probs.append(tone_prob[3:].softmax(dim=-1))
                    predicted_tone_id = torch.argmax(tone_prob.softmax(dim=-1)).item()
                    transcription += (
                        self.tone_tokenizer.decode([predicted_tone_id]) + "_"
                    )
                transcription += "_" + cur_word
                last_onset_index = symbol_index
            else:
                transcription += cur_word
            if symbol_index == len(predicted_ids[0]) - 1:
                # last word: flush the tone of the trailing syllable
                tone_prob = torch.zeros(tone_logits.shape[-1]).to(tone_logits.device)
                for i in range(last_onset_index, len(predicted_ids[0])):
                    tone_prob += tone_logits[0, i, :]
                tone_prob[[0, 1, 2]] = 0.0  # set padding, unknown, sep to 0 prob
                tone_probs.append(tone_prob[3:].softmax(dim=-1))
                predicted_tone_id = torch.argmax(tone_prob.softmax(dim=-1)).item()
                transcription += self.tone_tokenizer.decode([predicted_tone_id]) + "_"
        # "_" markers become spaces; collapse runs of whitespace
        transcription = re.sub(
            r"\s+", " ", "".join(transcription).replace("_", " ").strip()
        )
        tone_probs = torch.stack(tone_probs).cpu().numpy()

        return {"transcription": transcription, "tone_probs": tone_probs}
157
+
158
 
159
  PIPELINE_REGISTRY.register_pipeline(
160
  "speech-to-jyutping",
 
162
  )
163
 
164
 
165
@dataclass
class JuytpingOutput(ModelOutput):
    """Output type of ``Wav2Vec2BertForCantonese``.

    Carries separate logits and losses for the two CTC heads (jyutping and
    tone) on top of the usual optional encoder outputs.
    """

    # combined loss (sum of the two head losses) when labels were provided
    loss: Optional[torch.FloatTensor] = None
    # frame-level logits of the jyutping head
    jyutping_logits: torch.FloatTensor = None
    # frame-level logits of the tone head
    tone_logits: torch.FloatTensor = None
    jyutping_loss: Optional[torch.FloatTensor] = None
    tone_loss: Optional[torch.FloatTensor] = None
    # populated only when the encoder was asked for them
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
178
+
179
+
180
class Wav2Vec2BertForCantonese(Wav2Vec2BertPreTrainedModel):
    """
    Wav2Vec2BertForCantonese is a Wav2Vec2BertModel with a language model head on top (a linear layer on top of the hidden-states output) that outputs Jyutping and tone logits.
    """

    def __init__(
        self,
        config,
        tone_vocab_size: int = 9,
    ):
        """Build the shared encoder plus the two linear CTC heads.

        Args:
            config: Wav2Vec2-BERT configuration; must define ``vocab_size``.
            tone_vocab_size: size of the tone vocabulary (default 9).
        """
        super().__init__(config)

        self.wav2vec2_bert = Wav2Vec2BertModel(config)
        self.dropout = nn.Dropout(config.final_dropout)
        self.tone_vocab_size = tone_vocab_size

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `Wav2Vec2BertForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )
        # the adapter (when present) changes the encoder's output width
        output_hidden_size = (
            config.output_hidden_size
            if hasattr(config, "add_adapter") and config.add_adapter
            else config.hidden_size
        )
        self.jyutping_head = nn.Linear(output_hidden_size, config.vocab_size)
        self.tone_head = nn.Linear(output_hidden_size, tone_vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def _ctc_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        input_lengths: torch.Tensor,
    ) -> torch.Tensor:
        """CTC loss for one head; padded label positions are filled with -100."""
        labels_mask = labels >= 0
        target_lengths = labels_mask.sum(-1)
        flattened_targets = labels.masked_select(labels_mask)

        # ctc_loss doesn't support fp16
        log_probs = nn.functional.log_softmax(
            logits, dim=-1, dtype=torch.float32
        ).transpose(0, 1)

        with torch.backends.cudnn.flags(enabled=False):
            return nn.functional.ctc_loss(
                log_probs,
                flattened_targets,
                input_lengths,
                target_lengths,
                blank=self.config.pad_token_id,
                reduction=self.config.ctc_loss_reduction,
                zero_infinity=self.config.ctc_zero_infinity,
            )

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        jyutping_labels: Optional[torch.Tensor] = None,
        tone_labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, JuytpingOutput]:
        """Run the encoder and both heads; compute CTC losses when labels are given.

        Returns a ``JuytpingOutput`` (or tuple when ``return_dict=False``)
        with jyutping/tone logits and, when both label tensors are supplied,
        their summed CTC loss.
        """
        # labels are token ids, so the largest valid value is vocab_size - 1
        if (
            jyutping_labels is not None
            and jyutping_labels.max() >= self.config.vocab_size
        ):
            raise ValueError(
                f"Label values must be < vocab_size: {self.config.vocab_size}"
            )

        if tone_labels is not None and tone_labels.max() >= self.tone_vocab_size:
            raise ValueError(
                f"Label values must be < tone_vocab_size: {self.tone_vocab_size}"
            )

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.wav2vec2_bert(
            input_features,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)

        jyutping_logits = self.jyutping_head(hidden_states)
        tone_logits = self.tone_head(hidden_states)

        loss = None
        jyutping_loss = None
        tone_loss = None

        if jyutping_labels is not None and tone_labels is not None:
            # retrieve loss input_lengths from attention_mask
            attention_mask = (
                attention_mask
                if attention_mask is not None
                else torch.ones(
                    input_features.shape[:2],
                    device=input_features.device,
                    dtype=torch.long,
                )
            )
            input_lengths = self._get_feat_extract_output_lengths(
                attention_mask.sum(-1)
            ).to(torch.long)

            # same CTC recipe for both heads; see _ctc_loss
            jyutping_loss = self._ctc_loss(
                jyutping_logits, jyutping_labels, input_lengths
            )
            tone_loss = self._ctc_loss(tone_logits, tone_labels, input_lengths)

            loss = jyutping_loss + tone_loss

        if not return_dict:
            output = (jyutping_logits, tone_logits) + outputs[
                _HIDDEN_STATES_START_POSITION:
            ]
            return ((loss,) + output) if loss is not None else output

        return JuytpingOutput(
            loss=loss,
            jyutping_logits=jyutping_logits,
            tone_logits=tone_logits,
            jyutping_loss=jyutping_loss,
            tone_loss=tone_loss,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def inference(
        self,
        processor: Wav2Vec2BertProcessor,
        tone_tokenizer: Wav2Vec2CTCTokenizer,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Greedy-decode jyutping and tone, returning a merged transcription.

        Returns:
            A tuple ``(transcription, jyutping_logits, tone_logits)`` where
            ``transcription`` is a space-separated string of jyutping
            syllables, each suffixed with its tone.
        """
        outputs = self.forward(
            input_features=input_features,
            attention_mask=attention_mask,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        jyutping_logits = outputs.jyutping_logits
        tone_logits = outputs.tone_logits
        jyutping_pred_ids = torch.argmax(jyutping_logits, dim=-1)
        tone_pred_ids = torch.argmax(tone_logits, dim=-1)
        jyutping_pred = processor.batch_decode(jyutping_pred_ids)[0]
        tone_pred = tone_tokenizer.batch_decode(tone_pred_ids)[0]
        jyutping_list = jyutping_pred.split(" ")
        tone_list = tone_pred.split(" ")
        jyutping_output = []

        # mark syllable boundaries: "_" before an onset, after anything else
        for jypt in jyutping_list:
            is_initial = jypt in ONSETS

            if is_initial:
                jypt = "_" + jypt
            else:
                jypt = jypt + "_"

            jyutping_output.append(jypt)

        jyutping_output = re.sub(
            r"\s+", " ", "".join(jyutping_output).replace("_", " ").strip()
        ).split(" ")

        if len(tone_list) > len(jyutping_output):
            tone_list = tone_list[: len(jyutping_output)]
        elif len(tone_list) < len(jyutping_output):
            # repeat the last tone if the length of tone list is shorter than the length of jyutping list
            tone_list = tone_list + [tone_list[-1]] * (
                len(jyutping_output) - len(tone_list)
            )

        return (
            " ".join(
                [f"{jypt}{tone}" for jypt, tone in zip(jyutping_output, tone_list)]
            ),
            jyutping_logits,
            tone_logits,
        )
388
+
389
+
390
  class EndpointHandler:
391
  def __init__(self, path="hon9kon9ize/wav2vec2bert-jyutping"):
392
  feature_extractor = SeamlessM4TFeatureExtractor.from_pretrained(path)
 
408
  Return:
409
  A :obj:`list` | `dict`: will be serialized and returned
410
  """
411
+ # get inputs, assuming a base64 encoded wav file
412
  inputs = data.pop("inputs", data)
413
+ # decode base64 file and save to temp file
414
+ audio = inputs["audio"]
415
+ audio_bytes = base64.b64decode(audio)
416
+ temp_wav_path = "/tmp/temp.wav"
417
+
418
+ with open(temp_wav_path, "wb") as f:
419
+ f.write(audio_bytes)
420
 
421
  # run normal prediction
422
+ prediction = self.pipeline(temp_wav_path)
423
 
424
  return prediction