File size: 3,380 Bytes
5106c86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os

import torch
from torchvision import transforms
from PIL import Image


def make_transform(size: tuple, normalize=True):
    """
    将PIL图像处理为可以直接作为模型输入的张量
    :param size: 模型输入的图像尺寸
    :param normalize: 是否进行规范化(vgg的输入需要规范化)
    :return:
    """
    transform_lst = [transforms.Resize(size),  # 将图像大小调整为 450x300
        transforms.ToTensor()]  # 将 PIL 图像转换为 Tensor]

    if normalize:
        transform_lst.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    transform = transforms.Compose(transform_lst)
    return transform


def load_image(image_path, transform):
    """
    加载图像
    :param image_path: 图像路径
    :param transform: 应用图像变换
    :return:
    """
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0)
    return image


def save_model(model: torch.nn.Module, save_path):
    """
    保存pytorch模型文件
    :param model:
    :param save_path:
    :return:
    """
    save_dir, filename = os.path.split(save_path)
    if save_dir and not os.path.exists(save_dir):
        os.makedirs(save_dir)
    torch.save(model.state_dict(), os.path.join(save_dir, filename))


def denormalize(tensor):
    """反归一化张量以将其转换回图像格式"""
    mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1).to(tensor.device)
    std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1).to(tensor.device)
    tensor = tensor * std + mean
    tensor = tensor.clamp(0, 1)
    return tensor.to(tensor.device)


def gram_matrix(feature):
    """
    计算特征图的格莱姆矩阵,作为风格特征表示
    :param feature:输入特征图
    :return: 格莱姆矩阵
    """
    b, c, h, w = feature.size()
    feature = feature.view(b, c, h * w)  # 沿着h,w维度拉平
    G = torch.bmm(feature, feature.transpose(1, 2))
    # G = torch.mm(feature, feature.t())  # mm, t()仅仅适用于二维矩阵的运算
    return G


def calculate_content_loss(original_feat, generated_feat) -> torch.Tensor:
    """计算内容损失,即生成特征图与标准特征图的规范化误差平方和"""
    b, c, h, w = original_feat.shape
    x = 2. * c * h * w  # 规范化系数
    return torch.sum((generated_feat - original_feat)**2) / x


def calculate_style_loss(style_feat, generated_feat) -> torch.Tensor:
    """计算风格损失,即生成特征图与标准特征图的格拉姆矩阵的规范化误差平方和"""
    b, c, h, w = style_feat.shape
    G = gram_matrix(generated_feat)
    A = gram_matrix(style_feat)
    x = 4. * ((h * w) ** 2) * (c ** 2)  # 规范化系数
    return torch.sum((G - A)**2) / x


def save_image(tensor, output_dir, filename, denormalization=True):
    """
    保存图像到OUTPUT_DIR
    :param tensor: [0-1]区间的图像张量,形状为(1, 3, h, w)或(3, h, w)
    :param output_dir: 输出路径
    :param filename: 文件名
    :param denormalization: 是否使用反规范化
    :return:
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    if denormalization:
        tensor = denormalize(tensor)
    image = transforms.ToPILImage()(tensor[0].cpu().detach())  # 只保存第一张
    image.save(os.path.join(output_dir, filename))