# BirdNET-onnx-backbone / extract_backbone.py
# biodiversica's picture
# Upload models and script
# 47cca71 verified
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
"""
Extract and validate BirdNET v2.4 ONNX backbone models.
Downloads model.onnx and birdnet.onnx from HuggingFace
(justinchuby/BirdNET-onnx), strips the classification head, and saves:
- model_backbone.onnx
- birdnet_backbone.onnx
Also downloads the reference TF SavedModel from Zenodo
(BirdNET_v2.4_protobuf) and verifies that embeddings match.
"""
import io
import os
import urllib.request
import zipfile
import huggingface_hub
import numpy as np
import onnx
import onnx.helper
import onnxruntime as ort
import tensorflow as tf
# Suppress TF C++ info/warning logs; only errors are shown.
# NOTE(review): this assignment runs after `import tensorflow` above. TensorFlow
# is generally documented to read TF_CPP_MIN_LOG_LEVEL when its native runtime
# is loaded, so the setting may not take effect from here — consider moving it
# above the import. TODO confirm.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# Source HuggingFace repo that hosts the full BirdNET v2.4 ONNX models.
HF_REPO_ID = "justinchuby/BirdNET-onnx"
# Zenodo URL for the BirdNET v2.4 protobuf SavedModel archive.
ZENODO_URL = "https://zenodo.org/records/15050749/files/BirdNET_v2.4_protobuf.zip?download=1"
# Sub-directory inside the Zenodo zip that contains the audio SavedModel.
# Archive members are selected by this prefix in download_pb_model().
AUDIO_MODEL_ZIP_PREFIX = "audio-model/"
# Internal tensor name of the global-average-pool output — the last node of
# the backbone, immediately before the classification dense layer.
BACKBONE_RAW_OUTPUT = "model/GLOBAL_AVG_POOL/Mean_reduced_0"
# Public name exposed by the extracted backbone model; BACKBONE_RAW_OUTPUT is
# renamed to this when the backbone is saved.
BACKBONE_OUTPUT = "embedding"
# Expected number of audio samples fed to the model (3 s at 48 kHz).
BIRDNET_SAMPLE_LEN = 144000
# Tolerances for np.testing.assert_allclose when comparing ONNX vs TF outputs.
# birdnet.onnx is a separate ONNX export whose weights differ slightly from the
# reference SavedModel, so a loose tolerance is used to accommodate both variants.
RTOL = 1e-3
ATOL = 1e-3
# ---------------------------------------------------------------------------
# Download helpers
# ---------------------------------------------------------------------------
def download_onnx_models(output_dir: str) -> dict[str, str]:
    """Fetch the two full BirdNET ONNX files from the HuggingFace Hub.

    Each file is requested from ``HF_REPO_ID`` with ``local_dir=output_dir``
    and one progress line is printed per file.

    Args:
        output_dir: Directory passed to ``hf_hub_download`` as ``local_dir``.

    Returns:
        Mapping of filename ("model.onnx", "birdnet.onnx") to the local path
        returned by ``hf_hub_download``.
    """
    local_paths = {}
    for model_file in ("model.onnx", "birdnet.onnx"):
        local_path = huggingface_hub.hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=model_file,
            local_dir=output_dir,
        )
        local_paths[model_file] = local_path
        print(f"Downloaded {model_file} -> {local_path}")
    return local_paths
