File size: 10,034 Bytes
1d971a3 ba60410 1d971a3 7f4b648 1d971a3 7f4b648 1d971a3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 | #!/usr/bin/env python3
"""
Export PE-A-Frame (Perception Encoder Audio Frame) span predictor to ONNX.
The PE-A-Frame model is used for automatic anchor detection in SAM Audio.
It analyzes audio features and predicts which segments correspond to the
target audio source.
Usage:
python -m onnx_export.export_peaframe --output-dir onnx_models --verify
"""
import os
import argparse
import torch
import torch.nn as nn
from typing import Optional
class PEAFrameWrapper(nn.Module):
    """Thin ``nn.Module`` adapter around a PE-A-Frame model for ONNX export.

    ``forward`` passes the audio features positionally and the optional mask
    as the ``attention_mask`` keyword to the wrapped model, and returns the
    wrapped model's output unchanged.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        # Assigning an nn.Module attribute registers it as a submodule.
        self.model = model

    def forward(
        self,
        audio_features: torch.Tensor,
        audio_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the wrapped model on ``audio_features``.

        Args:
            audio_features: Audio feature tensor; presumably
                ``[batch, seq_len, hidden_dim]`` (not enforced here).
            audio_mask: Optional mask forwarded as ``attention_mask``;
                presumably ``[batch, seq_len]``.

        Returns:
            Whatever the wrapped model returns. Described as frame-level
            predictions ``[batch, seq_len, num_classes]`` — TODO confirm
            against the actual model.
        """
        wrapped = self.model
        return wrapped(audio_features, attention_mask=audio_mask)
def load_peaframe_model(config_name: str = "pe-a-frame-large", device: str = "cpu"):
    """Load a pretrained PE-A-Frame model ready for inference.

    Args:
        config_name: Config name passed to ``PEAudioFrame.from_config``.
        device: Torch device string the model is moved to.

    Returns:
        The model in eval mode on ``device``.
    """
    # Imported lazily so the perception_models package is only required
    # when a model is actually loaded.
    from core.audio_visual_encoder.pe import PEAudioFrame

    print(f"Loading PE-A-Frame model: {config_name}...")
    loaded = PEAudioFrame.from_config(config_name, pretrained=True).eval().to(device)

    param_count = sum(tensor.numel() for tensor in loaded.parameters())
    print(f" ✓ Model loaded: {param_count:,} parameters")
    return loaded
def get_tokenizer(model):
    """Return the Hugging Face tokenizer matching ``model``'s text encoder.

    The checkpoint name is read from ``model.config.text_model._name_or_path``
    and loaded with ``transformers.AutoTokenizer.from_pretrained``.
    """
    from transformers import AutoTokenizer

    checkpoint = model.config.text_model._name_or_path
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return tokenizer
def create_sample_inputs(model, batch_size: int = 1, device: str = "cpu"):
    """Build dummy text + audio inputs for tracing and verification.

    Args:
        model: PE-A-Frame model; only used to obtain its text tokenizer.
        batch_size: Number of (identical) text queries and random audio clips.
        device: Device the returned tensors are placed on.

    Returns:
        Dict with ``input_ids`` and ``attention_mask`` (the query
        "a person speaking", padded and truncated to at most 77 tokens) and
        ``input_values``, a random tensor of shape ``(batch_size, 1, 160000)``,
        i.e. 10 seconds of audio at 16 kHz.
    """
    tokenize = get_tokenizer(model)
    queries = ["a person speaking" for _ in range(batch_size)]
    encoded = tokenize(
        queries,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=77,
    )

    num_samples = 16000 * 10  # 10 seconds at 16 kHz
    # (batch, channels, samples) layout with a single channel, which the DAC
    # encoder expects.
    waveform = torch.randn(batch_size, 1, num_samples, device=device)

    return {
        "input_ids": encoded["input_ids"].to(device),
        "attention_mask": encoded["attention_mask"].to(device),
        "input_values": waveform,
    }
