# export_encoder_proprocess_onnx.py
import torch
import torchaudio
from transformers import AutoModel
import argparse
import os
import onnxruntime_extensions # Ensure extensions are available if needed
from dotenv import load_dotenv
load_dotenv()
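# load_dotenv() reads a local .env file into the environment, e.g. an HF token
# such as HF_TOKEN if downloading the checkpoint requires authentication
# (an assumption; the script itself reads no environment variables directly).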
parser = argparse.ArgumentParser()
parser.add_argument("--model_id", default="wsntxxn/effb2-trm-audiocaps-captioning")
parser.add_argument("--out", default="audio-caption/effb2_encoder_preprocess-2.onnx")
parser.add_argument("--opset", type=int, default=17)
parser.add_argument("--device", default="cpu")
args = parser.parse_args()
device = torch.device(args.device)
print("Loading model (trust_remote_code=True)...")
model = AutoModel.from_pretrained(args.model_id, trust_remote_code=True).to(device)
model.eval()
# Find the encoder (same logic as original script)
encoder_wrapper = None
for candidate in ("audio_encoder", "encoder", "model", "encoder_model"):
    if hasattr(model, candidate):
        encoder_wrapper = getattr(model, candidate)
        break
if encoder_wrapper is None:
    try:
        encoder_wrapper = model.model.encoder
    except Exception:
        encoder_wrapper = None
if encoder_wrapper is None:
    raise RuntimeError("Couldn't find encoder attribute on model.")
# Locate the actual encoder module inside the wrapper
actual_encoder = None
if hasattr(encoder_wrapper, 'model'):
    if hasattr(encoder_wrapper.model, 'encoder'):
        actual_encoder = encoder_wrapper.model.encoder
    elif hasattr(encoder_wrapper.model, 'model') and hasattr(encoder_wrapper.model.model, 'encoder'):
        actual_encoder = encoder_wrapper.model.model.encoder
if actual_encoder is None:
    print("Could not find the actual encoder; falling back to encoder_wrapper (may fail if it expects a dict input)")
    actual_encoder = encoder_wrapper
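# Note: for wsntxxn/effb2-trm-audiocaps-captioning the module found here is
# expected to be the EfficientNet-B2 CNN audio encoder (the "effb2" in the
# model id); the attribute chain is probed defensively because the layout of
# trust_remote_code models is not guaranteed.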
# Custom MelSpectrogram to avoid complex-typed tensors, which some ONNX
# exporters cannot handle
class OnnxCompatibleMelSpectrogram(torch.nn.Module):
    def __init__(self, sample_rate=16000, n_fft=512, win_length=512, hop_length=160, n_mels=64):
        super().__init__()
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        # Register the analysis window as a buffer so it is exported with the graph
        self.register_buffer('window', torch.hann_window(win_length))
        self.mel_scale = torchaudio.transforms.MelScale(
            n_mels=n_mels,
            sample_rate=sample_rate,
            n_stft=n_fft // 2 + 1,
        )

    def forward(self, waveform):
        # return_complex=False yields a real tensor of shape (batch, freq, time, 2)
        # instead of a complex tensor, which the exporter handles reliably
        spec = torch.stft(
            waveform,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=True,
            pad_mode='reflect',
            normalized=False,
            onesided=True,
            return_complex=False,
        )
        # Power spectrogram: real^2 + imag^2 -> (batch, freq, time)
        power_spec = spec.pow(2).sum(-1)
        # MelScale expects (..., freq, time)
        mel_spec = self.mel_scale(power_spec)
        return mel_spec
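
# Quick sanity check for the custom transform (the assumption being validated:
# with matching parameters and the default power=2.0, it should agree
# numerically with torchaudio.transforms.MelSpectrogram).
_ref_mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000, n_fft=512, win_length=512, hop_length=160, n_mels=64
)
_probe = torch.randn(1, 16000)
_mel_diff = (_ref_mel(_probe) - OnnxCompatibleMelSpectrogram()(_probe)).abs().max()
print(f"Custom mel vs torchaudio.transforms.MelSpectrogram, max abs diff: {_mel_diff:.2e}")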
class PreprocessEncoderWrapper(torch.nn.Module):
    def __init__(self, actual_encoder):
        super().__init__()
        self.actual_encoder = actual_encoder
        # Extract components; not every checkpoint layout exposes all of them
        self.backbone = getattr(actual_encoder, 'backbone', None)
        self.fc = getattr(actual_encoder, 'fc', None)
        self.fc_proj = getattr(actual_encoder, 'fc_proj', None)
        if self.backbone is None:
            self.backbone = actual_encoder
        # Preprocessing settings
        self.mel_transform = OnnxCompatibleMelSpectrogram(
            sample_rate=16000,
            n_fft=512,
            win_length=512,
            hop_length=160,
            n_mels=64,
        )
        self.db_transform = torchaudio.transforms.AmplitudeToDB(top_db=120)

    @staticmethod
    def _pool(features):
        # Mean-pool any spatial/time dimensions down to (batch, channels)
        if features.dim() == 4:
            return torch.mean(features, dim=[2, 3])
        if features.dim() == 3:
            return torch.mean(features, dim=2)
        return features

    def forward(self, audio):
        """
        Args:
            audio: (batch, time) - raw waveform at 16 kHz
        """
        # 1. Compute mel spectrogram
        mel = self.mel_transform(audio)
        # 2. Amplitude to dB
        mel_db = self.db_transform(mel)
        # 3. Encoder forward pass
        features = self.backbone(mel_db)
        # 4. Pooling / projection
        if self.fc is not None:
            attn_emb = self.fc(self._pool(features)).unsqueeze(1)
        elif self.fc_proj is not None:
            attn_emb = self.fc_proj(self._pool(features)).unsqueeze(1)
        elif features.dim() == 4:
            attn_emb = torch.mean(features, dim=[2, 3]).unsqueeze(1)
        elif features.dim() == 3:
            attn_emb = features
        else:
            attn_emb = features.unsqueeze(1)
        return attn_emb
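
# Shape sketch (an assumption inferred from the pooling logic above, not
# verified against the remote checkpoint): audio (B, T) -> mel (B, 64, ~T/160)
# -> dB -> backbone features (B, C, F', T') -> mean-pool (B, C)
# -> fc/fc_proj (B, D) -> unsqueeze(1) -> attn_emb (B, 1, D).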
print("\nAttempting to export Encoder with Preprocessing...")
# Create dummy audio input
# 1 second of audio at 16kHz
dummy_audio = torch.randn(1, 16000).to(device)
wrapper = PreprocessEncoderWrapper(actual_encoder).to(device)
wrapper.eval()
# Test forward pass
with torch.no_grad():
    out = wrapper(dummy_audio)
print(f"✓ Wrapper output shape: {out.shape}")
# Export
export_inputs = (dummy_audio,)
input_names = ["audio"]
output_names = ["attn_emb"]
# The dynamic-axes keys must match input_names/output_names exactly;
# otherwise torch.onnx.export silently leaves those axes static.
dynamic_axes = {
    "audio": {0: "batch", 1: "time"},
    "attn_emb": {0: "batch", 1: "time"},
}
print(f"Exporting to {args.out}...")
try:
    torch.onnx.export(
        wrapper,
        export_inputs,
        args.out,
        export_params=True,
        opset_version=args.opset,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        dynamo=False,  # use the legacy TorchScript exporter
    )
    print("✅ Export successful!")
except Exception as e:
    print(f"❌ Export failed: {e}")
    import traceback
    traceback.print_exc()