khsyee committed
Commit 1c1cd5e · 1 Parent(s): 78c1bf2

Change using wrapper. but fail

Files changed (5)
  1. .gitignore +1 -0
  2. Makefile +8 -0
  3. load_model.py +45 -0
  4. requirements.txt +3 -0
  5. wrapper.py +41 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__
Makefile ADDED
@@ -0,0 +1,8 @@
+ PYTHON=3.9
+ BASENAME=sam-vit-h-encoder-torchscript
+
+ env:
+ 	conda create -n $(BASENAME) python=$(PYTHON) -y
+
+ setup:
+ 	pip install -r requirements.txt
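
A quick usage note (inferred from the Makefile, not stated in the commit): `make env` creates a conda environment named sam-vit-h-encoder-torchscript on Python 3.9; after `conda activate sam-vit-h-encoder-torchscript`, `make setup` installs the pinned dependencies from requirements.txt.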
load_model.py ADDED
@@ -0,0 +1,45 @@
+ import os
+ import urllib.request
+
+ import torch
+ from segment_anything import sam_model_registry
+ from segment_anything.modeling import Sam
+
+ from wrapper import ImageEncoderViTWrapper
+
+ CHECKPOINT_PATH = os.path.join(os.path.expanduser("~"), ".cache", "SAM")
+ CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
+ CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
+ MODEL_TYPE = "default"
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+ def load_model(
+     checkpoint_path: str = CHECKPOINT_PATH,
+     checkpoint_name: str = CHECKPOINT_NAME,
+     checkpoint_url: str = CHECKPOINT_URL,
+     model_type: str = MODEL_TYPE,
+ ) -> Sam:
+     if not os.path.exists(checkpoint_path):
+         os.makedirs(checkpoint_path)
+     checkpoint = os.path.join(checkpoint_path, checkpoint_name)
+     if not os.path.exists(checkpoint):
+         print("Downloading the model weights...")
+         urllib.request.urlretrieve(checkpoint_url, checkpoint)
+         print(f"The model weights were saved as {checkpoint}")
+     print(f"Loading the model weights from {checkpoint}")
+     return sam_model_registry[model_type](checkpoint=checkpoint)
+
+
+ if __name__ == "__main__":
+     # model = load_model().image_encoder.eval().to(device)
+     image_encoder = load_model().image_encoder
+     print(type(image_encoder))
+     image_encoder_wrapper = ImageEncoderViTWrapper(image_encoder).eval().to(device)
+     image_encoder_wrapper.change_block()
+
+     print(type(image_encoder_wrapper.image_encoder.blocks[0]))
+
+     with torch.jit.optimized_execution(True):
+         script_model = torch.jit.script(image_encoder_wrapper)
+         script_model.save("model.pt")
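
The commit message notes the export currently fails. Assuming it eventually succeeds, and assuming the wrapper delegates calls to the wrapped encoder (see the note after wrapper.py below), a minimal sketch of consuming the exported model.pt:

    # Minimal consumption sketch; assumes the export above succeeded and
    # that ImageEncoderViTWrapper forwards calls to the wrapped encoder.
    import torch

    model = torch.jit.load("model.pt", map_location="cpu").eval()

    # SAM's ViT-H image encoder expects 1024x1024 RGB input (B, 3, 1024, 1024)
    # and produces image embeddings of shape (B, 256, 64, 64).
    dummy = torch.randn(1, 3, 1024, 1024)
    with torch.no_grad():
        embedding = model(dummy)
    print(embedding.shape)  # torch.Size([1, 256, 64, 64])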
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch==2.0.0
+ torchvision==0.15.1
+ git+https://github.com/facebookresearch/segment-anything.git
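
A note on the pins: torch 2.0.0 and torchvision 0.15.1 are the matching release pair, and segment-anything is installed straight from GitHub, which is the install method the upstream README recommends.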
wrapper.py ADDED
@@ -0,0 +1,41 @@
+ import torch
+ import torch.nn as nn
+
+ from segment_anything.modeling import ImageEncoderViT
+ from segment_anything.modeling.image_encoder import Block, window_partition, window_unpartition
+
+
+ class BlockWrapper(nn.Module):
+     def __init__(self, block: Block):
+         super().__init__()
+         self.block = block
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         shortcut = x
+         x = self.block.norm1(x)
+         # Window partition
+         if self.block.window_size > 0:
+             H, W = x.shape[1], x.shape[2]
+             x, pad_hw = window_partition(x, self.block.window_size)
+             x = self.block.attn(x)
+             # Reverse window partition
+             x = window_unpartition(x, self.block.window_size, pad_hw, (H, W))
+         else:
+             x = self.block.attn(x)
+
+         x = shortcut + x
+         x = x + self.block.mlp(self.block.norm2(x))
+
+         return x
+
+
+ class ImageEncoderViTWrapper(nn.Module):
+     def __init__(self, image_encoder: ImageEncoderViT):
+         super().__init__()
+         self.image_encoder = image_encoder
+
+     def change_block(self):
+         block_wrappers = nn.ModuleList()
+         for block in self.image_encoder.blocks:
+             block_wrappers.append(BlockWrapper(block))
+         self.image_encoder.blocks = block_wrappers
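
Why BlockWrapper restructures the upstream Block.forward (my reading, not stated in the commit): in segment_anything's original Block, H, W and pad_hw are assigned inside one `if self.window_size > 0:` block and consumed inside a second one. TorchScript rejects a variable that is defined on only one branch and used afterwards, so BlockWrapper merges both conditionals into a single branch, keeping those names in scope. A minimal reproduction of the restriction:

    # Sketch of the TorchScript branch-scoping rule that (presumably)
    # motivates BlockWrapper: `pad` is assigned on only one path, so
    # scripting fails even though the eager code runs fine.
    import torch

    def conditional_use(x: torch.Tensor, flag: bool) -> torch.Tensor:
        if flag:
            pad = int(x.shape[0])  # `pad` exists only on this path
        x = x + 1
        if flag:
            x = x * pad  # TorchScript: `pad` is not defined in the false branch
        return x

    try:
        torch.jit.script(conditional_use)
    except Exception as err:
        print(f"TorchScript rejected it: {err}")

Separately, one plausible contributor to the "but fail" in the commit message (an assumption on my part, not confirmed by the commit): ImageEncoderViTWrapper never defines forward, so the scripted module has nothing to call. A sketch of a delegating forward that could be added to the class:

    # Hypothetical method on ImageEncoderViTWrapper (not part of this commit):
    # delegate to the wrapped encoder so the scripted module is callable.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.image_encoder(x)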