matbee commited on
Commit
4e44c24
·
verified ·
1 Parent(s): 2d75c95

Delete files export_vision.py vision_encoder.onnx vision_encoder.onnx.data with huggingface_hub

Browse files
export_vision.py DELETED
@@ -1,111 +0,0 @@
1
- #!/usr/bin/env python3
2
- import os
3
- import torch
4
- import torch.nn as nn
5
- import onnx
6
- from sam_audio.model.vision_encoder import PerceptionEncoder
7
- from onnx_export.standalone_config import PerceptionEncoderConfig
8
-
9
- class VisionEncoderWrapper(nn.Module):
10
- """
11
- Wrapper for the Vision Encoder (CLIP visual backbone).
12
- """
13
- def __init__(self, vision_encoder):
14
- super().__init__()
15
- self.model = vision_encoder.model
16
- self.normalize = vision_encoder.normalize_feature
17
-
18
- def forward(self, x):
19
- # x: (N, 3, H, W) where N is number of frames
20
- # returns: (N, 1024) features
21
- return self.model.encode_image(x, normalize=self.normalize)
22
-
23
- def export_vision_encoder(model_id="facebook/sam-audio-small", output_dir="onnx_models"):
24
- """Export the vision encoder to ONNX."""
25
- print(f"Loading Vision Encoder from {model_id}...")
26
-
27
- import torch
28
- from transformers import AutoConfig
29
- from sam_audio.model.vision_encoder import PerceptionEncoder
30
- from onnx_export.standalone_config import PerceptionEncoderConfig
31
-
32
- print("Fetching config...")
33
- cfg_hf = AutoConfig.from_pretrained(model_id)
34
- cfg_dict = cfg_hf.to_dict()
35
-
36
- # Extract vision encoder config
37
- v_cfg_dict = cfg_dict.get("vision_encoder", {})
38
- v_cfg = PerceptionEncoderConfig(**v_cfg_dict)
39
-
40
- print(f"Initializing PerceptionEncoder with name: {v_cfg.name}...")
41
- vision_encoder = PerceptionEncoder(v_cfg)
42
-
43
- # Load weights from checkpoint
44
- print("Loading weights from SAM Audio checkpoint...")
45
- from huggingface_hub import hf_hub_download
46
- checkpoint_path = hf_hub_download(repo_id=model_id, filename="checkpoint.pt")
47
- state_dict = torch.load(checkpoint_path, map_location="cpu", mmap=True)
48
-
49
- # Filter vision encoder weights
50
- vision_state = {}
51
- prefix = "vision_encoder."
52
- for key, value in state_dict.items():
53
- if key.startswith(prefix):
54
- new_key = key[len(prefix):]
55
- vision_state[new_key] = value
56
-
57
- if vision_state:
58
- print(f" Loading {len(vision_state)} tensors into vision encoder...")
59
- vision_encoder.load_state_dict(vision_state)
60
- print(" ✓ Vision encoder weights loaded.")
61
- else:
62
- print(" WARNING: No 'vision_encoder' weights found in checkpoint. Using base weights.")
63
-
64
- image_size = vision_encoder.image_size
65
- print(f" Image size: {image_size}")
66
-
67
-
68
- wrapper = VisionEncoderWrapper(vision_encoder).eval()
69
-
70
- # Create dummy input
71
- image_size = vision_encoder.image_size
72
- dummy_input = torch.randn(1, 3, image_size, image_size)
73
-
74
- output_path = os.path.join(output_dir, "vision_encoder.onnx")
75
- os.makedirs(output_dir, exist_ok=True)
76
-
77
- print(f"Exporting to {output_path}...")
78
- input_names = ["video_frames"]
79
- output_names = ["vision_features"]
80
- opset_version = 17 # Using 17 for better support of ViT/ConvNext
81
- torch.onnx.export(
82
- wrapper,
83
- dummy_input,
84
- output_path,
85
- input_names=input_names,
86
- output_names=output_names,
87
- dynamic_axes={
88
- "video_frames": {0: "num_frames"},
89
- "vision_features": {0: "num_frames"},
90
- },
91
- opset_version=opset_version,
92
- do_constant_folding=True,
93
- dynamo=False,
94
- external_data=True,
95
- )
96
-
97
- # Check if data was saved separately
98
- data_path = output_path + ".data"
99
- if os.path.exists(data_path):
100
- print(f" Large model detected, weights saved to {data_path}")
101
-
102
- print("✓ Vision encoder export complete!")
103
-
104
- if __name__ == "__main__":
105
- import argparse
106
- parser = argparse.ArgumentParser()
107
- parser.add_argument("--model", type=str, default="facebook/sam-audio-small")
108
- parser.add_argument("--output", type=str, default="onnx_models")
109
- args = parser.parse_args()
110
-
111
- export_vision_encoder(args.model, args.output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
vision_encoder.onnx DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:2143dc1338dcc11905d425b39f55a030e712920c3ab7a96db19867ca6fd82126
3
- size 1269876747
 
 
 
 
vision_encoder.onnx.data DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a6226f043f9600557b0bd9273e68a758ba024d5841627fff59cc0ec4f2a83275
3
- size 1268318208