indiejoseph committed on
Commit
696dd1f
·
verified ·
1 Parent(s): 63b0fa7

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +389 -0
handler.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import onnxruntime
3
+ import numpy as np
4
+ import base64
5
+ import whisper
6
+ import re
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torchaudio
11
+ from typing import List, Any, Dict
12
+ from models.ctc_model import CTCTransformerModel, PreTrainedModel, PretrainedConfig
13
+ from transformers import Wav2Vec2CTCTokenizer
14
+ import pycantonese
15
+
16
+
17
def parse_jyutping(jyutping: str) -> str:
    """Normalise one Jyutping syllable to ``onset + nucleus + coda + tone``.

    If the last character is not a digit, the first tone digit (1-6) found in
    the string is moved to the end, e.g. ``"ne5i"`` becomes ``"nei5"``.  The
    result is then split with ``pycantonese.parse_jyutping`` and the parts of
    its first syllable are concatenated again.

    Args:
        jyutping: A single Jyutping syllable. May be empty.

    Returns:
        The re-assembled syllable. Empty input is returned unchanged. If
        parsing fails for any reason, the (tone-reordered) input is returned
        so that one bad token does not abort a whole transcription.
    """
    # An empty token has no first syllable to index, so it would only ever be
    # rescued by the ``except`` below after printing a misleading failure
    # message. Splitting a transcription on " " produces such tokens.
    if not jyutping:
        return jyutping

    # Move the tone number to the end if it's not already there.
    if not jyutping[-1].isdigit():
        match = re.search(r"([1-6])", jyutping)
        if match:
            tone = match.group(1)
            jyutping = jyutping.replace(tone, "") + tone

    try:
        parsed_jyutping = pycantonese.parse_jyutping(jyutping)[0]
        onset = parsed_jyutping.onset if parsed_jyutping.onset else ""
        nucleus = parsed_jyutping.nucleus if parsed_jyutping.nucleus else ""
        coda = parsed_jyutping.coda if parsed_jyutping.coda else ""
        tone_val = str(parsed_jyutping.tone) if parsed_jyutping.tone else ""
        # NOTE(review): the exact phoneme format depends on what the CTC
        # model's vocabulary expects - confirm against the tokenizer.
        return "".join([onset, nucleus, coda, tone_val])
    except Exception as e:
        # Deliberately broad: this is a best-effort post-processing step.
        print(f"Failed to parse Jyutping '{jyutping}': {e}. Returning original.")
        return jyutping
