korallll commited on
Commit
86f4581
·
verified ·
1 Parent(s): b85921b

Add TensorRT export script + ONNX export

Browse files
Files changed (2) hide show
  1. spectra-aasist3.onnx +3 -0
  2. trt_spectra_aasist3.py +565 -0
spectra-aasist3.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f05c29a01ad80c702b32654db87c2aa6e467c11c67b6d47f2fac873f846cae9
3
+ size 1279022864
trt_spectra_aasist3.py ADDED
@@ -0,0 +1,565 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """TensorRT export + inference for SpectraAASIST3. Self-contained (no shared package).
3
+
4
+ Exports only the model's `net` (all preprocessing already lives in the original
5
+ `score_batch`) with a fixed time axis and a dynamic batch axis, builds a FP16
6
+ engine (FP32 fallback if parity drifts), finds the fastest batch on the current
7
+ GPU, and exposes a drop-in `SpectraAASIST3TRT` class identical to the PyTorch path except
8
+ the neural forward runs on TensorRT.
9
+
10
+ CLI:
11
+ python trt_spectra-aasist3.py export # ONNX -> engine -> parity -> sweep -> sidecar
12
+ python trt_spectra-aasist3.py sweep # re-run the batch sweep, update sidecar
13
+ python trt_spectra-aasist3.py parity # PyTorch vs TRT parity report
14
+ python trt_spectra-aasist3.py score AUDIO.wav
15
+
16
+ Pin the GPU with: CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=<n>
17
+ """
18
+ from __future__ import annotations
19
+ import argparse
20
+ import io
21
+ import json
22
+ import os
23
+ import sys
24
+ import time
25
+ from pathlib import Path
26
+
27
+ import numpy as np
28
+
29
+ HERE = Path(__file__).resolve().parent
30
+ sys.path.insert(0, str(HERE)) # import dir-local entry + _net
31
+ # Pin GPU deterministically: PCI order makes CUDA indices match `nvidia-smi`.
32
+ os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
33
+ os.environ.setdefault("CUDA_VISIBLE_DEVICES", os.environ.get("SSB_TRT_GPU", "3"))
34
+
35
+ import torch # noqa: E402 (after env pin)
36
+
37
+ # ======================= per-model config =======================
38
+ ENTRY_MODULE = "spectra_aasist3" # module exposing the AntiSpoofingModel subclass
39
+ ENTRY_CLASS = "SpectraAASIST3" # the subclass name
40
+ SLUG = "spectra-aasist3"
41
+ PARITY_DATASET = "InTheWild" # sibling dataset dir with data/*.parquet
42
+ MAX_BATCH_CAP = 24 # VRAM ceiling for the profile + sweep
43
+ PARITY_CHUNK = 8 # safe mini-batch for the parity comparison
44
+ OPSET = 17
45
+ # Keep FP16 iff it preserves the score RANKING (Spearman) -> identical EER.
46
+ # This is the metric that matters for the benchmark and is scale-invariant, so
47
+ # small absolute-logit drift (harmless for EER) does not force an FP32 fallback.
48
+ # FP16 is also mandatory for the largest models (FP32 would not fit in VRAM).
49
+ PARITY_SPEARMAN_TOL = 0.9999 # min Spearman rank-corr to keep FP16
50
+ PARITY_FLOOR = 0.99 # hard floor: below this the engine is wrong -> FAIL
51
+ PARITY_MAD_TOL = 1e-2 # informational only
52
+ PARITY_R_TOL = 0.9999 # informational only
53
+ FORCE_FP32 = False
54
+ FORCE_FP16 = False # skip FP32 (for giant models where FP32 won't fit VRAM)
55
+ DYNAMO_EXPORT = False # use the dynamo exporter + external data (models >2GB)
56
+ ALLOW_ORT_FALLBACK = False
57
+ # ================================================================
58
+
59
+ from importlib import import_module as _imp # noqa: E402
60
+ _OrigClass = getattr(_imp(ENTRY_MODULE), ENTRY_CLASS)
61
+
62
+
63
+ # ----------------------------------------------------------------------------
64
+ # helpers
65
+ # ----------------------------------------------------------------------------
66
+ def gpu_slug() -> str:
67
+ name = torch.cuda.get_device_name(0)
68
+ return name.replace("NVIDIA ", "").replace("GeForce ", "").strip().replace(" ", "_")
69
+
70
+
71
+ def load_model():
72
+ m = _OrigClass()
73
+ m.load()
74
+ return m
75
+
76
+
77
+ def real_audio(n=64):
78
+ """Decode up to n real 16 kHz mono utterances from PARITY_DATASET/data/*.parquet."""
79
+ import pyarrow.parquet as pq
80
+ import soundfile as sf
81
+ import torchaudio.functional as AF
82
+
83
+ data_dir = HERE.parent / PARITY_DATASET / "data"
84
+ files = sorted(data_dir.glob("test-*.parquet")) or sorted(data_dir.glob("*.parquet"))
85
+ out = []
86
+ for f in files:
87
+ t = pq.read_table(f)
88
+ col = "audio" if "audio" in t.column_names else t.column_names[0]
89
+ for row in t.column(col).to_pylist():
90
+ b = row["bytes"] if isinstance(row, dict) else row
91
+ if not b:
92
+ continue
93
+ a, sr = sf.read(io.BytesIO(b), dtype="float32")
94
+ if a.ndim > 1:
95
+ a = a.mean(1)
96
+ a = np.ascontiguousarray(a, dtype=np.float32)
97
+ if sr != 16000:
98
+ a = AF.resample(torch.from_numpy(a), sr, 16000).numpy().astype(np.float32)
99
+ sr = 16000
100
+ out.append((a, sr))
101
+ if len(out) >= n:
102
+ return out
103
+ if not out:
104
+ raise RuntimeError(f"no parity audio found under {data_dir}")
105
+ return out
106
+
107
+
108
+ class _Capture:
109
+ """Wrap net: pass through to the real net, record input tensor + output."""
110
+
111
+ def __init__(self, net):
112
+ self.net = net
113
+ self.x = None
114
+ self.out = None
115
+
116
+ def __call__(self, x, *a, **k):
117
+ self.x = x.detach()
118
+ self.out = self.net(x, *a, **k)
119
+ return self.out
120
+
121
+ def __getattr__(self, name):
122
+ return getattr(self.net, name)
123
+
124
+
125
+ def _logits_index(out):
126
+ """Return (L, i, n_classes): tuple length (None if tensor), logits slot, n_classes.
127
+
128
+ Heuristic: the class-logits tensor is the 2-D (B, C) tensor with the smallest C.
129
+ """
130
+ if isinstance(out, torch.Tensor):
131
+ return None, None, int(out.shape[-1])
132
+ cands = [(j, t) for j, t in enumerate(out)
133
+ if isinstance(t, torch.Tensor) and t.dim() == 2]
134
+ if not cands:
135
+ raise RuntimeError("could not locate a 2-D (B,C) logits tensor in net output")
136
+ j, t = min(cands, key=lambda it: int(it[1].shape[-1]))
137
+ return len(out), j, int(t.shape[-1])
138
+
139
+
140
+ def analyze(model):
141
+ """One real forward through the capture shim -> (T, L, i, n_classes)."""
142
+ data = real_audio(1)
143
+ audios = [a for a, _ in data]
144
+ srs = [s for _, s in data]
145
+ cap = _Capture(model.net)
146
+ model.net = cap
147
+ with torch.no_grad():
148
+ model.score_batch(audios, srs)
149
+ model.net = cap.net
150
+ T = int(cap.x.shape[-1])
151
+ L, i, n_classes = _logits_index(cap.out)
152
+ return T, L, i, n_classes
153
+
154
+
155
+ def _extractor(L, i):
156
+ """Pick the logits tensor out of a net's raw output."""
157
+ if L is None:
158
+ return lambda y: y
159
+ return lambda y, i=i: y[i]
160
+
161
+
162
+ def _rebuild(L, i):
163
+ """Wrap a bare logits tensor back into the net's original output structure."""
164
+ if L is None:
165
+ return lambda y: y
166
+ return lambda y, L=L, i=i: tuple(y if j == i else None for j in range(L))
167
+
168
+
169
+ def _prep_for_export(net):
170
+ """Make export-hostile layers traceable. No-op for non-fairseq models.
171
+
172
+ fairseq wav2vec2/hubert call `pad_to_multiple`, which does `(tsz/multiple)
173
+ .is_integer()`; under torch.jit tracing `tsz` becomes a Tensor with no
174
+ `.is_integer()`. Our time axis is static, so we swap in a constant-length
175
+ pad that traces cleanly. Patches every fairseq module that bound the name.
176
+ """
177
+ def _safe_pad(x, multiple, dim=-1, value=0):
178
+ import torch.nn.functional as F
179
+ if x is None:
180
+ return None, 0
181
+ tsz = int(x.shape[dim]) # static: time axis is fixed
182
+ rem = (multiple - tsz % multiple) % multiple
183
+ if rem == 0:
184
+ return x, 0
185
+ pad_offset = (0,) * (-1 - dim) * 2
186
+ return F.pad(x, (*pad_offset, 0, rem), value=value), rem
187
+
188
+ for modname in ("fairseq.models.wav2vec.utils",
189
+ "fairseq.models.wav2vec.wav2vec2",
190
+ "fairseq.models.hubert.hubert"):
191
+ mod = sys.modules.get(modname)
192
+ if mod is not None and hasattr(mod, "pad_to_multiple"):
193
+ mod.pad_to_multiple = _safe_pad
194
+ _freeze_sinc(net)
195
+ # optional per-model export patch (dir-local `_trt_patch.py` with `patch(net)`)
196
+ try:
197
+ import importlib
198
+ importlib.import_module("_trt_patch").patch(net)
199
+ except ModuleNotFoundError:
200
+ pass
201
+ if DYNAMO_EXPORT:
202
+ _replace_global_avgpool(net)
203
+ return net
204
+
205
+
206
+ class _MeanPool(torch.nn.Module):
207
+ """Global average over `dims` (keepdim) — == AdaptiveAvgPool{1,2}d(1)."""
208
+
209
+ def __init__(self, dims):
210
+ super().__init__()
211
+ self.dims = dims
212
+
213
+ def forward(self, x):
214
+ return x.mean(dim=self.dims, keepdim=True)
215
+
216
+
217
+ def _replace_global_avgpool(net):
218
+ """Swap AdaptiveAvgPool1d/2d(output_size=1) for an explicit mean. The dynamo
219
+ exporter lowers the adaptive pool to as_strided/SequenceEmpty, which TensorRT
220
+ rejects; a plain mean lowers to ReduceMean. Identical for output_size==1."""
221
+ import torch.nn as nn
222
+ for full_name, mod in list(net.named_modules()):
223
+ is1d = isinstance(mod, nn.AdaptiveAvgPool1d) and mod.output_size in (1, (1,))
224
+ is2d = isinstance(mod, nn.AdaptiveAvgPool2d) and mod.output_size in (1, (1, 1))
225
+ if not (is1d or is2d):
226
+ continue
227
+ parent = net
228
+ *parents, attr = full_name.split(".")
229
+ for p in parents:
230
+ parent = getattr(parent, p)
231
+ setattr(parent, attr, _MeanPool((-1,) if is1d else (-2, -1)))
232
+ return net
233
+
234
+
235
+ def _freeze_sinc(net):
236
+ """Replace SincConv-style layers with an equivalent nn.Conv1d holding the
237
+ precomputed band-pass filters. At eval the filters are constant, but their
238
+ in-forward construction (torch.sin/cat/flip from learnable params) either
239
+ won't build in TensorRT or constant-folds to wrong values. Baking them into a
240
+ plain Conv1d removes the sinc math from the graph. No-op when no Sinc layer.
241
+ """
242
+ import torch.nn as nn
243
+ sincs = [(n, m) for n, m in net.named_modules() if "Sinc" in type(m).__name__]
244
+ if not sincs:
245
+ return net
246
+ dev = next(net.parameters()).device
247
+ for full_name, mod in sincs:
248
+ kernel = int(getattr(mod, "kernel_size", 0)) or 1
249
+ with torch.no_grad():
250
+ try:
251
+ mod(torch.zeros(1, 1, max(kernel * 4, 4096), device=dev))
252
+ except Exception: # noqa: BLE001 — filters are set before the conv call
253
+ pass
254
+ W = mod.filters.detach().clone() # [out, 1, kernel] (or [out, kernel])
255
+ if W.dim() == 2:
256
+ W = W.unsqueeze(1)
257
+ conv = nn.Conv1d(W.shape[1], W.shape[0], W.shape[2],
258
+ stride=int(getattr(mod, "stride", 1)),
259
+ padding=int(getattr(mod, "padding", 0)),
260
+ dilation=int(getattr(mod, "dilation", 1)),
261
+ bias=False).to(dev).eval()
262
+ conv.weight.data.copy_(W)
263
+ parent = net
264
+ *parents, attr = full_name.split(".")
265
+ for p in parents:
266
+ parent = getattr(parent, p)
267
+ setattr(parent, attr, conv)
268
+ return net
269
+
270
+
271
+ class _ExportNet(torch.nn.Module):
272
+ """forward(x[B,T]) -> logits[B,C] (single tensor) for ONNX/TRT."""
273
+
274
+ def __init__(self, net, L, i):
275
+ super().__init__()
276
+ self.net = net
277
+ self._extract = _extractor(L, i)
278
+
279
+ def forward(self, x):
280
+ return self._extract(self.net(x))
281
+
282
+
283
+ # ----------------------------------------------------------------------------
284
+ # export + build
285
+ # ----------------------------------------------------------------------------
286
+ def export_onnx(model, T, L, i, onnx_path):
287
+ net = _prep_for_export(model.net)
288
+ wrap = _ExportNet(net, L, i).eval().to("cuda")
289
+ dummy = torch.zeros(2, T, device="cuda", dtype=torch.float32)
290
+ if DYNAMO_EXPORT:
291
+ # >2 GB models: TorchScript exporter's shape-inference overflows the 2 GB
292
+ # protobuf limit. The dynamo exporter writes weights as external data.
293
+ batch = torch.export.Dim("b", min=1, max=MAX_BATCH_CAP)
294
+ torch.onnx.export(
295
+ wrap, (dummy,), str(onnx_path), dynamo=True, external_data=True,
296
+ input_names=["wav"], output_names=["logits"],
297
+ dynamic_shapes={"x": {0: batch}},
298
+ )
299
+ else:
300
+ torch.onnx.export(
301
+ wrap, dummy, str(onnx_path), opset_version=OPSET,
302
+ input_names=["wav"], output_names=["logits"],
303
+ dynamic_axes={"wav": {0: "batch"}, "logits": {0: "batch"}},
304
+ do_constant_folding=True,
305
+ )
306
+ return onnx_path
307
+
308
+
309
+ def build_engine(onnx_path, T, precision, max_batch, opt_batch, engine_path, timing_cache):
310
+ import tensorrt as trt
311
+
312
+ sev = trt.Logger.VERBOSE if os.environ.get("SSB_TRT_VERBOSE") else trt.Logger.WARNING
313
+ logger = trt.Logger(sev)
314
+ builder = trt.Builder(logger)
315
+ network = builder.create_network(0)
316
+ parser = trt.OnnxParser(network, logger)
317
+ # parse_from_file resolves external-data sidecars (needed for >2 GB models);
318
+ # works for inline ONNX too.
319
+ if not parser.parse_from_file(str(onnx_path)):
320
+ errs = "; ".join(str(parser.get_error(k)) for k in range(parser.num_errors))
321
+ raise RuntimeError(f"onnx parse failed: {errs}")
322
+
323
+ cfg = builder.create_builder_config()
324
+ cfg.builder_optimization_level = 1 # minimum build time
325
+ if precision == "fp16":
326
+ cfg.set_flag(trt.BuilderFlag.FP16)
327
+
328
+ tc_bytes = Path(timing_cache).read_bytes() if Path(timing_cache).exists() else b""
329
+ tc = cfg.create_timing_cache(tc_bytes)
330
+ cfg.set_timing_cache(tc, ignore_mismatch=False)
331
+
332
+ profile = builder.create_optimization_profile()
333
+ profile.set_shape("wav", (1, T), (opt_batch, T), (max_batch, T))
334
+ cfg.add_optimization_profile(profile)
335
+
336
+ plan = builder.build_serialized_network(network, cfg)
337
+ if plan is None:
338
+ raise RuntimeError("engine build returned None")
339
+ Path(engine_path).write_bytes(bytes(plan))
340
+ Path(timing_cache).write_bytes(bytes(tc.serialize()))
341
+ return engine_path
342
+
343
+
344
+ # ----------------------------------------------------------------------------
345
+ # runtime
346
+ # ----------------------------------------------------------------------------
347
+ class _TRTCallable:
348
+ """Mimics net(xt): runs the engine on a [B,T] float32 CUDA tensor."""
349
+
350
+ def __init__(self, engine_path, n_classes, L, i):
351
+ import tensorrt as trt
352
+
353
+ self.n_classes = n_classes
354
+ self.rebuild = _rebuild(L, i)
355
+ logger = trt.Logger(trt.Logger.WARNING)
356
+ self.runtime = trt.Runtime(logger)
357
+ self.engine = self.runtime.deserialize_cuda_engine(Path(engine_path).read_bytes())
358
+ self.ctx = self.engine.create_execution_context()
359
+ if self.ctx is None:
360
+ raise RuntimeError(
361
+ "could not create execution context (likely OOM reserving max-profile "
362
+ "memory) — lower MAX_BATCH_CAP")
363
+ # resolve I/O tensor names
364
+ self.in_name, self.out_name = "wav", "logits"
365
+
366
+ def __call__(self, x, *a, **k):
367
+ x = x.to("cuda", torch.float32).contiguous()
368
+ B = x.shape[0]
369
+ self.ctx.set_input_shape(self.in_name, tuple(x.shape))
370
+ out = torch.empty((B, self.n_classes), device="cuda", dtype=torch.float32)
371
+ self.ctx.set_tensor_address(self.in_name, x.data_ptr())
372
+ self.ctx.set_tensor_address(self.out_name, out.data_ptr())
373
+ stream = torch.cuda.current_stream().cuda_stream
374
+ self.ctx.execute_async_v3(stream)
375
+ torch.cuda.current_stream().synchronize()
376
+ return self.rebuild(out)
377
+
378
+
379
+ # ----------------------------------------------------------------------------
380
+ # parity + sweep
381
+ # ----------------------------------------------------------------------------
382
+ def _chunked_scores(model, audios, srs, chunk):
383
+ out = []
384
+ for k in range(0, len(audios), chunk):
385
+ out.extend(model.score_batch(audios[k:k + chunk], srs[k:k + chunk]))
386
+ return np.asarray(out, dtype=np.float64)
387
+
388
+
389
+ def _spearman(a, b):
390
+ if len(a) < 2:
391
+ return 1.0
392
+ ra = np.argsort(np.argsort(a)).astype(np.float64)
393
+ rb = np.argsort(np.argsort(b)).astype(np.float64)
394
+ return float(np.corrcoef(ra, rb)[0, 1])
395
+
396
+
397
+ def parity(model, trt_call, n=64, chunk=PARITY_CHUNK):
398
+ data = real_audio(n)
399
+ audios = [a for a, _ in data]
400
+ srs = [s for _, s in data]
401
+ torch_net = model.net
402
+ py = _chunked_scores(model, audios, srs, chunk)
403
+ model.net = trt_call
404
+ tr = _chunked_scores(model, audios, srs, chunk)
405
+ model.net = torch_net
406
+ mad = float(np.max(np.abs(py - tr)))
407
+ pear = float(np.corrcoef(py, tr)[0, 1]) if len(py) > 1 else 1.0
408
+ spear = _spearman(py, tr)
409
+ return {"n": len(py), "max_abs_score_diff": mad, "pearson": pear,
410
+ "spearman": spear}
411
+
412
+
413
+ def sweep(model, trt_call,
414
+ batches=(1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128), iters=20):
415
+ a, sr = real_audio(1)[0]
416
+ model.net = trt_call
417
+ res = {}
418
+ for B in batches:
419
+ if B > MAX_BATCH_CAP:
420
+ break
421
+ ab, sb = [a] * B, [sr] * B
422
+ try:
423
+ for _ in range(3):
424
+ model.score_batch(ab, sb) # warmup
425
+ torch.cuda.synchronize()
426
+ t0 = time.time()
427
+ for _ in range(iters):
428
+ model.score_batch(ab, sb)
429
+ torch.cuda.synchronize()
430
+ dt = time.time() - t0
431
+ res[B] = B * iters / dt
432
+ except RuntimeError as e:
433
+ if "out of memory" in str(e).lower():
434
+ torch.cuda.empty_cache()
435
+ break
436
+ raise
437
+ best = max(res, key=res.get)
438
+ return best, res
439
+
440
+
441
+ # ----------------------------------------------------------------------------
442
+ # drop-in inference class
443
+ # ----------------------------------------------------------------------------
444
+ class SpectraAASIST3TRT(_OrigClass):
445
+ """Drop-in: original preprocessing/score_batch; net replaced by the TRT engine."""
446
+
447
+ def load(self):
448
+ self.device = "cuda"
449
+ side = json.loads((HERE / f"trt_{SLUG}.json").read_text())[gpu_slug()]
450
+ eng = HERE / side["engine"]
451
+ self.net = _TRTCallable(str(eng), side["n_classes"], side["L"], side["i"])
452
+ self.batch_size = side["best_batch"]
453
+
454
+
455
+ # ----------------------------------------------------------------------------
456
+ # CLI
457
+ # ----------------------------------------------------------------------------
458
+ def _do_export():
459
+ gpu = gpu_slug()
460
+ side_path = HERE / f"trt_{SLUG}.json"
461
+ tc = HERE / f".trt_timing_{gpu}.cache"
462
+ m = load_model()
463
+ T, L, i, n_classes = analyze(m)
464
+ print(f"[analyze] T={T} n_classes={n_classes} L={L} i={i}")
465
+ onnx_path = HERE / f"{SLUG}.onnx"
466
+ export_onnx(m, T, L, i, onnx_path)
467
+ print(f"[onnx] wrote {onnx_path.name}")
468
+
469
+ # PyTorch reference scores while the model is on GPU, then free it so the
470
+ # engine build + TRT inference never co-reside with the model (giant >2 GB
471
+ # models would otherwise OOM the 16 GB card).
472
+ pdata = real_audio(64)
473
+ paud, psr = [a for a, _ in pdata], [s for _, s in pdata]
474
+ py = _chunked_scores(m, paud, psr, PARITY_CHUNK)
475
+ m.net.to("cpu")
476
+ torch.cuda.empty_cache()
477
+
478
+ opt_batch = min(32, MAX_BATCH_CAP)
479
+ if FORCE_FP16:
480
+ precisions = ["fp16"]
481
+ elif FORCE_FP32:
482
+ precisions = ["fp32"]
483
+ else:
484
+ precisions = ["fp16", "fp32"]
485
+ chosen = None
486
+ last_err = None
487
+ for prec in precisions:
488
+ eng = HERE / f"engine_{gpu}_{prec}_b1-{opt_batch}-{MAX_BATCH_CAP}.plan"
489
+ try:
490
+ t0 = time.time()
491
+ build_engine(str(onnx_path), T, prec, MAX_BATCH_CAP, opt_batch, str(eng), str(tc))
492
+ bt = time.time() - t0
493
+ trt_call = _TRTCallable(str(eng), n_classes, L, i)
494
+ m.net = trt_call
495
+ tr = _chunked_scores(m, paud, psr, PARITY_CHUNK)
496
+ p = {"n": len(py),
497
+ "max_abs_score_diff": float(np.max(np.abs(py - tr))),
498
+ "pearson": float(np.corrcoef(py, tr)[0, 1]) if len(py) > 1 else 1.0,
499
+ "spearman": _spearman(py, tr)}
500
+ except Exception as e: # noqa: BLE001 — try the next precision (e.g. FP16 layer not buildable)
501
+ last_err = e
502
+ print(f"[{prec}] FAILED: {type(e).__name__}: {e}")
503
+ continue
504
+ print(f"[{prec}] build={bt:.1f}s parity={p}")
505
+ chosen = (prec, eng, p, trt_call)
506
+ if prec == "fp16" and p["spearman"] >= PARITY_SPEARMAN_TOL:
507
+ break
508
+
509
+ if chosen is None:
510
+ raise RuntimeError(f"all precisions failed to build; last error: {last_err}")
511
+ prec, eng, p, trt_call = chosen
512
+ if p["spearman"] < PARITY_FLOOR:
513
+ raise RuntimeError(
514
+ f"parity too low (spearman={p['spearman']:.4f} < {PARITY_FLOOR}): "
515
+ f"engine output does not match PyTorch — not accepting")
516
+ m.net = trt_call
517
+ best, table = sweep(m, trt_call)
518
+ side = json.loads(side_path.read_text()) if side_path.exists() else {}
519
+ side[gpu] = {
520
+ "precision": prec, "engine": eng.name, "window_samples": T,
521
+ "n_classes": n_classes, "L": L, "i": i, "best_batch": best,
522
+ "throughput_utt_s": {str(k): round(v, 2) for k, v in table.items()},
523
+ "parity": p, "trt_version": __import__("tensorrt").__version__,
524
+ }
525
+ side_path.write_text(json.dumps(side, indent=2, default=str))
526
+ print(f"[done] {SLUG}: prec={prec} best_batch={best} "
527
+ f"utt/s={table[best]:.1f} parity_mad={p['max_abs_score_diff']:.2e}")
528
+
529
+
530
+ def main():
531
+ ap = argparse.ArgumentParser()
532
+ ap.add_argument("cmd", choices=["export", "sweep", "parity", "score"])
533
+ ap.add_argument("audio", nargs="?")
534
+ args = ap.parse_args()
535
+ gpu = gpu_slug()
536
+ side_path = HERE / f"trt_{SLUG}.json"
537
+
538
+ if args.cmd == "export":
539
+ _do_export()
540
+ elif args.cmd in ("sweep", "parity"):
541
+ m = load_model()
542
+ side = json.loads(side_path.read_text())[gpu]
543
+ eng = HERE / side["engine"]
544
+ trt_call = _TRTCallable(str(eng), side["n_classes"], side["L"], side["i"])
545
+ if args.cmd == "parity":
546
+ print(parity(m, trt_call))
547
+ else:
548
+ best, table = sweep(m, trt_call)
549
+ full = json.loads(side_path.read_text())
550
+ full[gpu]["best_batch"] = best
551
+ full[gpu]["throughput_utt_s"] = {str(k): round(v, 2) for k, v in table.items()}
552
+ side_path.write_text(json.dumps(full, indent=2, default=str))
553
+ print(f"best_batch={best} utt/s={table[best]:.1f}")
554
+ elif args.cmd == "score":
555
+ import soundfile as sf
556
+ a, sr = sf.read(args.audio, dtype="float32")
557
+ if a.ndim > 1:
558
+ a = a.mean(1)
559
+ m = SpectraAASIST3TRT()
560
+ m.load()
561
+ print(m.score_batch([a.astype(np.float32)], [sr])[0])
562
+
563
+
564
+ if __name__ == "__main__":
565
+ main()