|
|
|
|
|
""" |
|
|
Export PE-A-Frame (Perception Encoder Audio Frame) span predictor to ONNX. |
|
|
|
|
|
The PE-A-Frame model is used for automatic anchor detection in SAM Audio. |
|
|
It analyzes audio features and predicts which segments correspond to the |
|
|
target audio source. |
|
|
|
|
|
Usage: |
|
|
python -m onnx_export.export_peaframe --output-dir onnx_models --verify |
|
|
""" |
|
|
|
|
|
import os |
|
|
import argparse |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from typing import Optional |
|
|
|
|
|
|
|
|
class PEAFrameWrapper(nn.Module):
    """Thin tracing wrapper around a PE-A-Frame model for ONNX export.

    Delegates directly to the wrapped model so the exporter sees a single
    forward pass mapping audio features to frame-level predictions.
    """

    def __init__(self, model: nn.Module):
        """Keep a reference to the underlying PE-A-Frame model."""
        super().__init__()
        self.model = model

    def forward(self, audio_features: torch.Tensor,
                audio_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the wrapped model on a batch of audio features.

        Args:
            audio_features: Audio features [batch, seq_len, hidden_dim].
            audio_mask: Optional attention mask [batch, seq_len].

        Returns:
            Frame-level predictions [batch, seq_len, num_classes].
        """
        return self.model(audio_features, attention_mask=audio_mask)
|
|
|
|
|
|
|
|
def load_peaframe_model(config_name: str = "pe-a-frame-large", device: str = "cpu"):
    """Instantiate a pretrained PE-A-Frame model and move it to *device*.

    Args:
        config_name: Name of the PE-A-Frame configuration to load.
        device: Torch device string the model is placed on.

    Returns:
        The loaded model, in eval mode, on the requested device.
    """
    # Imported lazily so this module can be imported without the
    # perception_models package being installed.
    from core.audio_visual_encoder.pe import PEAudioFrame

    print(f"Loading PE-A-Frame model: {config_name}...")
    model = PEAudioFrame.from_config(config_name, pretrained=True).eval().to(device)

    param_count = sum(weight.numel() for weight in model.parameters())
    print(f"  ✓ Model loaded: {param_count:,} parameters")

    return model
|
|
|
|
|
|
|
|
def get_tokenizer(model):
    """Return the HF text tokenizer matching the model's text backbone."""
    from transformers import AutoTokenizer

    # The model config records which Hugging Face text model it was built on.
    name_or_path = model.config.text_model._name_or_path
    return AutoTokenizer.from_pretrained(name_or_path)
|
|
|
|
|
|
|
|
def create_sample_inputs(model, batch_size: int = 1, device: str = "cpu"):
    """Build representative text/audio inputs for tracing and verification.

    Args:
        model: PE-A-Frame model (used to resolve the matching tokenizer).
        batch_size: Number of identical examples to batch together.
        device: Torch device string for the returned tensors.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``input_values``.
    """
    tokenizer = get_tokenizer(model)

    # Tokenize a fixed prompt, replicated across the batch.
    prompt = "a person speaking"
    encoded = tokenizer(
        [prompt] * batch_size,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=77,
    )

    # Ten seconds of random mono audio at 16 kHz.
    sample_rate = 16000
    num_samples = sample_rate * 10
    waveform = torch.randn(batch_size, 1, num_samples, device=device)

    return {
        "input_ids": encoded["input_ids"].to(device),
        "attention_mask": encoded["attention_mask"].to(device),
        "input_values": waveform,
    }
|
|
|
|
|
|
|
|
def export_peaframe(
    model: nn.Module,
    output_path: str,
    opset_version: int = 21,
    device: str = "cpu",
):
    """Export PE-A-Frame to ONNX.

    Traces the model through a small positional-argument wrapper, exports it
    with dynamic batch/sequence/audio axes, then structurally validates the
    resulting ONNX file.

    Args:
        model: Loaded PE-A-Frame model (will be put in eval mode).
        output_path: Destination path for the .onnx file.
        opset_version: ONNX opset to target.
        device: Device the sample inputs are created on.

    Returns:
        True once export and structural validation succeed.

    Raises:
        Re-raises any exception from the smoke-test forward pass; the
        exporter and checker may also raise on failure.
    """
    import onnx

    print(f"Exporting PE-A-Frame to {output_path}...")

    sample_inputs = create_sample_inputs(model, device=device)

    model = model.eval()

    # Smoke-test the eager forward pass before exporting so tracing failures
    # can be distinguished from plain model failures.
    with torch.no_grad():
        try:
            output = model(
                input_ids=sample_inputs["input_ids"],
                input_values=sample_inputs["input_values"],
                attention_mask=sample_inputs["attention_mask"],
                return_spans=False,
            )
            print(f"  Test forward pass: audio_embeds shape = {output.audio_embeds.shape}")
            print(f"  Test forward pass: text_embeds shape = {output.text_embeds.shape}")
        except Exception as e:
            print(f"  Forward pass failed: {e}")
            raise

    # ONNX export needs positional tensor inputs and tensor-tuple outputs,
    # so this wrapper flattens the model's keyword interface / output object.
    class PEAFrameONNXWrapper(nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, input_ids, input_values, attention_mask):
            output = self.model(
                input_ids=input_ids,
                input_values=input_values,
                attention_mask=attention_mask,
                return_spans=False,
            )
            # Return plain tensors in a fixed order matching output_names.
            return output.audio_embeds, output.text_embeds

    wrapper = PEAFrameONNXWrapper(model)
    wrapper.eval()

    torch.onnx.export(
        wrapper,
        (sample_inputs["input_ids"], sample_inputs["input_values"], sample_inputs["attention_mask"]),
        output_path,
        input_names=["input_ids", "input_values", "attention_mask"],
        output_names=["audio_embeds", "text_embeds"],
        # Mark batch, text-sequence, raw-audio-length and frame axes dynamic
        # so the exported graph is not pinned to the sample input sizes.
        dynamic_axes={
            "input_ids": {0: "batch_size", 1: "seq_len"},
            "input_values": {0: "batch_size", 1: "audio_len"},
            "attention_mask": {0: "batch_size", 1: "seq_len"},
            "audio_embeds": {0: "batch_size", 1: "num_frames"},
            "text_embeds": {0: "batch_size"},
        },
        opset_version=opset_version,
        do_constant_folding=True,
        # Store weights outside the .onnx protobuf (avoids the 2 GB limit).
        # NOTE(review): `external_data` is only accepted by newer torch
        # exporter versions — confirm against the pinned torch release.
        external_data=True,
    )

    print("  ✓ PE-A-Frame exported successfully")

    # Structural check only: external weight files are not loaded here, and
    # full_check=False skips shape inference over the whole graph.
    onnx_model = onnx.load(output_path, load_external_data=False)
    onnx.checker.check_model(onnx_model, full_check=False)
    print("  ✓ ONNX model validation passed")

    return True
|
|
|
|
|
|
|
|
def verify_peaframe(
    model: nn.Module,
    onnx_path: str,
    device: str = "cpu",
    tolerance: float = 1e-3,
) -> bool:
    """Check that ONNX Runtime reproduces the PyTorch model's outputs.

    Runs both backends on the same sampled inputs and compares the
    element-wise max absolute difference of the embedding outputs.

    Args:
        model: The eager PyTorch reference model.
        onnx_path: Path to the exported ONNX model.
        device: Device used for the PyTorch forward pass.
        tolerance: Maximum allowed absolute difference.

    Returns:
        True when both outputs agree within *tolerance*.
    """
    import onnxruntime as ort
    import numpy as np

    print("Verifying PE-A-Frame output...")

    sample_inputs = create_sample_inputs(model, device=device)

    # Reference outputs from the eager PyTorch model.
    model = model.eval()
    with torch.no_grad():
        reference = model(
            input_ids=sample_inputs["input_ids"],
            input_values=sample_inputs["input_values"],
            attention_mask=sample_inputs["attention_mask"],
            return_spans=False,
        )
    ref_audio = reference.audio_embeds.cpu().numpy()
    ref_text = reference.text_embeds.cpu().numpy()

    # Same inputs through ONNX Runtime on CPU.
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    feeds = {
        "input_ids": sample_inputs["input_ids"].cpu().numpy().astype(np.int64),
        "input_values": sample_inputs["input_values"].cpu().numpy().astype(np.float32),
        "attention_mask": sample_inputs["attention_mask"].cpu().numpy().astype(np.int64),
    }
    ort_audio, ort_text = session.run(["audio_embeds", "text_embeds"], feeds)

    audio_max_diff = np.abs(ref_audio - ort_audio).max()
    text_max_diff = np.abs(ref_text - ort_text).max()

    print(f"  Audio embeds max diff: {audio_max_diff:.2e}")
    print(f"  Text embeds max diff: {text_max_diff:.2e}")

    if max(audio_max_diff, text_max_diff) < tolerance:
        print(f"  ✓ Verification passed (tolerance: {tolerance})")
        return True
    print(f"  ✗ Verification failed (tolerance: {tolerance})")
    return False
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: export PE-A-Frame to ONNX and optionally verify it.

    Exits with a non-zero status when --verify is requested and the ONNX
    outputs diverge from PyTorch beyond the tolerance.
    """
    parser = argparse.ArgumentParser(description="Export PE-A-Frame to ONNX")
    parser.add_argument(
        "--config",
        type=str,
        default="pe-a-frame-large",
        help="PE-A-Frame config name",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="onnx_models",
        help="Output directory for ONNX models",
    )
    parser.add_argument(
        "--opset",
        type=int,
        default=18,
        help="ONNX opset version",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to use",
    )
    parser.add_argument(
        "--verify",
        action="store_true",
        help="Verify ONNX output",
    )
    parser.add_argument(
        "--tolerance",
        type=float,
        default=1e-3,
        help="Verification tolerance",
    )

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    model = load_peaframe_model(args.config, args.device)

    output_path = os.path.join(args.output_dir, "peaframe.onnx")
    export_peaframe(model, output_path, args.opset, args.device)

    if args.verify:
        # Bug fix: the verification result used to be discarded, so the
        # script reported success (exit 0) even when outputs diverged.
        if not verify_peaframe(model, output_path, args.device, args.tolerance):
            raise SystemExit("✗ ONNX verification failed; see diffs above")

    print(f"\n✓ Export complete! Model saved to {output_path}")


if __name__ == "__main__":
    main()
|
|
|