#!/usr/bin/env python3
import os
import torch
import torch.nn as nn
import onnx
from sam_audio.model.vision_encoder import PerceptionEncoder
from onnx_export.standalone_config import PerceptionEncoderConfig

class VisionEncoderWrapper(nn.Module):
    """
    Wrapper for the Vision Encoder (CLIP visual backbone).
    """
    def __init__(self, vision_encoder):
        super().__init__()
        self.model = vision_encoder.model
        self.normalize = vision_encoder.normalize_feature

    def forward(self, x):
        # x: (N, 3, H, W) where N is number of frames
        # returns: (N, 1024) features
        return self.model.encode_image(x, normalize=self.normalize)

def export_vision_encoder(model_id="facebook/sam-audio-small", output_dir="onnx_models"):
    """Export the vision encoder to ONNX."""
    print(f"Loading Vision Encoder from {model_id}...")
    
    from transformers import AutoConfig
    
    print("Fetching config...")
    cfg_hf = AutoConfig.from_pretrained(model_id)
    cfg_dict = cfg_hf.to_dict()
    
    # Extract vision encoder config
    v_cfg_dict = cfg_dict.get("vision_encoder", {})
    v_cfg = PerceptionEncoderConfig(**v_cfg_dict)
    
    print(f"Initializing PerceptionEncoder with name: {v_cfg.name}...")
    vision_encoder = PerceptionEncoder(v_cfg)
    
    # Load weights from checkpoint
    print("Loading weights from SAM Audio checkpoint...")
    from huggingface_hub import hf_hub_download
    checkpoint_path = hf_hub_download(repo_id=model_id, filename="checkpoint.pt")
    state_dict = torch.load(checkpoint_path, map_location="cpu", mmap=True)
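    
    # Some checkpoints nest the actual weights under a wrapper key; unwrap it
    # if present. This is a defensive assumption, not a documented layout of
    # the SAM Audio checkpoint.
    for wrapper_key in ("model", "state_dict"):
        if wrapper_key in state_dict and isinstance(state_dict[wrapper_key], dict):
            state_dict = state_dict[wrapper_key]
            break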
    
    # Filter vision encoder weights
    vision_state = {}
    prefix = "vision_encoder."
    for key, value in state_dict.items():
        if key.startswith(prefix):
            new_key = key[len(prefix):]
            vision_state[new_key] = value
            
    if vision_state:
        print(f"  Loading {len(vision_state)} tensors into vision encoder...")
        vision_encoder.load_state_dict(vision_state)
        print("  ✓ Vision encoder weights loaded.")
    else:
        print("  WARNING: No 'vision_encoder' weights found in checkpoint. Using base weights.")

    image_size = vision_encoder.image_size
    print(f"  Image size: {image_size}")

    wrapper = VisionEncoderWrapper(vision_encoder).eval()
    
    # Create a dummy input: one RGB frame at the encoder's native resolution
    dummy_input = torch.randn(1, 3, image_size, image_size)
    
    output_path = os.path.join(output_dir, "vision_encoder.onnx")
    os.makedirs(output_dir, exist_ok=True)
    
    print(f"Exporting to {output_path}...")
    input_names = ["video_frames"]
    output_names = ["vision_features"]
    opset_version = 17 # Using 17 for better support of ViT/ConvNext
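    # Only the frame/batch dimension is dynamic; the spatial dims stay fixed
    # at the resolution the encoder was trained with.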
    torch.onnx.export(
        wrapper,
        dummy_input,
        output_path,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes={
            "video_frames": {0: "num_frames"},
            "vision_features": {0: "num_frames"},
        },
        opset_version=opset_version,
        do_constant_folding=True,
        dynamo=False,
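        # Note: external_data is consumed by the dynamo exporter; with
        # dynamo=False the legacy exporter spills weights to external files
        # automatically once the model exceeds the 2 GB protobuf limit.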
        external_data=True,
    )
    
    # Check if data was saved separately
    data_path = output_path + ".data"
    if os.path.exists(data_path):
        print(f"  Large model detected, weights saved to {data_path}")
    
    print("✓ Vision encoder export complete!")

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="facebook/sam-audio-small")
    parser.add_argument("--output", type=str, default="onnx_models")
    args = parser.parse_args()
    
    export_vision_encoder(args.model, args.output)