| |
| |
|
|
| """ |
| Extract and validate BirdNET v2.4 ONNX backbone models. |
| |
| Downloads model.onnx and birdnet.onnx from HuggingFace |
| (justinchuby/BirdNET-onnx), strips the classification head, and saves: |
| - model_backbone.onnx |
| - birdnet_backbone.onnx |
| |
| Also downloads the reference TF SavedModel from Zenodo |
| (BirdNET_v2.4_protobuf) and verifies that embeddings match. |
| """ |
|
|
| import io |
| import os |
| import urllib.request |
| import zipfile |
|
|
| import huggingface_hub |
| import numpy as np |
| import onnx |
| import onnx.helper |
| import onnxruntime as ort |
| import tensorflow as tf |
|
|
| |
| os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" |
|
|
| |
| HF_REPO_ID = "justinchuby/BirdNET-onnx" |
|
|
| |
| ZENODO_URL = "https://zenodo.org/records/15050749/files/BirdNET_v2.4_protobuf.zip?download=1" |
|
|
| |
| AUDIO_MODEL_ZIP_PREFIX = "audio-model/" |
|
|
| |
| |
| BACKBONE_RAW_OUTPUT = "model/GLOBAL_AVG_POOL/Mean_reduced_0" |
|
|
| |
| BACKBONE_OUTPUT = "embedding" |
|
|
| |
| BIRDNET_SAMPLE_LEN = 144000 |
|
|
| |
| |
| |
| RTOL = 1e-3 |
| ATOL = 1e-3 |
|
|
|
|
| |
| |
| |
|
|
def download_onnx_models(output_dir: str) -> dict[str, str]:
    """Fetch model.onnx and birdnet.onnx from the HuggingFace repo.

    Returns a dict mapping filename -> absolute local path.
    """
    paths: dict[str, str] = {}
    for name in ("model.onnx", "birdnet.onnx"):
        local_path = huggingface_hub.hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=name,
            local_dir=output_dir,
        )
        print(f"Downloaded {name} -> {local_path}")
        paths[name] = local_path
    return paths
|
|
|
|
def download_pb_model(output_dir: str) -> str:
    """Download BirdNET_v2.4_protobuf.zip from Zenodo and extract audio-model.

    The zip contains two SavedModel sub-directories; only audio-model is
    extracted since that is the one whose embeddings signature we compare
    against. If the directory already exists the download is skipped entirely.

    Args:
        output_dir: Directory into which the archive members are extracted.

    Returns:
        The path to the extracted audio-model SavedModel directory.
    """
    audio_model_dir = os.path.join(output_dir, "audio-model")
    if os.path.isdir(audio_model_dir):
        print(f"Protobuf already extracted -> {audio_model_dir}")
        return audio_model_dir

    # Plain string (was an f-string with no placeholders).
    print("Downloading BirdNET protobuf from Zenodo...")
    # NOTE: the whole archive is buffered in memory before extraction.
    with urllib.request.urlopen(ZENODO_URL) as response:
        data = response.read()
    print(f"Download complete ({len(data) / 1_000_000:.1f} MB)")

    with zipfile.ZipFile(io.BytesIO(data)) as zf:
        # Extract only the audio-model/ members; the other SavedModel in the
        # archive is never used.
        members = [m for m in zf.namelist() if m.startswith(AUDIO_MODEL_ZIP_PREFIX)]
        zf.extractall(output_dir, members=members)

    print(f"Extracted audio-model -> {audio_model_dir}")
    return audio_model_dir
|
|
|
|
| |
| |
| |
|
|
def _extract(
    src_path: str,
    out_path: str,
    input_names: list[str],
    output_names: list[str],
    output_renames: dict[str, str] | None = None,
) -> None:
    """Extract a subgraph from an ONNX model using backwards BFS and save it.

    Starting from `output_names`, the algorithm traces each tensor back through
    the graph to find every node that contributes to those outputs. Nodes that
    only serve the classification head (i.e., downstream of `output_names`) are
    never reached and are therefore excluded from the new model.

    Args:
        src_path: Path to the source ONNX model file.
        out_path: Destination path for the extracted subgraph.
        input_names: Graph-level input tensor names to keep (weight initializers
            that appear in graph.input are automatically excluded).
        output_names: Tensor names that define the extraction boundary — the new
            model will produce exactly these tensors as outputs.
        output_renames: Optional mapping {old_name: new_name} applied to the
            output tensors of the producing nodes before saving.
    """
    model = onnx.load(src_path)
    renames = output_renames or {}

    # Map each tensor name to the node that produces it (SSA: one producer).
    tensor_to_node: dict = {}
    for node in model.graph.node:
        for out in node.output:
            if out:  # optional node outputs can be empty strings — skip those
                tensor_to_node[out] = node

    # Backwards traversal from the requested outputs. NodeProto is not
    # hashable, so reached nodes are tracked by id().
    visited_node_ids: set = set()
    queue = list(output_names)
    while queue:
        tensor = queue.pop()
        node = tensor_to_node.get(tensor)
        # No producer means the tensor is a graph input or an initializer.
        if node is None or id(node) in visited_node_ids:
            continue
        visited_node_ids.add(id(node))
        for inp in node.input:
            if inp:
                queue.append(inp)

    # Filter against the original node list so topological order is preserved.
    filtered_nodes = [n for n in model.graph.node if id(n) in visited_node_ids]

    # Apply the renames in place; this mutates the loaded model's nodes, which
    # is safe because only the new (extracted) graph is saved afterwards.
    for node in filtered_nodes:
        for i, out in enumerate(node.output):
            if out in renames:
                node.output[i] = renames[out]

    # Keep only initializers that some surviving node actually consumes.
    needed_tensors: set = set()
    for node in filtered_nodes:
        needed_tensors.update(i for i in node.input if i)
    filtered_inits = [i for i in model.graph.initializer if i.name in needed_tensors]

    # Graph inputs: restrict to the caller-supplied data inputs (drops any
    # weights that were also listed in graph.input).
    input_name_set = set(input_names)
    graph_inputs = [vi for vi in model.graph.input if vi.name in input_name_set]

    # Graph outputs: reuse the original value_info when available (keeps its
    # type/shape); otherwise synthesize a FLOAT value_info with unknown shape.
    existing_out = {o.name: o for o in model.graph.output}
    graph_outputs = []
    for name in output_names:
        final_name = renames.get(name, name)
        if final_name in existing_out:
            graph_outputs.append(existing_out[final_name])
        else:
            graph_outputs.append(
                onnx.helper.make_tensor_value_info(final_name, onnx.TensorProto.FLOAT, None)
            )

    new_graph = onnx.helper.make_graph(
        filtered_nodes,
        "backbone",
        graph_inputs,
        graph_outputs,
        initializer=filtered_inits,
    )
    new_model = onnx.helper.make_model(new_graph)
    # Carry over IR version and opset imports so the subgraph's operators are
    # interpreted exactly as in the source model.
    new_model.ir_version = model.ir_version
    del new_model.opset_import[:]
    new_model.opset_import.extend(model.opset_import)
    onnx.save(new_model, out_path)
