| | from collections import OrderedDict |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torchvision.models import ( |
| | ResNet50_Weights, |
| | VGG16_BN_Weights, |
| | VGG16_Weights, |
| | resnet50, |
| | vgg16, |
| | vgg16_bn, |
| | ) |
| |
|
| | from engine.BiRefNet.config import Config |
| | from engine.BiRefNet.models.backbones.pvt_v2 import ( |
| | pvt_v2_b0, |
| | pvt_v2_b1, |
| | pvt_v2_b2, |
| | pvt_v2_b5, |
| | ) |
| | from engine.BiRefNet.models.backbones.swin_v1 import ( |
| | swin_v1_b, |
| | swin_v1_l, |
| | swin_v1_s, |
| | swin_v1_t, |
| | ) |
| |
|
# Shared BiRefNet configuration instance; ``load_weights`` looks up backbone
# checkpoint locations in ``config.weights[model_name]``.
config = Config()
| |
|
| |
|
def build_backbone(bb_name, pretrained=True, params_settings=""):
    """Build the encoder backbone called *bb_name*.

    Args:
        bb_name: ``"vgg16"``, ``"vgg16bn"`` or ``"resnet50"`` for a torchvision
            network cut into four stages, or the name of one of the PVT-v2 /
            Swin-v1 builders imported by this module (``"pvt_v2_b0"``,
            ``"swin_v1_l"``, ...).
        pretrained: For the torchvision backbones, initialise from the
            ``DEFAULT`` weights of the matching ``*_Weights`` enum. For the
            PVT-v2 / Swin-v1 backbones, load the checkpoint via
            ``load_weights(bb, bb_name)``.
        params_settings: Argument-list source text spliced into the PVT-v2 /
            Swin-v1 builder call, i.e. ``builder(<params_settings>)``. Ignored
            for the torchvision backbones.

    Returns:
        For the torchvision backbones, an ``nn.Sequential`` whose children are
        named ``conv1`` to ``conv4``. Otherwise the module produced by the
        builder, or -- when *pretrained* is true -- whatever ``load_weights``
        returns (which may be ``None`` if the checkpoint could not be matched).

    Raises:
        ValueError: If *bb_name* is not one of the supported backbones.
    """
    # The torchvision builders take weights via ``weights=`` (the ``*_Weights``
    # enums used here only exist in torchvision >= 0.13, which has that API).
    # The legacy ``pretrained=`` keyword is deprecated there, and its shim maps
    # any truthy value to the model's legacy default weights, ignoring the enum
    # passed in (for resnet50 that meant IMAGENET1K_V1 instead of DEFAULT).
    if bb_name == "vgg16":
        features = vgg16(
            weights=VGG16_Weights.DEFAULT if pretrained else None
        ).features
        stages = [
            ("conv1", features[:4]),
            ("conv2", features[4:9]),
            ("conv3", features[9:16]),
            ("conv4", features[16:23]),
        ]
    elif bb_name == "vgg16bn":
        features = vgg16_bn(
            weights=VGG16_BN_Weights.DEFAULT if pretrained else None
        ).features
        # Boundaries differ from vgg16 because of the extra BatchNorm layers.
        stages = [
            ("conv1", features[:6]),
            ("conv2", features[6:13]),
            ("conv3", features[13:23]),
            ("conv4", features[23:33]),
        ]
    elif bb_name == "resnet50":
        net = resnet50(weights=ResNet50_Weights.DEFAULT if pretrained else None)
        stages = [
            # NOTE(review): the stem max-pool (``net.maxpool``) is not part of
            # any stage, matching the previous slicing that skipped child 3.
            # Presumably intentional -- confirm before changing.
            ("conv1", nn.Sequential(net.conv1, net.bn1, net.relu)),
            ("conv2", net.layer1),
            ("conv3", net.layer2),
            ("conv4", net.layer3),
        ]
    else:
        # Resolve the name against the builders imported by this module, so an
        # unexpected ``bb_name`` fails clearly instead of being run as code.
        builders = {
            "pvt_v2_b0": pvt_v2_b0,
            "pvt_v2_b1": pvt_v2_b1,
            "pvt_v2_b2": pvt_v2_b2,
            "pvt_v2_b5": pvt_v2_b5,
            "swin_v1_t": swin_v1_t,
            "swin_v1_s": swin_v1_s,
            "swin_v1_b": swin_v1_b,
            "swin_v1_l": swin_v1_l,
        }
        builder = builders.get(bb_name)
        if builder is None:
            supported = ["vgg16", "vgg16bn", "resnet50", *builders]
            raise ValueError(
                "Unsupported backbone {!r}; expected one of: {}".format(
                    bb_name, ", ".join(supported)
                )
            )
        if params_settings:
            # NOTE(security): ``params_settings`` is still evaluated as Python
            # source for backward compatibility with callers passing an
            # argument-list string. Never feed it untrusted text.
            bb = eval("builder({})".format(params_settings))
        else:
            bb = builder()
        if pretrained:
            bb = load_weights(bb, bb_name)
        return bb
    return nn.Sequential(OrderedDict(stages))
| |
|
| |
|
def _match_state_dict(source, model_dict):
    """Return the entries of *source* whose keys also appear in *model_dict*.

    When a checkpoint tensor's shape differs from the model's, the model's own
    tensor is kept instead, so ``load_state_dict`` does not fail on it.
    """
    return {
        key: value if value.size() == model_dict[key].size() else model_dict[key]
        for key, value in source.items()
        if key in model_dict
    }


def load_weights(model, model_name, weights_path=None):
    """Load pretrained backbone weights into *model*, tolerating partial checkpoints.

    Only checkpoint entries whose keys exist in the model's state dict are used.
    Entries with a mismatched shape, and model keys missing from the
    checkpoint, keep the model's current values. If no top-level key matches
    and the checkpoint has exactly one top-level entry that is itself a dict,
    that nested dict is searched instead (e.g. ``{"model": state_dict}``).

    Args:
        model: Module whose parameters/buffers are overwritten via
            ``load_state_dict``.
        model_name: Key into ``config.weights`` used to locate the checkpoint
            when *weights_path* is not given.
        weights_path: Optional explicit checkpoint location for ``torch.load``;
            overrides the ``config.weights`` lookup when provided.

    Returns:
        *model* on success, or ``None`` (after printing a message) if no usable
        weights were found.
    """
    if weights_path is None:
        weights_path = config.weights[model_name]
    save_model = torch.load(weights_path, map_location="cpu", weights_only=True)
    failure_msg = (
        "Weights are not successfully loaded. Check the state dict of weights file."
    )
    if not isinstance(save_model, dict):
        print(failure_msg)
        return None

    model_dict = model.state_dict()
    state_dict = _match_state_dict(save_model, model_dict)

    if not state_dict:
        # Fall back to a single wrapper entry. The key is resolved before
        # indexing, so empty or multi-key checkpoints report failure instead
        # of raising KeyError on ``save_model[None]``.
        sub_item = next(iter(save_model)) if len(save_model) == 1 else None
        nested = save_model[sub_item] if sub_item is not None else None
        if isinstance(nested, dict):
            state_dict = _match_state_dict(nested, model_dict)
        if not state_dict:
            print(failure_msg)
            return None
        print(
            'Found correct weights in the "{}" item of loaded state_dict.'.format(
                sub_item
            )
        )

    model_dict.update(state_dict)
    model.load_state_dict(model_dict)
    return model
| |
|