# export_encoder_preprocess_onnx.py
"""Export an audio-captioning encoder together with its waveform preprocessing
(mel spectrogram + amplitude-to-dB) as a single ONNX graph.

The exported model takes raw audio ``(batch, time)`` at 16 kHz and emits the
encoder embedding (``attn_emb``) consumed by the captioning decoder.
"""
import argparse
import os
import traceback

import torch
import torchaudio
import onnxruntime_extensions  # noqa: F401 -- ensure ONNX custom-op extensions are available if needed
from dotenv import load_dotenv
from transformers import AutoModel

load_dotenv()


def _find_encoder_wrapper(model):
    """Return the encoder-like attribute of a remote-code HF model, or None.

    Probes common attribute names first, then falls back to the nested
    ``model.model.encoder`` path used by some checkpoints.
    """
    for candidate in ("audio_encoder", "encoder", "model", "encoder_model"):
        if hasattr(model, candidate):
            return getattr(model, candidate)
    try:
        return model.model.encoder
    except Exception:
        return None


def _find_actual_encoder(encoder_wrapper):
    """Dig past wrapper layers to the concrete encoder module.

    Falls back to the wrapper itself when no nested encoder is found
    (which may fail later if the wrapper's forward expects a dict).
    """
    inner = getattr(encoder_wrapper, 'model', None)
    if inner is not None:
        if hasattr(inner, 'encoder'):
            return inner.encoder
        if hasattr(inner, 'model') and hasattr(inner.model, 'encoder'):
            return inner.model.encoder
    print("Could not find actual encoder, using encoder_wrapper as fallback (might fail if it expects dict)")
    return encoder_wrapper


class OnnxCompatibleMelSpectrogram(torch.nn.Module):
    """Mel spectrogram built on a real-valued STFT (``return_complex=False``)
    so the traced graph never carries complex tensors, which some ONNX
    exporters cannot handle."""

    def __init__(self, sample_rate=16000, n_fft=512, win_length=512,
                 hop_length=160, n_mels=64):
        super().__init__()
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        # Registered as a buffer so the window is baked into the exported file.
        self.register_buffer('window', torch.hann_window(win_length))
        self.mel_scale = torchaudio.transforms.MelScale(
            n_mels=n_mels,
            sample_rate=sample_rate,
            n_stft=n_fft // 2 + 1,
        )

    def forward(self, waveform):
        """waveform ``(batch, time)`` -> mel power spectrogram ``(batch, n_mels, frames)``."""
        # return_complex=False yields (..., freq, time, 2) with real/imag
        # stacked on the last axis, avoiding complex dtypes during export.
        spec = torch.stft(
            waveform,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=True,
            pad_mode='reflect',
            normalized=False,
            onesided=True,
            return_complex=False,
        )
        # Power spectrogram: real^2 + imag^2 -> (batch, freq, time).
        power_spec = spec.pow(2).sum(-1)
        # MelScale expects (..., freq, time).
        return self.mel_scale(power_spec)


class PreprocessEncoderWrapper(torch.nn.Module):
    """Bundle mel/dB preprocessing with the encoder backbone and its
    projection head into one exportable module."""

    def __init__(self, actual_encoder):
        super().__init__()
        self.actual_encoder = actual_encoder
        # Pull out the pieces we need; fall back to the whole module as backbone.
        self.backbone = getattr(actual_encoder, 'backbone', None)
        self.fc = getattr(actual_encoder, 'fc', None)
        self.fc_proj = getattr(actual_encoder, 'fc_proj', None)
        if self.backbone is None:
            self.backbone = actual_encoder
        # Preprocessing settings match the checkpoint's training pipeline
        # (16 kHz, 512-point FFT, 10 ms hop, 64 mel bins) -- TODO confirm
        # against the upstream model card.
        self.mel_transform = OnnxCompatibleMelSpectrogram(
            sample_rate=16000, n_fft=512, win_length=512,
            hop_length=160, n_mels=64,
        )
        self.db_transform = torchaudio.transforms.AmplitudeToDB(top_db=120)

    @staticmethod
    def _pool(features):
        """Collapse spatial/temporal dims to one vector per example."""
        if features.dim() == 4:
            return torch.mean(features, dim=[2, 3])
        if features.dim() == 3:
            return torch.mean(features, dim=2)
        return features

    def forward(self, audio):
        """audio ``(batch, time)`` raw waveform -> encoder embedding ``(batch, 1-or-T, dim)``."""
        # 1. Mel spectrogram, 2. amplitude to dB, 3. encoder forward.
        mel = self.mel_transform(audio)
        mel_db = self.db_transform(mel)
        features = self.backbone(mel_db)

        # Apply pooling + projection, matching the original encoder head.
        if self.fc is not None:
            attn_emb = self.fc(self._pool(features)).unsqueeze(1)
        elif self.fc_proj is not None:
            attn_emb = self.fc_proj(self._pool(features)).unsqueeze(1)
        else:
            # No projection head: 3-D features pass through untouched so the
            # time dimension is preserved; other ranks are mean-pooled.
            if features.dim() == 4:
                attn_emb = torch.mean(features, dim=[2, 3]).unsqueeze(1)
            elif features.dim() == 3:
                attn_emb = features
            else:
                attn_emb = features.unsqueeze(1)
        return attn_emb


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", default="wsntxxn/effb2-trm-audiocaps-captioning")
    parser.add_argument("--out", default="audio-caption/effb2_encoder_preprocess-2.onnx")
    parser.add_argument("--opset", type=int, default=17)
    parser.add_argument("--device", default="cpu")
    args = parser.parse_args()

    device = torch.device(args.device)
    print("Loading model (trust_remote_code=True)...")
    model = AutoModel.from_pretrained(args.model_id, trust_remote_code=True).to(device)
    model.eval()

    encoder_wrapper = _find_encoder_wrapper(model)
    if encoder_wrapper is None:
        raise RuntimeError("Couldn't find encoder attribute on model.")
    actual_encoder = _find_actual_encoder(encoder_wrapper)

    print("\nAttempting to export Encoder with Preprocessing...")
    # Dummy input: 1 second of audio at 16 kHz.
    dummy_audio = torch.randn(1, 16000).to(device)

    wrapper = PreprocessEncoderWrapper(actual_encoder).to(device)
    wrapper.eval()

    # Sanity-check the forward pass before attempting export.
    with torch.no_grad():
        out = wrapper(dummy_audio)
    print(f"✓ Wrapper output shape: {out.shape}")

    input_names = ["audio"]
    output_names = ["attn_emb"]
    # FIX: dynamic_axes keys must exactly match the declared input/output
    # names. The previous key "encoder_features" did not match the output
    # name "attn_emb", so the output's dynamic axes were silently ignored.
    dynamic_axes = {
        "audio": {0: "batch", 1: "time"},
        "attn_emb": {0: "batch", 1: "time"},
    }

    # FIX: ensure the target directory exists before export.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    print(f"Exporting to {args.out}...")
    try:
        torch.onnx.export(
            wrapper,
            (dummy_audio,),
            args.out,
            export_params=True,
            opset_version=args.opset,
            do_constant_folding=True,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            dynamo=False,
        )
        print("✅ Export successful!")
    except Exception as e:
        print(f"❌ Export failed: {e}")
        traceback.print_exc()


if __name__ == "__main__":
    main()