#!/usr/bin/env python3
"""
Export PE-A-Frame (Perception Encoder Audio Frame) span predictor to ONNX.

The PE-A-Frame model is used for automatic anchor detection in SAM Audio.
It analyzes audio features and predicts which segments correspond to the
target audio source.

Usage:
    python -m onnx_export.export_peaframe --output-dir onnx_models --verify
"""

import argparse
import os
from typing import Optional

import torch
import torch.nn as nn


class PEAFrameWrapper(nn.Module):
    """
    Wrapper for the PE-A-Frame model for ONNX export.

    Exposes the forward pass that takes audio features and returns
    frame-level predictions.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(
        self,
        audio_features: torch.Tensor,
        audio_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass for span prediction.

        Args:
            audio_features: Audio features [batch, seq_len, hidden_dim]
            audio_mask: Optional attention mask [batch, seq_len]

        Returns:
            Frame-level predictions [batch, seq_len, num_classes]
        """
        return self.model(audio_features, attention_mask=audio_mask)


def load_peaframe_model(config_name: str = "pe-a-frame-large", device: str = "cpu"):
    """Load the PE-A-Frame model from perception_models."""
    from core.audio_visual_encoder.pe import PEAudioFrame

    print(f"Loading PE-A-Frame model: {config_name}...")
    model = PEAudioFrame.from_config(config_name, pretrained=True)
    model = model.eval().to(device)

    num_params = sum(p.numel() for p in model.parameters())
    print(f" ✓ Model loaded: {num_params:,} parameters")
    return model


def get_tokenizer(model):
    """Get the text tokenizer from the model config."""
    from transformers import AutoTokenizer

    text_model_name = model.config.text_model._name_or_path
    return AutoTokenizer.from_pretrained(text_model_name)


def create_sample_inputs(model, batch_size: int = 1, device: str = "cpu"):
    """Create sample inputs for tracing."""
    tokenizer = get_tokenizer(model)

    # Sample text query
    text = "a person speaking"
    tokens = tokenizer(
        [text] * batch_size,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=77,
    )

    # Sample audio (10 seconds at 16 kHz).
    # The DAC encoder expects (batch, channels, samples) input, so a mono
    # channel dimension is included.
    sample_rate = 16000
    audio_len = sample_rate * 10
    audio = torch.randn(batch_size, 1, audio_len, device=device)

    return {
        "input_ids": tokens["input_ids"].to(device),
        "attention_mask": tokens["attention_mask"].to(device),
        "input_values": audio,
    }


def export_peaframe(
    model: nn.Module,
    output_path: str,
    opset_version: int = 21,
    device: str = "cpu",
):
    """Export PE-A-Frame to ONNX."""
    import onnx

    print(f"Exporting PE-A-Frame to {output_path}...")
    sample_inputs = create_sample_inputs(model, device=device)

    # Put model in eval mode
    model = model.eval()

    # Test forward pass first
    with torch.no_grad():
        try:
            output = model(
                input_ids=sample_inputs["input_ids"],
                input_values=sample_inputs["input_values"],
                attention_mask=sample_inputs["attention_mask"],
                return_spans=False,  # Disable span return for ONNX (list output)
            )
            print(f" Test forward pass: audio_embeds shape = {output.audio_embeds.shape}")
            print(f" Test forward pass: text_embeds shape = {output.text_embeds.shape}")
        except Exception as e:
            print(f" Forward pass failed: {e}")
            raise

    # Create a wrapper that returns only the audio and text embedding tensors,
    # giving the ONNX graph a simpler, fixed output signature.
    class PEAFrameONNXWrapper(nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, input_ids, input_values, attention_mask):
            output = self.model(
                input_ids=input_ids,
                input_values=input_values,
                attention_mask=attention_mask,
                return_spans=False,
            )
            return output.audio_embeds, output.text_embeds

    wrapper = PEAFrameONNXWrapper(model)
    wrapper.eval()

    torch.onnx.export(
        wrapper,
        (
            sample_inputs["input_ids"],
            sample_inputs["input_values"],
            sample_inputs["attention_mask"],
        ),
        output_path,
        input_names=["input_ids", "input_values", "attention_mask"],
        output_names=["audio_embeds", "text_embeds"],
        dynamic_axes={
            "input_ids": {0: "batch_size", 1: "seq_len"},
            # input_values is (batch, channels, samples); the variable audio
            # length is axis 2, not the mono channel axis 1.
            "input_values": {0: "batch_size", 2: "audio_len"},
            "attention_mask": {0: "batch_size", 1: "seq_len"},
            "audio_embeds": {0: "batch_size", 1: "num_frames"},
            "text_embeds": {0: "batch_size"},
        },
        opset_version=opset_version,
        do_constant_folding=True,
        external_data=True,
    )
    print(" ✓ PE-A-Frame exported successfully")

    # Load without external data to avoid OOM - we just need to validate structure
    onnx_model = onnx.load(output_path, load_external_data=False)
    onnx.checker.check_model(onnx_model, full_check=False)
    print(" ✓ ONNX model validation passed")

    return True


def verify_peaframe(
    model: nn.Module,
    onnx_path: str,
    device: str = "cpu",
    tolerance: float = 1e-3,
) -> bool:
    """Verify that the ONNX output matches PyTorch."""
    import numpy as np
    import onnxruntime as ort

    print("Verifying PE-A-Frame output...")
    sample_inputs = create_sample_inputs(model, device=device)

    # PyTorch output
    model = model.eval()
    with torch.no_grad():
        pytorch_output = model(
            input_ids=sample_inputs["input_ids"],
            input_values=sample_inputs["input_values"],
            attention_mask=sample_inputs["attention_mask"],
            return_spans=False,
        )
    pytorch_audio_embeds = pytorch_output.audio_embeds.cpu().numpy()
    pytorch_text_embeds = pytorch_output.text_embeds.cpu().numpy()

    # ONNX Runtime output
    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    onnx_inputs = {
        "input_ids": sample_inputs["input_ids"].cpu().numpy().astype(np.int64),
        "input_values": sample_inputs["input_values"].cpu().numpy().astype(np.float32),
        "attention_mask": sample_inputs["attention_mask"].cpu().numpy().astype(np.int64),
    }
    onnx_outputs = sess.run(["audio_embeds", "text_embeds"], onnx_inputs)
    onnx_audio_embeds = onnx_outputs[0]
    onnx_text_embeds = onnx_outputs[1]

    # Compare
    audio_max_diff = np.abs(pytorch_audio_embeds - onnx_audio_embeds).max()
    text_max_diff = np.abs(pytorch_text_embeds - onnx_text_embeds).max()
    print(f" Audio embeds max diff: {audio_max_diff:.2e}")
    print(f" Text embeds max diff: {text_max_diff:.2e}")

    max_diff = max(audio_max_diff, text_max_diff)
    if max_diff < tolerance:
        print(f" ✓ Verification passed (tolerance: {tolerance})")
        return True
    else:
        print(f" ✗ Verification failed (tolerance: {tolerance})")
        return False


def main():
    parser = argparse.ArgumentParser(description="Export PE-A-Frame to ONNX")
    parser.add_argument(
        "--config",
        type=str,
        default="pe-a-frame-large",
        help="PE-A-Frame config name",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="onnx_models",
        help="Output directory for ONNX models",
    )
    parser.add_argument(
        "--opset",
        type=int,
        default=18,
        help="ONNX opset version",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to use",
    )
    parser.add_argument(
        "--verify",
        action="store_true",
        help="Verify ONNX output",
    )
    parser.add_argument(
        "--tolerance",
        type=float,
        default=1e-3,
        help="Verification tolerance",
    )
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Load model
    model = load_peaframe_model(args.config, args.device)

    # Export
    output_path = os.path.join(args.output_dir, "peaframe.onnx")
    export_peaframe(model, output_path, args.opset, args.device)

    # Verify
    if args.verify:
        verify_peaframe(model, output_path, args.device, args.tolerance)

    print(f"\n✓ Export complete! Model saved to {output_path}")


if __name__ == "__main__":
    main()
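
# ---------------------------------------------------------------------------
# Example (sketch): consuming the exported model with ONNX Runtime.
# This is illustrative only and is not executed by the script. The query text,
# random audio tensor, and the "<text-model-name>" placeholder (in practice,
# the same tokenizer that get_tokenizer() resolves from the model config) are
# assumptions, not values defined by this module.
#
#   import numpy as np
#   import onnxruntime as ort
#   from transformers import AutoTokenizer
#
#   sess = ort.InferenceSession("onnx_models/peaframe.onnx",
#                               providers=["CPUExecutionProvider"])
#   tokenizer = AutoTokenizer.from_pretrained("<text-model-name>")
#   tokens = tokenizer(["a person speaking"], return_tensors="np",
#                      padding=True, truncation=True, max_length=77)
#
#   # (batch, channels, samples): 10 s of mono audio at 16 kHz
#   audio = np.random.randn(1, 1, 16000 * 10).astype(np.float32)
#
#   audio_embeds, text_embeds = sess.run(
#       ["audio_embeds", "text_embeds"],
#       {
#           "input_ids": tokens["input_ids"].astype(np.int64),
#           "attention_mask": tokens["attention_mask"].astype(np.int64),
#           "input_values": audio,
#       },
#   )
# ---------------------------------------------------------------------------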