File size: 1,348 Bytes
1c1cd5e
 
 
 
 
 
65dd0ae
1c1cd5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65dd0ae
1c1cd5e
 
 
65dd0ae
1c1cd5e
 
65dd0ae
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
import urllib
import urllib.request

import torch
from segment_anything.modeling import Sam

from custom_encoder import build_sam_vit_h_torchscript

# Directory where the downloaded SAM checkpoint is cached: <home>/.cache/SAM.
CHECKPOINT_PATH = os.path.join(os.path.expanduser("~"), ".cache", "SAM")
# File name of the ViT-H SAM checkpoint and the URL it is fetched from.
CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
CHECKPOINT_URL = (
    "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
)
# Default value for load_model's ``model_type`` argument.
MODEL_TYPE = "default"
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def load_model(
    checkpoint_path: str = CHECKPOINT_PATH,
    checkpoint_name: str = CHECKPOINT_NAME,
    checkpoint_url: str = CHECKPOINT_URL,
    model_type: str = MODEL_TYPE,
) -> Sam:
    """Build the TorchScript-friendly SAM ViT-H model, downloading weights if needed.

    The checkpoint file ``os.path.join(checkpoint_path, checkpoint_name)`` is
    downloaded from ``checkpoint_url`` the first time it is missing and reused
    on later calls.

    Args:
        checkpoint_path: Directory holding the checkpoint; created if absent.
            An empty string means the current working directory.
        checkpoint_name: File name of the checkpoint inside ``checkpoint_path``.
        checkpoint_url: URL the checkpoint is fetched from when it is missing.
        model_type: Accepted for interface compatibility but currently unused;
            the model is always built by ``build_sam_vit_h_torchscript``.

    Returns:
        Whatever ``build_sam_vit_h_torchscript(checkpoint=...)`` returns
        (annotated as ``Sam``).

    Raises:
        OSError / urllib.error.URLError: if the directory cannot be created
            or the download fails (no partial file is left behind).
    """
    # exist_ok=True avoids the race of checking os.path.exists and then creating.
    # os.makedirs("") would raise, so skip it when the target is the CWD.
    if checkpoint_path:
        os.makedirs(checkpoint_path, exist_ok=True)
    checkpoint = os.path.join(checkpoint_path, checkpoint_name)
    if not os.path.exists(checkpoint):
        print("Downloading the model weights...")
        # Download to a sibling temp file and rename it only on success, so an
        # interrupted download never leaves a truncated file at ``checkpoint``
        # that a later run would mistake for a complete checkpoint.
        partial = checkpoint + ".part"
        try:
            urllib.request.urlretrieve(checkpoint_url, partial)
            os.replace(partial, checkpoint)
        finally:
            # On success os.replace consumed ``partial``; on failure, remove leftovers.
            if os.path.exists(partial):
                os.remove(partial)
        print(f"The model weights saved as {checkpoint}")
    print(f"Load the model weights from {checkpoint}")
    return build_sam_vit_h_torchscript(checkpoint=checkpoint)


if __name__ == "__main__":
    # Export only SAM's image encoder, in eval mode, on the selected device.
    # NOTE(review): the saved weights stay on ``device``; loading a CUDA-exported
    # file on a CPU-only host may need torch.jit.load(..., map_location="cpu").
    model = load_model().image_encoder.eval().to(device)

    with torch.jit.optimized_execution(True):
        script_model = torch.jit.script(model)

    output_path = "model_repository/sam_torchscript_fp32/model.pt"
    # ScriptModule.save does not create parent directories; ensure they exist
    # so a fresh checkout does not fail after the (slow) scripting step.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    script_model.save(output_path)