stivenDR14
feat: Introduce audio captioning and categorization model with ONNX/ExecuTorch hybrid inference and category embedding generation.
5c8d855
# export_encoder_proprocess_onnx.py
import torch
import torchaudio
from transformers import AutoModel
import argparse
import os
import onnxruntime_extensions # Ensure extensions are available if needed
from dotenv import load_dotenv

load_dotenv()

# CLI configuration: which model to export, where to write the ONNX file,
# which opset to target, and which device to trace on.
parser = argparse.ArgumentParser()
for flag, options in (
    ("--model_id", {"default": "wsntxxn/effb2-trm-audiocaps-captioning"}),
    ("--out", {"default": "audio-caption/effb2_encoder_preprocess-2.onnx"}),
    ("--opset", {"type": int, "default": 17}),
    ("--device", {"default": "cpu"}),
):
    parser.add_argument(flag, **options)
args = parser.parse_args()

device = torch.device(args.device)

print("Loading model (trust_remote_code=True)...")
model = AutoModel.from_pretrained(args.model_id, trust_remote_code=True).to(device)
model.eval()
# Locate the encoder submodule (same heuristic as the original script):
# probe a few well-known attribute names, then fall back to model.model.encoder.
encoder_wrapper = next(
    (
        getattr(model, name)
        for name in ("audio_encoder", "encoder", "model", "encoder_model")
        if hasattr(model, name)
    ),
    None,
)
if encoder_wrapper is None:
    try:
        encoder_wrapper = model.model.encoder
    except Exception:
        encoder_wrapper = None
if encoder_wrapper is None:
    raise RuntimeError("Couldn't find encoder attribute on model.")

# Drill down to the concrete encoder module beneath the wrapper, if any.
actual_encoder = None
if hasattr(encoder_wrapper, 'model'):
    inner = encoder_wrapper.model
    if hasattr(inner, 'encoder'):
        actual_encoder = inner.encoder
    elif hasattr(inner, 'model') and hasattr(inner.model, 'encoder'):
        actual_encoder = inner.model.encoder
if actual_encoder is None:
    print("Could not find actual encoder, using encoder_wrapper as fallback (might fail if it expects dict)")
    actual_encoder = encoder_wrapper
| # Custom MelSpectrogram to avoid complex type issues in ONNX export | |
class OnnxCompatibleMelSpectrogram(torch.nn.Module):
    """Mel-spectrogram front end that stays ONNX-exportable.

    torchaudio's stock MelSpectrogram routes through a complex-valued STFT,
    which some ONNX exporters cannot trace. Here ``torch.stft`` is called with
    ``return_complex=False`` so the spectrum arrives as a real tensor with a
    trailing (real, imag) axis; the power spectrum is formed from that before
    the mel filterbank is applied.
    """

    def __init__(self, sample_rate=16000, n_fft=512, win_length=512, hop_length=160, n_mels=64):
        super().__init__()
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        # Hann window stored as a buffer so it is baked into the exported graph.
        self.register_buffer('window', torch.hann_window(win_length))
        # Mel filterbank applied to the power spectrum.
        self.mel_scale = torchaudio.transforms.MelScale(
            n_mels=n_mels,
            sample_rate=sample_rate,
            n_stft=n_fft // 2 + 1
        )

    def forward(self, waveform):
        """Map a (batch, time) waveform to a mel power spectrogram (batch, n_mels, frames)."""
        # Real-valued STFT output: (batch, freq, frames, 2) with real/imag last.
        stft_out = torch.stft(
            waveform,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=True,
            pad_mode='reflect',
            normalized=False,
            onesided=True,
            return_complex=False
        )
        # Power spectrum = real^2 + imag^2, collapsing the trailing axis.
        power = stft_out.pow(2).sum(-1)
        # Project linear-frequency bins onto the mel scale.
        return self.mel_scale(power)
