biodiversica commited on
Commit
47cca71
·
verified ·
1 Parent(s): da6c49b

Upload models and script

Browse files
birdnet_backbone.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:130cb8a574141ee97c49528319f9df33f14b3c047cbed86555aa7c9dc7a41417
3
+ size 40160497
extract_backbone.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """
5
+ Extract and validate BirdNET v2.4 ONNX backbone models.
6
+
7
+ Downloads model.onnx and birdnet.onnx from HuggingFace
8
+ (justinchuby/BirdNET-onnx), strips the classification head, and saves:
9
+ - model_backbone.onnx
10
+ - birdnet_backbone.onnx
11
+
12
+ Also downloads the reference TF SavedModel from Zenodo
13
+ (BirdNET_v2.4_protobuf) and verifies that embeddings match.
14
+ """
15
+
16
+ import io
17
+ import os
18
+ import urllib.request
19
+ import zipfile
20
+
21
+ import huggingface_hub
22
+ import numpy as np
23
+ import onnx
24
+ import onnx.helper
25
+ import onnxruntime as ort
26
+ import tensorflow as tf
27
+
28
+ # Suppress TF C++ info/warning logs; only errors are shown.
29
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
30
+
31
+ # Source HuggingFace repo that hosts the full BirdNET v2.4 ONNX models.
32
+ HF_REPO_ID = "justinchuby/BirdNET-onnx"
33
+
34
+ # Zenodo URL for the BirdNET v2.4 protobuf SavedModel archive.
35
+ ZENODO_URL = "https://zenodo.org/records/15050749/files/BirdNET_v2.4_protobuf.zip?download=1"
36
+
37
+ # Sub-directory inside the Zenodo zip that contains the audio SavedModel.
38
+ AUDIO_MODEL_ZIP_PREFIX = "audio-model/"
39
+
40
+ # Internal tensor name of the global-average-pool output — the last node of
41
+ # the backbone, immediately before the classification dense layer.
42
+ BACKBONE_RAW_OUTPUT = "model/GLOBAL_AVG_POOL/Mean_reduced_0"
43
+
44
+ # Public name exposed by the extracted backbone model.
45
+ BACKBONE_OUTPUT = "embedding"
46
+
47
+ # Expected number of audio samples fed to the model (3 s at 48 kHz).
48
+ BIRDNET_SAMPLE_LEN = 144000
49
+
50
+ # Tolerances for np.testing.assert_allclose when comparing ONNX vs TF outputs.
51
+ # birdnet.onnx is a separate ONNX export whose weights differ slightly from the
52
+ # reference SavedModel, so a loose tolerance is used to accommodate both variants.
53
+ RTOL = 1e-3
54
+ ATOL = 1e-3
55
+
56
+
57
+ # ---------------------------------------------------------------------------
58
+ # Download helpers
59
+ # ---------------------------------------------------------------------------
60
+
61
+ def download_onnx_models(output_dir: str) -> dict[str, str]:
62
+ """Download model.onnx and birdnet.onnx from HuggingFace.
63
+
64
+ Returns a dict mapping filename -> absolute local path.
65
+ """
66
+ filenames = ["model.onnx", "birdnet.onnx"]
67
+ paths = {}
68
+ for fname in filenames:
69
+ path = huggingface_hub.hf_hub_download(
70
+ repo_id=HF_REPO_ID,
71
+ filename=fname,
72
+ local_dir=output_dir,
73
+ )
74
+ paths[fname] = path
75
+ print(f"Downloaded {fname} -> {path}")
76
+ return paths
77
+
78
+
79
+ def download_pb_model(output_dir: str) -> str:
80
+ """Download BirdNET_v2.4_protobuf.zip from Zenodo and extract audio-model.
81
+
82
+ The zip contains two SavedModel sub-directories; only audio-model is
83
+ extracted since that is the one whose embeddings signature we compare
84
+ against.
85
+
86
+ Returns the path to the extracted audio-model SavedModel directory.
87
+ """
88
+ audio_model_dir = os.path.join(output_dir, "audio-model")
89
+ if os.path.isdir(audio_model_dir):
90
+ print(f"Protobuf already extracted -> {audio_model_dir}")
91
+ return audio_model_dir
92
+
93
+ print(f"Downloading BirdNET protobuf from Zenodo...")
94
+ with urllib.request.urlopen(ZENODO_URL) as response:
95
+ data = response.read()
96
+ print(f"Download complete ({len(data) / 1_000_000:.1f} MB)")
97
+
98
+ with zipfile.ZipFile(io.BytesIO(data)) as zf:
99
+ members = [m for m in zf.namelist() if m.startswith(AUDIO_MODEL_ZIP_PREFIX)]
100
+ zf.extractall(output_dir, members=members)
101
+
102
+ print(f"Extracted audio-model -> {audio_model_dir}")
103
+ return audio_model_dir
104
+
105
+
106
+ # ---------------------------------------------------------------------------
107
+ # Backbone extraction
108
+ # ---------------------------------------------------------------------------
109
+
110
+ def _extract(
111
+ src_path: str,
112
+ out_path: str,
113
+ input_names: list[str],
114
+ output_names: list[str],
115
+ output_renames: dict[str, str] | None = None,
116
+ ) -> None:
117
+ """Extract a subgraph from an ONNX model using backwards BFS and save it.
118
+
119
+ Starting from `output_names`, the algorithm traces each tensor back through
120
+ the graph to find every node that contributes to those outputs. Nodes that
121
+ only serve the classification head (i.e., downstream of `output_names`) are
122
+ never reached and are therefore excluded from the new model.
123
+
124
+ Args:
125
+ src_path: Path to the source ONNX model file.
126
+ out_path: Destination path for the extracted subgraph.
127
+ input_names: Graph-level input tensor names to keep (weight initializers
128
+ that appear in graph.input are automatically excluded).
129
+ output_names: Tensor names that define the extraction boundary — the new
130
+ model will produce exactly these tensors as outputs.
131
+ output_renames: Optional mapping {old_name: new_name} applied to the
132
+ output tensors of the producing nodes before saving.
133
+ """
134
+ model = onnx.load(src_path)
135
+ renames = output_renames or {}
136
+
137
+ # Build a reverse lookup: tensor name -> the node that produces it.
138
+ tensor_to_node: dict = {}
139
+ for node in model.graph.node:
140
+ for out in node.output:
141
+ if out:
142
+ tensor_to_node[out] = node
143
+
144
+ # BFS backwards from the requested outputs to collect all contributing nodes.
145
+ visited_node_ids: set = set()
146
+ queue = list(output_names)
147
+ while queue:
148
+ tensor = queue.pop()
149
+ node = tensor_to_node.get(tensor)
150
+ if node is None or id(node) in visited_node_ids:
151
+ continue
152
+ visited_node_ids.add(id(node))
153
+ for inp in node.input:
154
+ if inp:
155
+ queue.append(inp)
156
+
157
+ # Re-filter from the original node list to preserve topological order.
158
+ filtered_nodes = [n for n in model.graph.node if id(n) in visited_node_ids]
159
+
160
+ # Apply any requested output renames directly on the producing nodes.
161
+ for node in filtered_nodes:
162
+ for i, out in enumerate(node.output):
163
+ if out in renames:
164
+ node.output[i] = renames[out]
165
+
166
+ # Collect only the initializers consumed by the retained nodes.
167
+ needed_tensors: set = set()
168
+ for node in filtered_nodes:
169
+ needed_tensors.update(i for i in node.input if i)
170
+ filtered_inits = [i for i in model.graph.initializer if i.name in needed_tensors]
171
+
172
+ # Keep only the declared data inputs (skip weight aliases in graph.input).
173
+ input_name_set = set(input_names)
174
+ graph_inputs = [vi for vi in model.graph.input if vi.name in input_name_set]
175
+
176
+ # Build output ValueInfoProtos.
177
+ existing_out = {o.name: o for o in model.graph.output}
178
+ graph_outputs = []
179
+ for name in output_names:
180
+ final_name = renames.get(name, name)
181
+ if final_name in existing_out:
182
+ graph_outputs.append(existing_out[final_name])
183
+ else:
184
+ graph_outputs.append(
185
+ onnx.helper.make_tensor_value_info(final_name, onnx.TensorProto.FLOAT, None)
186
+ )
187
+
188
+ new_graph = onnx.helper.make_graph(
189
+ filtered_nodes,
190
+ "backbone",
191
+ graph_inputs,
192
+ graph_outputs,
193
+ initializer=filtered_inits,
194
+ )
195
+ new_model = onnx.helper.make_model(new_graph)
196
+ new_model.ir_version = model.ir_version
197
+ del new_model.opset_import[:]
198
+ new_model.opset_import.extend(model.opset_import)
199
+ onnx.save(new_model, out_path)
200
+
201
+
202
+ def _get_graph_input_names(onnx_path: str) -> list[str]:
203
+ """Return the true data-input tensor names for an ONNX model."""
204
+ model = onnx.load(onnx_path)
205
+ init_names = {i.name for i in model.graph.initializer}
206
+ return [vi.name for vi in model.graph.input if vi.name not in init_names]
207
+
208
+
209
+ def extract_backbone(src_path: str, out_path: str) -> str:
210
+ """Extract the backbone subgraph from a full BirdNET ONNX model and save it.
211
+
212
+ Traces backwards from BACKBONE_RAW_OUTPUT (the global average pool tensor)
213
+ and renames it to BACKBONE_OUTPUT ("embedding") in the saved file.
214
+
215
+ Returns out_path for chaining.
216
+ """
217
+ input_names = _get_graph_input_names(src_path)
218
+ _extract(
219
+ src_path,
220
+ out_path,
221
+ input_names,
222
+ [BACKBONE_RAW_OUTPUT],
223
+ output_renames={BACKBONE_RAW_OUTPUT: BACKBONE_OUTPUT},
224
+ )
225
+
226
+ model = onnx.load(out_path)
227
+ print(f"Backbone saved -> {out_path}")
228
+ print(f" inputs : {input_names}")
229
+ print(f" outputs: {[o.name for o in model.graph.output]}")
230
+ return out_path
231
+
232
+
233
+ # ---------------------------------------------------------------------------
234
+ # Comparison helpers
235
+ # ---------------------------------------------------------------------------
236
+
237
+ def _make_audio(length: int, seed: int = 42) -> np.ndarray:
238
+ """Generate a reproducible Gaussian noise waveform shaped (1, length)."""
239
+ rng = np.random.default_rng(seed)
240
+ return rng.standard_normal((1, length)).astype(np.float32)
241
+
242
+
243
+ def _onnx_embedding(onnx_path: str, audio: np.ndarray) -> np.ndarray:
244
+ """Run inference on a backbone ONNX model and return the embedding array."""
245
+ input_names = _get_graph_input_names(onnx_path)
246
+ sess = ort.InferenceSession(onnx_path)
247
+ (emb,) = sess.run([BACKBONE_OUTPUT], {input_names[0]: audio})
248
+ return emb
249
+
250
+
251
+ def _pb_embedding(pb_dir: str, audio: np.ndarray) -> np.ndarray:
252
+ """Run inference on the BirdNET TF SavedModel and return the embedding array.
253
+
254
+ The audio-model SavedModel exposes an "embeddings" signature whose output
255
+ dict contains an "embeddings" key, used here as the ground-truth reference.
256
+ """
257
+ model = tf.saved_model.load(pb_dir)
258
+ audio_tf = tf.constant(audio)
259
+ return model.signatures["embeddings"](inputs=audio_tf)["embeddings"].numpy()
260
+
261
+
262
+ # ---------------------------------------------------------------------------
263
+ # Main
264
+ # ---------------------------------------------------------------------------
265
+
266
+ def main():
267
+ """End-to-end pipeline: download → extract → compare."""
268
+ out_dir = os.path.dirname(os.path.abspath(__file__))
269
+
270
+ # --- Step 1: download source models ---
271
+ print("=== Downloading models ===")
272
+ onnx_paths = download_onnx_models(out_dir)
273
+ pb_dir = download_pb_model(out_dir)
274
+
275
+ # --- Step 2: extract backbone from each ONNX variant ---
276
+ print("\n=== Extracting backbones ===")
277
+ backbone_paths = {}
278
+ for fname, src in onnx_paths.items():
279
+ stem = fname.replace(".onnx", "")
280
+ out_path = os.path.join(out_dir, f"{stem}_backbone.onnx")
281
+ backbone_paths[stem] = extract_backbone(src, out_path)
282
+
283
+ # --- Step 3: numerical comparison against the TF SavedModel reference ---
284
+ print("\n=== Comparing embeddings against Zenodo TF SavedModel ===")
285
+ audio = _make_audio(BIRDNET_SAMPLE_LEN)
286
+
287
+ pb_emb = _pb_embedding(pb_dir, audio)
288
+ print(f"PB embedding shape: {pb_emb.shape}")
289
+
290
+ for stem, path in backbone_paths.items():
291
+ onnx_emb = _onnx_embedding(path, audio)
292
+ diff = np.abs(onnx_emb - pb_emb)
293
+ print(f"\n{stem}_backbone.onnx:")
294
+ print(f" ONNX embedding shape: {onnx_emb.shape}")
295
+ print(f" |diff| mean={diff.mean():.6e} max={diff.max():.6e}")
296
+ try:
297
+ np.testing.assert_allclose(onnx_emb, pb_emb, rtol=RTOL, atol=ATOL)
298
+ print(f" Embeddings match PB reference with rtol={RTOL:.0e}, atol={ATOL:.0e} PASSED")
299
+ except AssertionError as e:
300
+ print(f" Embeddings differ from PB reference FAILED\n {e}")
301
+
302
+ print("\nDone.")
303
+
304
+
305
+ if __name__ == "__main__":
306
+ main()
model_backbone.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:015344d3662d262e56dc52975523f1ea9b5e3852c1fd8ceb789f5cbbfba1dc25
3
+ size 24971712
pyproject.toml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "birdnet-onnx-backbone"
3
+ version = "0.1.0"
4
+ description = "Backbone-only ONNX exports of BirdNET v2.4 bird sound classifier"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "huggingface-hub>=0.23.0",
9
+ "numpy>=2.0.0",
10
+ "onnx>=1.16.0",
11
+ "onnxruntime>=1.18.0",
12
+ "tensorflow-cpu>=2.16.0",
13
+ ]