khsyee committed
Commit 65dd0ae · Parent(s): 1c1cd5e

Change using inheritance

README.md CHANGED
@@ -1,3 +1,18 @@
 ---
 license: apache-2.0
 ---
+
+## Run
+Set up the conda env:
+```
+make env
+conda activate sam-vit-h-encoder-torchscript
+make setup
+```
+
+Load the SAM model and convert the image encoder to TorchScript:
+```
+python convert_torchscript.py
+```
+
+Check that `model.pt` was written to `model_repository/sam_torchscript_fp32/1`.
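
To sanity-check the exported artifact, here is a minimal sketch (not part of this commit) that loads the scripted encoder and checks the embedding shape; the 1×3×1024×1024 input and 1×256×64×64 output follow the SAM ViT-H image encoder's spec:

```
import torch

# Load the TorchScript encoder on CPU, regardless of the device it was scripted on.
model = torch.jit.load(
    "model_repository/sam_torchscript_fp32/1/model.pt", map_location="cpu"
)
model.eval()

x = torch.randn(1, 3, 1024, 1024)  # SAM's image encoder expects 1024x1024 inputs
with torch.no_grad():
    embedding = model(x)
print(embedding.shape)  # expected: torch.Size([1, 256, 64, 64])
```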
load_model.py → convert_torchscript.py RENAMED
@@ -2,10 +2,9 @@ import os
 import urllib
 
 import torch
-from segment_anything import sam_model_registry
 from segment_anything.modeling import Sam
 
-from wrapper import ImageEncoderViTWrapper
+from custom_encoder import build_sam_vit_h_torchscript
 
 CHECKPOINT_PATH = os.path.join(os.path.expanduser("~"), ".cache", "SAM")
 CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
@@ -28,18 +27,12 @@ def load_model(
         urllib.request.urlretrieve(checkpoint_url, checkpoint)
         print(f"The model weights saved as {checkpoint}")
     print(f"Load the model weights from {checkpoint}")
-    return sam_model_registry[model_type](checkpoint=checkpoint)
+    return build_sam_vit_h_torchscript(checkpoint=checkpoint)
 
 
 if __name__ == "__main__":
-    # model = load_model().image_encoder.eval().to(device)
-    image_encoder = load_model().image_encoder
-    print(type(image_encoder))
-    image_encoder_wrapper = ImageEncoderViTWrapper(image_encoder).eval().to(device)
-    image_encoder_wrapper.change_block()
-
-    print(type(image_encoder_wrapper.image_encoder.blocks[0]))
+    model = load_model().image_encoder.eval().to(device)
 
     with torch.jit.optimized_execution(True):
-        script_model = torch.jit.script(image_encoder_wrapper)
-        script_model.save("model.pt")
+        script_model = torch.jit.script(model)
+        script_model.save("model_repository/sam_torchscript_fp32/1/model.pt")
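
Since the script now builds the encoder from `custom_encoder` instead of `sam_model_registry`, a numerical parity check against the stock encoder is a natural follow-up. A sketch (not in the repo), assuming the checkpoint already sits in the cache path used by `load_model()`:

```
import os

import torch
from segment_anything import sam_model_registry

from custom_encoder import build_sam_vit_h_torchscript

ckpt = os.path.join(os.path.expanduser("~"), ".cache", "SAM", "sam_vit_h_4b8939.pth")
ref = sam_model_registry["vit_h"](checkpoint=ckpt).image_encoder.eval()
new = build_sam_vit_h_torchscript(checkpoint=ckpt).image_encoder.eval()

x = torch.randn(1, 3, 1024, 1024)
with torch.no_grad():
    # Same weights, same math, only the control flow is reordered:
    # the two encoders should agree to within float tolerance.
    assert torch.allclose(ref(x), new(x), atol=1e-5)
```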
custom_encoder.py ADDED
@@ -0,0 +1,185 @@
+from functools import partial
+from typing import Optional, Tuple, Type
+
+import torch
+import torch.nn as nn
+from segment_anything.modeling import (MaskDecoder, PromptEncoder, Sam,
+                                       TwoWayTransformer)
+from segment_anything.modeling.common import LayerNorm2d
+from segment_anything.modeling.image_encoder import (Block, PatchEmbed,
+                                                     window_partition,
+                                                     window_unpartition)
+
+
+class CustomBlock(Block):
+    def __init__(self, **kwargs) -> None:
+        super().__init__(**kwargs)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        shortcut = x
+        x = self.norm1(x)
+        # Window partition
+        if self.window_size > 0:
+            H, W = x.shape[1], x.shape[2]
+            x, pad_hw = window_partition(x, self.window_size)
+            x = self.attn(x)
+            # Reverse window partition
+            x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+        else:
+            x = self.attn(x)
+
+        x = shortcut + x
+        x = x + self.mlp(self.norm2(x))
+
+        return x
+
+
+class CustomImageEncoderViT(nn.Module):
+    def __init__(
+        self,
+        img_size: int = 1024,
+        patch_size: int = 16,
+        in_chans: int = 3,
+        embed_dim: int = 768,
+        depth: int = 12,
+        num_heads: int = 12,
+        mlp_ratio: float = 4.0,
+        out_chans: int = 256,
+        qkv_bias: bool = True,
+        norm_layer: Type[nn.Module] = nn.LayerNorm,
+        act_layer: Type[nn.Module] = nn.GELU,
+        use_abs_pos: bool = True,
+        use_rel_pos: bool = False,
+        rel_pos_zero_init: bool = True,
+        window_size: int = 0,
+        global_attn_indexes: Tuple[int, ...] = (),
+    ) -> None:
+        super().__init__()
+        self.img_size = img_size
+
+        self.patch_embed = PatchEmbed(
+            kernel_size=(patch_size, patch_size),
+            stride=(patch_size, patch_size),
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+        )
+
+        self.pos_embed: Optional[nn.Parameter] = None
+        if use_abs_pos:
+            # Initialize absolute positional embedding with pretrain image size.
+            self.pos_embed = nn.Parameter(
+                torch.zeros(
+                    1, img_size // patch_size, img_size // patch_size, embed_dim
+                )
+            )
+
+        self.blocks = nn.ModuleList()
+        for i in range(depth):
+            block = CustomBlock(
+                dim=embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+                use_rel_pos=use_rel_pos,
+                rel_pos_zero_init=rel_pos_zero_init,
+                window_size=window_size if i not in global_attn_indexes else 0,
+                input_size=(img_size // patch_size, img_size // patch_size),
+            )
+            self.blocks.append(block)
+
+        self.neck = nn.Sequential(
+            nn.Conv2d(
+                embed_dim,
+                out_chans,
+                kernel_size=1,
+                bias=False,
+            ),
+            LayerNorm2d(out_chans),
+            nn.Conv2d(
+                out_chans,
+                out_chans,
+                kernel_size=3,
+                padding=1,
+                bias=False,
+            ),
+            LayerNorm2d(out_chans),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.patch_embed(x)
+        if self.pos_embed is not None:
+            x = x + self.pos_embed
+
+        for blk in self.blocks:
+            x = blk(x)
+
+        x = self.neck(x.permute(0, 3, 1, 2))
+
+        return x
+
+
+def _build_sam_torchscript(
+    encoder_embed_dim,
+    encoder_depth,
+    encoder_num_heads,
+    encoder_global_attn_indexes,
+    checkpoint=None,
+):
+    prompt_embed_dim = 256
+    image_size = 1024
+    vit_patch_size = 16
+    image_embedding_size = image_size // vit_patch_size
+    sam = Sam(
+        image_encoder=CustomImageEncoderViT(
+            depth=encoder_depth,
+            embed_dim=encoder_embed_dim,
+            img_size=image_size,
+            mlp_ratio=4,
+            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
+            num_heads=encoder_num_heads,
+            patch_size=vit_patch_size,
+            qkv_bias=True,
+            use_rel_pos=True,
+            global_attn_indexes=encoder_global_attn_indexes,
+            window_size=14,
+            out_chans=prompt_embed_dim,
+        ),
+        prompt_encoder=PromptEncoder(
+            embed_dim=prompt_embed_dim,
+            image_embedding_size=(image_embedding_size, image_embedding_size),
+            input_image_size=(image_size, image_size),
+            mask_in_chans=16,
+        ),
+        mask_decoder=MaskDecoder(
+            num_multimask_outputs=3,
+            transformer=TwoWayTransformer(
+                depth=2,
+                embedding_dim=prompt_embed_dim,
+                mlp_dim=2048,
+                num_heads=8,
+            ),
+            transformer_dim=prompt_embed_dim,
+            iou_head_depth=3,
+            iou_head_hidden_dim=256,
+        ),
+        pixel_mean=[123.675, 116.28, 103.53],
+        pixel_std=[58.395, 57.12, 57.375],
+    )
+    sam.eval()
+    if checkpoint is not None:
+        with open(checkpoint, "rb") as f:
+            state_dict = torch.load(f)
+        sam.load_state_dict(state_dict)
+    return sam
+
+
+def build_sam_vit_h_torchscript(checkpoint=None):
+    return _build_sam_torchscript(
+        encoder_embed_dim=1280,
+        encoder_depth=32,
+        encoder_num_heads=16,
+        encoder_global_attn_indexes=[7, 15, 23, 31],
+        checkpoint=checkpoint,
+    )
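
The reason `CustomBlock` exists: upstream, `Block.forward` guards the window partition and un-partition with two separate `if self.window_size > 0:` checks, so `pad_hw` and `(H, W)` are undefined on the else path, which `torch.jit.script` rejects; the fused if/else above keeps every variable defined on all paths. A quick scriptability smoke test (a sketch; random weights suffice):

```
import torch

from custom_encoder import build_sam_vit_h_torchscript

# checkpoint=None gives random weights, which is enough to test scriptability.
encoder = build_sam_vit_h_torchscript().image_encoder.eval()
scripted = torch.jit.script(encoder)  # raises if any reachable code is unscriptable
print(type(scripted))
```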
wrapper.py DELETED
@@ -1,41 +0,0 @@
-import torch
-import torch.nn as nn
-
-from segment_anything.modeling import ImageEncoderViT
-from segment_anything.modeling.image_encoder import Block, window_partition, window_unpartition
-
-
-class BlockWrapper(nn.Module):
-    def __init__(self, block: Block):
-        super().__init__()
-        self.block = block
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        shortcut = x
-        x = self.block.norm1(x)
-        # Window partition
-        if self.block.window_size > 0:
-            H, W = x.shape[1], x.shape[2]
-            x, pad_hw = window_partition(x, self.block.window_size)
-            x = self.block.attn(x)
-            # Reverse window partition
-            x = window_unpartition(x, self.block.window_size, pad_hw, (H, W))
-        else:
-            x = self.block.attn(x)
-
-        x = shortcut + x
-        x = x + self.block.mlp(self.block.norm2(x))
-
-        return x
-
-
-class ImageEncoderViTWrapper(nn.Module):
-    def __init__(self, image_encoder: ImageEncoderViT):
-        super().__init__()
-        self.image_encoder = image_encoder
-
-    def change_block(self):
-        block_wrappers = nn.ModuleList()
-        for block in self.image_encoder.blocks:
-            block_wrappers.append(BlockWrapper(block))
-        self.image_encoder.blocks = block_wrappers
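
Design note: the deleted wrapper.py solved the same control-flow problem by composition, wrapping each stock `Block` after the model was built and swapping `encoder.blocks` in place via `change_block()`. Subclassing `Block` instead (hence the commit title) bakes the scriptable `forward` into the model definition, so `load_model()` returns an encoder that is TorchScript-ready without post-hoc mutation.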