def download_pb_model(output_dir: str) -> str:
    """Download BirdNET_v2.4_protobuf.zip from Zenodo and extract audio-model.

    The zip contains two SavedModel sub-directories; only audio-model is
    extracted since that is the one whose embeddings signature we compare
    against. The archive is read fully into memory before extraction.

    If ``<output_dir>/audio-model`` already exists as a directory, nothing is
    downloaded and that path is returned immediately (its contents are not
    validated).

    Args:
        output_dir: Directory under which ``audio-model/`` is extracted.

    Returns:
        The path to the extracted audio-model SavedModel directory.

    Raises:
        FileNotFoundError: If the downloaded archive contains no member whose
            name starts with ``AUDIO_MODEL_ZIP_PREFIX``. Without this check the
            function would return a directory that was never created.
    """
    audio_model_dir = os.path.join(output_dir, "audio-model")
    if os.path.isdir(audio_model_dir):
        print(f"Protobuf already extracted -> {audio_model_dir}")
        return audio_model_dir

    print("Downloading BirdNET protobuf from Zenodo...")
    with urllib.request.urlopen(ZENODO_URL) as response:
        data = response.read()
    print(f"Download complete ({len(data) / 1_000_000:.1f} MB)")

    with zipfile.ZipFile(io.BytesIO(data)) as zf:
        # Restrict extraction to the audio SavedModel sub-tree.
        members = [m for m in zf.namelist() if m.startswith(AUDIO_MODEL_ZIP_PREFIX)]
        if not members:
            raise FileNotFoundError(
                f"No entries starting with {AUDIO_MODEL_ZIP_PREFIX!r} found in "
                f"the archive downloaded from {ZENODO_URL}"
            )
        zf.extractall(output_dir, members=members)
    print(f"Extracted audio-model -> {audio_model_dir}")
    return audio_model_dir
# ---------------------------------------------------------------------------
# Backbone extraction
# ---------------------------------------------------------------------------
def _extract(
    src_path: str,
    out_path: str,
    input_names: list[str],
    output_names: list[str],
    output_renames: dict[str, str] | None = None,
) -> None:
    """Extract a subgraph from an ONNX model by backwards traversal and save it.

    Starting from `output_names`, the algorithm traces each tensor back through
    the graph to find every node that contributes to those outputs. Nodes that
    only serve the classification head (i.e., downstream of `output_names`) are
    never reached and are therefore excluded from the new model. Only
    ``node.input`` edges are followed; tensors referenced solely from inside
    control-flow subgraph attributes are not traced.

    Args:
        src_path: Path to the source ONNX model file.
        out_path: Destination path for the extracted subgraph.
        input_names: Graph-level input tensor names to keep (weight initializers
            that appear in graph.input are automatically excluded).
        output_names: Tensor names that define the extraction boundary — the new
            model will produce exactly these tensors as outputs.
        output_renames: Optional mapping {old_name: new_name} applied to the
            output tensors of the producing nodes (and to any retained node
            that consumes them) before saving.

    Raises:
        ValueError: If a name in `output_names` is neither produced by a node
            nor declared as a graph input or initializer in the source model.
    """
    model = onnx.load(src_path)
    graph = model.graph
    renames = output_renames or {}

    # Reverse lookup: tensor name -> index (within graph.node) of its producer.
    # Indices are used instead of id(node) so the bookkeeping does not depend
    # on protobuf returning the same Python wrapper object on every access.
    producer_index: dict[str, int] = {}
    for idx, node in enumerate(graph.node):
        for out in node.output:
            if out:
                producer_index[out] = idx

    # Fail early on unknown boundary tensors instead of silently writing a
    # model whose declared output is produced by nothing.
    known_tensors = set(producer_index)
    known_tensors.update(vi.name for vi in graph.input)
    known_tensors.update(init.name for init in graph.initializer)
    missing = [name for name in output_names if name not in known_tensors]
    if missing:
        raise ValueError(f"Output tensor(s) not found in {src_path}: {missing}")

    # Walk backwards from the requested outputs to collect contributing nodes.
    # Visit order is irrelevant; only reachability matters.
    kept: set[int] = set()
    pending = list(output_names)
    while pending:
        idx = producer_index.get(pending.pop())
        if idx is None or idx in kept:
            continue
        kept.add(idx)
        pending.extend(inp for inp in graph.node[idx].input if inp)

    # Re-filter from the original node list to preserve topological order.
    filtered_nodes = [node for idx, node in enumerate(graph.node) if idx in kept]

    # Apply requested renames on the producing nodes, and on retained consumers
    # of a renamed node-produced tensor so no input is left dangling.
    if renames:
        for node in filtered_nodes:
            for i, out in enumerate(node.output):
                if out in renames:
                    node.output[i] = renames[out]
            for i, inp in enumerate(node.input):
                if inp in renames and inp in producer_index:
                    node.input[i] = renames[inp]

    # Collect only the initializers consumed by the retained nodes.
    needed_tensors: set[str] = set()
    for node in filtered_nodes:
        needed_tensors.update(inp for inp in node.input if inp)
    filtered_inits = [init for init in graph.initializer if init.name in needed_tensors]

    # Keep only the declared data inputs (skip weight aliases in graph.input).
    input_name_set = set(input_names)
    graph_inputs = [vi for vi in graph.input if vi.name in input_name_set]

    # Build output ValueInfoProtos. Declarations are looked up under the
    # ORIGINAL tensor name (graph.output takes precedence over value_info) and
    # then renamed, so a renamed output keeps its recorded type/shape.
    declared = {vi.name: vi for vi in graph.value_info}
    declared.update((vi.name, vi) for vi in graph.output)
    graph_outputs = []
    for name in output_names:
        final_name = renames.get(name, name)
        source_info = declared.get(name)
        if source_info is None:
            # No declaration available: fall back to float with unknown shape.
            out_info = onnx.helper.make_tensor_value_info(
                final_name, onnx.TensorProto.FLOAT, None
            )
        else:
            out_info = onnx.ValueInfoProto()
            out_info.CopyFrom(source_info)
            out_info.name = final_name
        graph_outputs.append(out_info)

    new_graph = onnx.helper.make_graph(
        filtered_nodes,
        "backbone",
        graph_inputs,
        graph_outputs,
        initializer=filtered_inits,
    )
    new_model = onnx.helper.make_model(new_graph)
    # Carry over the source model's IR version and opset imports so the
    # retained nodes are interpreted under the same operator definitions.
    new_model.ir_version = model.ir_version
    del new_model.opset_import[:]
    new_model.opset_import.extend(model.opset_import)
    onnx.save(new_model, out_path)
