Change to exporting via a wrapper module, but scripting still fails
Browse files- .gitignore +1 -0
- Makefile +8 -0
- load_model.py +45 -0
- requirements.txt +3 -0
- wrapper.py +41 -0
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
Makefile
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Build automation for exporting the SAM ViT-H image encoder to TorchScript.

PYTHON=3.9
BASENAME=sam-vit-h-encoder-torchscript

# Neither recipe produces a file named after its target, so declare them
# phony: otherwise a stray file called "env" or "setup" would silently
# make the target "up to date" and skip the recipe.
.PHONY: env setup

env:
	conda create -n $(BASENAME) python=$(PYTHON) -y

setup:
	pip install -r requirements.txt
load_model.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import urllib
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from segment_anything import sam_model_registry
|
| 6 |
+
from segment_anything.modeling import Sam
|
| 7 |
+
|
| 8 |
+
from wrapper import ImageEncoderViTWrapper
|
# ---- Checkpoint / model configuration --------------------------------------
# SAM ViT-H weights are cached under ~/.cache/SAM and downloaded on first use.
CHECKPOINT_PATH = os.path.join(os.path.expanduser("~"), ".cache", "SAM")
CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
# Registry key for sam_model_registry; "default" selects the ViT-H variant.
MODEL_TYPE = "default"

# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model(
    checkpoint_path: str = CHECKPOINT_PATH,
    checkpoint_name: str = CHECKPOINT_NAME,
    checkpoint_url: str = CHECKPOINT_URL,
    model_type: str = MODEL_TYPE,
) -> Sam:
    """Load a SAM model, downloading the checkpoint on first use.

    Args:
        checkpoint_path: Directory the checkpoint file is cached in.
        checkpoint_name: File name of the checkpoint inside that directory.
        checkpoint_url: URL the checkpoint is fetched from when missing.
        model_type: Key into ``sam_model_registry`` ("default" is ViT-H).

    Returns:
        The instantiated ``Sam`` model with the checkpoint weights loaded.
    """
    # Bug fix: a bare module-level ``import urllib`` does not import the
    # ``urllib.request`` submodule, so ``urllib.request.urlretrieve`` below
    # may raise AttributeError. Import the submodule explicitly here.
    import urllib.request

    # exist_ok avoids the check-then-create race of the original
    # ``if not os.path.exists(...): os.makedirs(...)`` pattern.
    os.makedirs(checkpoint_path, exist_ok=True)
    checkpoint = os.path.join(checkpoint_path, checkpoint_name)
    if not os.path.exists(checkpoint):
        print("Downloading the model weights...")
        urllib.request.urlretrieve(checkpoint_url, checkpoint)
        print(f"The model weights saved as {checkpoint}")
    print(f"Load the model weights from {checkpoint}")
    return sam_model_registry[model_type](checkpoint=checkpoint)
if __name__ == "__main__":
    # Pull the ViT image encoder out of the full SAM model.
    image_encoder = load_model().image_encoder
    print(type(image_encoder))

    # Wrap it and swap every transformer block for its scriptable wrapper.
    image_encoder_wrapper = ImageEncoderViTWrapper(image_encoder).eval().to(device)
    image_encoder_wrapper.change_block()

    print(type(image_encoder_wrapper.image_encoder.blocks[0]))

    # Script the wrapped encoder and persist it for deployment.
    with torch.jit.optimized_execution(True):
        script_model = torch.jit.script(image_encoder_wrapper)
        script_model.save("model.pt")
requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch == 2.0.0
|
| 2 |
+
torchvision == 0.15.1
|
| 3 |
+
git+https://github.com/facebookresearch/segment-anything.git
|
wrapper.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from segment_anything.modeling import ImageEncoderViT
|
| 5 |
+
from segment_anything.modeling.image_encoder import Block, window_partition, window_unpartition
|
class BlockWrapper(nn.Module):
    """TorchScript-friendly re-expression of a SAM ViT ``Block``.

    Reproduces the wrapped block's forward pass, but keeps the
    window-partition bookkeeping (``pad_hw``, ``H``, ``W``) strictly inside
    the windowed branch so every variable is defined on all paths —
    a requirement for ``torch.jit.script``.
    """

    def __init__(self, block: Block):
        super().__init__()
        self.block = block

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        normed = self.block.norm1(x)

        if self.block.window_size > 0:
            # Windowed attention: partition, attend per window, then stitch
            # the windows back together at the original spatial size.
            h, w = normed.shape[1], normed.shape[2]
            windows, pad_hw = window_partition(normed, self.block.window_size)
            attended = window_unpartition(
                self.block.attn(windows), self.block.window_size, pad_hw, (h, w)
            )
        else:
            # Global attention over the whole feature map.
            attended = self.block.attn(normed)

        out = residual + attended
        return out + self.block.mlp(self.block.norm2(out))
class ImageEncoderViTWrapper(nn.Module):
    """Wraps a SAM ``ImageEncoderViT`` to make it ``torch.jit.script``-able.

    Call :meth:`change_block` before scripting to replace each transformer
    ``Block`` with a scriptable :class:`BlockWrapper`.
    """

    def __init__(self, image_encoder: ImageEncoderViT):
        super().__init__()
        self.image_encoder = image_encoder

    def change_block(self):
        """Swap every transformer block for a TorchScript-friendly wrapper."""
        block_wrappers = nn.ModuleList()
        for block in self.image_encoder.blocks:
            block_wrappers.append(BlockWrapper(block))
        self.image_encoder.blocks = block_wrappers

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped encoder on an input image batch.

        Bug fix: the original wrapper defined no ``forward`` at all, so
        ``torch.jit.script`` had nothing to compile into the module's entry
        point and calling the module raised ``nn.Module``'s default
        NotImplementedError — the likely cause of the "but fail" result.
        """
        return self.image_encoder(x)