Spaces:
Runtime error
Runtime error
| from functools import partial | |
| import torch | |
| from .modeling import ImageEncoderViT, PromptEncoder | |
| from .modeling.tiny_vit_sam import TinyViT | |
# Architecture hyper-parameters for the plain-ViT image encoders, keyed by
# model-type name. build_image_encoder splats the selected entry into
# ImageEncoderViT as keyword arguments, so every inner key must be a valid
# ImageEncoderViT constructor argument.
#
# Each global_attn_indexes list picks the last block of every quarter of the
# encoder depth, i.e. (k + 1) * depth // 4 - 1 for k in 0..3. Judging by the
# name, these are the blocks that use global instead of windowed attention —
# NOTE(review): confirm against ImageEncoderViT.
model_type_registry = {
    "vit_l": {
        "embed_dim": 1024,
        "depth": 24,
        "num_heads": 16,
        "global_attn_indexes": [5, 11, 17, 23],
    },
    "vit_h": {
        "embed_dim": 1280,
        "depth": 32,
        "num_heads": 16,
        "global_attn_indexes": [7, 15, 23, 31],
    },
    "vit_b": {
        "embed_dim": 768,
        "depth": 12,
        "num_heads": 12,
        "global_attn_indexes": [2, 5, 8, 11],
    },
}
def build_image_encoder(model_type: str):
    """Construct the image encoder for the requested backbone type.

    Args:
        model_type: ``'vit_t'`` selects the TinyViT backbone. Any key of
            ``model_type_registry`` selects an ``ImageEncoderViT`` configured
            with that registry entry (``embed_dim``, ``depth``, ``num_heads``,
            ``global_attn_indexes``) plus the shared settings below.

    Returns:
        A newly constructed encoder, built with ``img_size=1024``.

    Raises:
        ValueError: if ``model_type`` is neither ``'vit_t'`` nor a key of
            ``model_type_registry``.
    """
    if model_type == 'vit_t':
        # TinyViT takes a different argument set from ImageEncoderViT, so its
        # hyper-parameters are spelled out here rather than in the registry.
        return TinyViT(
            img_size=1024,
            in_chans=3,
            num_classes=1000,  # NOTE(review): presumably an unused classification head — confirm
            embed_dims=[64, 128, 160, 320],
            depths=[2, 2, 6, 2],
            num_heads=[2, 4, 5, 10],
            window_sizes=[7, 7, 14, 7],
            mlp_ratio=4.,
            drop_rate=0.,
            drop_path_rate=0.0,
            use_checkpoint=False,
            mbconv_expand_ratio=4.0,
            local_conv_size=3,
            layer_lr_decay=0.8,
        )

    # Validate explicitly rather than with `assert`. The assert is stripped
    # under `python -O`, which would leave only an opaque KeyError from the
    # registry lookup below.
    if model_type not in model_type_registry:
        supported = ', '.join(
            repr(name) for name in ('vit_t', *model_type_registry)
        )
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of: {supported}"
        )

    return ImageEncoderViT(
        img_size=1024,
        mlp_ratio=4,
        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
        patch_size=16,
        qkv_bias=True,
        use_rel_pos=True,
        window_size=14,
        out_chans=256,
        **model_type_registry[model_type],
    )
def build_prompt_encoder(image_size=1024, vit_patch_size=16):
    """Construct a PromptEncoder for a square input image.

    Args:
        image_size: Side length of the square input image. It is used for
            both dimensions of ``input_image_size``.
        vit_patch_size: Patch size of the image encoder. The side of the
            square image-embedding grid is ``image_size // vit_patch_size``
            (floor division).

    Returns:
        A ``PromptEncoder`` with ``embed_dim=256`` and ``mask_in_chans=16``.
    """
    # Side length of the square embedding grid.
    grid_side = image_size // vit_patch_size
    embedding_grid = (grid_side,) * 2
    input_shape = (image_size,) * 2
    return PromptEncoder(
        # Same value as the out_chans=256 passed to ImageEncoderViT in
        # build_image_encoder.
        embed_dim=256,
        image_embedding_size=embedding_grid,
        input_image_size=input_shape,
        mask_in_chans=16,
    )