"""Load the Segment Anything (SAM) model and export its image encoder to TorchScript.

Running this file as a script downloads the SAM checkpoint on first use,
wraps ``image_encoder`` with ``ImageEncoderViTWrapper``, scripts it with
``torch.jit.script`` and saves the result to ``model.pt``.
"""

import os
import urllib.request  # importing bare ``urllib`` does not load the ``request`` submodule

import torch
from segment_anything import sam_model_registry
from segment_anything.modeling import Sam

from wrapper import ImageEncoderViTWrapper

# Default directory where the downloaded checkpoint is cached.
CHECKPOINT_PATH = os.path.join(os.path.expanduser("~"), ".cache", "SAM")
CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
# Key into ``sam_model_registry``.
# NOTE(review): presumably selects the builder matching CHECKPOINT_NAME — confirm.
MODEL_TYPE = "default"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _download(url: str, destination: str) -> None:
    """Download ``url`` to ``destination`` without leaving a truncated file behind.

    The data is written to ``destination + ".part"``. It is renamed into place
    only after the transfer completes, so an interrupted or failed download never
    produces a file at ``destination`` that later runs would treat as a valid
    cached checkpoint.
    """
    partial = destination + ".part"
    try:
        urllib.request.urlretrieve(url, partial)
        # os.replace is an atomic rename when both paths are on the same filesystem.
        os.replace(partial, destination)
    finally:
        # After a successful replace the partial file no longer exists; this only
        # cleans up after a failure (including KeyboardInterrupt).
        if os.path.exists(partial):
            os.remove(partial)


def load_model(
    checkpoint_path: str = CHECKPOINT_PATH,
    checkpoint_name: str = CHECKPOINT_NAME,
    checkpoint_url: str = CHECKPOINT_URL,
    model_type: str = MODEL_TYPE,
) -> Sam:
    """Build a SAM model, downloading its checkpoint first if it is not cached.

    Args:
        checkpoint_path: Directory holding the checkpoint; created if missing.
        checkpoint_name: File name of the checkpoint inside ``checkpoint_path``.
        checkpoint_url: URL the checkpoint is fetched from when not on disk.
        model_type: Key into ``sam_model_registry`` selecting the model builder.

    Returns:
        The model built by ``sam_model_registry[model_type]``, given the local
        checkpoint file path.

    Raises:
        urllib.error.URLError: If the checkpoint has to be downloaded and the
            download fails.
        FileExistsError: If ``checkpoint_path`` exists but is not a directory.
    """
    # exist_ok avoids the race between a separate existence check and creation.
    os.makedirs(checkpoint_path, exist_ok=True)

    checkpoint = os.path.join(checkpoint_path, checkpoint_name)
    if not os.path.exists(checkpoint):
        print("Downloading the model weights...")
        _download(checkpoint_url, checkpoint)
        print(f"The model weights saved as {checkpoint}")

    print(f"Load the model weights from {checkpoint}")
    return sam_model_registry[model_type](checkpoint=checkpoint)


def main() -> None:
    """Script the wrapped SAM image encoder and save it as ``model.pt``."""
    image_encoder = load_model().image_encoder
    print(type(image_encoder))

    image_encoder_wrapper = ImageEncoderViTWrapper(image_encoder).eval().to(device)
    # NOTE(review): change_block() is defined in wrapper.py. Presumably it swaps
    # the encoder blocks for a TorchScript-compatible variant (the type printed
    # below shows the result) — confirm.
    image_encoder_wrapper.change_block()
    print(type(image_encoder_wrapper.image_encoder.blocks[0]))

    with torch.jit.optimized_execution(True):
        script_model = torch.jit.script(image_encoder_wrapper)
        script_model.save("model.pt")


if __name__ == "__main__":
    main()