|
|
import os |
|
|
import urllib |
|
|
|
|
|
import torch |
|
|
from segment_anything import sam_model_registry |
|
|
from segment_anything.modeling import Sam |
|
|
|
|
|
from wrapper import ImageEncoderViTWrapper |
|
|
|
|
|
# --- SAM checkpoint configuration -------------------------------------------
# Directory where downloaded weights are cached: <home>/.cache/SAM.
CHECKPOINT_PATH = os.path.join(os.path.expanduser("~"), ".cache", "SAM")
# File name of the checkpoint inside CHECKPOINT_PATH.
CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
# Where the checkpoint is fetched from when it is not cached yet.
CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
# Key passed to segment_anything's sam_model_registry.
MODEL_TYPE = "default"

# --- Runtime device ----------------------------------------------------------
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
|
|
|
def load_model(
    checkpoint_path: str = CHECKPOINT_PATH,
    checkpoint_name: str = CHECKPOINT_NAME,
    checkpoint_url: str = CHECKPOINT_URL,
    model_type: str = MODEL_TYPE,
) -> Sam:
    """Build a SAM model, downloading its checkpoint first if it is not cached.

    Args:
        checkpoint_path: Directory that holds the checkpoint file; it is
            created if it does not exist.
        checkpoint_name: File name of the checkpoint inside ``checkpoint_path``.
        checkpoint_url: URL the checkpoint is downloaded from when the file
            is missing locally.
        model_type: Key into ``sam_model_registry`` selecting the builder.

    Returns:
        Whatever ``sam_model_registry[model_type]`` builds from the local
        checkpoint file.
    """
    # `import urllib` at module level does not load the `urllib.request`
    # submodule, so import it explicitly before using urlretrieve.
    import urllib.request

    os.makedirs(checkpoint_path, exist_ok=True)
    checkpoint = os.path.join(checkpoint_path, checkpoint_name)

    if not os.path.exists(checkpoint):
        print("Downloading the model weights...")
        # Download to a temporary name and only rename on success, so an
        # interrupted or failed download never leaves a truncated file at the
        # final path (which later runs would wrongly treat as a valid cache).
        partial = checkpoint + ".part"
        try:
            urllib.request.urlretrieve(checkpoint_url, partial)
            os.replace(partial, checkpoint)
        except BaseException:
            # Also covers KeyboardInterrupt; the error is re-raised after cleanup.
            if os.path.exists(partial):
                os.remove(partial)
            raise
        print(f"The model weights saved as {checkpoint}")

    print(f"Load the model weights from {checkpoint}")
    return sam_model_registry[model_type](checkpoint=checkpoint)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Only SAM's image encoder is exported; the rest of the model is dropped.
    encoder = load_model().image_encoder
    print(type(encoder))

    # Wrap the encoder, switch to inference mode and move it to `device`.
    wrapper = ImageEncoderViTWrapper(encoder)
    wrapper = wrapper.eval().to(device)

    # Let the wrapper modify the encoder's blocks (see
    # ImageEncoderViTWrapper.change_block), then report the first block's type.
    wrapper.change_block()
    print(type(wrapper.image_encoder.blocks[0]))

    # Compile the wrapper to TorchScript and write it to model.pt in the CWD.
    with torch.jit.optimized_execution(True):
        scripted_encoder = torch.jit.script(wrapper)
        scripted_encoder.save("model.pt")
|
|
|