"""
Export PE-A-Frame (Perception Encoder Audio Frame) span predictor to ONNX.

The PE-A-Frame model is used for automatic anchor detection in SAM Audio.
It analyzes audio features and predicts which segments correspond to the
target audio source.

Usage:
    python -m onnx_export.export_peaframe --output-dir onnx_models --verify
"""
|
|
| import os |
| import argparse |
| import torch |
| import torch.nn as nn |
| from typing import Optional |
|
|
|
|
class PEAFrameWrapper(nn.Module):
    """ONNX-export wrapper around a PE-A-Frame span-prediction model.

    Exposes a forward pass mapping audio features (plus an optional
    attention mask) to frame-level predictions.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(
        self,
        audio_features: torch.Tensor,
        audio_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Predict frame-level spans.

        Args:
            audio_features: Audio features of shape [batch, seq_len, hidden_dim].
            audio_mask: Optional attention mask of shape [batch, seq_len].

        Returns:
            Frame-level predictions of shape [batch, seq_len, num_classes].
        """
        predictions = self.model(audio_features, attention_mask=audio_mask)
        return predictions
|
|
|
|
def load_peaframe_model(config_name: str = "pe-a-frame-large", device: str = "cpu"):
    """Instantiate a pretrained PE-A-Frame model in eval mode on *device*."""
    from core.audio_visual_encoder.pe import PEAudioFrame

    print(f"Loading PE-A-Frame model: {config_name}...")
    model = PEAudioFrame.from_config(config_name, pretrained=True).eval().to(device)

    param_count = sum(p.numel() for p in model.parameters())
    print(f" ✓ Model loaded: {param_count:,} parameters")

    return model
|
|
|
|
def get_tokenizer(model):
    """Build the HF text tokenizer named in the model's text-model config."""
    from transformers import AutoTokenizer

    name_or_path = model.config.text_model._name_or_path
    tokenizer = AutoTokenizer.from_pretrained(name_or_path)
    return tokenizer
|
|
|
|
def create_sample_inputs(model, batch_size: int = 1, device: str = "cpu"):
    """Build one batch of (text, audio) sample inputs for ONNX tracing.

    Returns a dict with "input_ids" / "attention_mask" (a tokenized fixed
    prompt, truncated to 77 tokens) and "input_values" (10 s of random
    mono 16 kHz audio), all placed on *device*.
    """
    tokenizer = get_tokenizer(model)

    # One copy of the fixed prompt per batch element.
    prompt = "a person speaking"
    encoded = tokenizer(
        [prompt] * batch_size,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=77,
    )

    # 10 seconds of random mono audio at 16 kHz: [batch, 1, samples].
    num_samples = 16000 * 10
    waveform = torch.randn(batch_size, 1, num_samples, device=device)

    return {
        "input_ids": encoded["input_ids"].to(device),
        "attention_mask": encoded["attention_mask"].to(device),
        "input_values": waveform,
    }
|
|
|
|
def _save_peaframe_config(model: nn.Module, output_path: str) -> str:
    """Write runtime constants to a JSON sidecar next to the ONNX file.

    Saves the learned logit scale/bias, the DAC-VAE encoder's hop length
    and sampling rate, and a default detection threshold as
    ``<stem>_config.json``. Returns the config path.
    """
    import json

    config = {
        "logit_scale": float(model.logit_scale.item()),
        "logit_bias": float(model.logit_bias.item()),
        "hop_length": model.config.audio_model.dac_vae_encoder.hop_length,
        "sampling_rate": model.config.audio_model.dac_vae_encoder.sampling_rate,
        # Default span-detection threshold; consumers may override.
        "threshold": 0.3,
    }
    config_path = output_path.replace(".onnx", "_config.json")
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)
    print(f" ✓ Config saved to {config_path}")
    return config_path


def _validate_onnx_structure(output_path: str) -> None:
    """Best-effort structural validation of the exported graph.

    Loads the model proto without pulling in external weight files and
    runs the ONNX checker. Failures are reported as warnings rather than
    raised, since the export itself already succeeded.
    """
    import onnx

    try:
        onnx_model = onnx.load(output_path, load_external_data=False)
        # Bug fix: previously the proto was only loaded, so a structurally
        # invalid graph was still reported as "validated". Actually run
        # the checker before claiming success.
        onnx.checker.check_model(onnx_model)
        print(" ✓ ONNX model structure validated")
    except Exception as e:
        print(f" ⚠ Warning: Could not validate ONNX structure: {e}")


