File size: 1,634 Bytes
1c1cd5e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import os
import urllib
import urllib.request  # required: `import urllib` alone does not load the `request` submodule

import torch
from segment_anything import sam_model_registry
from segment_anything.modeling import Sam

from wrapper import ImageEncoderViTWrapper
# Directory used to cache the downloaded SAM checkpoint (~/.cache/SAM).
CHECKPOINT_PATH = os.path.join(os.path.expanduser("~"), ".cache", "SAM")
# File name of the ViT-H SAM checkpoint.
CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
# Official download URL for the ViT-H SAM weights.
CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
# Key into sam_model_registry; "default" pairs with the ViT-H checkpoint above.
MODEL_TYPE = "default"
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model(
    checkpoint_path: str = CHECKPOINT_PATH,
    checkpoint_name: str = CHECKPOINT_NAME,
    checkpoint_url: str = CHECKPOINT_URL,
    model_type: str = MODEL_TYPE,
) -> Sam:
    """Build a SAM model, downloading the checkpoint on first use.

    Args:
        checkpoint_path: Directory where checkpoints are cached.
        checkpoint_name: File name of the checkpoint inside that directory.
        checkpoint_url: URL to fetch the checkpoint from when it is missing.
        model_type: Key into ``sam_model_registry``.

    Returns:
        The SAM model constructed from the cached (or freshly downloaded)
        weights file.
    """
    # exist_ok avoids the check-then-create race of the original
    # `if not exists: makedirs` pattern (safe if another process creates
    # the directory between the check and the call).
    os.makedirs(checkpoint_path, exist_ok=True)
    checkpoint = os.path.join(checkpoint_path, checkpoint_name)
    if not os.path.exists(checkpoint):
        print("Downloading the model weights...")
        # urllib.request is imported explicitly at module level; bare
        # `import urllib` does not guarantee this attribute exists.
        urllib.request.urlretrieve(checkpoint_url, checkpoint)
        print(f"The model weights saved as {checkpoint}")
    print(f"Load the model weights from {checkpoint}")
    return sam_model_registry[model_type](checkpoint=checkpoint)
if __name__ == "__main__":
    # Export the SAM image encoder as a TorchScript module.
    image_encoder = load_model().image_encoder
    print(type(image_encoder))
    image_encoder_wrapper = ImageEncoderViTWrapper(image_encoder).eval().to(device)
    # NOTE(review): presumably swaps the encoder's transformer blocks for
    # script-friendly variants — confirm against wrapper.py.
    image_encoder_wrapper.change_block()
    print(type(image_encoder_wrapper.image_encoder.blocks[0]))
    with torch.jit.optimized_execution(True):
        script_model = torch.jit.script(image_encoder_wrapper)
    script_model.save("model.pt")
|