| | |
| | import os |
| | import torch |
| | import torch.nn as nn |
| | import onnx |
| | from sam_audio.model.vision_encoder import PerceptionEncoder |
| | from onnx_export.standalone_config import PerceptionEncoderConfig |
| |
|
class VisionEncoderWrapper(nn.Module):
    """Thin ``nn.Module`` wrapper around the CLIP visual backbone.

    Exposes the backbone's ``encode_image`` through ``forward`` so the
    encoder can be traced by ``torch.onnx.export``.
    """

    def __init__(self, vision_encoder):
        super().__init__()
        # Pull out the backbone and its feature-normalization flag once,
        # then store them under the names the exporter expects.
        backbone = vision_encoder.model
        normalize_flag = vision_encoder.normalize_feature
        self.model = backbone
        self.normalize = normalize_flag

    def forward(self, x):
        """Encode a batch of image frames into vision features."""
        features = self.model.encode_image(x, normalize=self.normalize)
        return features
| |
|
def export_vision_encoder(model_id="facebook/sam-audio-small", output_dir="onnx_models"):
    """Export the SAM Audio vision encoder (CLIP visual backbone) to ONNX.

    Pulls the model config and checkpoint from the Hugging Face Hub, loads
    the ``vision_encoder.*`` weights into a freshly built
    :class:`PerceptionEncoder`, and traces the wrapped encoder with
    ``torch.onnx.export`` using a dynamic leading ("num_frames") axis.

    Args:
        model_id: Hugging Face repo id to fetch the config/checkpoint from.
        output_dir: Directory where ``vision_encoder.onnx`` (and, for large
            models, a ``.onnx.data`` sidecar) is written.
    """
    # Heavy third-party deps are imported lazily so that merely importing
    # this module stays cheap. torch / PerceptionEncoder / the config class
    # are already imported at module level — no need to re-import them here.
    from transformers import AutoConfig
    from huggingface_hub import hf_hub_download

    print(f"Loading Vision Encoder from {model_id}...")

    print("Fetching config...")
    cfg_dict = AutoConfig.from_pretrained(model_id).to_dict()
    # Missing "vision_encoder" section falls back to the config defaults.
    v_cfg = PerceptionEncoderConfig(**cfg_dict.get("vision_encoder", {}))

    print(f"Initializing PerceptionEncoder with name: {v_cfg.name}...")
    vision_encoder = PerceptionEncoder(v_cfg)

    print("Loading weights from SAM Audio checkpoint...")
    checkpoint_path = hf_hub_download(repo_id=model_id, filename="checkpoint.pt")
    # NOTE(review): torch.load on a downloaded pickle can execute arbitrary
    # code; the repo is trusted here, but consider weights_only=True if the
    # checkpoint loads cleanly with it.
    state_dict = torch.load(checkpoint_path, map_location="cpu", mmap=True)

    # Keep only the vision-encoder tensors, stripping the submodule prefix.
    prefix = "vision_encoder."
    vision_state = {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }

    if vision_state:
        print(f" Loading {len(vision_state)} tensors into vision encoder...")
        vision_encoder.load_state_dict(vision_state)
        print(" ✓ Vision encoder weights loaded.")
    else:
        print(" WARNING: No 'vision_encoder' weights found in checkpoint. Using base weights.")

    # Square input resolution expected by the backbone; computed once and
    # reused for the dummy tracing input.
    image_size = vision_encoder.image_size
    print(f" Image size: {image_size}")

    wrapper = VisionEncoderWrapper(vision_encoder).eval()
    dummy_input = torch.randn(1, 3, image_size, image_size)

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "vision_encoder.onnx")

    print(f"Exporting to {output_path}...")
    torch.onnx.export(
        wrapper,
        dummy_input,
        output_path,
        input_names=["video_frames"],
        output_names=["vision_features"],
        dynamic_axes={
            "video_frames": {0: "num_frames"},
            "vision_features": {0: "num_frames"},
        },
        opset_version=17,
        do_constant_folding=True,
        dynamo=False,
        external_data=True,
    )

    # external_data=True stores large weights in a sidecar file next to the
    # .onnx graph; report it when it was produced.
    data_path = output_path + ".data"
    if os.path.exists(data_path):
        print(f" Large model detected, weights saved to {data_path}")

    print("✓ Vision encoder export complete!")
| |
|
if __name__ == "__main__":
    import argparse

    # CLI entry point: flags mirror export_vision_encoder's defaults.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--model", type=str, default="facebook/sam-audio-small")
    arg_parser.add_argument("--output", type=str, default="onnx_models")
    opts = arg_parser.parse_args()
    export_vision_encoder(model_id=opts.model, output_dir=opts.output)
| |
|