|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from lib.core.config import cfg |
|
|
|
|
|
|
|
|
|
|
|
class HACO(nn.Module):
    """Image -> contact network: a configurable backbone followed by a contact decoder.

    Both submodules are selected from ``cfg.MODEL.backbone_type`` at construction
    time, and the whole module is placed on CUDA when available, else CPU.
    """

    def __init__(self):
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.backbone = get_backbone_network(type=cfg.MODEL.backbone_type)
        self.decoder = get_decoder_network(type=cfg.MODEL.backbone_type)

        # Move to the target device only AFTER the submodules are registered:
        # ``nn.Module.to`` only affects parameters/buffers that already exist, so
        # calling it earlier left backbone and decoder on the CPU while
        # ``forward`` sends the input image to ``self.device``.
        self.to(self.device)

    def forward(self, inputs, mode='test'):
        """Run the backbone and decoder on the input image.

        Args:
            inputs: Nested mapping; the image tensor is read from
                ``inputs['input']['image']`` and moved to ``self.device``.
            mode: Not used by this method; kept so existing callers that pass
                it keep working.

        Returns:
            dict with a single key ``'contact_out'`` holding the decoder output.
        """
        image = inputs['input']['image'].to(self.device)

        # The ViT backbones are the *_patch16_224 variants, so resize to 224x224.
        # NOTE(review): bilinear interpolation with a 2-element size requires a
        # 4-D (N, C, H, W) tensor — confirm callers always pass batched images.
        if 'vit' in cfg.MODEL.backbone_type:
            image = F.interpolate(image, size=(224, 224), mode='bilinear', align_corners=False)

        img_feat = self.backbone(image)
        contact_out = self.decoder(img_feat)

        return dict(contact_out=contact_out)
|
|
|
|
|
|
|
|
|
|
|
def get_backbone_network(type='hamer'):
    """Construct the image backbone identified by ``type``.

    Backbone modules are imported lazily inside each branch, so only the code for
    the selected backbone has to be importable.

    Args:
        type: Backbone identifier. One of ``'hamer'``, ``'resnet-18'``,
            ``'resnet-34'``, ``'resnet-50'``, ``'resnet-101'``, ``'resnet-152'``,
            ``'hrnet-w32'``, ``'hrnet-w48'``, ``'handoccnet'``, ``'vit-s-16'``,
            ``'vit-b-16'`` or ``'vit-l-16'``. (The name shadows the builtin
            ``type`` but is kept because callers pass it by keyword.)

    Returns:
        The constructed backbone module.

    Raises:
        NotImplementedError: If ``type`` is not one of the identifiers above.
    """
    # Lookup tables replace the previously copy-pasted branches. They are built
    # without touching ``cfg`` so an unknown type fails fast with a clear error.
    resnet_depths = {
        'resnet-18': 18,
        'resnet-34': 34,
        'resnet-50': 50,
        'resnet-101': 101,
        'resnet-152': 152,
    }
    vit_model_names = {
        'vit-s-16': 'vit_small_patch16_224',
        'vit-b-16': 'vit_base_patch16_224',
        'vit-l-16': 'vit_large_patch16_224',
    }

    if type == 'hamer':
        from lib.models.backbone.backbone_hamer_style import ViT_HaMeR
        return ViT_HaMeR()

    if type in resnet_depths:
        from lib.models.backbone.resnet import ResNetBackbone
        backbone = ResNetBackbone(resnet_depths[type])
        backbone.init_weights()
        return backbone

    if type in ('hrnet-w32', 'hrnet-w48'):
        from lib.models.backbone.hrnet import HighResolutionNet
        from lib.utils.func_utils import load_config
        if type == 'hrnet-w32':
            config_path = cfg.MODEL.hrnet_w32_backbone_config_path
            pretrained = cfg.MODEL.hrnet_w32_backbone_pretrained_path
        else:
            config_path = cfg.MODEL.hrnet_w48_backbone_config_path
            pretrained = cfg.MODEL.hrnet_w48_backbone_pretrained_path
        config = load_config(config_path)
        backbone = HighResolutionNet(config)
        backbone.init_weights(pretrained=pretrained)
        return backbone

    if type == 'handoccnet':
        from lib.models.backbone.fpn import FPN
        backbone = FPN(pretrained=False)
        pretrained = cfg.MODEL.handoccnet_backbone_pretrained_path
        # map_location='cpu' lets a checkpoint saved on GPU load on CPU-only hosts;
        # load_state_dict copies the tensors into the module's own parameters anyway.
        checkpoint = torch.load(pretrained, map_location='cpu')
        # Keep only the backbone weights and strip their wrapper prefix.
        prefix = 'module.backbone.'
        state_dict = {
            k[len(prefix):]: v
            for k, v in checkpoint['network'].items()
            if k.startswith(prefix)
        }
        backbone.load_state_dict(state_dict, strict=True)
        return backbone

    if type in vit_model_names:
        from lib.models.backbone.vit import ViTBackbone
        return ViTBackbone(model_name=vit_model_names[type], pretrained=True)

    raise NotImplementedError(f"Unsupported backbone type: {type!r}")
|
|
|
|
|
|
|
|
|
|
|
def get_decoder_network(type='hamer'):
    """Construct the contact-prediction decoder head.

    Args:
        type: Accepted for signature symmetry with ``get_backbone_network``;
            it is ignored, and the same decoder head is built for every value.

    Returns:
        A new ``ContactTransformerDecoderHead`` instance.
    """
    # Imported at call time rather than at module load.
    from lib.models.decoder.decoder_hamer_style import ContactTransformerDecoderHead as DecoderHead

    return DecoderHead()