File size: 1,634 Bytes
1c1cd5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os
import urllib
import urllib.request

import torch
from segment_anything import sam_model_registry
from segment_anything.modeling import Sam

from wrapper import ImageEncoderViTWrapper

# Directory where the SAM checkpoint is cached (~/.cache/SAM).
CHECKPOINT_PATH = os.path.join(os.path.expanduser("~"), ".cache", "SAM")
# Checkpoint file name (ViT-H variant, per the "vit_h" in the name).
CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
# Official download URL for the ViT-H SAM weights.
CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
# Key into sam_model_registry; presumably "default" selects the ViT-H
# build matching the checkpoint above — confirm against segment_anything.
MODEL_TYPE = "default"
# Target device for the export below; prefer GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model(
    checkpoint_path: str = CHECKPOINT_PATH,
    checkpoint_name: str = CHECKPOINT_NAME,
    checkpoint_url: str = CHECKPOINT_URL,
    model_type: str = MODEL_TYPE,
) -> Sam:
    """Load a SAM model, downloading the checkpoint on a cache miss.

    Args:
        checkpoint_path: Directory used as the local checkpoint cache.
        checkpoint_name: File name of the checkpoint inside the cache dir.
        checkpoint_url: URL the checkpoint is fetched from if not cached.
        model_type: Key into ``sam_model_registry`` selecting the variant.

    Returns:
        The instantiated SAM model with the checkpoint weights loaded.
    """
    # exist_ok avoids the check-then-create race of the previous
    # os.path.exists()/os.makedirs() pair and is a no-op if present.
    os.makedirs(checkpoint_path, exist_ok=True)
    checkpoint = os.path.join(checkpoint_path, checkpoint_name)
    if not os.path.exists(checkpoint):
        print("Downloading the model weights...")
        # Requires the explicit `import urllib.request`; `import urllib`
        # alone does not guarantee the submodule is loaded.
        urllib.request.urlretrieve(checkpoint_url, checkpoint)
        print(f"The model weights saved as {checkpoint}")
    print(f"Load the model weights from {checkpoint}")
    return sam_model_registry[model_type](checkpoint=checkpoint)


if __name__ == "__main__":
    # Only the image encoder (ViT backbone) is exported; the prompt
    # encoder and mask decoder are not used here.
    image_encoder = load_model().image_encoder
    print(type(image_encoder))

    # Wrap the encoder, put it in eval mode, and move it to the target
    # device. change_block() presumably swaps the transformer blocks for
    # TorchScript-compatible ones — confirm in wrapper.py.
    image_encoder_wrapper = ImageEncoderViTWrapper(image_encoder).eval().to(device)
    image_encoder_wrapper.change_block()

    # Sanity check: show the (possibly replaced) block type.
    print(type(image_encoder_wrapper.image_encoder.blocks[0]))

    # Script (not trace) the wrapper, then serialize it for deployment.
    with torch.jit.optimized_execution(True):
        script_model = torch.jit.script(image_encoder_wrapper)
    script_model.save("model.pt")