| import torch |
| import torch.nn as nn |
| from torchvision.models import vgg19, vgg16 |
| import os.path as osp |
|
|
|
|
| |
| |
| |
class VGG_LOSS(nn.Module):
    '''
    Perceptual loss measured on intermediate VGG16 / VGG19 feature maps.

    Both images are rescaled to 0~1, normalised by ``MeanShift`` and pushed
    through the leading part of the VGG ``features`` stack; the selected
    activations of the two images are compared with an L1 or L2 criterion
    and the per-layer losses are summed.

    Args:
        model_type: 'vgg16' or 'vgg19'.
        layer_names: names of the layers to compare, as produced by
            ``vgg_all_layers`` (e.g. 'conv_1_1', 'relu_2_1', 'pool_3').
            The pseudo name 'input' additionally compares the normalised
            images themselves.
        loss_type: 'l1' or 'l2'. Any other value falls back to L1.

    Raises:
        ValueError: if ``model_type`` is neither 'vgg16' nor 'vgg19'.
        KeyError: if a layer name is unknown for the chosen network.

    The module is frozen (no gradients for its parameters, eval mode) and is
    moved to the default CUDA device when CUDA is available.
    '''

    def __init__(self, model_type='vgg19', layer_names=('conv_1_1', 'conv_2_1'), loss_type='l1'):
        super(VGG_LOSS, self).__init__()

        # 'input' is not a VGG layer, so it must not reach the name lookup
        # (which would raise KeyError); it is handled separately in get_feas.
        self.lname_input = 'input' if ('input' in layer_names) else None
        vgg_layer_names = [lname for lname in layer_names if lname != 'input']

        # Resolve the names before loading any weights so that a bad
        # model_type / layer name fails fast.
        self.layer_names = get_layer_name_id(model_type, vgg_layer_names)
        self.layer_ids = inverse_dict(self.layer_names)
        self.lid_list = list(self.layer_names.values())

        vgg_model = self._build_vgg(model_type)

        # Keep only the layers up to the deepest requested one. When only
        # 'input' was requested no VGG layer is needed (empty Sequential).
        lid_max = max(self.lid_list, default=-1)
        self.network = vgg_model.features[:lid_max + 1]

        # Converts 0~1 RGB into the normalised input expected by VGG.
        self.mean_shift = MeanShift()

        # Unknown loss types deliberately fall back to L1.
        if loss_type == 'l2':
            self.loss_fun = nn.MSELoss()
        else:
            self.loss_fun = nn.L1Loss()

        self.set_not_requires_grad()

        # Move the *whole* module (network and mean_shift together) so that
        # all sub-modules live on the same device; stay on CPU otherwise.
        if torch.cuda.is_available():
            self.cuda()

    @staticmethod
    def _build_vgg(model_type):
        '''
        Create the VGG network and load its weights from the local checkpoint.

        The checkpoint is looked up relative to the current working directory
        first and, if it is not there, relative to this file's directory.
        '''
        builders = {
            'vgg16': (vgg16, '../vgg16-397923af.pth'),
            'vgg19': (vgg19, '../vgg19-dcbb9e9d.pth'),
        }
        if model_type not in builders:
            raise ValueError('Vgg network type should be either vgg16 or vgg19.')
        builder, weight_path = builders[model_type]

        if not osp.isfile(weight_path):
            module_dir = osp.dirname(osp.realpath(__file__))
            fallback_path = osp.join(module_dir, weight_path)
            if osp.isfile(fallback_path):
                weight_path = fallback_path

        # Called without arguments the constructor creates the network with
        # no pretrained weights; they are loaded from the checkpoint below.
        vgg_model = builder()
        # Load onto the CPU so that construction also works without CUDA.
        state_dict = torch.load(weight_path, map_location='cpu')
        vgg_model.load_state_dict(state_dict)
        return vgg_model

    def forward(self, img_gt, img_infer, img_range=(-1.0, 1.0)):
        '''
        Compute the VGG loss between two images.

        Args:
            img_gt: reference image tensor.
            img_infer: predicted image tensor.
            img_range: (min, max) value range of both input images.

        Returns:
            Sum of the per-layer losses (the int 0 if no layer was selected).
        '''
        feas_gt = self.get_feas(img_gt, img_range)
        feas_infer = self.get_feas(img_infer, img_range)

        loss_total = 0
        for lname, gt in feas_gt.items():
            infer = feas_infer[lname]
            loss_total = loss_total + self.loss_fun(gt, infer)
        return loss_total

    def get_feas(self, xx, in_range):
        '''
        Extract the selected intermediate features.

        Args:
            xx: image tensor with values inside ``in_range``.
            in_range: (min, max) value range of ``xx``.

        Returns:
            dict mapping layer name -> feature tensor, plus an 'input' entry
            holding the normalised image when 'input' was requested.
        '''
        # Bring the data to 0~1, then apply the VGG input normalisation.
        xx = reset_range(xx, in_range)
        xx = self.mean_shift(xx)

        out_feas = dict()
        if self.lname_input is not None:
            out_feas[self.lname_input] = xx.clone()
        for lid, layer in enumerate(self.network):
            xx = layer(xx)
            if lid in self.layer_ids:
                # Store a copy so later layers cannot modify the saved
                # feature (presumably needed because torchvision's VGG ReLU
                # layers work in place -- TODO confirm).
                out_feas[self.layer_ids[lid]] = xx.clone()
        return out_feas

    def set_not_requires_grad(self):
        '''Freeze all parameters and switch the module to eval mode.'''
        for para in self.parameters():
            para.requires_grad = False
        self.eval()
