Alex-Wengg committed on
Commit
46daad3
·
1 Parent(s): 20c78b6

remove TDT decoder conversion script (not CTC)

Browse files
convert/parakeet-tdt-ctc-110m/convert_tdt_decoder.py DELETED
@@ -1,323 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Convert Parakeet TDT-CTC 110M decoder components to CoreML.
4
-
5
- This script exports the TDT decoder (prediction network) and joint network
6
- with the SAME format as the working 0.6B model:
7
- - JointDecision outputs token_id, token_prob, duration (argmax done inside)
8
- - Uses shape [1, dim, 1] for encoder/decoder steps
9
- - Matches the interface expected by TdtDecoderV3
10
- """
11
-
12
- import argparse
13
- import os
14
- import torch
15
- import torch.nn.functional as F
16
- import coremltools as ct
17
- import numpy as np
18
- from pathlib import Path
19
-
20
- # NeMo imports
21
- import nemo.collections.asr as nemo_asr
22
-
23
-
24
def get_model_config(model):
    """Collect the dimensions needed to export the decoder and joint network.

    Each value is read from the matching sub-module of ``model`` when the
    attribute exists and otherwise falls back to a built-in default.

    Args:
        model: Object that may expose ``encoder``, ``decoder`` and ``joint``
            attributes.

    Returns:
        dict with the keys ``encoder_dim``, ``pred_hidden``, ``num_layers``,
        ``vocab_size`` and ``num_durations`` (in that order).
        ``encoder_dim`` is ``None`` when the encoder exposes neither
        ``d_model`` nor ``_feat_out``.
    """
    # Sentinel that tells "attribute is absent" apart from any real value.
    missing = object()

    # Encoder width: `d_model` wins over `_feat_out`; unknown stays None.
    encoder = getattr(model, 'encoder', None)
    encoder_dim = None
    for attr_name in ('d_model', '_feat_out'):
        if hasattr(encoder, attr_name):
            encoder_dim = getattr(encoder, attr_name)
            break

    # Prediction network: 640 hidden units and 1 layer unless told otherwise.
    decoder = getattr(model, 'decoder', None)
    pred_hidden = getattr(decoder, 'pred_hidden', 640)
    num_layers = getattr(decoder, 'pred_rnn_layers', 1)

    # Joint network: the duration count must be resolved first because the
    # vocabulary size is derived as `num_classes - num_durations`.
    joint = getattr(model, 'joint', None)
    num_durations = getattr(joint, 'num_extra_outputs', 5)
    num_classes = getattr(joint, 'num_classes', missing)
    vocab_size = 1024 if num_classes is missing else num_classes - num_durations

    return dict(
        encoder_dim=encoder_dim,
        pred_hidden=pred_hidden,
        num_layers=num_layers,
        vocab_size=vocab_size,
        num_durations=num_durations,
    )
63
-
64
-
65
class DecoderWrapper(torch.nn.Module):
    """
    Single-step adapter around the RNNT/TDT prediction network.

    Interface (mirrors the 0.6B export):
    - Input: targets[1,1], target_lengths[1], h_in[num_layers,1,pred_hidden], c_in[...]
    - Output: decoder_output[1,pred_hidden,2], h_out[...], c_out[...]
    """

    def __init__(self, decoder, pred_hidden):
        super().__init__()
        self.pred_hidden = pred_hidden
        self.decoder = decoder

    def forward(self, targets, target_lengths, h_in, c_in):
        """
        Run one prediction-network step.

        Args:
            targets: [1, 1] - previous token ID
            target_lengths: [1] - accepted for interface parity; not read here
            h_in: [num_layers, 1, pred_hidden]
            c_in: [num_layers, 1, pred_hidden]
        Returns:
            decoder_output: [1, pred_hidden, 2] - the single step, transposed
                and repeated twice along the last axis
            h_out: [num_layers, 1, pred_hidden]
            c_out: [num_layers, 1, pred_hidden]
        """
        # predict() yields [batch, time, pred_hidden] plus the updated LSTM state.
        step, (h_out, c_out) = self.decoder.predict(
            targets, state=(h_in, c_in), add_sos=False
        )

        # Channel-first layout: [1, pred_hidden, 1].
        step = step.transpose(1, 2)

        # The 0.6B decoder emits two time steps; duplicate ours to match.
        return torch.cat((step, step), dim=2), h_out, c_out
102
-
103
-
104
class JointWrapper(torch.nn.Module):
    """
    Adapter around the TDT joint network that performs the argmax internally.

    Interface (mirrors the 0.6B export):
    - Input: encoder_step[1,encoder_dim,1], decoder_step[1,pred_hidden,1]
    - Output: token_id[1,1,1], token_prob[1,1,1], duration[1,1,1]
    """

    def __init__(self, joint, vocab_size, num_durations=5):
        super().__init__()
        self.joint = joint
        self.vocab_size = vocab_size
        self.num_durations = num_durations

    def _report_shape_once(self, joint_out):
        """Print the joint output shape on the first call only."""
        if hasattr(self, '_debug_printed'):
            return
        self._debug_printed = True
        print(f" Joint output shape: {joint_out.shape}")
        print(f" Expected: vocab={self.vocab_size} + blank=1 + durations={self.num_durations} = {self.vocab_size + 1 + self.num_durations}")

    def forward(self, encoder_step, decoder_step):
        """
        Pick the best token and duration bin for one encoder/decoder step pair.

        Args:
            encoder_step: [1, encoder_dim, 1]
            decoder_step: [1, pred_hidden, 1]
        Returns:
            token_id: [1, 1, 1] - argmax token ID (int32)
            token_prob: [1, 1, 1] - softmax probability of that token
            duration: [1, 1, 1] - argmax duration bin (int32)
        """
        # The joint expects [batch, 1, dim]; inputs arrive as [batch, dim, 1].
        joint_out = self.joint.joint(
            encoder_step.transpose(1, 2),
            decoder_step.transpose(1, 2),
        )
        self._report_shape_once(joint_out)

        # Last-axis layout: vocab tokens, then blank, then duration bins.
        split_at = self.vocab_size + 1
        token_logits = joint_out[..., :split_at]
        duration_logits = joint_out[..., split_at:]

        # Keep the reduced axis so the winning index can drive gather directly.
        best = torch.argmax(token_logits, dim=-1, keepdim=True)
        token_prob = torch.gather(F.softmax(token_logits, dim=-1), -1, best).squeeze(-1)
        token_id = best.squeeze(-1)

        duration = torch.argmax(duration_logits, dim=-1)

        return token_id.int(), token_prob, duration.int()
164
-
165
-
166
def convert_decoder(model, config, output_dir: Path):
    """Trace the prediction network and save it as ``Decoder.mlpackage``.

    Args:
        model: Loaded model; only ``model.decoder`` is used.
        config: Mapping providing ``pred_hidden`` and ``num_layers``.
        output_dir: Directory the package is written into.

    Returns:
        The converted Core ML model object.
    """
    print("Converting Decoder...")
    hidden = config['pred_hidden']
    layers = config['num_layers']
    print(f" pred_hidden={hidden}, num_layers={layers}")

    wrapper = DecoderWrapper(model.decoder, hidden)
    wrapper.eval()

    # Example inputs for tracing: one token, length 1, zeroed LSTM state.
    state_shape = (layers, 1, hidden)
    example_inputs = (
        torch.zeros(1, 1, dtype=torch.long),
        torch.ones(1, dtype=torch.long),
        torch.zeros(state_shape),
        torch.zeros(state_shape),
    )

    with torch.no_grad():
        traced = torch.jit.trace(wrapper, example_inputs)

    # Tensor specs for the Core ML conversion; IDs and lengths are int32,
    # the recurrent state is float32.
    input_specs = [
        ct.TensorType(name="targets", shape=(1, 1), dtype=np.int32),
        ct.TensorType(name="target_lengths", shape=(1,), dtype=np.int32),
        ct.TensorType(name="h_in", shape=state_shape, dtype=np.float32),
        ct.TensorType(name="c_in", shape=state_shape, dtype=np.float32),
    ]
    output_specs = [
        ct.TensorType(name=output_name)
        for output_name in ("decoder_output", "h_out", "c_out")
    ]

    mlmodel = ct.convert(
        traced,
        inputs=input_specs,
        outputs=output_specs,
        minimum_deployment_target=ct.target.iOS17,
        compute_precision=ct.precision.FLOAT16,
    )

    mlmodel.author = "Fluid Inference"
    mlmodel.short_description = "Hybrid TDT Decoder (110M)"

    output_path = output_dir / "Decoder.mlpackage"
    mlmodel.save(str(output_path))
    print(f" Saved to {output_path}")

    return mlmodel
212
-
213
-
214
def convert_joint(model, config, output_dir: Path):
    """Trace the joint network and save it as ``JointDecision.mlpackage``.

    Args:
        model: Loaded model; only ``model.joint`` is used.
        config: Mapping providing ``encoder_dim``, ``pred_hidden``,
            ``vocab_size`` and ``num_durations``.
        output_dir: Directory the package is written into.

    Returns:
        The converted Core ML model object.
    """
    print("Converting JointDecision...")
    enc_dim = config['encoder_dim']
    hidden = config['pred_hidden']
    print(f" encoder_dim={enc_dim}, pred_hidden={hidden}")
    vocab = config['vocab_size']
    durations = config['num_durations']
    print(f" vocab_size={vocab}, num_durations={durations}")

    wrapper = JointWrapper(model.joint, vocab_size=vocab, num_durations=durations)
    wrapper.eval()

    # One encoder frame and one decoder frame, both laid out as [1, dim, 1].
    enc_shape = (1, enc_dim, 1)
    dec_shape = (1, hidden, 1)
    encoder_step = torch.randn(*enc_shape)
    decoder_step = torch.randn(*dec_shape)

    with torch.no_grad():
        traced = torch.jit.trace(wrapper, (encoder_step, decoder_step))

    input_specs = [
        ct.TensorType(name="encoder_step", shape=enc_shape, dtype=np.float32),
        ct.TensorType(name="decoder_step", shape=dec_shape, dtype=np.float32),
    ]
    output_specs = [
        ct.TensorType(name=output_name)
        for output_name in ("token_id", "token_prob", "duration")
    ]

    mlmodel = ct.convert(
        traced,
        inputs=input_specs,
        outputs=output_specs,
        minimum_deployment_target=ct.target.iOS17,
        compute_precision=ct.precision.FLOAT16,
    )

    mlmodel.author = "Fluid Inference"
    mlmodel.short_description = "Hybrid Joint Decision (110M)"

    output_path = output_dir / "JointDecision.mlpackage"
    mlmodel.save(str(output_path))
    print(f" Saved to {output_path}")

    return mlmodel
261
-
262
-
263
def _detect_encoder_dim(model):
    """Measure the encoder feature dimension by running one second of noise.

    The waveform is passed through ``model.preprocessor`` first and the
    resulting features are fed to ``model.encoder``; the feature size is then
    read from axis 1 of the encoder output.

    NOTE(review): this relies on the NeMo convention that encoders consume
    and emit ``[batch, dim, time]`` tensors, and that the model exposes a
    ``preprocessor`` - confirm against the NeMo version in use.
    """
    dummy_audio = torch.randn(1, 16000)
    dummy_length = torch.tensor([16000])
    with torch.no_grad():
        features, feature_length = model.preprocessor(
            input_signal=dummy_audio,
            length=dummy_length,
        )
        enc_out, _ = model.encoder(
            audio_signal=features,
            length=feature_length,
        )
    # Axis 1 is the feature dimension; the last axis is time.
    return enc_out.shape[1]


def main():
    """Command-line entry point: load the model and export decoder + joint.

    Parses ``--model-name`` and ``--output-dir``, creates the output
    directory, loads the pretrained model, resolves its configuration
    (probing the encoder when its dimension cannot be read directly),
    converts both components and prints follow-up instructions.
    """
    parser = argparse.ArgumentParser(description="Convert TDT decoder to CoreML (0.6B format)")
    parser.add_argument(
        "--model-name",
        default="nvidia/parakeet-tdt_ctc-110m",
        help="NeMo model name or path"
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("./output"),
        help="Output directory for CoreML models"
    )
    args = parser.parse_args()

    # Create output directory
    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Load model
    print(f"Loading model: {args.model_name}")
    model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(args.model_name)
    model.eval()

    # Get model configuration
    config = get_model_config(model)

    # Fall back to probing the encoder when no attribute exposed its width.
    if config['encoder_dim'] is None:
        print("Auto-detecting encoder dimension...")
        config['encoder_dim'] = _detect_encoder_dim(model)

    print("\nModel config:")
    for k, v in config.items():
        print(f" {k}: {v}")

    # Convert components
    print()
    convert_decoder(model, config, args.output_dir)
    convert_joint(model, config, args.output_dir)

    print("\nConversion complete!")
    print(f"Models saved to: {args.output_dir}")
    print("\nNext steps:")
    print("1. Compile to .mlmodelc:")
    print(f" cd {args.output_dir}")
    print(" xcrun coremlcompiler compile Decoder.mlpackage .")
    print(" xcrun coremlcompiler compile JointDecision.mlpackage .")
    print("2. Copy to model cache:")
    print(" cp -r Decoder.mlmodelc JointDecision.mlmodelc ~/Library/Application\\ Support/FluidAudio/Models/parakeet-ctc-110m-coreml/")
    print("3. Test with: swift run fluidaudio hybrid-earnings-benchmark --max-files 1")


if __name__ == "__main__":
    main()