def export_peaframe(
    model: nn.Module,
    output_path: str,
    opset_version: int = 21,
    device: str = "cpu",
):
    """Export PE-A-Frame to ONNX and write its post-processing config.

    Steps:
    1. Run one eager forward pass as a sanity check.
    2. Export a wrapper returning ``(audio_embeds, text_embeds)``, with the
       weights stored as external data.
    3. Write a JSON sidecar with scalars read from the model.
    4. Reload the graph as a best-effort check.

    Args:
        model: Loaded PE-A-Frame model (switched to eval mode here).
        output_path: Destination ``.onnx`` file. The sidecar config is written
            to the same path with its extension replaced by ``_config.json``
            (``peaframe.onnx`` -> ``peaframe_config.json``).
        opset_version: ONNX opset version passed to ``torch.onnx.export``.
        device: Device on which the sample tracing inputs are created.

    Returns:
        True once the export and config have been written.

    Raises:
        Exception: Any error from the sanity forward pass is re-raised.
    """
    import json

    import onnx

    print(f"Exporting PE-A-Frame to {output_path}...")
    sample_inputs = create_sample_inputs(model, device=device)
    model = model.eval()

    # Sanity-check the eager forward pass before tracing it.
    with torch.no_grad():
        try:
            output = model(
                input_ids=sample_inputs["input_ids"],
                input_values=sample_inputs["input_values"],
                attention_mask=sample_inputs["attention_mask"],
                return_spans=False,  # span output is a list, which ONNX cannot express
            )
            print(f" Test forward pass: audio_embeds shape = {output.audio_embeds.shape}")
            print(f" Test forward pass: text_embeds shape = {output.text_embeds.shape}")
        except Exception as e:
            print(f" Forward pass failed: {e}")
            raise

    class PEAFrameONNXWrapper(nn.Module):
        """Expose the audio and text embeddings as a flat output tuple for ONNX."""

        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, input_ids, input_values, attention_mask):
            output = self.model(
                input_ids=input_ids,
                input_values=input_values,
                attention_mask=attention_mask,
                return_spans=False,
            )
            return output.audio_embeds, output.text_embeds

    wrapper = PEAFrameONNXWrapper(model)
    wrapper.eval()

    torch.onnx.export(
        wrapper,
        (sample_inputs["input_ids"], sample_inputs["input_values"], sample_inputs["attention_mask"]),
        output_path,
        input_names=["input_ids", "input_values", "attention_mask"],
        output_names=["audio_embeds", "text_embeds"],
        dynamic_axes={
            "input_ids": {0: "batch_size", 1: "seq_len"},
            # input_values is traced as (batch, channels=1, samples). The
            # variable-length axis is 2; marking axis 1 would make the channel
            # count dynamic and freeze the audio length at the traced size.
            "input_values": {0: "batch_size", 2: "audio_len"},
            "attention_mask": {0: "batch_size", 1: "seq_len"},
            "audio_embeds": {0: "batch_size", 1: "num_frames"},
            "text_embeds": {0: "batch_size"},
        },
        opset_version=opset_version,
        do_constant_folding=True,
        external_data=True,
    )
    print(" ✓ PE-A-Frame exported successfully")

    # Scalars read off the model for post-processing outside PyTorch.
    # "threshold" is a fixed 0.3 default and is not derived from the model.
    config = {
        "logit_scale": float(model.logit_scale.item()),
        "logit_bias": float(model.logit_bias.item()),
        "hop_length": model.config.audio_model.dac_vae_encoder.hop_length,
        "sampling_rate": model.config.audio_model.dac_vae_encoder.sampling_rate,
        "threshold": 0.3,
    }
    # Derive the sidecar name from the extension only. A plain str.replace
    # would rewrite every ".onnx" in the path, and would return output_path
    # unchanged (overwriting the model) when the path has no ".onnx".
    config_path = os.path.splitext(output_path)[0] + "_config.json"
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
    print(f" ✓ Config saved to {config_path}")

    # Best-effort check: parse the graph protobuf without loading the external
    # weight files, to avoid external-data path issues. This does not run
    # onnx.checker; it only confirms the file loads.
    try:
        onnx.load(output_path, load_external_data=False)
        print(" ✓ ONNX model structure validated")
    except Exception as e:
        print(f" ⚠ Warning: Could not validate ONNX structure: {e}")
    return True
