# Source: audio-caption-categorizer-model / audio-caption / export_encoder_preprocess_onnx.py
# Author: stivenDR14
# feat: Introduce audio captioning and categorization model with ONNX/ExecuTorch hybrid inference and category embedding generation.
# Commit: 5c8d855
# export_encoder_preprocess_onnx.py
import torch
import torchaudio
from transformers import AutoModel
import argparse
import os
import onnxruntime_extensions # Ensure extensions are available if needed
from dotenv import load_dotenv
load_dotenv()
# --- CLI configuration -------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--model_id", default="wsntxxn/effb2-trm-audiocaps-captioning")
parser.add_argument("--out", default="audio-caption/effb2_encoder_preprocess-2.onnx")
parser.add_argument("--opset", type=int, default=17)
parser.add_argument("--device", default="cpu")
args = parser.parse_args()
device = torch.device(args.device)

# Load the captioning model from the Hub.
# NOTE(review): trust_remote_code=True executes model code downloaded from the
# Hub repo — only use with trusted model ids.
print("Loading model (trust_remote_code=True)...")
model = AutoModel.from_pretrained(args.model_id, trust_remote_code=True).to(device)
model.eval()

# --- Locate the audio encoder inside the remote-code model -------------------
# The attribute name varies between checkpoint variants, so probe a list of
# common candidates in priority order.
encoder_wrapper = None
for candidate in ("audio_encoder", "encoder", "model", "encoder_model"):
    if hasattr(model, candidate):
        encoder_wrapper = getattr(model, candidate)
        break
if encoder_wrapper is None:
    # Last resort: some wrappers nest the encoder one level deeper.
    try:
        encoder_wrapper = model.model.encoder
    except Exception:
        encoder_wrapper = None
if encoder_wrapper is None:
    raise RuntimeError("Couldn't find encoder attribute on model.")

# The wrapper found above may itself wrap the real encoder module; dig for a
# nested `.encoder` attribute one or two levels down.
actual_encoder = None
if hasattr(encoder_wrapper, 'model'):
    if hasattr(encoder_wrapper.model, 'encoder'):
        actual_encoder = encoder_wrapper.model.encoder
    elif hasattr(encoder_wrapper.model, 'model') and hasattr(encoder_wrapper.model.model, 'encoder'):
        actual_encoder = encoder_wrapper.model.model.encoder
if actual_encoder is None:
    print("Could not find actual encoder, using encoder_wrapper as fallback (might fail if it expects dict)")
    actual_encoder = encoder_wrapper
# Custom MelSpectrogram to avoid complex type issues in ONNX export
class OnnxCompatibleMelSpectrogram(torch.nn.Module):
    """Mel-spectrogram front-end that exports cleanly to ONNX.

    torchaudio's stock MelSpectrogram goes through a complex-valued STFT,
    which some ONNX exporters cannot trace. Here the STFT is requested with
    ``return_complex=False`` so it yields a real tensor with a trailing
    (real, imag) axis, and the power spectrum is assembled from that.
    """

    def __init__(self, sample_rate=16000, n_fft=512, win_length=512, hop_length=160, n_mels=64):
        super().__init__()
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        # Registered as a buffer so the window follows .to(device) and is
        # baked into the exported graph as a constant.
        self.register_buffer('window', torch.hann_window(win_length))
        # Linear-frequency -> mel-frequency projection, applied to power spec.
        self.mel_scale = torchaudio.transforms.MelScale(
            n_mels=n_mels,
            sample_rate=sample_rate,
            n_stft=n_fft // 2 + 1,
        )

    def forward(self, waveform):
        """Map raw audio (batch, time) to a mel power spectrogram (batch, n_mels, frames)."""
        # Real-valued STFT: shape (batch, freq, frames, 2), last axis = (re, im).
        stft_ri = torch.stft(
            waveform,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=True,
            pad_mode='reflect',
            normalized=False,
            onesided=True,
            return_complex=False,
        )
        # Power spectrogram: re^2 + im^2 -> (batch, freq, frames).
        real_part = stft_ri[..., 0]
        imag_part = stft_ri[..., 1]
        power_spec = real_part * real_part + imag_part * imag_part
        # Project onto the mel filter bank; MelScale expects (..., freq, frames).
        return self.mel_scale(power_spec)
