| import os | |
| import urllib | |
| import torch | |
| from segment_anything.modeling import Sam | |
| from custom_encoder import build_sam_vit_h_torchscript | |
# Checkpoint identity: file name and the URL it is downloaded from.
CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"

# Per-user cache directory in which the checkpoint file is stored.
CHECKPOINT_PATH = os.path.join(os.path.expanduser("~"), ".cache", "SAM")

# Default value for the `model_type` argument of `load_model`.
MODEL_TYPE = "default"

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
def load_model(
    checkpoint_path: str = CHECKPOINT_PATH,
    checkpoint_name: str = CHECKPOINT_NAME,
    checkpoint_url: str = CHECKPOINT_URL,
    model_type: str = MODEL_TYPE,
) -> Sam:
    """Build the SAM model from a local checkpoint, downloading it if missing.

    Args:
        checkpoint_path: Directory that holds (or will hold) the checkpoint.
        checkpoint_name: File name of the checkpoint inside ``checkpoint_path``.
        checkpoint_url: URL the checkpoint is fetched from when the file is
            not present locally.
        model_type: Accepted for interface compatibility but currently
            unused; the model is always built with
            ``build_sam_vit_h_torchscript``.

    Returns:
        The model returned by ``build_sam_vit_h_torchscript`` for the
        checkpoint file.

    Raises:
        OSError: If the cache directory cannot be created or the download
            fails (``urllib.error.URLError`` is an ``OSError`` subclass).
    """
    # A bare `import urllib` does not guarantee that the `request` submodule
    # is loaded, so import it explicitly before use.
    import urllib.request

    # `exist_ok=True` avoids the check-then-create race; an empty path means
    # "current directory", which needs no creation.
    if checkpoint_path:
        os.makedirs(checkpoint_path, exist_ok=True)

    checkpoint = os.path.join(checkpoint_path, checkpoint_name)
    if not os.path.exists(checkpoint):
        print("Downloading the model weights...")
        # Download to a temporary name and move it into place only once the
        # transfer has completed, so an interrupted download never leaves a
        # truncated file that a later run would mistake for a valid checkpoint.
        partial = checkpoint + ".part"
        try:
            urllib.request.urlretrieve(checkpoint_url, partial)
            os.replace(partial, checkpoint)
        finally:
            # After a successful move the partial file is gone; this only
            # cleans up the leftovers of a failed download.
            if os.path.exists(partial):
                os.remove(partial)
        print(f"The model weights saved as {checkpoint}")

    print(f"Load the model weights from {checkpoint}")
    return build_sam_vit_h_torchscript(checkpoint=checkpoint)
if __name__ == "__main__":
    # Export the SAM image encoder as a TorchScript module.
    output_file = "model_repository/sam_torchscript_fp32/model.pt"

    # Put the image encoder in evaluation mode and move it to the selected device.
    model = load_model().image_encoder.eval().to(device)

    with torch.jit.optimized_execution(True):
        script_model = torch.jit.script(model)

    # Saving fails if the target directory is missing, so create it first.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    script_model.save(output_file)