def verify_peaframe(
    model: nn.Module,
    onnx_path: str,
    device: str = "cpu",
    tolerance: float = 1e-3,
) -> bool:
    """Check that the exported ONNX graph reproduces the PyTorch outputs.

    A fresh set of sample inputs is run through the eager model and through
    ONNX Runtime (CPU provider). Both outputs must have identical shapes, and
    each output's maximum absolute element-wise difference must be strictly
    below ``tolerance``.

    Args:
        model: The PyTorch PE-A-Frame model that was exported.
        onnx_path: Path to the exported ``.onnx`` file.
        device: Device for the PyTorch reference run.
        tolerance: Maximum allowed absolute element-wise difference.

    Returns:
        True if verification passes. False on a shape mismatch, a difference
        at or above ``tolerance``, or a NaN difference in either output.
    """
    import numpy as np
    import onnxruntime as ort

    print("Verifying PE-A-Frame output...")
    sample_inputs = create_sample_inputs(model, device=device)

    # Reference outputs from eager PyTorch.
    model = model.eval()
    with torch.no_grad():
        reference = model(
            input_ids=sample_inputs["input_ids"],
            input_values=sample_inputs["input_values"],
            attention_mask=sample_inputs["attention_mask"],
            return_spans=False,
        )
    pytorch_audio_embeds = reference.audio_embeds.cpu().numpy()
    pytorch_text_embeds = reference.text_embeds.cpu().numpy()

    # Same inputs through ONNX Runtime.
    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    onnx_inputs = {
        "input_ids": sample_inputs["input_ids"].cpu().numpy().astype(np.int64),
        "input_values": sample_inputs["input_values"].cpu().numpy().astype(np.float32),
        "attention_mask": sample_inputs["attention_mask"].cpu().numpy().astype(np.int64),
    }
    onnx_audio_embeds, onnx_text_embeds = sess.run(["audio_embeds", "text_embeds"], onnx_inputs)

    # Compare shapes first. NumPy broadcasting would otherwise let mismatched
    # shapes (e.g. an extra or missing size-1 axis) yield a misleading diff.
    for label, expected, actual in (
        ("Audio", pytorch_audio_embeds, onnx_audio_embeds),
        ("Text", pytorch_text_embeds, onnx_text_embeds),
    ):
        if expected.shape != actual.shape:
            print(f" ✗ {label} embeds shape mismatch: PyTorch {expected.shape} vs ONNX {actual.shape}")
            return False

    audio_max_diff = np.abs(pytorch_audio_embeds - onnx_audio_embeds).max()
    text_max_diff = np.abs(pytorch_text_embeds - onnx_text_embeds).max()
    print(f" Audio embeds max diff: {audio_max_diff:.2e}")
    print(f" Text embeds max diff: {text_max_diff:.2e}")

    # Test each diff separately. Python's max(a, nan) returns a, so folding
    # the two diffs with max() could let a NaN in the second output pass.
    passed = bool(audio_max_diff < tolerance) and bool(text_max_diff < tolerance)
    if passed:
        print(f" ✓ Verification passed (tolerance: {tolerance})")
        return True
    print(f" ✗ Verification failed (tolerance: {tolerance})")
    return False
def main():
    """Command-line entry point for the PE-A-Frame export.

    Steps:
    1. Load the model.
    2. Export it to ``<output-dir>/peaframe.onnx``.
    3. Save the text tokenizer to ``<output-dir>/peaframe_tokenizer``.
    4. Optionally verify the export with ``--verify``.

    Raises:
        SystemExit: With status 1 when ``--verify`` is given and verification
            fails, so shells and CI can detect the failure.
    """
    parser = argparse.ArgumentParser(description="Export PE-A-Frame to ONNX")
    parser.add_argument(
        "--config",
        type=str,
        default="pe-a-frame-large",
        help="PE-A-Frame config name",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="onnx_models",
        help="Output directory for ONNX models",
    )
    # NOTE(review): this CLI default (18) differs from export_peaframe's own
    # default opset (21); confirm which one is intended.
    parser.add_argument(
        "--opset",
        type=int,
        default=18,
        help="ONNX opset version",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to use",
    )
    parser.add_argument(
        "--verify",
        action="store_true",
        help="Verify ONNX output",
    )
    parser.add_argument(
        "--tolerance",
        type=float,
        default=1e-3,
        help="Verification tolerance",
    )
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    model = load_peaframe_model(args.config, args.device)

    output_path = os.path.join(args.output_dir, "peaframe.onnx")
    export_peaframe(model, output_path, args.opset, args.device)

    # Save the text tokenizer next to the ONNX model for inference. Reuse
    # get_tokenizer instead of duplicating its checkpoint lookup.
    tokenizer_dir = os.path.join(args.output_dir, "peaframe_tokenizer")
    os.makedirs(tokenizer_dir, exist_ok=True)
    get_tokenizer(model).save_pretrained(tokenizer_dir)
    print(f" ✓ Tokenizer saved to {tokenizer_dir}")

    verified = True
    if args.verify:
        verified = verify_peaframe(model, output_path, args.device, args.tolerance)

    print(f"\n✓ Export complete! Model saved to {output_path}")
    if not verified:
        # The files were written, but report the failed verification through
        # the exit status instead of silently exiting 0.
        raise SystemExit(1)
if __name__ == "__main__":
    # Run the CLI only when executed as a script or via `python -m`, not on import.
    main()
|