# stivenDR14
# feat: Introduce audio captioning and categorization model with ONNX/ExecuTorch hybrid inference and category embedding generation.
# 5c8d855
"""
Export the decoder to ExecuTorch .pte format as an alternative to ONNX.
The ExecuTorch export might handle dynamic sequence lengths better.
"""
import argparse
import math

import torch
from dotenv import load_dotenv
from transformers import AutoModel, AutoTokenizer

load_dotenv()
class DecoderStepWrapper(torch.nn.Module):
    """Export-friendly wrapper that runs the caption decoder and returns logits.

    The wrapper re-implements the decoder's forward pass using only the
    sub-modules it reads from ``decoder``: ``attn_proj``, ``word_embedding``,
    ``in_dropout``, ``pos_encoder``, ``model`` and ``classifier``. All masks
    are built explicitly from tensor shapes so that tracing does not depend on
    helper functions with their own detection logic.
    """

    def __init__(self, decoder, vocab_size):
        """Store the decoder to wrap and the tokenizer's vocabulary size."""
        super().__init__()
        self.decoder = decoder
        self.vocab_size = vocab_size

    def forward(self, word_ids, attn_emb, attn_emb_len):
        """Compute logits for every position of ``word_ids``.

        Args:
            word_ids: (batch, seq_len) token ids; all positions are treated
                as valid (no padding).
            attn_emb: (batch, time, dim) encoder output.
            attn_emb_len: (batch,) number of valid frames in ``attn_emb``.

        Returns:
            logits: (batch, seq_len, vocab_size)
        """
        dec = self.decoder

        # Encoder memory, projected and made time-major: [time, batch, dim].
        memory = dec.attn_proj(attn_emb).transpose(0, 1)

        # Token embeddings scaled by sqrt(emb_dim); 256 is used when the
        # decoder does not expose an ``emb_dim`` attribute.
        emb_dim = getattr(dec, "emb_dim", 256)
        embed = dec.in_dropout(dec.word_embedding(word_ids)) * math.sqrt(emb_dim)
        embed = dec.pos_encoder(embed.transpose(0, 1))  # [seq, batch, dim]

        # Causal mask: 0 on and below the diagonal, -inf above it.
        # Built with masked_fill on a zero tensor rather than ``ones * -inf``
        # because 0 * inf = NaN. No ``seq_len > 1`` special case is needed:
        # for a 1x1 matrix the strict upper triangle is empty, and avoiding a
        # Python branch on seq_len keeps the traced graph free of a guard on
        # the dynamic sequence dimension.
        seq_len = embed.size(0)
        future_positions = torch.triu(
            torch.ones(seq_len, seq_len, device=embed.device), diagonal=1
        ).bool()
        tgt_mask = torch.zeros(
            seq_len, seq_len, device=embed.device, dtype=torch.float32
        ).masked_fill(future_positions, float("-inf"))

        # memory_key_padding_mask: True where the frame index >= valid length.
        batch_size, max_len = attn_emb.shape[0], attn_emb.shape[1]
        frame_index = (
            torch.arange(max_len, device=attn_emb.device)
            .unsqueeze(0)
            .expand(batch_size, -1)
        )
        memory_key_padding_mask = frame_index >= attn_emb_len.unsqueeze(1)

        # tgt_key_padding_mask: during generation no token is padding.
        tgt_key_padding_mask = torch.zeros(
            word_ids.shape[0],
            word_ids.shape[1],
            dtype=torch.bool,
            device=word_ids.device,
        )

        # Pass both the explicit mask and the is_causal hint to the inner
        # decoder; generate_square_subsequent_mask is deliberately not used.
        output = dec.model(
            embed,
            memory,
            tgt_mask=tgt_mask,
            tgt_is_causal=True,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return dec.classifier(output.transpose(0, 1))  # [batch, seq, vocab]


def _find_decoder(model):
    """Locate the caption decoder inside a loaded model.

    Tries ``model.model.model.decoder``, then ``model.model.decoder``, and
    finally scans ``model.named_modules()`` for a module whose name contains
    "decoder" and whose class name contains "TransformerDecoder".

    Returns:
        (decoder, location) where ``location`` is the attribute path or
        module name at which the decoder was found.

    Raises:
        RuntimeError: if no decoder can be found.
    """
    inner = getattr(model, "model", None)
    nested = getattr(inner, "model", None)
    if nested is not None and hasattr(nested, "decoder"):
        return nested.decoder, "model.model.model.decoder"
    if inner is not None and hasattr(inner, "decoder"):
        return inner.decoder, "model.model.decoder"
    for name, module in model.named_modules():
        if "decoder" in name.lower() and "TransformerDecoder" in module.__class__.__name__:
            return module, name
    raise RuntimeError("Could not find decoder in model")


def _export_without_executorch(wrapper, vocab_size, attn_emb, attn_emb_len):
    """Diagnostic fallback: check whether plain ``torch.export`` succeeds.

    Marks batch, sequence and time as dynamic. The example inputs use
    batch=2 and seq_len=3 because torch.export specializes dimensions whose
    example size is 0 or 1, which conflicts with marking them dynamic.
    Failures are reported on stdout and not re-raised.
    """
    try:
        from torch.export import Dim, export

        batch = Dim("batch", min=1, max=4)
        seq = Dim("seq", min=1, max=30)
        # Keep the upper bound at least as large as the example's time length.
        time_dim = Dim("time", min=1, max=max(100, attn_emb.shape[1]))
        dynamic_shapes = {
            "word_ids": {0: batch, 1: seq},
            "attn_emb": {0: batch, 1: time_dim},
            "attn_emb_len": {0: batch},
        }
        example_inputs = (
            torch.randint(0, vocab_size, (2, 3), dtype=torch.long),
            torch.cat([attn_emb, attn_emb], dim=0),
            torch.cat([attn_emb_len, attn_emb_len], dim=0),
        )
        export(wrapper, example_inputs, dynamic_shapes=dynamic_shapes)
        print("✓ torch.export successful (without ExecuTorch conversion)")
        print("   Dynamic shapes are supported in the exported graph")
    except Exception as e2:
        print(f"✗ torch.export also failed: {e2}")


def main():
    """Export the captioning decoder to an ExecuTorch ``.pte`` file.

    Steps: load the model, locate its decoder, wrap it in
    ``DecoderStepWrapper``, obtain a real ``attn_emb`` from the ONNX encoder
    on 5 seconds of random audio, sanity-check the wrapper in eager mode,
    then export with a dynamic sequence dimension. If the ExecuTorch path
    fails, a plain ``torch.export`` is attempted to help localize the failure.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="wsntxxn/effb2-trm-audiocaps-captioning")
    parser.add_argument("--out", default="effb2_decoder_step.pte")
    parser.add_argument(
        "--tokenizer",
        default="wsntxxn/audiocaps-simple-tokenizer",
        help="Tokenizer used to determine the vocabulary size",
    )
    parser.add_argument(
        "--encoder-onnx",
        default="audio-caption/effb2_encoder_preprocess.onnx",
        help="ONNX encoder used to produce an example attn_emb",
    )
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=100,
        help="Upper bound of the dynamic sequence dimension",
    )
    args = parser.parse_args()

    print(f"Loading model: {args.model}")
    model = AutoModel.from_pretrained(args.model, trust_remote_code=True)
    model.eval()

    decoder, location = _find_decoder(model)
    print(f"Found decoder at {location}")
    print(f"Decoder: {decoder.__class__.__name__}")

    # Vocabulary size comes from the tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
    vocab_size = len(tokenizer)

    wrapper = DecoderStepWrapper(decoder, vocab_size)
    wrapper.eval()
    wrapper = wrapper.to(torch.device("cpu"))

    # Use the existing ONNX encoder to get attn_emb, avoiding the HF encoder.
    print("\nLoading ONNX encoder to get attn_emb...")
    import numpy as np
    import onnxruntime as ort

    enc_sess = ort.InferenceSession(args.encoder_onnx)

    # Exactly 5 seconds of audio (production use case).
    sample_rate = 16000
    dummy_audio_np = np.random.randn(1, sample_rate * 5).astype(np.float32)
    enc_in_name = enc_sess.get_inputs()[0].name
    enc_out_name = enc_sess.get_outputs()[0].name
    attn_emb_np = enc_sess.run([enc_out_name], {enc_in_name: dummy_audio_np})[0]
    attn_emb = torch.from_numpy(attn_emb_np)
    # NOTE(review): the "- 1" marks the final frame as padding in the example
    # input; presumably intentional -- confirm.
    attn_emb_len = torch.tensor([attn_emb.shape[1] - 1], dtype=torch.int64)
    print(f"attn_emb shape for 5-sec audio: {attn_emb.shape}")

    # Eager-mode sanity check with two different sequence lengths.
    for seq_len in (1, 5):
        print(f"\n--- Testing with seq_len={seq_len} ---")
        dummy_input_ids = torch.randint(0, vocab_size, (1, seq_len), dtype=torch.long)
        with torch.no_grad():
            test_out = wrapper(dummy_input_ids, attn_emb, attn_emb_len)
        print(f"✓ Forward pass successful! Output shape: {test_out.shape}")

    print("\n--- Attempting ExecuTorch Export ---")
    try:
        from executorch.exir import to_edge
        from torch.export import Dim, export

        # Batch is fixed at 1 for mobile inference and the time axis is fixed
        # per audio clip; only the sequence dimension is dynamic.
        seq = Dim("seq", max=args.max_seq_len)
        dynamic_shapes = {
            "word_ids": {1: seq},
            "attn_emb": {},
            "attn_emb_len": {},
        }

        # A mid-range example (seq_len=3) so the dimension is not specialized.
        example_inputs = (
            torch.randint(0, vocab_size, (1, 3), dtype=torch.long),
            attn_emb,
            attn_emb_len,
        )

        print("Exporting with torch.export (seq_len=3 example)...")
        exported_program = export(
            wrapper,
            example_inputs,
            dynamic_shapes=dynamic_shapes,
        )
        print("✓ torch.export successful!")

        print("Converting to ExecuTorch edge dialect...")
        edge_program = to_edge(exported_program)
        print("✓ Edge conversion successful!")

        # Lower before opening the output file so that a failure here does
        # not leave an empty .pte behind.
        executorch_program = edge_program.to_executorch()
        with open(args.out, "wb") as f:
            executorch_program.write_to_file(f)

        print(f"✓ ExecuTorch export done: {args.out}")
        print("\n🎉 This .pte model supports dynamic sequence lengths!")
        print(
            f"   You can pass (1, seq_len) inputs with seq_len up to "
            f"{args.max_seq_len} at inference"
        )
    except ImportError:
        print("✗ ExecuTorch not installed. Install with:")
        print("   pip install executorch")
    except Exception as e:
        print(f"✗ ExecuTorch export failed: {e}")
        import traceback

        traceback.print_exc()
        print("\nFalling back to regular torch.export (no ExecuTorch)")
        _export_without_executorch(wrapper, vocab_size, attn_emb, attn_emb_len)


if __name__ == "__main__":
    main()