40
+
41
+
42
class CTCTransformerConfig(PretrainedConfig):
    """Configuration holder for the CTC transformer model in this file.

    Every constructor argument is stored unchanged as an attribute of the
    same name; any extra keyword arguments are forwarded to
    ``PretrainedConfig``.

    Args:
        vocab_size: Number of unique speech tokens.
        num_labels: Number of phoneme IDs (+1 for blank).
        eos_token_id: End-of-sequence token id.
        bos_token_id: Beginning-of-sequence token id.
        pad_token_id: Padding token id.
        blank_id: Blank token id for CTC decoding.
        hidden_size: Width of the embeddings and encoder layers.
        num_hidden_layers: Number of encoder layers.
        num_attention_heads: Attention heads per encoder layer.
        intermediate_size: Feed-forward width inside each encoder layer.
        dropout: Dropout probability.
        max_position_embeddings: Stored on the config; not read elsewhere in
            the visible code.
        ctc_loss_reduction: ``reduction`` argument passed to the CTC loss.
        ctc_zero_infinity: ``zero_infinity`` argument passed to the CTC loss.
    """

    def __init__(
        self,
        vocab_size=100,
        num_labels=50,
        eos_token_id=2,
        bos_token_id=1,
        pad_token_id=0,
        blank_id=0,
        hidden_size=384,
        num_hidden_layers=50,
        num_attention_heads=4,
        intermediate_size=2048,
        dropout=0.1,
        max_position_embeddings=1024,
        ctc_loss_reduction="mean",
        ctc_zero_infinity=True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Vocabulary sizes and special token ids.
        self.vocab_size = vocab_size
        self.num_labels = num_labels
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.blank_id = blank_id

        # Encoder geometry.
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.dropout = dropout

        # CTC loss options.
        self.ctc_loss_reduction = ctc_loss_reduction
        self.ctc_zero_infinity = ctc_zero_infinity
76
+
77
+
78
class SinusoidalPositionEncoder(torch.nn.Module):
    """Add sinusoidal positional encodings to a sequence batch, then dropout.

    Channel ``2k`` of the encoding holds ``sin(pos * w_k)`` and channel
    ``2k + 1`` holds ``cos(pos * w_k)``, where
    ``w_k = exp(-k * ln(10000) / (depth / 2 - 1))``.

    Args:
        d_model: Encoding width used by ``encode`` when no ``depth`` is given.
        dropout_rate: Dropout probability applied in ``forward``.
    """

    def __init__(self, d_model=384, dropout_rate=0.1):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(p=dropout_rate)

    def encode(
        self,
        positions: torch.Tensor = None,
        depth: int = None,
        dtype: torch.dtype = torch.float32,
    ):
        """Compute the encodings for a 2-D tensor of position indices.

        Args:
            positions: Tensor of shape ``[batch, seq_len]`` with position
                indices.
            depth: Encoding width; defaults to ``self.d_model``. Must be an
                even number >= 2.
            dtype: Floating dtype used for the whole computation.

        Returns:
            Tensor of shape ``[batch, seq_len, depth]`` in ``dtype``.

        Raises:
            ValueError: If ``positions`` is missing or ``depth`` is odd or
                smaller than 2.
        """
        if positions is None:
            raise ValueError("positions must be a [batch, seq_len] tensor")
        if depth is None:
            depth = self.d_model
        if depth < 2 or depth % 2 != 0:
            # Sin and cos each fill exactly half of the channels.
            raise ValueError(f"depth must be an even number >= 2, got {depth}")

        half_depth = depth // 2
        positions = positions.type(dtype)
        device = positions.device

        # With depth == 2 there is a single timescale (k = 0), so the
        # increment is irrelevant; clamping the divisor avoids a division by
        # zero that would otherwise turn the encodings into NaN.
        divisor = float(max(half_depth - 1, 1))
        log_timescale_increment = (
            torch.log(torch.tensor([10000.0], dtype=dtype, device=device)) / divisor
        )
        inv_timescales = torch.exp(
            torch.arange(half_depth, device=device, dtype=dtype)
            * (-log_timescale_increment)
        )  # [depth // 2]

        scaled_time = positions.unsqueeze(-1) * inv_timescales  # [B, T, depth // 2]

        # Interleave as sin, cos, sin, cos, ... along the last axis.
        return torch.stack(
            (torch.sin(scaled_time), torch.cos(scaled_time)), dim=-1
        ).flatten(-2)

    def forward(self, x):
        """Return ``dropout(x + encoding)`` for ``x`` of shape ``[B, T, D]``.

        Positions are numbered starting at 1, and the encoding is computed
        in ``x``'s dtype and on ``x``'s device.
        """
        batch_size, timesteps, input_dim = x.size()
        # Create position indices [1, 2, ..., timesteps]
        positions = (
            torch.arange(1, timesteps + 1, device=x.device)
            .unsqueeze(0)
            .expand(batch_size, -1)
        )
        position_encoding = self.encode(positions, input_dim, x.dtype)

        # Apply dropout to the sum
        return self.dropout(x + position_encoding)
150
+
151
+
152
class CTCTransformerModel(PreTrainedModel):
    """Transformer encoder over discrete speech tokens with a CTC head.

    The model embeds token ids (the table has ``vocab_size + 1`` rows, the
    last one being the padding row), adds sinusoidal positions, runs a
    ``nn.TransformerEncoder`` and projects every frame to ``num_labels``
    logits.
    """

    config_class = CTCTransformerConfig

    def __init__(self, config):
        super().__init__(config)

        self.embed = nn.Embedding(
            config.vocab_size + 1,
            config.hidden_size,
            padding_idx=config.vocab_size,
        )
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            dim_feedforward=config.intermediate_size,
            dropout=self.config.dropout,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=config.num_hidden_layers
        )
        self.pos_embed = SinusoidalPositionEncoder(
            d_model=config.hidden_size, dropout_rate=config.dropout
        )
        # The same LayerNorm instance is applied twice in ``forward``.
        self.norm = nn.LayerNorm(config.hidden_size)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
    ):
        """Compute per-frame logits and, when ``labels`` is given, the CTC loss.

        Args:
            input_ids: Token ids of shape ``[B, T]``.
            attention_mask: Optional ``[B, T]`` mask, 1 for real frames and 0
                for padding. When omitted, every frame counts as real.
            labels: Optional ``[B, L]`` target ids; negative entries (e.g.
                -100) are treated as padding and excluded from the loss.

        Returns:
            ``{"loss": loss_or_None, "logits": logits}`` with logits of shape
            ``[B, T, num_labels]``.
        """
        # Embed the input tokens
        x = self.embed(input_ids)

        x = self.norm(x)

        # Add positional embeddings
        x = self.pos_embed(x)

        # PyTorch's encoder expects True at positions to be IGNORED, while the
        # Transformers-style attention_mask uses 1 for positions to attend to,
        # so the mask is inverted here.
        if attention_mask is not None:
            src_key_padding_mask = attention_mask == 0
        else:
            src_key_padding_mask = None

        # Pass through encoder with proper masking
        x = self.encoder(x, src_key_padding_mask=src_key_padding_mask)

        x = self.norm(x)

        # Project to output labels
        logits = self.classifier(x)  # [B, T, num_labels]

        loss = None
        if labels is not None:
            if attention_mask is None:
                # No mask supplied: every frame of every sequence is valid.
                input_lengths = torch.full(
                    (logits.size(0),),
                    logits.size(1),
                    dtype=torch.long,
                    device=logits.device,
                )
            else:
                input_lengths = attention_mask.sum(-1)

            # assuming that padded tokens are filled with -100
            # when not being attended to
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16
            log_probs = nn.functional.log_softmax(
                logits, dim=-1, dtype=torch.float32
            ).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    # Same blank id that ``predict`` uses when decoding.
                    blank=self.config.blank_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        return {"loss": loss, "logits": logits}

    @torch.inference_mode()
    def predict(self, input_ids: torch.Tensor) -> List[int]:
        """Greedy CTC decoding of a single token sequence.

        Args:
            input_ids: Token-id tensor of shape ``[1, T]``. The leading
                dimension is squeezed away, so only batch size 1 yields a
                flat list of ids.

        Returns:
            Predicted label ids after collapsing consecutive repeats and
            dropping ``config.blank_id``.
        """
        # NOTE(review): ``forward`` applies ``self.norm`` before the positional
        # encoding and again after the encoder; this method applies it in
        # neither place. Left unchanged because aligning the two would alter
        # inference output - confirm which path matches the trained checkpoint.
        blank_id = self.config.blank_id
        # Create attention mask with 1s (not masked) for all positions
        attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(
            input_ids.device
        )

        # ``inference_mode`` already disables autograd; no ``no_grad`` needed.
        x = self.embed(input_ids)
        x = self.pos_embed(x)  # Add positional embeddings
        # Using the same masking convention as forward method
        encoded = self.encoder(x, src_key_padding_mask=(attention_mask == 0))
        logits = self.classifier(encoded)  # [1, T, V]
        log_probs = F.log_softmax(logits, dim=-1)  # [1, T, V]
        pred_ids = torch.argmax(log_probs, dim=-1).squeeze(0).tolist()

        # Greedy decode with collapse: keep an id only when it differs from
        # the previous frame and is not the blank.
        pred_phoneme_ids = []
        prev = None
        for idx in pred_ids:
            if idx != blank_id and idx != prev:
                pred_phoneme_ids.append(idx)
            prev = idx

        return pred_phoneme_ids