def _get_graph_input_names(onnx_path: str) -> list[str]:
    """List the ``graph.input`` names that are not also initializers.

    Entries of ``graph.input`` that share a name with an initializer are
    treated as weights and left out; the remaining names are returned in
    their original order.
    """
    model = onnx.load(onnx_path)
    weight_names = set()
    for init in model.graph.initializer:
        weight_names.add(init.name)
    data_inputs = []
    for value_info in model.graph.input:
        if value_info.name in weight_names:
            continue
        data_inputs.append(value_info.name)
    return data_inputs
def extract_backbone(src_path: str, out_path: str) -> str:
    """Cut the backbone out of a full BirdNET ONNX model and write it to disk.

    The subgraph ending at BACKBONE_RAW_OUTPUT is extracted via ``_extract``,
    with that tensor renamed to BACKBONE_OUTPUT. The saved file is then
    re-loaded so the reported output names reflect what was actually written.

    Args:
        src_path: Path to the full ONNX model.
        out_path: Path where the backbone model is saved.

    Returns:
        ``out_path``, so calls can be chained.
    """
    data_inputs = _get_graph_input_names(src_path)
    _extract(
        src_path,
        out_path,
        data_inputs,
        output_names=[BACKBONE_RAW_OUTPUT],
        output_renames={BACKBONE_RAW_OUTPUT: BACKBONE_OUTPUT},
    )
    saved_model = onnx.load(out_path)
    saved_outputs = [out.name for out in saved_model.graph.output]
    print(f"Backbone saved -> {out_path}")
    print(f" inputs : {data_inputs}")
    print(f" outputs: {saved_outputs}")
    return out_path
