# Author: khsyee — commit 1c1cd5e: "Change to using the wrapper, but it fails."
# (Original file size: 1.63 kB.)
import os
import urllib
import torch
from segment_anything import sam_model_registry
from segment_anything.modeling import Sam
from wrapper import ImageEncoderViTWrapper
# Where SAM checkpoints are cached locally: <home>/.cache/SAM
CHECKPOINT_PATH = os.path.join(os.path.expanduser("~"), ".cache", "SAM")
# File name of the checkpoint inside CHECKPOINT_PATH.
CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
# Remote location the checkpoint is fetched from when it is not cached yet.
CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
# Key into sam_model_registry; presumably selects the builder matching the
# ViT-H checkpoint above — verify against segment_anything's registry.
MODEL_TYPE = "default"

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
def load_model(
    checkpoint_path: str = CHECKPOINT_PATH,
    checkpoint_name: str = CHECKPOINT_NAME,
    checkpoint_url: str = CHECKPOINT_URL,
    model_type: str = MODEL_TYPE,
) -> Sam:
    """Build a SAM model, downloading its checkpoint on first use.

    Args:
        checkpoint_path: Directory in which the checkpoint file is cached.
            Created (including parents) if it does not exist.
        checkpoint_name: File name of the checkpoint inside ``checkpoint_path``.
        checkpoint_url: URL the checkpoint is downloaded from when it is not
            already present in the cache directory.
        model_type: Key into ``sam_model_registry`` selecting the model builder.

    Returns:
        The model returned by ``sam_model_registry[model_type]`` when called
        with the path of the cached checkpoint.
    """
    # `import urllib` alone does not guarantee that the `request` submodule
    # is loaded, so `urllib.request` could raise AttributeError without this.
    import urllib.request

    # exist_ok avoids the race of a separate exists() check followed by makedirs().
    os.makedirs(checkpoint_path, exist_ok=True)
    checkpoint = os.path.join(checkpoint_path, checkpoint_name)
    if not os.path.exists(checkpoint):
        print("Downloading the model weights...")
        # Download to a temporary name and rename only on success. This stops an
        # interrupted download from leaving a truncated file that later runs
        # would mistake for a valid cached checkpoint.
        partial = checkpoint + ".part"
        try:
            urllib.request.urlretrieve(checkpoint_url, partial)
            os.replace(partial, checkpoint)
        except BaseException:
            if os.path.exists(partial):
                os.remove(partial)
            raise
        print(f"The model weights saved as {checkpoint}")
    print(f"Load the model weights from {checkpoint}")
    return sam_model_registry[model_type](checkpoint=checkpoint)
if __name__ == "__main__":
    # Export only the image encoder of the SAM model as a TorchScript module.
    image_encoder = load_model().image_encoder
    print(type(image_encoder))

    image_encoder_wrapper = ImageEncoderViTWrapper(image_encoder)
    # change_block() appears to replace the encoder's blocks; the type printed
    # below is how this script checks that. Call it BEFORE eval()/to(device):
    # freshly constructed nn.Modules default to training mode with parameters
    # on the CPU, so blocks created afterwards would otherwise be exported in
    # the wrong mode and on the wrong device.
    # NOTE(review): assumes change_block() does not depend on the device or
    # eval mode — confirm against wrapper.py.
    image_encoder_wrapper.change_block()
    image_encoder_wrapper = image_encoder_wrapper.eval().to(device)
    print(type(image_encoder_wrapper.image_encoder.blocks[0]))

    with torch.jit.optimized_execution(True):
        script_model = torch.jit.script(image_encoder_wrapper)
        script_model.save("model.pt")