alexwengg committed on
Commit
07826d2
·
verified ·
1 Parent(s): b5cf390

Delete inference.py

Browse files
Files changed (1) hide show
  1. inference.py +0 -304
inference.py DELETED
@@ -1,304 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Inference script for Parakeet-TDT-CTC-110M CoreML model.
4
-
5
- This script demonstrates how to run inference using the converted CoreML models
6
- on Apple Silicon. It supports both TDT (Token-Duration Transducer) decoding for
7
- full transcription and CTC decoding for keyword spotting.
8
-
9
- Usage:
10
- uv run scripts/inference.py --audio audio.wav --mode tdt
11
- uv run scripts/inference.py --audio audio.wav --mode ctc
12
-
13
- Requirements:
14
- - macOS 13+ with Apple Silicon
15
- - Python 3.10+
16
- - coremltools
17
- """
18
-
19
- import argparse
20
- import json
21
- from pathlib import Path
22
-
23
- import coremltools as ct
24
- import numpy as np
25
-
26
-
27
class ParakeetCoreML:
    """CoreML inference wrapper for Parakeet-TDT-CTC-110M.

    Wraps the five converted CoreML pipeline stages (preprocessor, encoder,
    CTC head, transducer decoder, joint network) and provides greedy CTC
    and TDT decoding plus end-to-end transcription.
    """

    def __init__(self, model_dir: str):
        """Load CoreML models from directory.

        Args:
            model_dir: Path to directory containing the .mlpackage files,
                plus metadata.json and vocab.json.
        """
        self.model_dir = Path(model_dir)

        # Model metadata (decoder layer count / hidden size, used by decode_tdt).
        with open(self.model_dir / "metadata.json") as f:
            self.metadata = json.load(f)

        # Vocabulary maps token ID -> SentencePiece piece. JSON object keys
        # are strings, so convert them back to ints.
        with open(self.model_dir / "vocab.json") as f:
            vocab_dict = json.load(f)
        self.vocab = {int(k): v for k, v in vocab_dict.items()}

        self.blank_id = len(self.vocab)  # Blank token is last

        # Load the five pipeline stages.
        print("Loading CoreML models...")
        self.preprocessor = ct.models.MLModel(
            str(self.model_dir / "Preprocessor.mlpackage")
        )
        self.encoder = ct.models.MLModel(
            str(self.model_dir / "Encoder.mlpackage")
        )
        self.ctc_head = ct.models.MLModel(
            str(self.model_dir / "CTCHead.mlpackage")
        )
        self.decoder = ct.models.MLModel(
            str(self.model_dir / "Decoder.mlpackage")
        )
        self.joint = ct.models.MLModel(
            str(self.model_dir / "JointDecision.mlpackage")
        )
        print("Models loaded successfully.")

    def load_audio(self, audio_path: str) -> np.ndarray:
        """Load audio file and convert to 16kHz mono.

        Args:
            audio_path: Path to audio file (WAV, MP3, etc.)

        Returns:
            Audio samples as a float32 numpy array in [-1, 1].
        """
        try:
            import librosa
            audio, sr = librosa.load(audio_path, sr=16000, mono=True)
            return audio.astype(np.float32)
        except ImportError:
            # Fallback to scipy for WAV files
            from scipy.io import wavfile
            sr, audio = wavfile.read(audio_path)

            # BUGFIX: normalize integer PCM to [-1, 1] FIRST. Both the
            # stereo mix-down (mean) and signal.resample return float64,
            # so running the dtype checks after those steps (as before)
            # never matched and left samples in integer range.
            if audio.dtype == np.int16:
                audio = audio.astype(np.float32) / 32768.0
            elif audio.dtype == np.int32:
                audio = audio.astype(np.float32) / 2147483648.0

            # Convert to mono if stereo
            if len(audio.shape) > 1:
                audio = audio.mean(axis=1)

            # Resample if needed
            if sr != 16000:
                from scipy import signal
                num_samples = int(len(audio) * 16000 / sr)
                audio = signal.resample(audio, num_samples)

            return audio.astype(np.float32)

    def preprocess(self, audio: np.ndarray) -> tuple[np.ndarray, int]:
        """Convert audio to mel spectrogram.

        Args:
            audio: Audio samples as float32 array

        Returns:
            Tuple of (mel spectrogram, mel length)
        """
        # Preprocessor expects a (1, num_samples) signal plus its length.
        audio_signal = audio.reshape(1, -1).astype(np.float32)
        audio_length = np.array([len(audio)], dtype=np.int32)

        result = self.preprocessor.predict({
            "audio_signal": audio_signal,
            "audio_length": audio_length
        })

        return result["mel"], int(result["mel_length"][0])

    def encode(self, mel: np.ndarray, mel_length: int) -> tuple[np.ndarray, int]:
        """Run encoder on mel spectrogram.

        Args:
            mel: Mel spectrogram from preprocessor
            mel_length: Length of mel spectrogram

        Returns:
            Tuple of (encoder output, encoder length)
        """
        result = self.encoder.predict({
            "mel": mel,
            "mel_length": np.array([mel_length], dtype=np.int32)
        })

        return result["encoder"], int(result["encoder_length"][0])

    def decode_ctc(self, encoder_output: np.ndarray) -> list[int]:
        """CTC greedy decoding.

        Args:
            encoder_output: Output from encoder

        Returns:
            List of token IDs (with duplicates and blanks removed)
        """
        result = self.ctc_head.predict({"encoder_output": encoder_output})
        log_probs = result["ctc_log_probs"]

        # Greedy decoding: take argmax at each timestep
        predictions = np.argmax(log_probs[0], axis=-1)

        # Standard CTC collapse: drop repeated tokens and blanks. Seeding
        # prev_token with blank means a leading non-blank is always kept.
        tokens = []
        prev_token = self.blank_id
        for token in predictions:
            if token != self.blank_id and token != prev_token:
                tokens.append(int(token))
            prev_token = token

        return tokens

    def decode_tdt(self, encoder_output: np.ndarray, encoder_length: int) -> list[int]:
        """TDT (Token-Duration Transducer) decoding.

        Args:
            encoder_output: Output from encoder
            encoder_length: Length of encoder output

        Returns:
            List of token IDs
        """
        hidden_size = self.metadata["decoder_hidden_dim"]
        num_layers = self.metadata["decoder_num_layers"]

        # Initialize LSTM decoder state to zeros.
        h = np.zeros((num_layers, 1, hidden_size), dtype=np.float32)
        c = np.zeros((num_layers, 1, hidden_size), dtype=np.float32)

        # Start with blank token as the first decoder input.
        targets = np.zeros((1, 1), dtype=np.int32)
        target_length = np.array([1], dtype=np.int32)

        tokens = []
        frame = 0
        max_tokens = 1000  # Safety limit against runaway emission

        while frame < encoder_length and len(tokens) < max_tokens:
            # Run the prediction network one step; carry LSTM state forward.
            decoder_result = self.decoder.predict({
                "targets": targets,
                "target_length": target_length,
                "h_in": h,
                "c_in": c
            })

            decoder_output = decoder_result["decoder"]
            h = decoder_result["h_out"]
            c = decoder_result["c_out"]

            # Reshape encoder/decoder steps to the (1, dim, 1) layout the
            # joint network expects.
            encoder_step = encoder_output[0, frame:frame+1, :].T.reshape(1, -1, 1)
            decoder_step = decoder_output.T.reshape(1, -1, 1)

            # Joint prediction: emits a token ID and a duration bin.
            joint_result = self.joint.predict({
                "encoder_step": encoder_step.astype(np.float32),
                "decoder_step": decoder_step.astype(np.float32)
            })

            token_id = int(joint_result["token_id"])
            duration_bin = int(joint_result["duration_bin"])

            # Duration bins: 0=0, 1=1, 2=2, 3=3, 4=4+
            durations = [0, 1, 2, 3, 4]
            duration = durations[min(duration_bin, 4)]

            if token_id != self.blank_id:
                tokens.append(token_id)
                # Feed the emitted token back as the next decoder input.
                # On blank, targets is left unchanged on purpose.
                targets = np.array([[token_id]], dtype=np.int32)

            # Advance by duration (minimum 1 frame) so the loop always
            # makes progress even when the model predicts duration 0.
            frame += max(1, duration)

        return tokens

    def tokens_to_text(self, tokens: list[int]) -> str:
        """Convert token IDs to text.

        Args:
            tokens: List of token IDs

        Returns:
            Decoded text string
        """
        # Unknown IDs map to "" rather than raising.
        pieces = [self.vocab.get(t, "") for t in tokens]
        # Join and handle SentencePiece encoding: "▁" marks a word start.
        text = "".join(pieces).replace("▁", " ").strip()
        return text

    def transcribe(self, audio_path: str, mode: str = "tdt") -> str:
        """Transcribe audio file.

        Args:
            audio_path: Path to audio file
            mode: Decoding mode - "tdt" for full transcription, "ctc" for keyword spotting

        Returns:
            Transcribed text
        """
        # Load and preprocess audio
        audio = self.load_audio(audio_path)
        mel, mel_length = self.preprocess(audio)

        # Encode
        encoder_output, encoder_length = self.encode(mel, mel_length)

        # Decode: any mode other than "ctc" falls through to TDT.
        if mode == "ctc":
            tokens = self.decode_ctc(encoder_output)
        else:
            tokens = self.decode_tdt(encoder_output, encoder_length)

        # Convert to text
        text = self.tokens_to_text(tokens)

        return text
271
-
272
-
273
def main():
    """CLI entry point: parse arguments, load the models, print transcript."""
    parser = argparse.ArgumentParser(
        description="Run inference with Parakeet-TDT-CTC-110M CoreML model"
    )
    parser.add_argument(
        "--audio",
        type=str,
        required=True,
        help="Path to audio file (WAV, MP3, etc.)",
    )
    parser.add_argument(
        "--model-dir",
        type=str,
        default=".",
        help="Directory containing CoreML model files",
    )
    parser.add_argument(
        "--mode",
        type=str,
        choices=["tdt", "ctc"],
        default="tdt",
        help="Decoding mode: 'tdt' for transcription, 'ctc' for keyword spotting",
    )
    args = parser.parse_args()

    # Load all CoreML pipeline stages once up front.
    model = ParakeetCoreML(args.model_dir)

    # Banner, then the transcription itself.
    print(f"\nTranscribing: {args.audio}")
    print(f"Mode: {args.mode.upper()}")
    print("-" * 40)

    text = model.transcribe(args.audio, mode=args.mode)
    print(f"Result: {text}")


if __name__ == "__main__":
    main()