# sam-audio-large-onnx / onnx_export/export_peaframe.py
# (Hugging Face repo listing header; uploaded via huggingface_hub, revision 07823f7)
#!/usr/bin/env python3
"""
Export PE-A-Frame (Perception Encoder Audio Frame) span predictor to ONNX.
The PE-A-Frame model is used for automatic anchor detection in SAM Audio.
It analyzes audio features and predicts which segments correspond to the
target audio source.
Usage:
python -m onnx_export.export_peaframe --output-dir onnx_models --verify
"""
import os
import argparse
import torch
import torch.nn as nn
from typing import Optional
class PEAFrameWrapper(nn.Module):
    """
    Thin export wrapper around a PE-A-Frame model.

    Delegates straight to the wrapped model so the traced graph is a single
    forward pass: audio features in, frame-level predictions out.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(
        self,
        audio_features: torch.Tensor,
        audio_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run frame-level span prediction.

        Args:
            audio_features: Audio features [batch, seq_len, hidden_dim].
            audio_mask: Optional attention mask [batch, seq_len].

        Returns:
            Frame-level predictions [batch, seq_len, num_classes].
        """
        predictions = self.model(audio_features, attention_mask=audio_mask)
        return predictions
def load_peaframe_model(config_name: str = "pe-a-frame-large", device: str = "cpu"):
    """Instantiate a pretrained PE-A-Frame model on *device* in eval mode.

    Args:
        config_name: Name of the PE-A-Frame configuration to load.
        device: Torch device string the model is moved to.

    Returns:
        The loaded model, in eval mode.
    """
    from core.audio_visual_encoder.pe import PEAudioFrame

    print(f"Loading PE-A-Frame model: {config_name}...")
    model = PEAudioFrame.from_config(config_name, pretrained=True).eval().to(device)

    param_count = sum(p.numel() for p in model.parameters())
    print(f" ✓ Model loaded: {param_count:,} parameters")
    return model
def get_tokenizer(model):
    """Build the text tokenizer named in the model's text-model config."""
    from transformers import AutoTokenizer

    tokenizer_name = model.config.text_model._name_or_path
    return AutoTokenizer.from_pretrained(tokenizer_name)
def create_sample_inputs(model, batch_size: int = 1, device: str = "cpu"):
    """Build one batch of (text, audio) sample inputs for tracing.

    Returns a dict with "input_ids" / "attention_mask" from a fixed tokenized
    text query, and "input_values" holding a random 10-second mono waveform
    at 16 kHz in the (batch, channels, samples) layout the DAC encoder expects.
    """
    tokenizer = get_tokenizer(model)

    # Fixed sample query, replicated across the batch.
    encoded = tokenizer(
        ["a person speaking"] * batch_size,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=77,
    )

    # Random 10 s of audio at 16 kHz; the middle axis is the (mono) channel
    # dimension required by the DAC encoder.
    num_samples = 16000 * 10
    waveform = torch.randn(batch_size, 1, num_samples, device=device)

    return {
        "input_ids": encoded["input_ids"].to(device),
        "attention_mask": encoded["attention_mask"].to(device),
        "input_values": waveform,
    }
def export_peaframe(
    model: nn.Module,
    output_path: str,
    opset_version: int = 21,
    device: str = "cpu",
):
    """Export PE-A-Frame to ONNX.

    Smoke-tests a forward pass, traces a thin wrapper whose outputs are plain
    tensors (audio_embeds, text_embeds), then structurally validates the
    written model with onnx.checker.

    Args:
        model: Loaded PE-A-Frame model (see load_peaframe_model).
        output_path: Destination .onnx file path.
        opset_version: ONNX opset to target.
        device: Device the sample tracing inputs are created on.

    Returns:
        True on success; exceptions propagate on failure.
    """
    import onnx
    print(f"Exporting PE-A-Frame to {output_path}...")
    sample_inputs = create_sample_inputs(model, device=device)
    # Put model in eval mode
    model = model.eval()
    # Smoke-test the forward pass first so failures surface with a readable
    # error instead of from inside torch.onnx.export.
    with torch.no_grad():
        try:
            output = model(
                input_ids=sample_inputs["input_ids"],
                input_values=sample_inputs["input_values"],
                attention_mask=sample_inputs["attention_mask"],
                return_spans=False,  # Disable span return for ONNX (list output)
            )
            print(f" Test forward pass: audio_embeds shape = {output.audio_embeds.shape}")
            print(f" Test forward pass: text_embeds shape = {output.text_embeds.shape}")
        except Exception as e:
            print(f" Forward pass failed: {e}")
            raise
    # Adapter with a fixed positional-tensor signature and a tuple return,
    # which gives the exported graph clean input/output names.
    class PEAFrameONNXWrapper(nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, input_ids, input_values, attention_mask):
            output = self.model(
                input_ids=input_ids,
                input_values=input_values,
                attention_mask=attention_mask,
                return_spans=False,
            )
            return output.audio_embeds, output.text_embeds

    wrapper = PEAFrameONNXWrapper(model)
    wrapper.eval()
    torch.onnx.export(
        wrapper,
        (sample_inputs["input_ids"], sample_inputs["input_values"], sample_inputs["attention_mask"]),
        output_path,
        input_names=["input_ids", "input_values", "attention_mask"],
        output_names=["audio_embeds", "text_embeds"],
        # Mark batch and length dimensions dynamic so one graph serves any
        # batch size, text length, or audio duration.
        dynamic_axes={
            "input_ids": {0: "batch_size", 1: "seq_len"},
            # BUG FIX: input_values is (batch, channels, samples) — see
            # create_sample_inputs — so the sample-count axis is 2, not 1.
            # Axis 1 is the (fixed, mono) channel dimension.
            "input_values": {0: "batch_size", 2: "audio_len"},
            "attention_mask": {0: "batch_size", 1: "seq_len"},
            "audio_embeds": {0: "batch_size", 1: "num_frames"},
            "text_embeds": {0: "batch_size"},
        },
        opset_version=opset_version,
        do_constant_folding=True,
        # NOTE(review): external_data=True writes weights beside the .onnx
        # file; it requires a torch.onnx.export version that accepts this
        # keyword — confirm against the pinned torch release.
        external_data=True,
    )
    print(" ✓ PE-A-Frame exported successfully")
    # Load without external data to avoid OOM - we just need to validate structure
    onnx_model = onnx.load(output_path, load_external_data=False)
    onnx.checker.check_model(onnx_model, full_check=False)
    print(" ✓ ONNX model validation passed")
    return True
