sam-audio-small-onnx / export_vision.py
#!/usr/bin/env python3
import os
import torch
import torch.nn as nn
import onnx
from sam_audio.model.vision_encoder import PerceptionEncoder
from onnx_export.standalone_config import PerceptionEncoderConfig
class VisionEncoderWrapper(nn.Module):
    """
    Wrapper for the Vision Encoder (CLIP visual backbone).
    """

    def __init__(self, vision_encoder):
        super().__init__()
        self.model = vision_encoder.model
        self.normalize = vision_encoder.normalize_feature

    def forward(self, x):
        # x: (N, 3, H, W) where N is the number of frames
        # returns: (N, 1024) features
        return self.model.encode_image(x, normalize=self.normalize)
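
# Minimal usage sketch for the wrapper (commented out; it assumes the sam_audio
# package is importable and that encode_image accepts a `normalize` keyword as
# used above; the (4, 1024) feature shape follows the comment in forward() and is
# not verified here):
#
#   encoder = PerceptionEncoder(v_cfg)          # v_cfg: a PerceptionEncoderConfig
#   wrapper = VisionEncoderWrapper(encoder).eval()
#   frames = torch.randn(4, 3, encoder.image_size, encoder.image_size)
#   with torch.no_grad():
#       feats = wrapper(frames)                 # expected shape: (4, 1024)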
def export_vision_encoder(model_id="facebook/sam-audio-small", output_dir="onnx_models"):
    """Export the vision encoder to ONNX."""
    print(f"Loading Vision Encoder from {model_id}...")
    from transformers import AutoConfig

    print("Fetching config...")
    cfg_hf = AutoConfig.from_pretrained(model_id)
    cfg_dict = cfg_hf.to_dict()

    # Extract the vision encoder sub-config
    v_cfg_dict = cfg_dict.get("vision_encoder", {})
    v_cfg = PerceptionEncoderConfig(**v_cfg_dict)

    print(f"Initializing PerceptionEncoder with name: {v_cfg.name}...")
    vision_encoder = PerceptionEncoder(v_cfg)

    # Load weights from the SAM Audio checkpoint
    print("Loading weights from SAM Audio checkpoint...")
    from huggingface_hub import hf_hub_download

    checkpoint_path = hf_hub_download(repo_id=model_id, filename="checkpoint.pt")
    state_dict = torch.load(checkpoint_path, map_location="cpu", mmap=True)

    # Keep only the vision encoder weights, stripping the "vision_encoder." prefix
    # (e.g. a key "vision_encoder.model.xyz" becomes "model.xyz")
    vision_state = {}
    prefix = "vision_encoder."
    for key, value in state_dict.items():
        if key.startswith(prefix):
            new_key = key[len(prefix):]
            vision_state[new_key] = value

    if vision_state:
        print(f" Loading {len(vision_state)} tensors into vision encoder...")
        vision_encoder.load_state_dict(vision_state)
        print(" ✓ Vision encoder weights loaded.")
    else:
        print(" WARNING: No 'vision_encoder' weights found in checkpoint. Using base weights.")

    image_size = vision_encoder.image_size
    print(f" Image size: {image_size}")

    wrapper = VisionEncoderWrapper(vision_encoder).eval()

    # Create a dummy input (a single frame; the frame axis is exported as dynamic)
    dummy_input = torch.randn(1, 3, image_size, image_size)

    output_path = os.path.join(output_dir, "vision_encoder.onnx")
    os.makedirs(output_dir, exist_ok=True)

    print(f"Exporting to {output_path}...")
    input_names = ["video_frames"]
    output_names = ["vision_features"]
    opset_version = 17  # Using 17 for better support of ViT/ConvNeXt ops

    torch.onnx.export(
        wrapper,
        dummy_input,
        output_path,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes={
            "video_frames": {0: "num_frames"},
            "vision_features": {0: "num_frames"},
        },
        opset_version=opset_version,
        do_constant_folding=True,
        dynamo=False,
        external_data=True,
    )

    # Check whether the weights were saved as external data (models over 2 GB)
    data_path = output_path + ".data"
    if os.path.exists(data_path):
        print(f" Large model detected, weights saved to {data_path}")

    print("✓ Vision encoder export complete!")
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="facebook/sam-audio-small")
    parser.add_argument("--output", type=str, default="onnx_models")
    args = parser.parse_args()

    export_vision_encoder(args.model, args.output)
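
# Example invocation (the defaults mirror the argparse values above; the repo id and
# checkpoint layout are taken from this script, not independently verified):
#
#   python export_vision.py --model facebook/sam-audio-small --output onnx_models
#
# On success this writes onnx_models/vision_encoder.onnx and, for large models, an
# external weight file alongside it.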