|
|
def reset_range(indata, in_range):
    '''
    Rescale data from the range ``in_range`` to 0~1.

    Args:
        indata: value(s) to rescale; anything supporting ``-`` and ``*``.
        in_range: (min, max) pair describing the current range of ``indata``.

    Returns:
        ``indata`` shifted by the minimum and multiplied by 1 / (max - min).
    '''
    low, high = in_range
    scale = 1.0 / (high - low)
    shifted = indata - low
    return shifted * scale
|
|
def get_layer_name_id(vgg_type, lnames):
    '''
    Look up the layer index for each requested layer name.

    Args:
        vgg_type: network type passed on to ``vgg_all_layers``.
        lnames: iterable of layer names.

    Returns:
        dict mapping each name in ``lnames`` to its layer index.

    Raises:
        KeyError: if a name is not present in the table for ``vgg_type``.
    '''
    name_to_index = vgg_all_layers(vgg_type)
    return {lname: name_to_index[lname] for lname in lnames}
|
|
def vgg_all_layers(vgg_type):
    '''
    Build the table of VGG intermediate layer names and their indices.

    Every stage consists of several conv layers, each directly followed by
    its relu (index = conv index + 1), and is closed by one pooling layer.
    Names have the form 'conv_<stage>_<n>', 'relu_<stage>_<n>' and
    'pool_<stage>'.

    Args:
        vgg_type: 'vgg16' or 'vgg19'.

    Returns:
        dict mapping layer name -> layer index.

    Raises:
        ValueError: for any other ``vgg_type``.
    '''
    # Number of conv layers in each of the five stages.
    if vgg_type == 'vgg16':
        convs_per_stage = (2, 2, 3, 3, 3)
    elif vgg_type == 'vgg19':
        convs_per_stage = (2, 2, 4, 4, 4)
    else:
        raise ValueError('Vgg network type should be either vgg16 or vgg19.')

    name_to_index = {}
    next_index = 0
    for stage, conv_count in enumerate(convs_per_stage, start=1):
        for conv_no in range(1, conv_count + 1):
            suffix = '{}_{}'.format(stage, conv_no)
            name_to_index['conv_' + suffix] = next_index
            name_to_index['relu_' + suffix] = next_index + 1
            next_index += 2  # one slot for the conv, one for its relu
        name_to_index['pool_{}'.format(stage)] = next_index
        next_index += 1
    return name_to_index
