#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
"""Extract and validate BirdNET v2.4 ONNX backbone models.

Downloads model.onnx and birdnet.onnx from HuggingFace
(justinchuby/BirdNET-onnx), strips the classification head, and saves:

- model_backbone.onnx
- birdnet_backbone.onnx

Also downloads the reference TF SavedModel from Zenodo
(BirdNET_v2.4_protobuf) and verifies that embeddings match.
"""

import io
import os

# Suppress TF C++ info/warning logs; only errors are shown.
# NOTE: this must be set BEFORE tensorflow is imported — the TF C++ runtime
# reads TF_CPP_MIN_LOG_LEVEL once at import time, so assigning it after the
# import has no effect.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import urllib.request
import zipfile

import huggingface_hub
import numpy as np
import onnx
import onnx.helper
import onnxruntime as ort
import tensorflow as tf

# Source HuggingFace repo that hosts the full BirdNET v2.4 ONNX models.
HF_REPO_ID = "justinchuby/BirdNET-onnx"
# Zenodo URL for the BirdNET v2.4 protobuf SavedModel archive.
ZENODO_URL = "https://zenodo.org/records/15050749/files/BirdNET_v2.4_protobuf.zip?download=1"
# Sub-directory inside the Zenodo zip that contains the audio SavedModel.
AUDIO_MODEL_ZIP_PREFIX = "audio-model/"
# Internal tensor name of the global-average-pool output — the last node of
# the backbone, immediately before the classification dense layer.
BACKBONE_RAW_OUTPUT = "model/GLOBAL_AVG_POOL/Mean_reduced_0"
# Public name exposed by the extracted backbone model.
BACKBONE_OUTPUT = "embedding"
# Expected number of audio samples fed to the model (3 s at 48 kHz).
BIRDNET_SAMPLE_LEN = 144000
# Tolerances for np.testing.assert_allclose when comparing ONNX vs TF outputs.
# birdnet.onnx is a separate ONNX export whose weights differ slightly from the
# reference SavedModel, so a loose tolerance is used to accommodate both variants.
RTOL = 1e-3 ATOL = 1e-3 # --------------------------------------------------------------------------- # Download helpers # --------------------------------------------------------------------------- def download_onnx_models(output_dir: str) -> dict[str, str]: """Download model.onnx and birdnet.onnx from HuggingFace. Returns a dict mapping filename -> absolute local path. """ filenames = ["model.onnx", "birdnet.onnx"] paths = {} for fname in filenames: path = huggingface_hub.hf_hub_download( repo_id=HF_REPO_ID, filename=fname, local_dir=output_dir, ) paths[fname] = path print(f"Downloaded {fname} -> {path}") return paths def download_pb_model(output_dir: str) -> str: """Download BirdNET_v2.4_protobuf.zip from Zenodo and extract audio-model. The zip contains two SavedModel sub-directories; only audio-model is extracted since that is the one whose embeddings signature we compare against. Returns the path to the extracted audio-model SavedModel directory. """ audio_model_dir = os.path.join(output_dir, "audio-model") if os.path.isdir(audio_model_dir): print(f"Protobuf already extracted -> {audio_model_dir}") return audio_model_dir print(f"Downloading BirdNET protobuf from Zenodo...") with urllib.request.urlopen(ZENODO_URL) as response: data = response.read() print(f"Download complete ({len(data) / 1_000_000:.1f} MB)") with zipfile.ZipFile(io.BytesIO(data)) as zf: members = [m for m in zf.namelist() if m.startswith(AUDIO_MODEL_ZIP_PREFIX)] zf.extractall(output_dir, members=members) print(f"Extracted audio-model -> {audio_model_dir}") return audio_model_dir # --------------------------------------------------------------------------- # Backbone extraction # --------------------------------------------------------------------------- def _extract( src_path: str, out_path: str, input_names: list[str], output_names: list[str], output_renames: dict[str, str] | None = None, ) -> None: """Extract a subgraph from an ONNX model using backwards BFS and save it. 
Starting from `output_names`, the algorithm traces each tensor back through the graph to find every node that contributes to those outputs. Nodes that only serve the classification head (i.e., downstream of `output_names`) are never reached and are therefore excluded from the new model. Args: src_path: Path to the source ONNX model file. out_path: Destination path for the extracted subgraph. input_names: Graph-level input tensor names to keep (weight initializers that appear in graph.input are automatically excluded). output_names: Tensor names that define the extraction boundary — the new model will produce exactly these tensors as outputs. output_renames: Optional mapping {old_name: new_name} applied to the output tensors of the producing nodes before saving. """ model = onnx.load(src_path) renames = output_renames or {} # Build a reverse lookup: tensor name -> the node that produces it. tensor_to_node: dict = {} for node in model.graph.node: for out in node.output: if out: tensor_to_node[out] = node # BFS backwards from the requested outputs to collect all contributing nodes. visited_node_ids: set = set() queue = list(output_names) while queue: tensor = queue.pop() node = tensor_to_node.get(tensor) if node is None or id(node) in visited_node_ids: continue visited_node_ids.add(id(node)) for inp in node.input: if inp: queue.append(inp) # Re-filter from the original node list to preserve topological order. filtered_nodes = [n for n in model.graph.node if id(n) in visited_node_ids] # Apply any requested output renames directly on the producing nodes. for node in filtered_nodes: for i, out in enumerate(node.output): if out in renames: node.output[i] = renames[out] # Collect only the initializers consumed by the retained nodes. 
needed_tensors: set = set() for node in filtered_nodes: needed_tensors.update(i for i in node.input if i) filtered_inits = [i for i in model.graph.initializer if i.name in needed_tensors] # Keep only the declared data inputs (skip weight aliases in graph.input). input_name_set = set(input_names) graph_inputs = [vi for vi in model.graph.input if vi.name in input_name_set] # Build output ValueInfoProtos. existing_out = {o.name: o for o in model.graph.output} graph_outputs = [] for name in output_names: final_name = renames.get(name, name) if final_name in existing_out: graph_outputs.append(existing_out[final_name]) else: graph_outputs.append( onnx.helper.make_tensor_value_info(final_name, onnx.TensorProto.FLOAT, None) ) new_graph = onnx.helper.make_graph( filtered_nodes, "backbone", graph_inputs, graph_outputs, initializer=filtered_inits, ) new_model = onnx.helper.make_model(new_graph) new_model.ir_version = model.ir_version del new_model.opset_import[:] new_model.opset_import.extend(model.opset_import) onnx.save(new_model, out_path) def _get_graph_input_names(onnx_path: str) -> list[str]: """Return the true data-input tensor names for an ONNX model.""" model = onnx.load(onnx_path) init_names = {i.name for i in model.graph.initializer} return [vi.name for vi in model.graph.input if vi.name not in init_names] def extract_backbone(src_path: str, out_path: str) -> str: """Extract the backbone subgraph from a full BirdNET ONNX model and save it. Traces backwards from BACKBONE_RAW_OUTPUT (the global average pool tensor) and renames it to BACKBONE_OUTPUT ("embedding") in the saved file. Returns out_path for chaining. 
""" input_names = _get_graph_input_names(src_path) _extract( src_path, out_path, input_names, [BACKBONE_RAW_OUTPUT], output_renames={BACKBONE_RAW_OUTPUT: BACKBONE_OUTPUT}, ) model = onnx.load(out_path) print(f"Backbone saved -> {out_path}") print(f" inputs : {input_names}") print(f" outputs: {[o.name for o in model.graph.output]}") return out_path # --------------------------------------------------------------------------- # Comparison helpers # --------------------------------------------------------------------------- def _make_audio(length: int, seed: int = 42) -> np.ndarray: """Generate a reproducible Gaussian noise waveform shaped (1, length).""" rng = np.random.default_rng(seed) return rng.standard_normal((1, length)).astype(np.float32) def _onnx_embedding(onnx_path: str, audio: np.ndarray) -> np.ndarray: """Run inference on a backbone ONNX model and return the embedding array.""" input_names = _get_graph_input_names(onnx_path) sess = ort.InferenceSession(onnx_path) (emb,) = sess.run([BACKBONE_OUTPUT], {input_names[0]: audio}) return emb def _pb_embedding(pb_dir: str, audio: np.ndarray) -> np.ndarray: """Run inference on the BirdNET TF SavedModel and return the embedding array. The audio-model SavedModel exposes an "embeddings" signature whose output dict contains an "embeddings" key, used here as the ground-truth reference. 
""" model = tf.saved_model.load(pb_dir) audio_tf = tf.constant(audio) return model.signatures["embeddings"](inputs=audio_tf)["embeddings"].numpy() # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main(): """End-to-end pipeline: download → extract → compare.""" out_dir = os.path.dirname(os.path.abspath(__file__)) # --- Step 1: download source models --- print("=== Downloading models ===") onnx_paths = download_onnx_models(out_dir) pb_dir = download_pb_model(out_dir) # --- Step 2: extract backbone from each ONNX variant --- print("\n=== Extracting backbones ===") backbone_paths = {} for fname, src in onnx_paths.items(): stem = fname.replace(".onnx", "") out_path = os.path.join(out_dir, f"{stem}_backbone.onnx") backbone_paths[stem] = extract_backbone(src, out_path) # --- Step 3: numerical comparison against the TF SavedModel reference --- print("\n=== Comparing embeddings against Zenodo TF SavedModel ===") audio = _make_audio(BIRDNET_SAMPLE_LEN) pb_emb = _pb_embedding(pb_dir, audio) print(f"PB embedding shape: {pb_emb.shape}") for stem, path in backbone_paths.items(): onnx_emb = _onnx_embedding(path, audio) diff = np.abs(onnx_emb - pb_emb) print(f"\n{stem}_backbone.onnx:") print(f" ONNX embedding shape: {onnx_emb.shape}") print(f" |diff| mean={diff.mean():.6e} max={diff.max():.6e}") try: np.testing.assert_allclose(onnx_emb, pb_emb, rtol=RTOL, atol=ATOL) print(f" Embeddings match PB reference with rtol={RTOL:.0e}, atol={ATOL:.0e} PASSED") except AssertionError as e: print(f" Embeddings differ from PB reference FAILED\n {e}") print("\nDone.") if __name__ == "__main__": main()