267
+
268
+
269
def load_speech_tokenizer(speech_tokenizer_path: str):
    """Build an ONNX Runtime session for the speech tokenizer model.

    The session runs on the CPU execution provider with all graph
    optimisations enabled and a single intra-op thread.

    Args:
        speech_tokenizer_path: Path to the ``.onnx`` model file.

    Returns:
        The ``onnxruntime.InferenceSession`` for that model.
    """
    session_options = onnxruntime.SessionOptions()
    session_options.intra_op_num_threads = 1
    session_options.graph_optimization_level = (
        onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    )
    return onnxruntime.InferenceSession(
        speech_tokenizer_path,
        sess_options=session_options,
        providers=["CPUExecutionProvider"],
    )
280
+
281
+
282
def extract_speech_token(audio, speech_tokenizer_session):
    """
    Extract speech tokens from audio using the ONNX speech tokenizer.

    Args:
        audio: audio signal (torch.Tensor or numpy.ndarray), shape (T,) at 16kHz
        speech_tokenizer_session: ONNX speech tokenizer session

    Returns:
        speech_token: int32 tensor of shape (1, num_tokens)
        speech_token_len: int32 tensor of shape (1,) with token sequence length

    Raises:
        ValueError: If ``audio`` is neither a torch.Tensor nor a numpy.ndarray.
    """
    if isinstance(audio, torch.Tensor):
        # detach() first: .numpy() refuses tensors that require grad. This
        # mirrors how ``feat`` is converted below.
        audio = audio.detach().cpu().numpy()
    elif not isinstance(audio, np.ndarray):
        raise ValueError("Audio must be torch.Tensor or numpy.ndarray")

    # Convert to a float tensor with a leading batch axis for the mel front-end.
    audio_tensor = torch.from_numpy(audio).float().unsqueeze(0)

    # Extract mel-spectrogram (whisper format)
    # NOTE(review): presumably shaped (1, 128, frames); axis 2 is passed to
    # the tokenizer as the feature length below.
    feat = whisper.log_mel_spectrogram(audio_tensor, n_mels=128)

    # Query the session's input descriptors once: [0] takes the features,
    # [1] takes their length.
    session_inputs = speech_tokenizer_session.get_inputs()
    feeds = {
        session_inputs[0].name: feat.detach().cpu().numpy(),
        session_inputs[1].name: np.array([feat.shape[2]], dtype=np.int32),
    }

    # Run speech tokenizer; only the first output is used.
    token_ids = speech_tokenizer_session.run(None, feeds)[0].flatten().tolist()

    speech_token = torch.tensor([token_ids], dtype=torch.int32)
    speech_token_len = torch.tensor([len(token_ids)], dtype=torch.int32)

    return speech_token, speech_token_len