|
|
def inverse_dict(inp_dict):
    '''
    Swap the keys and values of a dictionary.

    If several keys share the same value, the key visited last wins.

    Args:
        inp_dict: dictionary whose values are hashable.

    Returns:
        New dict mapping each value of ``inp_dict`` to its key.
    '''
    return {val: key for key, val in inp_dict.items()}
|
|
class MeanShift(nn.Conv2d):
    '''
    Fixed-parameter 1x1 convolution that converts an ordinary RGB image
    (range 0~1) into the VGG input format, i.e. computes
    ``(x - rgb_mean) / rgb_std`` independently for each channel.

    Args:
        rgb_mean: per-channel mean (R, G, B).
        rgb_std: per-channel standard deviation (R, G, B).
    '''
    def __init__(self, rgb_mean=(0.485, 0.456, 0.406), rgb_std=(0.229, 0.224, 0.225)):
        super(MeanShift, self).__init__(in_channels=3, out_channels=3, kernel_size=1)
        mean = torch.Tensor(rgb_mean)
        std = torch.Tensor(rgb_std)
        # Diagonal weight: output channel c is x_c / std_c.
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        # (x - mean) / std == x / std - mean / std, hence the NEGATIVE bias.
        self.bias.data = -mean / std
        # The layer is a fixed normalisation and must never be trained.
        for param in self.parameters():
            param.requires_grad = False
|
|
|
|
|
|
| |
| |
if __name__ == '__main__':
    # Demo: dump the selected VGG feature maps of one image as PNG grids.
    # Usage: python <this file> [image_path]
    import os
    import os.path as osp
    import sys

    import cv2
    import numpy as np
    import torchvision.utils as tv_utils

    def vgg_fea2img(vgg_fea):
        '''Arrange the channels of one feature map as a uint8 image grid.'''
        mid_feas = vgg_fea.detach()
        # Swap the batch and channel axes so every channel becomes one tile.
        mid_feas = torch.transpose(mid_feas, 0, 1)
        fea_nrow = round((mid_feas.shape[0]) ** 0.5)
        fea_grid = tv_utils.make_grid(mid_feas, nrow=fea_nrow, normalize=True, scale_each=True)
        fea_grid = fea_grid.cpu().float().numpy().transpose((1, 2, 0))
        fea_grid = (fea_grid * 255.0).round().clip(0, 255).astype(np.uint8)
        return fea_grid

    # The image path may be given on the command line; otherwise use the default.
    inp = sys.argv[1] if len(sys.argv) > 1 else r'D:\tmp\test\baboon.png'

    # cv2.imread returns None instead of raising when the file cannot be read.
    img_in = cv2.imread(inp)
    if img_in is None:
        raise FileNotFoundError('Cannot read image: {}'.format(inp))
    img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2RGB).transpose(2, 0, 1)
    img_in = np.float32(img_in) / 255.0  # uint8 0~255 -> float32 0~1
    img_in_t = torch.from_numpy(img_in).unsqueeze(0)

    layer_names = ('conv_1_1', 'conv_2_1', 'conv_3_1', 'conv_4_1')
    vgg_test = VGG_LOSS(layer_names=layer_names)

    # Put the whole module and the input on the device the VGG layers are on,
    # so the forward pass never mixes CPU and CUDA tensors.
    device = next(vgg_test.network.parameters()).device
    vgg_test.to(device)
    img_in_t = img_in_t.to(device)

    with torch.no_grad():
        feas = vgg_test.get_feas(img_in_t, in_range=(0.0, 1.0))

    # One output directory per input image, named after the image file.
    out_dir = osp.splitext(inp)[0]
    os.makedirs(out_dir, exist_ok=True)
    bind = 0  # index of the batch element to visualise
    for layer_name, mid_feas in feas.items():
        out_fea = vgg_fea2img(mid_feas[bind:(bind + 1)])
        out_path = osp.join(out_dir, '{}_{}.png'.format(bind, layer_name))
        cv2.imwrite(out_path, out_fea)