|
|
|
|
def _get_graph_input_names(onnx_path: str) -> list[str]:
    """Return the true data-input tensor names for an ONNX model."""
    graph = onnx.load(onnx_path).graph
    # graph.input may also list weight initializers; exclude those.
    weight_names = {init.name for init in graph.initializer}
    return [value.name for value in graph.input if value.name not in weight_names]
|
|
|
|
def extract_backbone(src_path: str, out_path: str) -> str:
    """Extract the backbone subgraph from a full BirdNET ONNX model and save it.

    Traces backwards from BACKBONE_RAW_OUTPUT (the global average pool tensor)
    and renames it to BACKBONE_OUTPUT ("embedding") in the saved file.

    Returns out_path for chaining.
    """
    data_inputs = _get_graph_input_names(src_path)
    _extract(
        src_path,
        out_path,
        data_inputs,
        [BACKBONE_RAW_OUTPUT],
        output_renames={BACKBONE_RAW_OUTPUT: BACKBONE_OUTPUT},
    )

    # Reload the saved file to report what was actually written.
    saved = onnx.load(out_path)
    output_names = [o.name for o in saved.graph.output]
    print(f"Backbone saved -> {out_path}")
    print(f" inputs : {data_inputs}")
    print(f" outputs: {output_names}")
    return out_path
|
|
|
|
| |
| |
| |
|
|
| def _make_audio(length: int, seed: int = 42) -> np.ndarray: |
| """Generate a reproducible Gaussian noise waveform shaped (1, length).""" |
| rng = np.random.default_rng(seed) |
| return rng.standard_normal((1, length)).astype(np.float32) |
|
|
|
|
def _onnx_embedding(onnx_path: str, audio: np.ndarray) -> np.ndarray:
    """Run inference on a backbone ONNX model and return the embedding array."""
    feed_name = _get_graph_input_names(onnx_path)[0]
    session = ort.InferenceSession(onnx_path)
    # Request exactly the renamed backbone output; run returns a 1-element list.
    results = session.run([BACKBONE_OUTPUT], {feed_name: audio})
    return results[0]
|
|
|
|
def _pb_embedding(pb_dir: str, audio: np.ndarray) -> np.ndarray:
    """Run inference on the BirdNET TF SavedModel and return the embedding array.

    The audio-model SavedModel exposes an "embeddings" signature whose output
    dict contains an "embeddings" key, used here as the ground-truth reference.
    """
    loaded = tf.saved_model.load(pb_dir)
    embed_fn = loaded.signatures["embeddings"]
    result = embed_fn(inputs=tf.constant(audio))
    return result["embeddings"].numpy()
|
|
|
|
| |
| |
| |
|
|
def main():
    """End-to-end pipeline: download → extract → compare."""
    out_dir = os.path.dirname(os.path.abspath(__file__))

    print("=== Downloading models ===")
    downloaded = download_onnx_models(out_dir)
    pb_dir = download_pb_model(out_dir)

    print("\n=== Extracting backbones ===")
    backbone_paths = {}
    for fname, src_path in downloaded.items():
        stem = fname.replace(".onnx", "")
        dst_path = os.path.join(out_dir, f"{stem}_backbone.onnx")
        backbone_paths[stem] = extract_backbone(src_path, dst_path)

    print("\n=== Comparing embeddings against Zenodo TF SavedModel ===")
    audio = _make_audio(BIRDNET_SAMPLE_LEN)
    pb_emb = _pb_embedding(pb_dir, audio)
    print(f"PB embedding shape: {pb_emb.shape}")

    # Compare each extracted backbone against the TF reference embedding.
    for stem, path in backbone_paths.items():
        onnx_emb = _onnx_embedding(path, audio)
        diff = np.abs(onnx_emb - pb_emb)
        print(f"\n{stem}_backbone.onnx:")
        print(f" ONNX embedding shape: {onnx_emb.shape}")
        print(f" |diff| mean={diff.mean():.6e} max={diff.max():.6e}")
        try:
            np.testing.assert_allclose(onnx_emb, pb_emb, rtol=RTOL, atol=ATOL)
            print(f" Embeddings match PB reference with rtol={RTOL:.0e}, atol={ATOL:.0e} PASSED")
        except AssertionError as e:
            # Report but do not abort, so every backbone gets compared.
            print(f" Embeddings differ from PB reference FAILED\n {e}")

    print("\nDone.")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|