# BoSAM / segment_anything / build_sam.py
# Uploaded by ziyanlu via huggingface_hub (commit 9859ea2, verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from functools import partial
from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
from torch.nn import functional as F
def build_sam_vit_h(args):
    """Build a SAM model with the ViT-H image-encoder configuration.

    Args:
        args: Object exposing ``image_size`` (input side length) and
            ``sam_checkpoint`` (checkpoint path or ``None``) attributes.

    Returns:
        Whatever ``_build_sam`` returns for the ViT-H hyper-parameters.
    """
    vit_h_config = dict(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
    )
    return _build_sam(
        image_size=args.image_size,
        checkpoint=args.sam_checkpoint,
        **vit_h_config,
    )


# Backwards-compatible alias: the generic ``build_sam`` builds ViT-H.
build_sam = build_sam_vit_h
def build_sam_vit_l(args):
    """Build a SAM model with the ViT-L image-encoder configuration.

    Args:
        args: Object exposing ``image_size`` (input side length) and
            ``sam_checkpoint`` (checkpoint path or ``None``) attributes.

    Returns:
        Whatever ``_build_sam`` returns for the ViT-L hyper-parameters.
    """
    vit_l_config = dict(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
    )
    return _build_sam(
        image_size=args.image_size,
        checkpoint=args.sam_checkpoint,
        **vit_l_config,
    )
def build_sam_vit_b(args):
    """Build a SAM model with the ViT-B image-encoder configuration.

    Args:
        args: Object exposing ``image_size`` (input side length) and
            ``sam_checkpoint`` (checkpoint path or ``None``) attributes.

    Returns:
        Whatever ``_build_sam`` returns for the ViT-B hyper-parameters.
    """
    vit_b_config = dict(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
    )
    return _build_sam(
        image_size=args.image_size,
        checkpoint=args.sam_checkpoint,
        **vit_b_config,
    )
# Lookup table from model-type name to its builder; "default" selects ViT-H.
sam_model_registry = dict(
    default=build_sam_vit_h,
    vit_h=build_sam_vit_h,
    vit_l=build_sam_vit_l,
    vit_b=build_sam_vit_b,
)
def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    image_size,
    checkpoint,
):
    """Construct a SAM model and optionally load weights from ``checkpoint``.

    Args:
        encoder_embed_dim: Embedding width of the ViT image encoder.
        encoder_depth: Number of blocks in the ViT image encoder.
        encoder_num_heads: Attention heads per image-encoder block.
        encoder_global_attn_indexes: Passed to ``ImageEncoderViT`` as
            ``global_attn_indexes``.
        image_size: Side length of the square input image, in pixels.
        checkpoint: Path to a ``torch.save``'d state dict, optionally nested
            under a ``'model'`` key, or ``None`` to skip loading.

    Returns:
        The constructed ``Sam`` model, switched to training mode.

    Raises:
        Whatever ``open``/``torch.load`` raise for an unreadable checkpoint,
        and ``RuntimeError`` if even the adapted state dict cannot be loaded.
    """
    prompt_embed_dim = 256
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    sam = Sam(
        image_encoder=ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.train()
    if checkpoint is None:
        return sam

    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only
    # hosts; load_state_dict copies the values into the model's parameters.
    # NOTE(review): torch.load unpickles arbitrary objects - only load
    # checkpoints from trusted sources.
    with open(checkpoint, "rb") as f:
        state_dict = torch.load(f, map_location="cpu")
    # Unwrap checkpoints that nest weights under 'model' BEFORE the fallback,
    # so load_from sees real parameter names rather than the wrapper keys.
    if isinstance(state_dict, dict) and 'model' in state_dict:
        state_dict = state_dict['model']
    try:
        sam.load_state_dict(state_dict)
    except RuntimeError:
        # Strict load failed (missing/unexpected keys or shape mismatch):
        # adapt the checkpoint to this model's shapes and retry.
        print('*******interpolate')
        new_state_dict = load_from(sam, state_dict, image_size, vit_patch_size)
        sam.load_state_dict(new_state_dict)
    print(f"*******load {checkpoint}")
    return sam
def load_from(sam, state_dicts, image_size, vit_patch_size):
    """Adapt a checkpoint state dict to ``sam``'s parameter names and shapes.

    Only checkpoint entries whose names also appear in ``sam.state_dict()``
    are kept. Entries containing ``mask_tokens``, ``output_hypernetworks_mlps``
    or ``iou_prediction_head`` are dropped, so the model keeps its current
    values for those. The image-encoder positional embedding and every
    ``rel_pos`` table whose shape differs from the model's are resized
    bilinearly.

    Args:
        sam: Model whose ``state_dict()`` defines the target names and shapes.
        state_dicts: Flat mapping of parameter name to tensor.
        image_size: Input image side length the model was built for.
        vit_patch_size: Patch size of the image encoder.

    Returns:
        ``sam``'s state dict with the adapted checkpoint entries merged in,
        ready for ``sam.load_state_dict``.
    """
    sam_dict = sam.state_dict()
    except_keys = ('mask_tokens', 'output_hypernetworks_mlps', 'iou_prediction_head')
    new_state_dict = {
        k: v for k, v in state_dicts.items()
        if k in sam_dict and not any(ex in k for ex in except_keys)
    }

    token_size = int(image_size // vit_patch_size)
    pos_key = 'image_encoder.pos_embed'
    if pos_key in new_state_dict:
        pos_embed = new_state_dict[pos_key]  # [b, h, w, c], per the permutes below
        if tuple(pos_embed.shape[1:3]) != (token_size, token_size):
            # Resizing the learned grid may sacrifice some accuracy, but it is
            # the only way to reuse it at a different input resolution.
            pos_embed = pos_embed.permute(0, 3, 1, 2)  # [b, c, h, w]
            pos_embed = F.interpolate(pos_embed, (token_size, token_size), mode='bilinear', align_corners=False)
            pos_embed = pos_embed.permute(0, 2, 3, 1)  # [b, h, w, c]
            new_state_dict[pos_key] = pos_embed

    # Resize every relative-position table whose shape differs from the
    # model's. Tables that already match are left untouched, so no
    # hand-maintained list of block indices is needed (the old digit-substring
    # heuristic missed e.g. block 0).
    for k, target in sam_dict.items():
        if 'rel_pos' not in k or k not in new_state_dict:
            continue
        rel_pos_params = new_state_dict[k]
        if rel_pos_params.shape != target.shape:
            # interpolate needs [N, C, H, W]; the tables are 2-D.
            resized = F.interpolate(
                rel_pos_params.unsqueeze(0).unsqueeze(0),
                tuple(target.shape),
                mode='bilinear',
                align_corners=False,
            )
            new_state_dict[k] = resized[0, 0, ...]

    sam_dict.update(new_state_dict)
    return sam_dict