Spaces:
Runtime error
Runtime error
| from functools import partial | |
| import torch | |
| from .transformer import INTR | |
| from .sam_transformer import SamTransformer | |
| from .sam import ImageEncoderViT, MaskDecoder, PromptEncoder, TwoWayTransformer | |
def build_demo_model():
    """Assemble the demo ``SamTransformer`` around a SAM ViT-B style encoder.

    The model is composed of one ``ImageEncoderViT``, one ``PromptEncoder``
    and three separately constructed ``MaskDecoder`` heads (mask, affordance
    and depth). Only the mask head is created with ``properties_on=True``.

    This function only instantiates the modules with fixed hyper-parameters;
    it takes no arguments and loads no weights itself.

    Returns:
        The newly constructed ``SamTransformer`` instance.
    """
    # NOTE: an earlier INTR-based configuration (resnet50 backbone,
    # 768x1024 input, 15 queries) used to be kept here as commented-out
    # code; the SAM-based model below is what is actually built.

    # Image encoder hyper-parameters (the "sam_vit_b" setting).
    vit_embed_dim = 768
    vit_depth = 12
    vit_heads = 12
    vit_global_attn = [2, 5, 8, 11]

    # Sizes shared between the encoder, the prompt encoder and the decoders.
    prompt_dim = 256
    input_side = 1024
    patch = 16
    embedding_side = input_side // patch

    def _new_decoder(properties_on):
        # Every call creates its own TwoWayTransformer, so the three heads
        # never share a transformer instance.
        two_way = TwoWayTransformer(
            depth=2,
            embedding_dim=prompt_dim,
            mlp_dim=2048,
            num_heads=8,
        )
        return MaskDecoder(
            num_multimask_outputs=3,
            transformer=two_way,
            transformer_dim=prompt_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
            properties_on=properties_on,
        )

    # Sub-modules are created in the same order as the keyword arguments
    # they are passed to below, keeping module initialization order stable.
    encoder = ImageEncoderViT(
        depth=vit_depth,
        embed_dim=vit_embed_dim,
        img_size=input_side,
        mlp_ratio=4,
        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
        num_heads=vit_heads,
        patch_size=patch,
        qkv_bias=True,
        use_rel_pos=True,
        global_attn_indexes=vit_global_attn,
        window_size=14,
        out_chans=prompt_dim,
    )
    prompts = PromptEncoder(
        embed_dim=prompt_dim,
        image_embedding_size=(embedding_side, embedding_side),
        input_image_size=(input_side, input_side),
        mask_in_chans=16,
    )
    mask_head = _new_decoder(properties_on=True)
    affordance_head = _new_decoder(properties_on=False)
    depth_head = _new_decoder(properties_on=False)

    return SamTransformer(
        image_encoder=encoder,
        prompt_encoder=prompts,
        mask_decoder=mask_head,
        affordance_decoder=affordance_head,
        depth_decoder=depth_head,
        transformer_hidden_dim=prompt_dim,
        backbone_name='vit_b',
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )