alexwengg committed on
Commit
2d00bb3
·
verified ·
1 Parent(s): c1b4251

Upload 8 files

Browse files
Files changed (1) hide show
  1. export_gradient_descent.py +342 -0
export_gradient_descent.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Export Sortformer models with Gradient Descent configuration.
3
+
4
+ This creates models compatible with the Swift SortformerDiarizer interface.
5
+ Outputs both .mlpackage and .mlmodelc (compiled) versions.
6
+
7
+ Gradient Descent Config:
8
+ - chunk_len: 6
9
+ - chunk_right_context: 7 (higher quality, more context)
10
+ - chunk_left_context: 1
11
+ - fifo_len: 40
12
+ - spkcache_len: 188
13
+ - spkcache_update_period: 31
14
+ """
15
+ import os
16
+ import subprocess
17
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import numpy as np
22
+ import coremltools as ct
23
+ from nemo.collections.asr.models import SortformerEncLabelModel
24
+ from coreml_wrappers import PreEncoderWrapper, PreprocessorWrapper
25
+
26
# Gradient Descent configuration (matching Swift --gradient-descent).
# These values are copied onto model.sortformer_modules below and fix the
# tensor shapes baked into the exported CoreML models.
GRADIENT_DESCENT_CONFIG = dict(
    chunk_len=6,
    chunk_right_context=7,  # Higher quality
    chunk_left_context=1,
    fifo_len=40,
    spkcache_len=188,
    spkcache_update_period=31,
)
35
+
36
+ print("=" * 70)
37
+ print("Exporting Sortformer Models - Gradient Descent Config")
38
+ print("=" * 70)
39
+ print(f"Config: {GRADIENT_DESCENT_CONFIG}")
40
+
41
+ # Load model
42
+ print("\nLoading NeMo model...")
43
+ model = SortformerEncLabelModel.from_pretrained(
44
+ "nvidia/diar_streaming_sortformer_4spk-v2.1", map_location="cpu"
45
+ )
46
+ model.eval()
47
+
48
+ # Apply Gradient Descent config
49
+ modules = model.sortformer_modules
50
+ modules.chunk_len = GRADIENT_DESCENT_CONFIG['chunk_len']
51
+ modules.chunk_right_context = GRADIENT_DESCENT_CONFIG['chunk_right_context']
52
+ modules.chunk_left_context = GRADIENT_DESCENT_CONFIG['chunk_left_context']
53
+ modules.fifo_len = GRADIENT_DESCENT_CONFIG['fifo_len']
54
+ modules.spkcache_len = GRADIENT_DESCENT_CONFIG['spkcache_len']
55
+ modules.spkcache_update_period = GRADIENT_DESCENT_CONFIG['spkcache_update_period']
56
+
57
# Calculate dimensions from the (just-updated) module configuration.
chunk_len = modules.chunk_len
# Mel frames fed to the pre-encoder: (chunk + left + right context) scaled
# by the encoder subsampling factor.
input_chunk_time = (chunk_len + modules.chunk_left_context + modules.chunk_right_context) * modules.subsampling_factor
fc_d_model = modules.fc_d_model # 512
spkcache_len = modules.spkcache_len
fifo_len = modules.fifo_len

feat_dim = 128  # mel feature dimension used for the PreEncoder chunk input
# Subsampling reverses the expansion above, so this equals chunk+left+right.
pre_encode_out_len = input_chunk_time // modules.subsampling_factor
# Head input is spkcache + FIFO + new chunk embeddings, concatenated in time.
total_concat_len = spkcache_len + fifo_len + pre_encode_out_len

print(f"\nDimensions:")
print(f" Input chunk frames: {input_chunk_time} (= ({chunk_len}+{modules.chunk_left_context}+{modules.chunk_right_context})*{modules.subsampling_factor})")
print(f" Pre-encode output: {pre_encode_out_len}")
print(f" Total concat len: {total_concat_len}")
print(f" FC d_model: {fc_d_model}")
print(f" FIFO len: {fifo_len}")
print(f" Spkcache len: {spkcache_len}")

# Calculate audio samples needed for preprocessor
# For gradient descent: (6+1+7)*8 = 112 mel frames
# Audio samples = (112-1)*160 + 400 = 18160, but NeMo adds padding
# Empirically: 112 frames needs specific sample count
mel_stride = 160  # hop size in samples (10 ms @ 16 kHz)
mel_window = 400  # window size in samples (25 ms @ 16 kHz)
# For 112 mel frames with NeMo padding
preprocessor_audio_samples = (input_chunk_time - 1) * mel_stride + mel_window
print(f" Preprocessor audio samples: {preprocessor_audio_samples}")

# Create output directory for both .mlpackage and compiled .mlmodelc files.
output_dir = "coreml_models_gradient_descent"
os.makedirs(output_dir, exist_ok=True)
89
+
90
# =========================================================
# 0. Export Preprocessor (audio -> mel features)
# =========================================================
print("\n[0/3] Exporting Preprocessor...")

preprocessor_wrapper = PreprocessorWrapper(model.preprocessor)
preprocessor_wrapper.eval()

# Trace with correct audio sample count; the CoreML model is exported with
# these fixed input shapes, so callers must supply exactly this many samples.
audio_input = torch.randn(1, preprocessor_audio_samples)
audio_length = torch.tensor([preprocessor_audio_samples], dtype=torch.long)

traced_preprocessor = torch.jit.trace(preprocessor_wrapper, (audio_input, audio_length))

# Convert the traced graph to CoreML with explicitly named inputs/outputs
# (names are what the Swift side looks up).
preprocessor_ml = ct.convert(
    traced_preprocessor,
    inputs=[
        ct.TensorType(name="audio_signal", shape=audio_input.shape, dtype=np.float32),
        ct.TensorType(name="length", shape=audio_length.shape, dtype=np.int32),
    ],
    outputs=[
        ct.TensorType(name="features", dtype=np.float32),
        ct.TensorType(name="feature_lengths", dtype=np.int32),
    ],
    minimum_deployment_target=ct.target.iOS16,
    compute_precision=ct.precision.FLOAT32,
    compute_units=ct.ComputeUnit.CPU_ONLY # CPU for FP32 precision
)

preprocessor_path = os.path.join(output_dir, "Pipeline_Preprocessor.mlpackage")
preprocessor_ml.save(preprocessor_path)
print(f" Saved {preprocessor_path}")
122
+
123
# =========================================================
# 1. Export PreEncoder
# =========================================================
print("\n[1/3] Exporting PreEncoder...")

# Dummy tensors at the exact shapes the deployed model will see; tracing
# fixes these shapes into the exported CoreML model.
input_chunk = torch.randn(1, input_chunk_time, feat_dim)
input_chunk_len = torch.tensor([input_chunk_time], dtype=torch.long)
input_spkcache = torch.randn(1, spkcache_len, fc_d_model)
input_spkcache_len = torch.tensor([spkcache_len], dtype=torch.long)
input_fifo = torch.randn(1, fifo_len, fc_d_model)
input_fifo_len = torch.tensor([fifo_len], dtype=torch.long)

pre_encoder = PreEncoderWrapper(model)
pre_encoder.eval()

traced_pre_encoder = torch.jit.trace(pre_encoder, (
    input_chunk, input_chunk_len,
    input_spkcache, input_spkcache_len,
    input_fifo, input_fifo_len
))

# Use names that match Swift expectations
pre_encoder_ml = ct.convert(
    traced_pre_encoder,
    inputs=[
        ct.TensorType(name="chunk", shape=input_chunk.shape, dtype=np.float32),
        ct.TensorType(name="chunk_lengths", shape=input_chunk_len.shape, dtype=np.int32),
        ct.TensorType(name="spkcache", shape=input_spkcache.shape, dtype=np.float32),
        ct.TensorType(name="spkcache_lengths", shape=input_spkcache_len.shape, dtype=np.int32),
        ct.TensorType(name="fifo", shape=input_fifo.shape, dtype=np.float32),
        ct.TensorType(name="fifo_lengths", shape=input_fifo_len.shape, dtype=np.int32),
    ],
    outputs=[
        ct.TensorType(name="pre_encoder_embs", dtype=np.float32),
        ct.TensorType(name="pre_encoder_lengths", dtype=np.int32),
        ct.TensorType(name="chunk_embs_in", dtype=np.float32),
        ct.TensorType(name="chunk_lens_in", dtype=np.int32),
    ],
    minimum_deployment_target=ct.target.iOS16,
    compute_precision=ct.precision.FLOAT32,
    compute_units=ct.ComputeUnit.ALL
)

pre_encoder_path = os.path.join(output_dir, "Pipeline_PreEncoder.mlpackage")
pre_encoder_ml.save(pre_encoder_path)
print(f" Saved {pre_encoder_path}")
169
+
170
# =========================================================
# 2. Export Fixed Head (with identity ops to preserve embeddings)
# =========================================================
print("\n[2/3] Exporting Fixed Head...")
174
+
175
+
176
class FixedSortformerHead(nn.Module):
    """Head wrapper that forces chunk_pre_encoder_embs to be computed.

    Runs the frontend encoder (bypassing its pre-encode stage) and the
    inference head, then routes the incoming chunk embeddings through
    identity operations (multiply by a constant 1.0 parameter, add 0) so
    the converter cannot prune them from the exported graph's outputs.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # Constant 1.0; exists only to anchor an identity multiply in the graph.
        self.identity_scale = nn.Parameter(torch.ones(1), requires_grad=False)

    def forward(self, pre_encoder_embs, pre_encoder_lengths, chunk_embs_in, chunk_lens_in):
        # Encode the concatenated spkcache/FIFO/chunk embeddings, skipping
        # the pre-encode stage (inputs are already pre-encoded).
        enc_embs, enc_lens = self.model.frontend_encoder(
            processed_signal=pre_encoder_embs,
            processed_signal_length=pre_encoder_lengths,
            bypass_pre_encode=True,
        )

        # Speaker predictions from the encoded sequence.
        speaker_preds = self.model.forward_infer(enc_embs, enc_lens)

        # Mathematical no-ops that keep the chunk embeddings/lengths alive
        # as graph outputs (prevents optimization from dropping them).
        preserved_embs = chunk_embs_in * self.identity_scale
        preserved_lens = chunk_lens_in + 0

        return speaker_preds, preserved_embs, preserved_lens
202
+
203
+
204
head = FixedSortformerHead(model)
head.eval()

# Input shapes for head - must match PreEncoder output
pre_encoder_embs = torch.randn(1, total_concat_len, fc_d_model)
pre_encoder_lengths = torch.tensor([total_concat_len], dtype=torch.long)
chunk_embs_in = torch.randn(1, pre_encode_out_len, fc_d_model)
chunk_lens_in = torch.tensor([pre_encode_out_len], dtype=torch.long)

traced_head = torch.jit.trace(head, (
    pre_encoder_embs, pre_encoder_lengths,
    chunk_embs_in, chunk_lens_in
))

# Convert with FP16 precision (unlike the FP32 preprocessor/pre-encoder);
# output names are what the Swift side looks up.
head_ml = ct.convert(
    traced_head,
    inputs=[
        ct.TensorType(name="pre_encoder_embs", shape=pre_encoder_embs.shape, dtype=np.float32),
        ct.TensorType(name="pre_encoder_lengths", shape=pre_encoder_lengths.shape, dtype=np.int32),
        ct.TensorType(name="chunk_embs_in", shape=chunk_embs_in.shape, dtype=np.float32),
        ct.TensorType(name="chunk_lens_in", shape=chunk_lens_in.shape, dtype=np.int32),
    ],
    outputs=[
        ct.TensorType(name="speaker_preds", dtype=np.float32),
        ct.TensorType(name="chunk_pre_encoder_embs", dtype=np.float32),
        ct.TensorType(name="chunk_pre_encoder_lengths", dtype=np.int32),
    ],
    minimum_deployment_target=ct.target.iOS16,
    compute_precision=ct.precision.FLOAT16,
    compute_units=ct.ComputeUnit.ALL
)

head_path = os.path.join(output_dir, "Pipeline_Head_Fixed.mlpackage")
head_ml.save(head_path)
print(f" Saved {head_path}")
239
+
240
# =========================================================
# 3. Compile to .mlmodelc
# =========================================================
print("\n[3/3] Compiling to .mlmodelc...")
244
+
245
def compile_model(mlpackage_path):
    """Compile .mlpackage to .mlmodelc using xcrun coremlcompiler.

    The compiled .mlmodelc is written into the same directory that holds
    the .mlpackage.

    Args:
        mlpackage_path: Path to the .mlpackage to compile.

    Returns:
        True if the compiled .mlmodelc exists afterwards; False on a
        compiler failure or when the Xcode tools are unavailable.
    """
    output_dir_path = os.path.dirname(mlpackage_path)
    # splitext strips only the trailing extension (str.replace would also
    # mangle a ".mlpackage" substring elsewhere in the name).
    model_name = os.path.splitext(os.path.basename(mlpackage_path))[0]

    try:
        # check=True raises CalledProcessError on a non-zero exit code; the
        # CompletedProcess result itself is not needed.
        subprocess.run(
            ['xcrun', 'coremlcompiler', 'compile', mlpackage_path, output_dir_path],
            capture_output=True,
            text=True,
            check=True
        )
    except subprocess.CalledProcessError as e:
        print(f" Error compiling {mlpackage_path}: {e.stderr}")
        return False
    except FileNotFoundError:
        print(" Error: xcrun not found. Make sure Xcode Command Line Tools are installed.")
        return False

    # coremlcompiler reports success via exit code; double-check the artifact.
    mlmodelc_path = os.path.join(output_dir_path, f"{model_name}.mlmodelc")
    if os.path.exists(mlmodelc_path):
        print(f" Compiled {mlmodelc_path}")
        return True
    print(f" Warning: {mlmodelc_path} not found after compilation")
    return False
270
+
271
# Compile all three exported packages to .mlmodelc (best-effort: each call
# reports its own failure and the script continues).
compile_model(preprocessor_path)
compile_model(pre_encoder_path)
compile_model(head_path)

# =========================================================
# Verification
# =========================================================
print("\n" + "=" * 70)
print("Verification")
print("=" * 70)

# Test PreEncoder with random chunk features and empty (zero-length)
# speaker cache / FIFO buffers, i.e. the first-chunk streaming state.
test_chunk = np.random.randn(1, input_chunk_time, feat_dim).astype(np.float32)
test_chunk_len = np.array([input_chunk_time], dtype=np.int32)
test_spkcache = np.zeros((1, spkcache_len, fc_d_model), dtype=np.float32)
test_spkcache_len = np.array([0], dtype=np.int32)
test_fifo = np.zeros((1, fifo_len, fc_d_model), dtype=np.float32)
test_fifo_len = np.array([0], dtype=np.int32)

pre_out = pre_encoder_ml.predict({
    'chunk': test_chunk,
    'chunk_lengths': test_chunk_len,
    'spkcache': test_spkcache,
    'spkcache_lengths': test_spkcache_len,
    'fifo': test_fifo,
    'fifo_lengths': test_fifo_len
})

print(f"PreEncoder output shapes:")
print(f" pre_encoder_embs: {pre_out['pre_encoder_embs'].shape}")
print(f" chunk_embs_in: {pre_out['chunk_embs_in'].shape}")

# Test Head by feeding it the PreEncoder's outputs directly (mirrors how
# the Swift pipeline chains the two models).
head_out = head_ml.predict({
    'pre_encoder_embs': pre_out['pre_encoder_embs'],
    'pre_encoder_lengths': pre_out['pre_encoder_lengths'],
    'chunk_embs_in': pre_out['chunk_embs_in'],
    'chunk_lens_in': pre_out['chunk_lens_in']
})

print(f"\nHead output shapes:")
print(f" speaker_preds: {head_out['speaker_preds'].shape}")
print(f" chunk_pre_encoder_embs: {head_out['chunk_pre_encoder_embs'].shape}")

# Verify the identity ops in FixedSortformerHead preserved the embedding
# value (loose tolerance because the head runs in FP16).
if np.isclose(pre_out['chunk_embs_in'][0,0,0], head_out['chunk_pre_encoder_embs'][0,0,0], atol=0.01):
    print("\n✓ Embedding [0,0,0] preserved correctly!")
else:
    print(f"\n✗ WARNING: Embedding [0,0,0] corrupted!")
320
+
321
+ print("\n" + "=" * 70)
322
+ print("Export Complete!")
323
+ print("=" * 70)
324
+ print(f"\nModels saved to: {output_dir}/")
325
+ print(f" .mlpackage files:")
326
+ print(f" - Pipeline_Preprocessor.mlpackage")
327
+ print(f" - Pipeline_PreEncoder.mlpackage")
328
+ print(f" - Pipeline_Head_Fixed.mlpackage")
329
+ print(f" .mlmodelc files (compiled):")
330
+ print(f" - Pipeline_Preprocessor.mlmodelc")
331
+ print(f" - Pipeline_PreEncoder.mlmodelc")
332
+ print(f" - Pipeline_Head_Fixed.mlmodelc")
333
+ print(f"\nConfiguration (Gradient Descent):")
334
+ for k, v in GRADIENT_DESCENT_CONFIG.items():
335
+ print(f" {k}: {v}")
336
+ print(f"\nInput shapes:")
337
+ print(f" Preprocessor audio: [1, {preprocessor_audio_samples}]")
338
+ print(f" PreEncoder chunk: [1, {input_chunk_time}, {feat_dim}]")
339
+ print(f" PreEncoder spkcache: [1, {spkcache_len}, {fc_d_model}]")
340
+ print(f" PreEncoder fifo: [1, {fifo_len}, {fc_d_model}]")
341
+ print(f" Head pre_encoder_embs: [1, {total_concat_len}, {fc_d_model}]")
342
+ print(f" Head chunk_embs_in: [1, {pre_encode_out_len}, {fc_d_model}]")