File size: 3,831 Bytes
5732928
 
 
 
 
 
 
 
 
 
 
2da0adb
 
5732928
 
 
 
 
 
 
 
 
 
 
 
5dff371
5732928
5dff371
5732928
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import torch.nn as nn
import torch.nn.functional as F

from lib.core.config import cfg



class HACO(nn.Module):
    def __init__(self):
        super(HACO, self).__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)

        # Load modules
        self.backbone = get_backbone_network(type=cfg.MODEL.backbone_type)
        self.decoder = get_decoder_network(type=cfg.MODEL.backbone_type)

    def forward(self, inputs, mode='test'):
        image = inputs['input']['image'].to(self.device)

        if 'vit' in cfg.MODEL.backbone_type:
            image = F.interpolate(image, size=(224, 224), mode='bilinear', align_corners=False)
        
        img_feat = self.backbone(image)
        contact_out = self.decoder(img_feat)
        
        return dict(contact_out=contact_out)



def get_backbone_network(type='hamer'):
    if type in ['hamer']:
        from lib.models.backbone.backbone_hamer_style import ViT_HaMeR
        backbone = ViT_HaMeR()
    elif type in ['resnet-18']:
        from lib.models.backbone.resnet import ResNetBackbone
        backbone = ResNetBackbone(18) # ResNet
        backbone.init_weights()
    elif type in ['resnet-34']:
        from lib.models.backbone.resnet import ResNetBackbone
        backbone = ResNetBackbone(34) # ResNet
        backbone.init_weights()
    elif type in ['resnet-50']:
        from lib.models.backbone.resnet import ResNetBackbone
        backbone = ResNetBackbone(50) # ResNet
        backbone.init_weights()
    elif type in ['resnet-101']:
        from lib.models.backbone.resnet import ResNetBackbone
        backbone = ResNetBackbone(101) # ResNet
        backbone.init_weights()
    elif type in ['resnet-152']:
        from lib.models.backbone.resnet import ResNetBackbone
        backbone = ResNetBackbone(152) # ResNet
        backbone.init_weights()
    elif type in ['hrnet-w32']:
        from lib.models.backbone.hrnet import HighResolutionNet
        from lib.utils.func_utils import load_config
        config = load_config(cfg.MODEL.hrnet_w32_backbone_config_path)
        pretrained = cfg.MODEL.hrnet_w32_backbone_pretrained_path
        backbone = HighResolutionNet(config)
        backbone.init_weights(pretrained=pretrained)
    elif type in ['hrnet-w48']:
        from lib.models.backbone.hrnet import HighResolutionNet
        from lib.utils.func_utils import load_config
        config = load_config(cfg.MODEL.hrnet_w48_backbone_config_path)
        pretrained = cfg.MODEL.hrnet_w48_backbone_pretrained_path
        backbone = HighResolutionNet(config)
        backbone.init_weights(pretrained=pretrained)
    elif type in ['handoccnet']:
        from lib.models.backbone.fpn import FPN
        backbone = FPN(pretrained=False)
        pretrained = cfg.MODEL.handoccnet_backbone_pretrained_path
        state_dict = {k[len('module.backbone.'):]: v for k, v in torch.load(pretrained)['network'].items() if k.startswith('module.backbone.')}
        backbone.load_state_dict(state_dict, strict=True)
    elif type in ['vit-s-16']:
        from lib.models.backbone.vit import ViTBackbone
        backbone = ViTBackbone(model_name='vit_small_patch16_224', pretrained=True)
    elif type in ['vit-b-16']:
        from lib.models.backbone.vit import ViTBackbone
        backbone = ViTBackbone(model_name='vit_base_patch16_224', pretrained=True)
    elif type in ['vit-l-16']:
        from lib.models.backbone.vit import ViTBackbone
        backbone = ViTBackbone(model_name='vit_large_patch16_224', pretrained=True)
    else:
        raise NotImplementedError

    return backbone



def get_decoder_network(type='hamer'):
    from lib.models.decoder.decoder_hamer_style import ContactTransformerDecoderHead
    decoder = ContactTransformerDecoderHead()

    return decoder