# ---------------------------------------------------------------------------
# Comparison helpers
# ---------------------------------------------------------------------------
def _make_audio(length: int, seed: int = 42) -> np.ndarray:
"""Generate a reproducible Gaussian noise waveform shaped (1, length)."""
rng = np.random.default_rng(seed)
return rng.standard_normal((1, length)).astype(np.float32)
def _onnx_embedding(onnx_path: str, audio: np.ndarray) -> np.ndarray:
    """Run a backbone ONNX model on `audio` and return its embedding output.

    The audio is fed to the first data input reported by
    ``_get_graph_input_names`` and only the BACKBONE_OUTPUT tensor is
    requested from the session.
    """
    data_inputs = _get_graph_input_names(onnx_path)
    session = ort.InferenceSession(onnx_path)
    feed = {data_inputs[0]: audio}
    outputs = session.run([BACKBONE_OUTPUT], feed)
    # Exactly one output was requested; unpacking enforces that.
    (embedding,) = outputs
    return embedding
def _pb_embedding(pb_dir: str, audio: np.ndarray) -> np.ndarray:
    """Compute the reference embedding with the BirdNET TF SavedModel.

    Loads the SavedModel at `pb_dir`, calls its "embeddings" signature with
    the audio passed as the ``inputs`` keyword, and returns the "embeddings"
    entry of the result converted to a NumPy array.
    """
    saved_model = tf.saved_model.load(pb_dir)
    audio_tensor = tf.constant(audio)
    embed_fn = saved_model.signatures["embeddings"]
    result = embed_fn(inputs=audio_tensor)
    return result["embeddings"].numpy()
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
    """End-to-end pipeline: download → extract → compare.

    All artifacts are written next to this script. Each extracted backbone is
    compared against the TF SavedModel embedding for one seeded noise clip of
    BIRDNET_SAMPLE_LEN samples. A failed comparison is reported on stdout
    only; this function does not raise for it and does not alter the process
    exit status.
    """
    out_dir = os.path.dirname(os.path.abspath(__file__))

    # --- Step 1: download source models ---
    print("=== Downloading models ===")
    onnx_paths = download_onnx_models(out_dir)
    pb_dir = download_pb_model(out_dir)

    # --- Step 2: extract backbone from each ONNX variant ---
    print("\n=== Extracting backbones ===")
    backbone_paths = {}
    for fname, src in onnx_paths.items():
        # Strip only the trailing extension; str.replace would also remove
        # any ".onnx" occurring elsewhere in the filename.
        stem = fname.removesuffix(".onnx")
        out_path = os.path.join(out_dir, f"{stem}_backbone.onnx")
        backbone_paths[stem] = extract_backbone(src, out_path)

    # --- Step 3: numerical comparison against the TF SavedModel reference ---
    print("\n=== Comparing embeddings against Zenodo TF SavedModel ===")
    audio = _make_audio(BIRDNET_SAMPLE_LEN)
    pb_emb = _pb_embedding(pb_dir, audio)
    print(f"PB embedding shape: {pb_emb.shape}")
    for stem, path in backbone_paths.items():
        onnx_emb = _onnx_embedding(path, audio)
        try:
            diff = np.abs(onnx_emb - pb_emb)
        except ValueError as e:
            # Shapes that cannot be broadcast are a comparison failure for
            # this model, not a reason to abort the remaining comparisons.
            print(f"\n{stem}_backbone.onnx:")
            print(f" ONNX embedding shape: {onnx_emb.shape}")
            print(f" Embeddings differ from PB reference FAILED\n {e}")
            continue
        print(f"\n{stem}_backbone.onnx:")
        print(f" ONNX embedding shape: {onnx_emb.shape}")
        print(f" |diff| mean={diff.mean():.6e} max={diff.max():.6e}")
        try:
            np.testing.assert_allclose(onnx_emb, pb_emb, rtol=RTOL, atol=ATOL)
        except AssertionError as e:
            print(f" Embeddings differ from PB reference FAILED\n {e}")
        else:
            print(f" Embeddings match PB reference with rtol={RTOL:.0e}, atol={ATOL:.0e} PASSED")
    print("\nDone.")


if __name__ == "__main__":
    main()