330
+
331
+
332
class EndpointHandler:
    """Inference endpoint handler: base64 audio in, Jyutping transcription out.

    On construction it loads the ONNX speech tokenizer, the CTC tokenizer and
    the CTC transformer model from ``model_dir``. Each call decodes the audio,
    converts it to 16 kHz, extracts speech tokens, runs greedy CTC decoding
    and normalises every decoded syllable with ``parse_jyutping``.
    """

    def __init__(self, model_dir: str, **kwargs: Any):
        """Load all model artefacts from ``model_dir``.

        The model is placed on CUDA when available, otherwise on the CPU.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.speech_tokenizer_session = load_speech_tokenizer(
            f"{model_dir}/speech_tokenizer_v2.onnx"
        )
        self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_dir)
        self.model = (
            CTCTransformerModel.from_pretrained(
                model_dir,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
            .eval()
            .to(device)
        )

    @staticmethod
    def _downmix_to_mono(waveform: torch.Tensor) -> torch.Tensor:
        """Average the channels of a ``(channels, time)`` waveform.

        Single-channel input is returned untouched so mono audio is processed
        exactly as before.
        """
        if waveform.dim() == 2 and waveform.size(0) > 1:
            return waveform.mean(dim=0, keepdim=True)
        return waveform

    def preprocess(self, inputs):
        """Load an audio file and return it as a flat 16 kHz numpy array.

        Args:
            inputs: Path (or other source accepted by ``torchaudio.load``).

        Returns:
            1-D numpy array of samples at 16 kHz.
        """
        waveform, original_sampling_rate = torchaudio.load(inputs)

        # Flattening a multi-channel tensor would concatenate the channels one
        # after another instead of mixing them, so downmix first.
        waveform = self._downmix_to_mono(waveform)

        if original_sampling_rate != 16000:
            resampler = torchaudio.transforms.Resample(
                orig_freq=original_sampling_rate, new_freq=16000
            )
            waveform = resampler(waveform)
        return waveform.numpy().flatten()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Transcribe one request.

        Args:
            data: Request payload. The audio is read from
                ``data["inputs"]["audio"]`` (or ``data["audio"]`` when there
                is no ``"inputs"`` key) as a base64 encoded audio file. A
                bare base64 string under ``"inputs"`` is accepted as well.

        Returns:
            ``{"transcription": <space separated Jyutping syllables>}``.
        """
        # Function-scope imports keep this block self-contained.
        import os
        import tempfile

        inputs = data.pop("inputs", data)
        audio = inputs["audio"] if isinstance(inputs, dict) else inputs
        audio_bytes = base64.b64decode(audio)

        # A unique temp file per request: a fixed path would be overwritten by
        # concurrent requests and was never cleaned up.
        tmp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        try:
            with tmp_file:
                tmp_file.write(audio_bytes)
            audio_array = self.preprocess(tmp_file.name)
        finally:
            os.remove(tmp_file.name)

        # Extract speech tokens
        speech_token, speech_token_len = extract_speech_token(
            audio_array, self.speech_tokenizer_session
        )

        with torch.no_grad():
            speech_token = speech_token.to(next(self.model.parameters()).device)
            outputs = self.model.predict(speech_token)

        transcription = self.tokenizer.decode(outputs, skip_special_tokens=True)
        print(transcription)
        transcription = " ".join(
            [parse_jyutping(jyt) for jyt in transcription.split(" ")]
        )

        return {"transcription": transcription}