BirdNET_GLOBAL_6K_V2.4_MData_Model_FP16.tflite DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:4813567791d2fe2f38fb9e195e61a6261141a6f3b134b3056b6b062d22ac88f5
- size 7071440
 
 
 
 
BirdNET_GLOBAL_6K_V2.4_Model_FP32.tflite DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:55f3e4055b1a13bfa9a2452731d0d34f6a02d6b775a334362665892794165e4c
- size 51726412
 
 
 
 
README.md CHANGED
@@ -1,26 +1,11 @@
  ---
- license: cc
- tags:
- - audio
- - bird
- - nature
- - science
- - vocalization
- - bio
- - birds-classification
- - bioacoustics
+ license: mit
+ base_model:
+ - onnx-community/BirdNET
  ---

  # BirdNET ONNX

  ONNX model converted and optimized from `BirdNET_GLOBAL_6K_V2.4_Model_FP32.tflite`.

- Files:
-
- - `model.onnx`: Initial model converted with tf2onnx and edited using NVIDIA Nsight DL Designer
- - `birdnet.onnx`: Model further optimized with the `scripts/optimize.py` script. Recommended
- - `birdnet_data_model.onnx`: The range filter meta model converted and optimized from `BirdNET_GLOBAL_6K_V2.4_MData_Model_FP16.tflite`
-
- Source: https://github.com/birdnet-team/BirdNET-Analyzer
-
- License: CC BY-NC-SA 4.0
+ Source: https://github.com/birdnet-team/BirdNET-Analyzer
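For orientation, a minimal, hedged usage sketch with ONNX Runtime follows. It assumes the input/output names and shapes set in `scripts/optimize.py` (`input` of shape `[batch, 144000]`, i.e. three seconds of audio at 48 kHz, and `output` of shape `[batch, 6522]`) and the file name `birdnet.onnx`; adjust these to match the files the repository actually ships.

```python
# Sketch only: run the optimized BirdNET graph with ONNX Runtime.
# Assumes the names/shapes set in scripts/optimize.py ("input": [batch, 144000],
# "output": [batch, 6522]); the random audio is a stand-in for a real 48 kHz clip.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("birdnet.onnx")

# One 3-second clip at 48 kHz = 144000 float32 samples per row.
audio = np.random.randn(1, 144000).astype(np.float32)

(scores,) = session.run(None, {"input": audio})
top5 = np.argsort(scores[0])[::-1][:5]
print("top-5 class indices:", top5)
```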
 
scripts/REALTIME_README.md → REALTIME_README.md RENAMED
File without changes
scripts/USAGE.md → USAGE.md RENAMED
File without changes
birdnet.onnx DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:ea590fad3d9616195f465db8330882ff24bad8eb2b4a95b4ecc3a4c228aa364e
- size 66932288
 
 
 
 
birdnet_data_model.onnx DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:92790ef53c2990c6fbd0657c1e3424df988657e4058a66640301ae3d27f27a3f
- size 14127934
 
 
 
 
scripts/predict_audio.py → predict_audio.py RENAMED
File without changes
scripts/realtime_detection.py → realtime_detection.py RENAMED
File without changes
scripts/compare_onnx_tflite.py DELETED
@@ -1,667 +0,0 @@
1
- #!/usr/bin/env python3
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- """
5
- Script to compare the results of an ONNX model with a TFLite model given the same input.
6
- Optionally also compare with Tract runtime for ONNX.
7
- Created by Copilot.
8
-
9
- Usage:
10
- python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite
11
- python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --input input.npy
12
- python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --rtol 1e-5 --atol 1e-5
13
- python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --benchmark
14
- python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --use-tract --benchmark
15
- """
16
-
17
- import argparse
18
- import time
19
- import numpy as np
20
- import onnxruntime as ort
21
- import tensorflow as tf
22
- from typing import Dict, List, Tuple, Optional, Any
23
-
24
- try:
25
- import tract
26
-
27
- TRACT_AVAILABLE = True
28
- except ImportError:
29
- TRACT_AVAILABLE = False
30
-
31
-
32
- def load_onnx_model(onnx_path: str) -> ort.InferenceSession:
33
- """Load an ONNX model and return an inference session."""
34
- print(f"Loading ONNX model from: {onnx_path}")
35
- session = ort.InferenceSession(onnx_path)
36
- return session
37
-
38
-
39
- def load_tflite_model(tflite_path: str) -> tf.lite.Interpreter:
40
- """Load a TFLite model and return an interpreter."""
41
- print(f"Loading TFLite model from: {tflite_path}")
42
- interpreter = tf.lite.Interpreter(model_path=tflite_path)
43
- interpreter.allocate_tensors()
44
- return interpreter
45
-
46
-
47
- def load_tract_model(onnx_path: str) -> Optional[Any]:
48
- """Load an ONNX model using tract and return a runnable model."""
49
- if not TRACT_AVAILABLE:
50
- print("Tract is not available. Install with: pip install tract")
51
- return None
52
- print(f"Loading ONNX model with tract from: {onnx_path}")
53
- model = tract.onnx().model_for_path(onnx_path).into_optimized().into_runnable()
54
- return model
55
-
56
-
57
- def get_onnx_model_info(session: ort.InferenceSession) -> Tuple[List, List]:
58
- """Get input and output information from ONNX model."""
59
- inputs = session.get_inputs()
60
- outputs = session.get_outputs()
61
-
62
- print("\nONNX Model Information:")
63
- print("Inputs:")
64
- for inp in inputs:
65
- print(f" - Name: {inp.name}, Shape: {inp.shape}, Type: {inp.type}")
66
- print("Outputs:")
67
- for out in outputs:
68
- print(f" - Name: {out.name}, Shape: {out.shape}, Type: {out.type}")
69
-
70
- return inputs, outputs
71
-
72
-
73
- def get_tflite_model_info(interpreter: tf.lite.Interpreter) -> Tuple[List, List]:
74
- """Get input and output information from TFLite model."""
75
- input_details = interpreter.get_input_details()
76
- output_details = interpreter.get_output_details()
77
-
78
- print("\nTFLite Model Information:")
79
- print("Inputs:")
80
- for inp in input_details:
81
- print(f" - Name: {inp['name']}, Shape: {inp['shape']}, Type: {inp['dtype']}")
82
- print("Outputs:")
83
- for out in output_details:
84
- print(f" - Name: {out['name']}, Shape: {out['shape']}, Type: {out['dtype']}")
85
-
86
- return input_details, output_details
87
-
88
-
89
- def generate_random_inputs(onnx_inputs: List, seed: int = 42) -> Dict[str, np.ndarray]:
90
- """Generate random inputs based on ONNX model input specs."""
91
- np.random.seed(seed)
92
- inputs = {}
93
-
94
- print("\nGenerating random inputs:")
95
- for inp in onnx_inputs:
96
- # Handle dynamic dimensions
97
- shape = []
98
- for dim in inp.shape:
99
- if isinstance(dim, str) or dim is None or dim < 0:
100
- # Default to 1 for dynamic dimensions
101
- shape.append(1)
102
- else:
103
- shape.append(dim)
104
-
105
- # Generate random data based on type
106
- if "float" in inp.type.lower():
107
- data = np.random.randn(*shape).astype(np.float32)
108
- elif "int64" in inp.type.lower():
109
- data = np.random.randint(0, 100, size=shape).astype(np.int64)
110
- elif "int32" in inp.type.lower():
111
- data = np.random.randint(0, 100, size=shape).astype(np.int32)
112
- else:
113
- # Default to float32
114
- data = np.random.randn(*shape).astype(np.float32)
115
-
116
- inputs[inp.name] = data
117
- print(f" - {inp.name}: shape={data.shape}, dtype={data.dtype}")
118
-
119
- return inputs
120
-
121
-
122
- def load_inputs_from_file(input_path: str) -> Dict[str, np.ndarray]:
123
- """Load inputs from a numpy file (.npy or .npz)."""
124
- print(f"\nLoading inputs from: {input_path}")
125
-
126
- if input_path.endswith(".npz"):
127
- data = np.load(input_path)
128
- inputs = {key: data[key] for key in data.files}
129
- elif input_path.endswith(".npy"):
130
- data = np.load(input_path)
131
- # Assume single input
132
- inputs = {"input": data}
133
- else:
134
- raise ValueError("Input file must be .npy or .npz format")
135
-
136
- for name, value in inputs.items():
137
- print(f" - {name}: shape={value.shape}, dtype={value.dtype}")
138
-
139
- return inputs
140
-
141
-
142
- def run_onnx_model(
143
- session: ort.InferenceSession, inputs: Dict[str, np.ndarray]
144
- ) -> List[np.ndarray]:
145
- """Run inference on ONNX model."""
146
- print("\nRunning ONNX model inference...")
147
- outputs = session.run(None, inputs)
148
- return outputs
149
-
150
-
151
- def run_tflite_model(
152
- interpreter: tf.lite.Interpreter, inputs: Dict[str, np.ndarray], input_details: List
153
- ) -> List[np.ndarray]:
154
- """Run inference on TFLite model."""
155
- print("Running TFLite model inference...")
156
-
157
- # Set input tensors
158
- for i, detail in enumerate(input_details):
159
- # Try to match by name or use order
160
- input_data = None
161
- if detail["name"] in inputs:
162
- input_data = inputs[detail["name"]]
163
- elif len(inputs) == 1:
164
- # If only one input, use it
165
- input_data = list(inputs.values())[0]
166
- elif i < len(inputs):
167
- # Use by order
168
- input_data = list(inputs.values())[i]
169
- else:
170
- raise ValueError(f"Cannot match input for TFLite input {detail['name']}")
171
-
172
- # Ensure correct dtype
173
- if input_data.dtype != detail["dtype"]:
174
- input_data = input_data.astype(detail["dtype"])
175
-
176
- interpreter.set_tensor(detail["index"], input_data)
177
-
178
- # Run inference
179
- interpreter.invoke()
180
-
181
- # Get output tensors
182
- output_details = interpreter.get_output_details()
183
- outputs = []
184
- for detail in output_details:
185
- outputs.append(interpreter.get_tensor(detail["index"]))
186
-
187
- return outputs
188
-
189
-
190
- def run_tract_model(model: Any, inputs: Dict[str, np.ndarray]) -> List[np.ndarray]:
191
- """Run inference on tract model."""
192
- if model is None:
193
- return []
194
- print("Running tract model inference...")
195
-
196
- # Convert inputs to list (tract expects a list of tensors)
197
- input_list = list(inputs.values())
198
-
199
- # Run inference
200
- outputs = model.run(input_list)
201
-
202
- # Convert outputs to numpy arrays
203
- result = []
204
- for output in outputs:
205
- result.append(output.to_numpy())
206
-
207
- return result
208
-
209
-
210
- def benchmark_onnx_model(
211
- session: ort.InferenceSession,
212
- inputs: Dict[str, np.ndarray],
213
- num_runs: int = 100,
214
- warmup_runs: int = 10,
215
- ) -> Dict[str, float]:
216
- """Benchmark ONNX model inference speed."""
217
- print(f"\nBenchmarking ONNX model ({warmup_runs} warmup + {num_runs} test runs)...")
218
-
219
- # Warmup runs
220
- for _ in range(warmup_runs):
221
- session.run(None, inputs)
222
-
223
- # Timed runs
224
- times = []
225
- for _ in range(num_runs):
226
- start = time.perf_counter()
227
- session.run(None, inputs)
228
- end = time.perf_counter()
229
- times.append((end - start) * 1000) # Convert to ms
230
-
231
- return {
232
- "mean": np.mean(times),
233
- "median": np.median(times),
234
- "std": np.std(times),
235
- "min": np.min(times),
236
- "max": np.max(times),
237
- }
238
-
239
-
240
- def benchmark_tflite_model(
241
- interpreter: tf.lite.Interpreter,
242
- inputs: Dict[str, np.ndarray],
243
- input_details: List,
244
- num_runs: int = 100,
245
- warmup_runs: int = 10,
246
- ) -> Dict[str, float]:
247
- """Benchmark TFLite model inference speed."""
248
- print(f"Benchmarking TFLite model ({warmup_runs} warmup + {num_runs} test runs)...")
249
-
250
- # Prepare inputs
251
- def set_inputs():
252
- for i, detail in enumerate(input_details):
253
- input_data = None
254
- if detail["name"] in inputs:
255
- input_data = inputs[detail["name"]]
256
- elif len(inputs) == 1:
257
- input_data = list(inputs.values())[0]
258
- elif i < len(inputs):
259
- input_data = list(inputs.values())[i]
260
- else:
261
- raise ValueError(
262
- f"Cannot match input for TFLite input {detail['name']}"
263
- )
264
-
265
- if input_data.dtype != detail["dtype"]:
266
- input_data = input_data.astype(detail["dtype"])
267
-
268
- interpreter.set_tensor(detail["index"], input_data)
269
-
270
- # Warmup runs
271
- for _ in range(warmup_runs):
272
- set_inputs()
273
- interpreter.invoke()
274
-
275
- # Timed runs
276
- times = []
277
- for _ in range(num_runs):
278
- set_inputs()
279
- start = time.perf_counter()
280
- interpreter.invoke()
281
- end = time.perf_counter()
282
- times.append((end - start) * 1000) # Convert to ms
283
-
284
- return {
285
- "mean": np.mean(times),
286
- "median": np.median(times),
287
- "std": np.std(times),
288
- "min": np.min(times),
289
- "max": np.max(times),
290
- }
291
-
292
-
293
- def benchmark_tract_model(
294
- model: Any,
295
- inputs: Dict[str, np.ndarray],
296
- num_runs: int = 100,
297
- warmup_runs: int = 10,
298
- ) -> Optional[Dict[str, float]]:
299
- """Benchmark tract model inference speed."""
300
- if model is None:
301
- return None
302
- print(f"Benchmarking tract model ({warmup_runs} warmup + {num_runs} test runs)...")
303
-
304
- # Convert inputs to list
305
- input_list = list(inputs.values())
306
-
307
- # Warmup runs
308
- for _ in range(warmup_runs):
309
- model.run(input_list)
310
-
311
- # Timed runs
312
- times = []
313
- for _ in range(num_runs):
314
- start = time.perf_counter()
315
- model.run(input_list)
316
- end = time.perf_counter()
317
- times.append((end - start) * 1000) # Convert to ms
318
-
319
- return {
320
- "mean": np.mean(times),
321
- "median": np.median(times),
322
- "std": np.std(times),
323
- "min": np.min(times),
324
- "max": np.max(times),
325
- }
326
-
327
-
328
- def print_benchmark_results(
329
- onnx_stats: Dict[str, float],
330
- tflite_stats: Dict[str, float],
331
- tract_stats: Optional[Dict[str, float]] = None,
332
- ) -> None:
333
- """Print benchmark comparison results."""
334
- print("\n" + "=" * 80)
335
- print("BENCHMARK RESULTS")
336
- print("=" * 80)
337
-
338
- print("\nONNX Model:")
339
- print(f" Mean: {onnx_stats['mean']:.3f} ms")
340
- print(f" Median: {onnx_stats['median']:.3f} ms")
341
- print(f" Std: {onnx_stats['std']:.3f} ms")
342
- print(f" Min: {onnx_stats['min']:.3f} ms")
343
- print(f" Max: {onnx_stats['max']:.3f} ms")
344
-
345
- print("\nTFLite Model:")
346
- print(f" Mean: {tflite_stats['mean']:.3f} ms")
347
- print(f" Median: {tflite_stats['median']:.3f} ms")
348
- print(f" Std: {tflite_stats['std']:.3f} ms")
349
- print(f" Min: {tflite_stats['min']:.3f} ms")
350
- print(f" Max: {tflite_stats['max']:.3f} ms")
351
-
352
- if tract_stats:
353
- print("\nTract Model:")
354
- print(f" Mean: {tract_stats['mean']:.3f} ms")
355
- print(f" Median: {tract_stats['median']:.3f} ms")
356
- print(f" Std: {tract_stats['std']:.3f} ms")
357
- print(f" Min: {tract_stats['min']:.3f} ms")
358
- print(f" Max: {tract_stats['max']:.3f} ms")
359
-
360
- print("\nComparison:")
361
- speedup = tflite_stats["mean"] / onnx_stats["mean"]
362
- if speedup > 1:
363
- print(f" ONNX Runtime is {speedup:.2f}x faster than TFLite")
364
- else:
365
- print(f" TFLite is {1 / speedup:.2f}x faster than ONNX Runtime")
366
- print(f" Difference: {abs(onnx_stats['mean'] - tflite_stats['mean']):.3f} ms")
367
-
368
- if tract_stats:
369
- speedup_tract = tflite_stats["mean"] / tract_stats["mean"]
370
- if speedup_tract > 1:
371
- print(f" Tract is {speedup_tract:.2f}x faster than TFLite")
372
- else:
373
- print(f" TFLite is {1 / speedup_tract:.2f}x faster than Tract")
374
- print(f" Difference: {abs(tract_stats['mean'] - tflite_stats['mean']):.3f} ms")
375
-
376
- speedup_ort = onnx_stats["mean"] / tract_stats["mean"]
377
- if speedup_ort > 1:
378
- print(f" Tract is {speedup_ort:.2f}x faster than ONNX Runtime")
379
- else:
380
- print(f" ONNX Runtime is {1 / speedup_ort:.2f}x faster than Tract")
381
- print(f" Difference: {abs(tract_stats['mean'] - onnx_stats['mean']):.3f} ms")
382
-
383
- print("=" * 80)
384
-
385
-
386
- def compare_outputs(
387
- onnx_outputs: List[np.ndarray],
388
- tflite_outputs: List[np.ndarray],
389
- tract_outputs: Optional[List[np.ndarray]] = None,
390
- rtol: float = 1e-5,
391
- atol: float = 1e-5,
392
- ) -> bool:
393
- """Compare outputs from ONNX, TFLite, and optionally Tract models."""
394
- print("\n" + "=" * 80)
395
- print("COMPARISON RESULTS")
396
- print("=" * 80)
397
-
398
- if len(onnx_outputs) != len(tflite_outputs):
399
- print(
400
- f"❌ Number of outputs differs: ONNX={len(onnx_outputs)}, TFLite={len(tflite_outputs)}"
401
- )
402
- return False
403
-
404
- if tract_outputs and len(onnx_outputs) != len(tract_outputs):
405
- print(
406
- f"❌ Number of outputs differs: ONNX={len(onnx_outputs)}, Tract={len(tract_outputs)}"
407
- )
408
- return False
409
-
410
- all_match = True
411
- for i, (onnx_out, tflite_out) in enumerate(zip(onnx_outputs, tflite_outputs)):
412
- tract_out = tract_outputs[i] if tract_outputs else None
413
-
414
- print(f"\nOutput {i}:")
415
- print(f" ONNX Runtime shape: {onnx_out.shape}, dtype: {onnx_out.dtype}")
416
- print(f" TFLite shape: {tflite_out.shape}, dtype: {tflite_out.dtype}")
417
- if tract_out is not None:
418
- print(f" Tract shape: {tract_out.shape}, dtype: {tract_out.dtype}")
419
-
420
- if onnx_out.shape != tflite_out.shape:
421
- print(" ❌ Shape mismatch between ONNX and TFLite!")
422
- all_match = False
423
- continue
424
-
425
- if tract_out is not None and onnx_out.shape != tract_out.shape:
426
- print(" ❌ Shape mismatch between ONNX and Tract!")
427
- all_match = False
428
- continue
429
-
430
- # Convert to same dtype for comparison
431
- if onnx_out.dtype != tflite_out.dtype:
432
- print(" ⚠️ Different dtypes, converting to float32 for comparison")
433
- onnx_out = onnx_out.astype(np.float32)
434
- tflite_out = tflite_out.astype(np.float32)
435
-
436
- if tract_out is not None and onnx_out.dtype != tract_out.dtype:
437
- tract_out = tract_out.astype(np.float32)
438
-
439
- # Compute statistics - ONNX vs TFLite
440
- print("\n ONNX Runtime vs TFLite:")
441
- diff = np.abs(onnx_out - tflite_out)
442
- max_diff = np.max(diff)
443
- mean_diff = np.mean(diff)
444
- is_close = np.allclose(onnx_out, tflite_out, rtol=rtol, atol=atol)
445
-
446
- print(f" Max difference: {max_diff:.10f}")
447
- print(f" Mean difference: {mean_diff:.10f}")
448
- print(f" Relative tolerance: {rtol}")
449
- print(f" Absolute tolerance: {atol}")
450
-
451
- if is_close:
452
- print(" ✅ Outputs match within tolerance")
453
- else:
454
- print(" ❌ Outputs do NOT match within tolerance")
455
- all_match = False
456
-
457
- # Show some sample values
458
- print("\n Sample values (first 5 elements):")
459
- flat_onnx = onnx_out.flatten()[:5]
460
- flat_tflite = tflite_out.flatten()[:5]
461
- for j, (o, t) in enumerate(zip(flat_onnx, flat_tflite)):
462
- print(
463
- f" [{j}] ONNX: {o:.10f}, TFLite: {t:.10f}, Diff: {abs(o - t):.10f}"
464
- )
465
-
466
- # Compute statistics - ONNX vs Tract
467
- if tract_out is not None:
468
- print("\n ONNX Runtime vs Tract:")
469
- diff_tract = np.abs(onnx_out - tract_out)
470
- max_diff_tract = np.max(diff_tract)
471
- mean_diff_tract = np.mean(diff_tract)
472
- is_close_tract = np.allclose(onnx_out, tract_out, rtol=rtol, atol=atol)
473
-
474
- print(f" Max difference: {max_diff_tract:.10f}")
475
- print(f" Mean difference: {mean_diff_tract:.10f}")
476
-
477
- if is_close_tract:
478
- print(" ✅ Outputs match within tolerance")
479
- else:
480
- print(" ❌ Outputs do NOT match within tolerance")
481
- all_match = False
482
-
483
- # Show some sample values
484
- print("\n Sample values (first 5 elements):")
485
- flat_onnx_tract = onnx_out.flatten()[:5]
486
- flat_tract = tract_out.flatten()[:5]
487
- for j, (o, tr) in enumerate(zip(flat_onnx_tract, flat_tract)):
488
- print(
489
- f" [{j}] ONNX: {o:.10f}, Tract: {tr:.10f}, Diff: {abs(o - tr):.10f}"
490
- )
491
-
492
- # Compute statistics - TFLite vs Tract
493
- print("\n TFLite vs Tract:")
494
- diff_tflite_tract = np.abs(tflite_out - tract_out)
495
- max_diff_tflite_tract = np.max(diff_tflite_tract)
496
- mean_diff_tflite_tract = np.mean(diff_tflite_tract)
497
- is_close_tflite_tract = np.allclose(
498
- tflite_out, tract_out, rtol=rtol, atol=atol
499
- )
500
-
501
- print(f" Max difference: {max_diff_tflite_tract:.10f}")
502
- print(f" Mean difference: {mean_diff_tflite_tract:.10f}")
503
-
504
- if is_close_tflite_tract:
505
- print(" ✅ Outputs match within tolerance")
506
- else:
507
- print(" ❌ Outputs do NOT match within tolerance")
508
- all_match = False
509
-
510
- print("\n" + "=" * 80)
511
- if all_match:
512
- print("✅ ALL OUTPUTS MATCH!")
513
- else:
514
- print("❌ SOME OUTPUTS DO NOT MATCH")
515
- print("=" * 80)
516
-
517
- return all_match
518
-
519
-
520
- def main():
521
- parser = argparse.ArgumentParser(
522
- description="Compare ONNX and TFLite model outputs",
523
- formatter_class=argparse.RawDescriptionHelpFormatter,
524
- epilog="""
525
- Examples:
526
- # Compare with random inputs
527
- python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite
528
-
529
- # Compare with custom inputs from file
530
- python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --input input.npz
531
-
532
- # Compare with custom tolerances
533
- python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --rtol 1e-3 --atol 1e-3
534
-
535
- # Save outputs for inspection
536
- python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --save-outputs
537
-
538
- # Benchmark execution speed
539
- python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --benchmark
540
-
541
- # Benchmark with custom number of runs
542
- python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --benchmark --num-runs 200 --warmup-runs 20
543
-
544
- # Compare with tract runtime as well
545
- python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --use-tract
546
-
547
- # Benchmark all three runtimes
548
- python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --use-tract --benchmark
549
- """,
550
- )
551
-
552
- parser.add_argument("--onnx", required=True, help="Path to ONNX model")
553
- parser.add_argument("--tflite", required=True, help="Path to TFLite model")
554
- parser.add_argument("--input", help="Path to input file (.npy or .npz)")
555
- parser.add_argument(
556
- "--rtol", type=float, default=1e-5, help="Relative tolerance (default: 1e-5)"
557
- )
558
- parser.add_argument(
559
- "--atol", type=float, default=1e-5, help="Absolute tolerance (default: 1e-5)"
560
- )
561
- parser.add_argument(
562
- "--seed",
563
- type=int,
564
- default=42,
565
- help="Random seed for input generation (default: 42)",
566
- )
567
- parser.add_argument(
568
- "--save-outputs", action="store_true", help="Save outputs to files"
569
- )
570
- parser.add_argument(
571
- "--benchmark",
572
- action="store_true",
573
- help="Benchmark execution speed of both models",
574
- )
575
- parser.add_argument(
576
- "--num-runs",
577
- type=int,
578
- default=100,
579
- help="Number of benchmark runs (default: 100)",
580
- )
581
- parser.add_argument(
582
- "--warmup-runs",
583
- type=int,
584
- default=10,
585
- help="Number of warmup runs (default: 10)",
586
- )
587
- parser.add_argument(
588
- "--use-tract", action="store_true", help="Also test with tract ONNX runtime"
589
- )
590
-
591
- args = parser.parse_args()
592
-
593
- # Load models
594
- onnx_session = load_onnx_model(args.onnx)
595
- tflite_interpreter = load_tflite_model(args.tflite)
596
-
597
- # Load tract model if requested
598
- tract_model = None
599
- if args.use_tract:
600
- if not TRACT_AVAILABLE:
601
- print(
602
- "\n⚠️ Warning: Tract is not installed. Install with: pip install tract"
603
- )
604
- print("Continuing without tract comparison...\n")
605
- else:
606
- tract_model = load_tract_model(args.onnx)
607
-
608
- # Get model info
609
- onnx_inputs, onnx_outputs = get_onnx_model_info(onnx_session)
610
- tflite_input_details, tflite_output_details = get_tflite_model_info(
611
- tflite_interpreter
612
- )
613
-
614
- # Prepare inputs
615
- if args.input:
616
- inputs = load_inputs_from_file(args.input)
617
- else:
618
- inputs = generate_random_inputs(onnx_inputs, seed=args.seed)
619
-
620
- # Run inference
621
- onnx_results = run_onnx_model(onnx_session, inputs)
622
- tflite_results = run_tflite_model(tflite_interpreter, inputs, tflite_input_details)
623
- tract_results = None
624
- if tract_model:
625
- tract_results = run_tract_model(tract_model, inputs)
626
-
627
- # Save outputs if requested
628
- if args.save_outputs:
629
- print("\nSaving outputs...")
630
- np.savez("onnx_outputs.npz", *onnx_results)
631
- np.savez("tflite_outputs.npz", *tflite_results)
632
- print(" - onnx_outputs.npz")
633
- print(" - tflite_outputs.npz")
634
- if tract_results:
635
- np.savez("tract_outputs.npz", *tract_results)
636
- print(" - tract_outputs.npz")
637
-
638
- # Compare results
639
- match = compare_outputs(
640
- onnx_results, tflite_results, tract_results, rtol=args.rtol, atol=args.atol
641
- )
642
-
643
- # Benchmark if requested
644
- if args.benchmark:
645
- onnx_stats = benchmark_onnx_model(
646
- onnx_session, inputs, args.num_runs, args.warmup_runs
647
- )
648
- tflite_stats = benchmark_tflite_model(
649
- tflite_interpreter,
650
- inputs,
651
- tflite_input_details,
652
- args.num_runs,
653
- args.warmup_runs,
654
- )
655
- tract_stats = None
656
- if tract_model:
657
- tract_stats = benchmark_tract_model(
658
- tract_model, inputs, args.num_runs, args.warmup_runs
659
- )
660
- print_benchmark_results(onnx_stats, tflite_stats, tract_stats)
661
-
662
- # Return exit code
663
- return 0 if match else 1
664
-
665
-
666
- if __name__ == "__main__":
667
- exit(main())
 
scripts/optimize.py DELETED
@@ -1,191 +0,0 @@
1
- import onnxscript
2
- import onnx_ir as ir
3
- import onnx_ir.passes.common
4
- import numpy as np
5
- import onnxslim
6
-
7
-
8
- class ReplaceDftWithMatMulRule(onnxscript.rewriter.RewriteRuleClassBase):
9
- def pattern(self, op, x, dft_length):
10
- x = op.Reshape(x, _allow_other_inputs=True)
11
- dft = op.DFT(x, dft_length, _outputs=["dft_output"])
12
- real_part = op.Slice(dft, [0], [1], [-1])
13
- return op.Squeeze(real_part, [-1])
14
-
15
- def rewrite(self, op, x: ir.Value, dft_length: ir.Value, dft_output: ir.Value):
16
- # Get the DFT node attributes
17
- dft_node = dft_output.producer()
18
- assert dft_node is not None
19
-
20
- dft_size = ir.convenience.get_const_tensor(dft_length).numpy().item()
21
-
22
- # Create one-sided DFT matrix (only real part, DC to Nyquist)
23
- # The real part of DFT is: Re(DFT[k]) = sum(x[n] * cos(2*pi*k*n/N))
24
- # For one-sided DFT, we only need frequencies from 0 to Nyquist (dft_size//2 + 1)
25
- num_freqs = dft_size // 2 + 1
26
-
27
- # Vectorized creation of DFT matrix
28
- n = np.arange(dft_size, dtype=np.float32)[:, np.newaxis] # Shape: (dft_size, 1)
29
- k = np.arange(num_freqs, dtype=np.float32)[
30
- np.newaxis, :
31
- ] # Shape: (1, num_freqs)
32
- dft_matrix = np.cos(
33
- 2 * np.pi * k * n / dft_size
34
- ) # Shape: (dft_size, num_freqs)
35
-
36
- # Create constant node for the DFT matrix
37
- dft_matrix = op.initializer(ir.tensor(dft_matrix), name=f"{x.name}_dft_matrix")
38
-
39
- # DFT axis is already at the end, direct matrix multiplication
40
- result = op.MatMul(x, dft_matrix)
41
-
42
- return result
43
-
44
-
45
- class ReplaceSplit(onnxscript.rewriter.RewriteRuleClassBase):
46
- def pattern(self, op, x):
47
- return op.Split(
48
- x, _allow_other_inputs=True, _outputs=["split_out_1", "split_out_2"]
49
- )
50
-
51
- def rewrite(self, op, x: ir.Value, **kwargs):
52
- zero = op.initializer(ir.tensor(np.array([0], dtype=np.int64)), "zero")
53
- batch_size = op.Gather(x, zero)
54
- sample_size = op.initializer(
55
- ir.tensor(np.array([144000], dtype=np.int32)), "sample_size"
56
- )
57
- return batch_size, sample_size
58
-
59
-
60
- class RemoveCast(onnxscript.rewriter.RewriteRuleClassBase):
61
- def pattern(self, op, x):
62
- return op.Cast(x)
63
-
64
- def rewrite(self, op, x: ir.Value, **kwargs):
65
- return op.Identity(x)
66
-
67
-
68
- class RemoveReversedSequenceFork(onnxscript.rewriter.RewriteRuleClassBase):
69
- def pattern(self, op, x, y, scale, bias):
70
- x = op.Transpose(x)
71
- y = op.Transpose(y)
72
- x = op.ReverseSequence(x, _allow_other_inputs=True)
73
- y = op.ReverseSequence(y, _allow_other_inputs=True)
74
- x = op.Unsqueeze(x, _allow_other_inputs=True)
75
- y = op.Unsqueeze(y, _allow_other_inputs=True)
76
- concat = op.Concat(x, y)
77
- mul = op.Mul(concat, scale)
78
- add = op.Add(mul, bias)
79
- return op.Transpose(add)
80
-
81
- def rewrite(self, op, x, y, scale, bias, **kwargs):
82
- # x: batch, 511, 96
83
- neg_one = op.initializer(ir.tensor(np.array([-1], dtype=np.int64)), "neg_one")
84
- int_64_min = op.initializer(
85
- ir.tensor(np.array([-9223372036854775808], dtype=np.int64)), "int_64_min"
86
- )
87
- # slice
88
- x = op.Slice(x, neg_one, int_64_min, neg_one, neg_one)
89
- y = op.Slice(y, neg_one, int_64_min, neg_one, neg_one)
90
- x = op.Unsqueeze(x, neg_one)
91
- y = op.Unsqueeze(y, neg_one)
92
- concat = op.Concat(x, y, axis=3)
93
- # batch, 511, 96, 2
94
- mul = op.Mul(concat, scale)
95
- add = op.Add(mul, bias)
96
- return op.Transpose(add, perm=[0, 3, 2, 1]) # batch, 2, 96, 511
97
-
98
-
99
- model = ir.load("model.onnx")
100
-
101
- # Set dynamic axes
102
- model.graph.inputs[0].shape = ir.Shape(["batch", 144000])
103
- model.graph.outputs[0].shape = ir.Shape(["batch", 6522])
104
-
105
- onnxscript.rewriter.rewrite(
106
- model,
107
- [
108
- ReplaceDftWithMatMulRule().rule(),
109
- ReplaceSplit().rule(),
110
- RemoveCast().rule(),
111
- ],
112
- )
113
-
114
- # Change all int32 initializers to int64
115
- initializers = list(model.graph.initializers.values())
116
- for initializer in initializers:
117
- if initializer.dtype == ir.DataType.INT32:
118
- int32_array = initializer.const_value.numpy()
119
- int64_array = int32_array.astype(np.int64)
120
- new_initializer = ir.val(initializer.name, const_value=ir.tensor(int64_array))
121
- model.graph.initializers.pop(initializer.name)
122
- model.graph.initializers.add(new_initializer)
123
- initializer.replace_all_uses_with(new_initializer)
124
-
125
- onnxscript.optimizer.optimize(
126
- model, input_size_limit=1024 * 1024 * 1024, output_size_limit=1024 * 1024 * 1024
127
- )
128
-
129
-
130
- # Remove Slice-Reshape
131
- def remove_slice_reshape(model: ir.Model):
132
- mul_node = model.graph.node("model/MEL_SPEC1/Mul")
133
- first_reshape = model.graph.node("model/MEL_SPEC1/stft/frame/Reshape_1")
134
- first_shape = ir.val(
135
- "first_shape", const_value=ir.tensor([-1, 72000, 2], dtype=ir.DataType.INT64)
136
- )
137
- model.graph.initializers.add(first_shape)
138
- second_reshape = model.graph.node("model/MEL_SPEC2/stft/frame/Reshape_1")
139
- second_shape = ir.val(
140
- "second_shape", const_value=ir.tensor([-1, 18000, 8], dtype=ir.DataType.INT64)
141
- )
142
- model.graph.initializers.add(second_shape)
143
-
144
- third_reshape = model.graph.node("model/MEL_SPEC1/stft/frame/Reshape_4")
145
- third_shape = ir.val(
146
- "third_shape", const_value=ir.tensor([-1, 511, 2048], dtype=ir.DataType.INT64)
147
- )
148
- model.graph.initializers.add(third_shape)
149
- fourth_reshape = model.graph.node("model/MEL_SPEC2/stft/frame/Reshape_4")
150
- fourth_shape = ir.val(
151
- "fourth_shape", const_value=ir.tensor([-1, 511, 1024], dtype=ir.DataType.INT64)
152
- )
153
- model.graph.initializers.add(fourth_shape)
154
-
155
- # Replace with Mul-Reshape-Gather
156
- first_reshape.replace_input_with(0, mul_node.outputs[0])
157
- first_reshape.replace_input_with(1, first_shape)
158
- second_reshape.replace_input_with(0, mul_node.outputs[0])
159
- second_reshape.replace_input_with(1, second_shape)
160
- third_reshape.replace_input_with(1, third_shape)
161
- fourth_reshape.replace_input_with(1, fourth_shape)
162
-
163
-
164
- remove_slice_reshape(model)
165
- # Run DCE again
166
- onnxscript.optimizer.optimize(
167
- model, input_size_limit=1024 * 1024 * 1024, output_size_limit=1024 * 1024 * 1024
168
- )
169
-
170
- print("Slimming model...")
171
- model = ir.from_proto(onnxslim.slim(ir.to_proto(model)))
172
-
173
- print("Removing reversed sequence fork...")
174
- onnxscript.rewriter.rewrite(
175
- model,
176
- [
177
- RemoveReversedSequenceFork.rule(),
178
- ],
179
- )
180
-
181
- # Use onnxslim to do shape inference
182
- model = ir.from_proto(onnxslim.slim(ir.to_proto(model)))
183
-
184
- onnx_ir.passes.common.ClearMetadataAndDocStringPass()(model)
185
- model.graph.inputs[0].name = "input"
186
- model.graph.outputs[0].name = "output"
187
- model.ir_version = 10
188
- model.producer_name = "onnx-ir"
189
- model.graph.name = "BirdNET-v2.4"
190
-
191
- ir.save(model, "birdnet.onnx")
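For context on `ReplaceDftWithMatMulRule` above: the rewrite relies on the real part of the DFT being a fixed linear map of the input frame, so the one-sided real spectrum (DC through Nyquist) can be computed with a single `MatMul` against a precomputed cosine matrix. In the notation used in the rule's comments:

```latex
\operatorname{Re}(X[k]) = \sum_{n=0}^{N-1} x[n]\cos\left(\frac{2\pi k n}{N}\right),
\qquad k = 0, \dots, \lfloor N/2 \rfloor
```

Equivalently, `Re(X) = x @ C` with `C[n, k] = cos(2*pi*k*n/N)` of shape `(N, N//2 + 1)`, which is exactly the `dft_matrix` constant the rewrite inserts in place of the `DFT` node.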
 
scripts/optimize_data_model.py DELETED
@@ -1,34 +0,0 @@
- """Optimize the data model ONNX file.
-
- The birdnet_data_model_slim.onnx file is obtained with
-
- python -m tf2onnx.convert --opset 18 --tflite 'BirdNET_GLOBAL_6K_V2.4_MData_Model_FP16.tflite' --output birdnet_data_model.onnx
- onnxslim birdnet_data_model.onnx birdnet_data_model_slim.onnx
- """
-
- import onnxscript.optimizer
- import onnx_ir as ir
-
- model = ir.load("birdnet_data_model_slim.onnx")
-
- # Remove add-mul-0
-
- cast_1 = model.graph.node("BirdNET/MNET_CONVERT/SelectV2_1__50")
- add_1 = model.graph.node("BirdNET/MNET_CONVERT/SelectV2_1")
- add_1.outputs[0].replace_all_uses_with(cast_1.outputs[0])
-
- cast_2 = model.graph.node("BirdNET/MNET_CONVERT/SelectV2__60")
- add_2 = model.graph.node("BirdNET/MNET_CONVERT/SelectV2")
- add_2.outputs[0].replace_all_uses_with(cast_2.outputs[0])
-
- onnxscript.optimizer.optimize(model)
-
- model.ir_version = 10
- model.graph.name = "BirdNET-v2.4-Data_Model"
- model.producer_name = "onnx-ir"
- model.producer_version = None
- model.graph.inputs[0].name = "input"
- model.graph.outputs[0].name = "output"
- model.graph.outputs[0].shape = ir.Shape(["batch", model.graph.outputs[0].shape[1]])
-
- ir.save(model, "birdnet_data_model.onnx")
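Finally, a short, hedged sketch of querying the exported `birdnet_data_model.onnx` with ONNX Runtime. It assumes that, as in upstream BirdNET-Analyzer, the range-filter meta model takes a single `[latitude, longitude, week]` row and returns one occurrence score per species; the tensor names `input` and `output` are the ones set in `scripts/optimize_data_model.py`, and the 0.03 cutoff is only illustrative.

```python
# Sketch only: query the range-filter meta model with ONNX Runtime.
# Assumes an input of shape [batch, 3] holding [latitude, longitude, week]
# (as in upstream BirdNET-Analyzer) and one occurrence score per species out.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("birdnet_data_model.onnx")

loc = np.array([[42.5, -76.45, 20.0]], dtype=np.float32)  # example lat, lon, week
(occurrence,) = session.run(None, {"input": loc})

likely = np.where(occurrence[0] >= 0.03)[0]  # illustrative cutoff
print(f"{likely.size} species above the example threshold")
```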