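"""Export the SAM Audio vision encoder (CLIP visual backbone) to ONNX.

The exported graph takes a batch of preprocessed video frames
("video_frames", shape [num_frames, 3, image_size, image_size]) and
returns per-frame embeddings ("vision_features").
"""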
import os

import torch
import torch.nn as nn
import onnx

from sam_audio.model.vision_encoder import PerceptionEncoder
from onnx_export.standalone_config import PerceptionEncoderConfig


class VisionEncoderWrapper(nn.Module):
    """Wrapper for the Vision Encoder (CLIP visual backbone)."""

    def __init__(self, vision_encoder):
        super().__init__()
        self.model = vision_encoder.model
        self.normalize = vision_encoder.normalize_feature

    def forward(self, x):
        # x: preprocessed frames of shape (num_frames, 3, image_size, image_size).
        return self.model.encode_image(x, normalize=self.normalize)


def export_vision_encoder(model_id="facebook/sam-audio-small", output_dir="onnx_models"):
    """Export the vision encoder to ONNX."""
    print(f"Loading Vision Encoder from {model_id}...")

    from transformers import AutoConfig

    # Build the encoder from the Hugging Face config instead of instantiating
    # the full SAM Audio model; only the vision sub-config is needed here.
    print("Fetching config...")
    cfg_hf = AutoConfig.from_pretrained(model_id)
    cfg_dict = cfg_hf.to_dict()

    v_cfg_dict = cfg_dict.get("vision_encoder", {})
    v_cfg = PerceptionEncoderConfig(**v_cfg_dict)

    print(f"Initializing PerceptionEncoder with name: {v_cfg.name}...")
    vision_encoder = PerceptionEncoder(v_cfg)

    # Pull only the "vision_encoder.*" tensors out of the full SAM Audio checkpoint.
    print("Loading weights from SAM Audio checkpoint...")
    from huggingface_hub import hf_hub_download

    checkpoint_path = hf_hub_download(repo_id=model_id, filename="checkpoint.pt")
    state_dict = torch.load(checkpoint_path, map_location="cpu", mmap=True)

    vision_state = {}
    prefix = "vision_encoder."
    for key, value in state_dict.items():
        if key.startswith(prefix):
            new_key = key[len(prefix):]
            vision_state[new_key] = value

    if vision_state:
        print(f" Loading {len(vision_state)} tensors into vision encoder...")
        vision_encoder.load_state_dict(vision_state)
        print(" ✓ Vision encoder weights loaded.")
    else:
        print(" WARNING: No 'vision_encoder' weights found in checkpoint. Using base weights.")

    image_size = vision_encoder.image_size
    print(f" Image size: {image_size}")

    # Put the wrapper in eval mode so dropout/normalization layers behave
    # deterministically during tracing.
    wrapper = VisionEncoderWrapper(vision_encoder).eval()

    # A single dummy frame for tracing; the frame dimension is marked dynamic below.
    dummy_input = torch.randn(1, 3, image_size, image_size)

    output_path = os.path.join(output_dir, "vision_encoder.onnx")
    os.makedirs(output_dir, exist_ok=True)

    print(f"Exporting to {output_path}...")
    input_names = ["video_frames"]
    output_names = ["vision_features"]
    opset_version = 17
    torch.onnx.export(
        wrapper,
        dummy_input,
        output_path,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes={
            "video_frames": {0: "num_frames"},
            "vision_features": {0: "num_frames"},
        },
        opset_version=opset_version,
        do_constant_folding=True,
        dynamo=False,
        external_data=True,
    )

    # With external data enabled, weights that exceed the 2 GB protobuf limit
    # are written to a separate file alongside the .onnx graph.
    data_path = output_path + ".data"
    if os.path.exists(data_path):
        print(f" Large model detected, weights saved to {data_path}")

    print("✓ Vision encoder export complete!")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="facebook/sam-audio-small")
    parser.add_argument("--output", type=str, default="onnx_models")
    args = parser.parse_args()

    export_vision_encoder(args.model, args.output)