class PreprocessEncoderWrapper(torch.nn.Module):
    """Bundle waveform preprocessing (mel + dB) with the audio encoder.

    The wrapper makes the exported ONNX graph accept raw audio instead of
    precomputed spectrograms.

    Input:
        audio: (batch, time) raw waveform — assumed 16 kHz (TODO confirm
        against the checkpoint's training config).
    Output:
        attn_emb: encoder embedding; (batch, 1, dim) when a projection head
        or mean pooling is applied, otherwise the backbone's 3-D output.
    """

    def __init__(self, actual_encoder):
        super().__init__()
        self.actual_encoder = actual_encoder
        # Components are optional — some checkpoint variants lack the
        # projection heads, or expose the backbone under a different layout.
        self.backbone = getattr(actual_encoder, 'backbone', None)
        self.fc = getattr(actual_encoder, 'fc', None)
        self.fc_proj = getattr(actual_encoder, 'fc_proj', None)
        if self.backbone is None:
            # Fall back to running the whole encoder as the feature extractor.
            self.backbone = actual_encoder
        # Preprocessing settings — must match the training-time front-end
        # (16 kHz, 512-pt FFT, 10 ms hop, 64 mel bins).
        self.mel_transform = OnnxCompatibleMelSpectrogram(
            sample_rate=16000,
            n_fft=512,
            win_length=512,
            hop_length=160,
            n_mels=64,
        )
        self.db_transform = torchaudio.transforms.AmplitudeToDB(top_db=120)

    @staticmethod
    def _pool(features):
        """Mean-pool backbone features down to (batch, dim).

        4-D (batch, ch, h, w) -> mean over (h, w); 3-D (batch, dim, t) ->
        mean over t; anything else is passed through unchanged.
        """
        if features.dim() == 4:
            return torch.mean(features, dim=[2, 3])
        if features.dim() == 3:
            return torch.mean(features, dim=2)
        return features

    def forward(self, audio):
        """
        Args:
            audio: (batch, time) - Raw waveform
        """
        # 1. Mel spectrogram, 2. amplitude -> dB, 3. encoder backbone.
        mel = self.mel_transform(audio)
        mel_db = self.db_transform(mel)
        features = self.backbone(mel_db)

        # Apply whichever projection head exists (fc takes priority), pooling
        # spatial/temporal axes first and restoring a singleton "time" axis.
        if self.fc is not None:
            attn_emb = self.fc(self._pool(features)).unsqueeze(1)
        elif self.fc_proj is not None:
            attn_emb = self.fc_proj(self._pool(features)).unsqueeze(1)
        else:
            # No projection head: pool 4-D maps, pass 3-D sequences through
            # unchanged, and give 2-D embeddings a singleton time axis.
            if features.dim() == 4:
                attn_emb = torch.mean(features, dim=[2, 3]).unsqueeze(1)
            elif features.dim() == 3:
                attn_emb = features
            else:
                attn_emb = features.unsqueeze(1)
        return attn_emb
print("\nAttempting to export Encoder with Preprocessing...")

# Dummy input: one second of mono audio at 16 kHz, shape (batch, time).
dummy_audio = torch.randn(1, 16000).to(device)

wrapper = PreprocessEncoderWrapper(actual_encoder).to(device)
wrapper.eval()

# Sanity-check the wrapper eagerly before paying for the ONNX trace.
with torch.no_grad():
    out = wrapper(dummy_audio)
    print(f"✓ Wrapper output shape: {out.shape}")

# Export configuration.
# BUG FIX: the dynamic_axes output key previously said "encoder_features"
# while the export call named the output "attn_emb", so the output's dynamic
# axes were silently ignored; the names must match. The input_names and
# output_names variables are now actually passed to torch.onnx.export.
export_inputs = (dummy_audio,)
input_names = ["audio"]
output_names = ["attn_emb"]
dynamic_axes = {
    "audio": {0: "batch", 1: "time"},
    # NOTE(review): when mean pooling applies, axis 1 of attn_emb is the
    # unsqueezed singleton, so "time" here is nominal — confirm consumers.
    "attn_emb": {0: "batch", 1: "time"},
}

print(f"Exporting to {args.out}...")
try:
    torch.onnx.export(
        wrapper,
        export_inputs,
        args.out,
        export_params=True,
        opset_version=args.opset,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        dynamo=False,  # legacy TorchScript-based exporter path
    )
    print("✅ Export successful!")
except Exception as e:
    print(f"❌ Export failed: {e}")
    import traceback
    traceback.print_exc()