def export_peaframe(
    model: nn.Module,
    output_path: str,
    opset_version: int = 21,
    device: str = "cpu",
):
    """Export PE-A-Frame to ONNX.

    Runs a sanity forward pass in eager mode, traces a thin wrapper that
    returns ``(audio_embeds, text_embeds)``, writes the ONNX graph (with
    external weight data) to *output_path*, saves a JSON sidecar of
    runtime constants, and validates the exported graph structure.

    Args:
        model: Loaded PE-A-Frame model (see ``load_peaframe_model``).
        output_path: Destination ``.onnx`` path.
        opset_version: ONNX opset to target. NOTE: this default (21)
            differs from the CLI default (18); ``main()`` always passes
            the CLI value explicitly.
        device: Device on which sample inputs are created.

    Returns:
        True on success; exceptions propagate on failure.
    """
    import onnx

    print(f"Exporting PE-A-Frame to {output_path}...")

    sample_inputs = create_sample_inputs(model, device=device)
    model = model.eval()

    # Sanity check: make sure the eager model runs on the sample inputs
    # before paying for the ONNX trace.
    with torch.no_grad():
        try:
            output = model(
                input_ids=sample_inputs["input_ids"],
                input_values=sample_inputs["input_values"],
                attention_mask=sample_inputs["attention_mask"],
                return_spans=False,
            )
            print(f" Test forward pass: audio_embeds shape = {output.audio_embeds.shape}")
            print(f" Test forward pass: text_embeds shape = {output.text_embeds.shape}")
        except Exception as e:
            print(f" Forward pass failed: {e}")
            raise

    # Local wrapper: flattens the model's output object into the plain
    # tensor tuple ONNX requires, with span decoding disabled.
    class PEAFrameONNXWrapper(nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, input_ids, input_values, attention_mask):
            output = self.model(
                input_ids=input_ids,
                input_values=input_values,
                attention_mask=attention_mask,
                return_spans=False,
            )
            return output.audio_embeds, output.text_embeds

    wrapper = PEAFrameONNXWrapper(model)
    wrapper.eval()

    torch.onnx.export(
        wrapper,
        (sample_inputs["input_ids"], sample_inputs["input_values"], sample_inputs["attention_mask"]),
        output_path,
        input_names=["input_ids", "input_values", "attention_mask"],
        output_names=["audio_embeds", "text_embeds"],
        dynamic_axes={
            "input_ids": {0: "batch_size", 1: "seq_len"},
            "input_values": {0: "batch_size", 1: "audio_len"},
            "attention_mask": {0: "batch_size", 1: "seq_len"},
            "audio_embeds": {0: "batch_size", 1: "num_frames"},
            "text_embeds": {0: "batch_size"},
        },
        opset_version=opset_version,
        do_constant_folding=True,
        # Store weights in external data files alongside the .onnx graph.
        external_data=True,
    )

    print(" ✓ PE-A-Frame exported successfully")

    _save_peaframe_config(model, output_path)
    _validate_onnx_structure(output_path)

    return True
|
|
|
|
def verify_peaframe(
    model: nn.Module,
    onnx_path: str,
    device: str = "cpu",
    tolerance: float = 1e-3,
) -> bool:
    """Compare ONNX Runtime outputs against eager PyTorch on sample inputs.

    Returns True when the max absolute difference over both embedding
    tensors is below *tolerance*.
    """
    import onnxruntime as ort
    import numpy as np

    print("Verifying PE-A-Frame output...")

    sample_inputs = create_sample_inputs(model, device=device)

    # Reference: eager PyTorch forward pass.
    model = model.eval()
    with torch.no_grad():
        reference = model(
            input_ids=sample_inputs["input_ids"],
            input_values=sample_inputs["input_values"],
            attention_mask=sample_inputs["attention_mask"],
            return_spans=False,
        )
    ref_audio = reference.audio_embeds.cpu().numpy()
    ref_text = reference.text_embeds.cpu().numpy()

    # Candidate: ONNX Runtime on CPU with the same inputs.
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    feeds = {
        "input_ids": sample_inputs["input_ids"].cpu().numpy().astype(np.int64),
        "input_values": sample_inputs["input_values"].cpu().numpy().astype(np.float32),
        "attention_mask": sample_inputs["attention_mask"].cpu().numpy().astype(np.int64),
    }
    ort_audio, ort_text = session.run(["audio_embeds", "text_embeds"], feeds)

    audio_max_diff = np.abs(ref_audio - ort_audio).max()
    text_max_diff = np.abs(ref_text - ort_text).max()

    print(f" Audio embeds max diff: {audio_max_diff:.2e}")
    print(f" Text embeds max diff: {text_max_diff:.2e}")

    if max(audio_max_diff, text_max_diff) < tolerance:
        print(f" ✓ Verification passed (tolerance: {tolerance})")
        return True
    print(f" ✗ Verification failed (tolerance: {tolerance})")
    return False
|
|
|
|
def main():
    """CLI entry point: export PE-A-Frame to ONNX, save its tokenizer,
    and optionally verify the exported graph against PyTorch.

    Exits with a nonzero status if --verify is given and verification fails.
    """
    parser = argparse.ArgumentParser(description="Export PE-A-Frame to ONNX")
    parser.add_argument(
        "--config",
        type=str,
        default="pe-a-frame-large",
        help="PE-A-Frame config name",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="onnx_models",
        help="Output directory for ONNX models",
    )
    parser.add_argument(
        "--opset",
        type=int,
        default=18,
        help="ONNX opset version",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to use",
    )
    parser.add_argument(
        "--verify",
        action="store_true",
        help="Verify ONNX output",
    )
    parser.add_argument(
        "--tolerance",
        type=float,
        default=1e-3,
        help="Verification tolerance",
    )

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    model = load_peaframe_model(args.config, args.device)

    output_path = os.path.join(args.output_dir, "peaframe.onnx")
    export_peaframe(model, output_path, args.opset, args.device)

    # Save the matching text tokenizer next to the model so downstream
    # consumers don't need the original model config.
    tokenizer_dir = os.path.join(args.output_dir, "peaframe_tokenizer")
    os.makedirs(tokenizer_dir, exist_ok=True)
    # Consistency fix: reuse get_tokenizer() instead of re-deriving the
    # tokenizer name and importing AutoTokenizer a second time here.
    tokenizer = get_tokenizer(model)
    tokenizer.save_pretrained(tokenizer_dir)
    print(f" ✓ Tokenizer saved to {tokenizer_dir}")

    if args.verify:
        # Bug fix: a failed verification previously still printed the
        # success banner and exited 0; now it aborts with a nonzero code.
        if not verify_peaframe(model, output_path, args.device, args.tolerance):
            raise SystemExit("Verification failed; see diffs above.")

    print(f"\n✓ Export complete! Model saved to {output_path}")
|
|
|
|
# Script entry point (also reachable via `python -m onnx_export.export_peaframe`).
if __name__ == "__main__":
    main()
|
|