matbee committed (verified)
Commit 0abf616 · 1 Parent(s): 9beb89b

Upload export_vision.py with huggingface_hub

Files changed (1)
  1. export_vision.py +111 -0
export_vision.py ADDED
@@ -0,0 +1,111 @@
+ #!/usr/bin/env python3
+ import os
+ import torch
+ import torch.nn as nn
+ import onnx
+ from sam_audio.model.vision_encoder import PerceptionEncoder
+ from onnx_export.standalone_config import PerceptionEncoderConfig
+ 
+ 
+ class VisionEncoderWrapper(nn.Module):
+     """
+     Wrapper for the Vision Encoder (CLIP visual backbone).
+     """
+     def __init__(self, vision_encoder):
+         super().__init__()
+         self.model = vision_encoder.model
+         self.normalize = vision_encoder.normalize_feature
+ 
+     def forward(self, x):
+         # x: (N, 3, H, W) where N is number of frames
+         # returns: (N, 1024) features
+         return self.model.encode_image(x, normalize=self.normalize)
+ 
+ 
+ def export_vision_encoder(model_id="facebook/sam-audio-small", output_dir="onnx_models"):
+     """Export the vision encoder to ONNX."""
+     print(f"Loading Vision Encoder from {model_id}...")
+ 
+     # torch, PerceptionEncoder, and PerceptionEncoderConfig are already
+     # imported at module level; only AutoConfig is needed locally.
+     from transformers import AutoConfig
+ 
+     print("Fetching config...")
+     cfg_hf = AutoConfig.from_pretrained(model_id)
+     cfg_dict = cfg_hf.to_dict()
+ 
+     # Extract vision encoder config
+     v_cfg_dict = cfg_dict.get("vision_encoder", {})
+     v_cfg = PerceptionEncoderConfig(**v_cfg_dict)
+ 
+     print(f"Initializing PerceptionEncoder with name: {v_cfg.name}...")
+     vision_encoder = PerceptionEncoder(v_cfg)
+ 
+     # Load weights from checkpoint
+     print("Loading weights from SAM Audio checkpoint...")
+     from huggingface_hub import hf_hub_download
+     checkpoint_path = hf_hub_download(repo_id=model_id, filename="checkpoint.pt")
+     state_dict = torch.load(checkpoint_path, map_location="cpu", mmap=True)
+ 
+     # Filter vision encoder weights
+     vision_state = {}
+     prefix = "vision_encoder."
+     for key, value in state_dict.items():
+         if key.startswith(prefix):
+             new_key = key[len(prefix):]
+             vision_state[new_key] = value
+ 
+     if vision_state:
+         print(f" Loading {len(vision_state)} tensors into vision encoder...")
+         vision_encoder.load_state_dict(vision_state)
+         print(" ✓ Vision encoder weights loaded.")
+     else:
+         print(" WARNING: No 'vision_encoder' weights found in checkpoint. Using base weights.")
+ 
+     image_size = vision_encoder.image_size
+     print(f" Image size: {image_size}")
+ 
+     wrapper = VisionEncoderWrapper(vision_encoder).eval()
+ 
+     # Create dummy input (single frame; the frame axis is exported as dynamic)
+     dummy_input = torch.randn(1, 3, image_size, image_size)
+ 
+     output_path = os.path.join(output_dir, "vision_encoder.onnx")
+     os.makedirs(output_dir, exist_ok=True)
+ 
+     print(f"Exporting to {output_path}...")
+     input_names = ["video_frames"]
+     output_names = ["vision_features"]
+     opset_version = 17  # Using 17 for better support of ViT/ConvNext
+     torch.onnx.export(
+         wrapper,
+         dummy_input,
+         output_path,
+         input_names=input_names,
+         output_names=output_names,
+         dynamic_axes={
+             "video_frames": {0: "num_frames"},
+             "vision_features": {0: "num_frames"},
+         },
+         opset_version=opset_version,
+         do_constant_folding=True,
+         dynamo=False,
+         external_data=True,
+     )
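+     # NOTE: dynamo=False selects the TorchScript-based exporter;
+     # external_data=True lets initializers be written to a side file,
+     # which is required when the graph would exceed protobuf's 2 GB limit.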
+ 
+     # Check if data was saved separately
+     data_path = output_path + ".data"
+     if os.path.exists(data_path):
+         print(f" Large model detected, weights saved to {data_path}")
+ 
+     print("✓ Vision encoder export complete!")
+ 
+ 
+ if __name__ == "__main__":
+     import argparse
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model", type=str, default="facebook/sam-audio-small")
+     parser.add_argument("--output", type=str, default="onnx_models")
+     args = parser.parse_args()
+ 
+     export_vision_encoder(args.model, args.output)
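
After running the script (e.g. `python export_vision.py --model facebook/sam-audio-small --output onnx_models`), a quick way to sanity-check the exported graph is to run it under ONNX Runtime. The sketch below is not part of the commit and makes a few assumptions: onnxruntime is installed, 448 is a placeholder for the image size the script prints, and the 1024-wide output follows the wrapper's docstring. ONNX Runtime resolves the `vision_encoder.onnx.data` side file automatically when it sits next to the `.onnx` file.

```python
import numpy as np
import onnxruntime as ort

# Load the exported model (external weights resolve relative to the .onnx file)
session = ort.InferenceSession("onnx_models/vision_encoder.onnx")

# Two random frames; 448 is a placeholder for vision_encoder.image_size
frames = np.random.randn(2, 3, 448, 448).astype(np.float32)

(features,) = session.run(["vision_features"], {"video_frames": frames})
print(features.shape)  # expected (2, 1024): the "num_frames" axis is dynamic
```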