def verify_peaframe(
    model: nn.Module,
    onnx_path: str,
    device: str = "cpu",
    tolerance: float = 1e-3,
) -> bool:
    """Compare ONNX Runtime output against PyTorch for one sample batch.

    Args:
        model: The PyTorch reference model.
        onnx_path: Path to the exported ONNX file.
        device: Device for the PyTorch reference pass.
        tolerance: Max allowed element-wise absolute difference.

    Returns:
        True when both audio and text embeddings agree within *tolerance*.
    """
    import onnxruntime as ort
    import numpy as np

    print("Verifying PE-A-Frame output...")
    sample_inputs = create_sample_inputs(model, device=device)

    # Reference outputs from the eval-mode PyTorch model.
    model = model.eval()
    with torch.no_grad():
        reference = model(
            input_ids=sample_inputs["input_ids"],
            input_values=sample_inputs["input_values"],
            attention_mask=sample_inputs["attention_mask"],
            return_spans=False,
        )
    ref_audio = reference.audio_embeds.cpu().numpy()
    ref_text = reference.text_embeds.cpu().numpy()

    # Same inputs through ONNX Runtime on CPU.
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    feeds = {
        "input_ids": sample_inputs["input_ids"].cpu().numpy().astype(np.int64),
        "input_values": sample_inputs["input_values"].cpu().numpy().astype(np.float32),
        "attention_mask": sample_inputs["attention_mask"].cpu().numpy().astype(np.int64),
    }
    ort_audio, ort_text = session.run(["audio_embeds", "text_embeds"], feeds)

    # Element-wise worst-case discrepancy per output.
    audio_max_diff = np.abs(ref_audio - ort_audio).max()
    text_max_diff = np.abs(ref_text - ort_text).max()
    print(f" Audio embeds max diff: {audio_max_diff:.2e}")
    print(f" Text embeds max diff: {text_max_diff:.2e}")

    if max(audio_max_diff, text_max_diff) < tolerance:
        print(f" ✓ Verification passed (tolerance: {tolerance})")
        return True
    print(f" ✗ Verification failed (tolerance: {tolerance})")
    return False
def main():
    """CLI entry point: parse arguments, load the model, export, optionally verify.

    Exits with a nonzero status if --verify is given and verification fails.
    """
    parser = argparse.ArgumentParser(description="Export PE-A-Frame to ONNX")
    parser.add_argument(
        "--config",
        type=str,
        default="pe-a-frame-large",
        help="PE-A-Frame config name",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="onnx_models",
        help="Output directory for ONNX models",
    )
    parser.add_argument(
        "--opset",
        type=int,
        default=18,
        help="ONNX opset version",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to use",
    )
    parser.add_argument(
        "--verify",
        action="store_true",
        help="Verify ONNX output",
    )
    parser.add_argument(
        "--tolerance",
        type=float,
        default=1e-3,
        help="Verification tolerance",
    )
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # Load model
    model = load_peaframe_model(args.config, args.device)

    # Export
    output_path = os.path.join(args.output_dir, "peaframe.onnx")
    export_peaframe(model, output_path, args.opset, args.device)

    # Verify. BUG FIX: the result used to be discarded, so a failed
    # verification still printed success and exited 0; now fail loudly.
    if args.verify and not verify_peaframe(model, output_path, args.device, args.tolerance):
        raise SystemExit("✗ Verification failed: ONNX output does not match PyTorch")

    print(f"\n✓ Export complete! Model saved to {output_path}")
# Script entry point (also runnable via `python -m onnx_export.export_peaframe`).
if __name__ == "__main__":
    main()