class PreprocessEncoderWrapper(torch.nn.Module):
    """Bundle waveform preprocessing (mel + dB) with the encoder forward pass
    so the exported ONNX graph accepts raw audio directly.

    Improvement over the original: the identical 4-D/3-D mean-pooling logic
    that was duplicated across the ``fc`` and ``fc_proj`` branches is factored
    into the ``_pool`` helper; ``hasattr`` ternaries are replaced with
    ``getattr(..., None)``. Behavior is unchanged.
    """

    def __init__(self, actual_encoder):
        super().__init__()
        self.actual_encoder = actual_encoder
        # Optional sub-modules; absent attributes become None so forward()
        # can pick whichever projection head actually exists.
        self.backbone = getattr(actual_encoder, 'backbone', None)
        self.fc = getattr(actual_encoder, 'fc', None)
        self.fc_proj = getattr(actual_encoder, 'fc_proj', None)
        if self.backbone is None:
            # No dedicated backbone: use the whole encoder as the feature extractor.
            self.backbone = actual_encoder
        # Preprocessing: 16 kHz audio -> 64-bin mel spectrogram -> dB scale.
        self.mel_transform = OnnxCompatibleMelSpectrogram(
            sample_rate=16000,
            n_fft=512,
            win_length=512,
            hop_length=160,
            n_mels=64
        )
        self.db_transform = torchaudio.transforms.AmplitudeToDB(top_db=120)

    @staticmethod
    def _pool(features):
        """Mean-pool encoder features down to (batch, dim).

        4-D (batch, ch, h, w) -> mean over both spatial dims;
        3-D (batch, dim, t)   -> mean over the last dim;
        anything else is passed through unchanged.
        """
        if features.dim() == 4:
            return torch.mean(features, dim=[2, 3])
        if features.dim() == 3:
            return torch.mean(features, dim=2)
        return features

    def forward(self, audio):
        """
        Args:
            audio: (batch, time) - Raw waveform

        Returns:
            attn_emb: pooled/projected embedding with a sequence axis, as in
            the original implementation (shape depends on the encoder head).
        """
        # 1. Compute Mel Spectrogram
        mel = self.mel_transform(audio)
        # 2. Amplitude to DB
        mel_db = self.db_transform(mel)
        # 3. Encoder Forward Pass
        features = self.backbone(mel_db)
        # 4. Pool + project with whichever head the encoder provides.
        if self.fc is not None:
            return self.fc(self._pool(features)).unsqueeze(1)
        if self.fc_proj is not None:
            return self.fc_proj(self._pool(features)).unsqueeze(1)
        # No projection head: pool 4-D features and add a sequence axis;
        # 3-D features already carry one; otherwise just add the axis.
        if features.dim() == 4:
            return torch.mean(features, dim=[2, 3]).unsqueeze(1)
        if features.dim() == 3:
            return features
        return features.unsqueeze(1)
print("\nAttempting to export Encoder with Preprocessing...")

# Dummy input: 1 second of mono audio at 16 kHz.
dummy_audio = torch.randn(1, 16000).to(device)

wrapper = PreprocessEncoderWrapper(actual_encoder).to(device)
wrapper.eval()

# Sanity-check the wrapper with a forward pass before exporting.
with torch.no_grad():
    out = wrapper(dummy_audio)
    print(f"✓ Wrapper output shape: {out.shape}")

# Export configuration.
# BUG FIX: the original defined output_names = ["encoder_features"] but passed
# output_names=["attn_emb"] to torch.onnx.export while dynamic_axes keyed on
# "encoder_features"; keys that match no input/output name are silently
# ignored, so the output axes were never marked dynamic. Names are now
# defined once and used consistently.
export_inputs = (dummy_audio,)
input_names = ["audio"]
output_names = ["attn_emb"]
dynamic_axes = {
    "audio": {0: "batch", 1: "time"},
    "attn_emb": {0: "batch", 1: "time"}
}

# Make sure the destination directory exists before writing the model file.
out_dir = os.path.dirname(args.out)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)

print(f"Exporting to {args.out}...")
try:
    torch.onnx.export(
        wrapper,
        export_inputs,
        args.out,
        export_params=True,
        opset_version=args.opset,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        dynamo=False,  # use the legacy TorchScript-based exporter
    )
    print("✅ Export successful!")
except Exception as e:
    print(f"❌ Export failed: {e}")
    import traceback
    traceback.print_exc()