#!/usr/bin/env python3
import os

import torch
import torch.nn as nn
import onnx

from sam_audio.model.vision_encoder import PerceptionEncoder
from onnx_export.standalone_config import PerceptionEncoderConfig


class VisionEncoderWrapper(nn.Module):
    """Wrapper for the Vision Encoder (CLIP visual backbone)."""

    def __init__(self, vision_encoder):
        super().__init__()
        self.model = vision_encoder.model
        self.normalize = vision_encoder.normalize_feature

    def forward(self, x):
        # x: (N, 3, H, W) where N is the number of frames
        # returns: (N, 1024) features
        return self.model.encode_image(x, normalize=self.normalize)


def export_vision_encoder(model_id="facebook/sam-audio-small", output_dir="onnx_models"):
    """Export the vision encoder to ONNX."""
    print(f"Loading Vision Encoder from {model_id}...")

    from transformers import AutoConfig
    from huggingface_hub import hf_hub_download

    print("Fetching config...")
    cfg_hf = AutoConfig.from_pretrained(model_id)
    cfg_dict = cfg_hf.to_dict()

    # Extract the vision encoder sub-config and build the standalone config object.
    v_cfg_dict = cfg_dict.get("vision_encoder", {})
    v_cfg = PerceptionEncoderConfig(**v_cfg_dict)

    print(f"Initializing PerceptionEncoder with name: {v_cfg.name}...")
    vision_encoder = PerceptionEncoder(v_cfg)

    # Load weights from the SAM Audio checkpoint.
    print("Loading weights from SAM Audio checkpoint...")
    checkpoint_path = hf_hub_download(repo_id=model_id, filename="checkpoint.pt")
    state_dict = torch.load(checkpoint_path, map_location="cpu", mmap=True)

    # Keep only the vision encoder tensors, stripping the "vision_encoder." prefix.
    vision_state = {}
    prefix = "vision_encoder."
    for key, value in state_dict.items():
        if key.startswith(prefix):
            new_key = key[len(prefix):]
            vision_state[new_key] = value

    if vision_state:
        print(f"  Loading {len(vision_state)} tensors into vision encoder...")
        vision_encoder.load_state_dict(vision_state)
        print("  ✓ Vision encoder weights loaded.")
    else:
        print("  WARNING: No 'vision_encoder' weights found in checkpoint. Using base weights.")

    image_size = vision_encoder.image_size
    print(f"  Image size: {image_size}")

    wrapper = VisionEncoderWrapper(vision_encoder).eval()

    # Dummy input: a single RGB frame at the encoder's native resolution.
    dummy_input = torch.randn(1, 3, image_size, image_size)

    output_path = os.path.join(output_dir, "vision_encoder.onnx")
    os.makedirs(output_dir, exist_ok=True)

    print(f"Exporting to {output_path}...")
    input_names = ["video_frames"]
    output_names = ["vision_features"]
    opset_version = 17  # Opset 17 for better support of ViT/ConvNeXt ops

    torch.onnx.export(
        wrapper,
        dummy_input,
        output_path,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes={
            "video_frames": {0: "num_frames"},
            "vision_features": {0: "num_frames"},
        },
        opset_version=opset_version,
        do_constant_folding=True,
        dynamo=False,
        external_data=True,
    )

    # Validate the exported graph structure.
    onnx.checker.check_model(output_path)

    # If the model exceeds the 2 GB protobuf limit, weights are stored in a sidecar file.
    data_path = output_path + ".data"
    if os.path.exists(data_path):
        print(f"  Large model detected, weights saved to {data_path}")

    print("✓ Vision encoder export complete!")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="facebook/sam-audio-small")
    parser.add_argument("--output", type=str, default="onnx_models")
    args = parser.parse_args()

    export_vision_encoder(args.model, args.output)
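
# --- Usage sketch (illustrative, not part of the export itself) ---
# A minimal smoke test of the exported model with onnxruntime, assuming
# onnxruntime is installed and the default output directory is used. The
# input must match the (num_frames, 3, H, W) contract declared above; the
# 448x448 resolution below is a placeholder, so substitute the "Image size"
# value printed during export.
#
#   import numpy as np
#   import onnxruntime as ort
#
#   sess = ort.InferenceSession("onnx_models/vision_encoder.onnx")
#   frames = np.random.randn(4, 3, 448, 448).astype(np.float32)  # 4 dummy frames
#   (features,) = sess.run(["vision_features"], {"video_frames": frames})
#   print(features.shape)  # expected: (4, feature_dim)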