import os import torch from torchvision import transforms from PIL import Image def make_transform(size: tuple, normalize=True): """ 将PIL图像处理为可以直接作为模型输入的张量 :param size: 模型输入的图像尺寸 :param normalize: 是否进行规范化(vgg的输入需要规范化) :return: """ transform_lst = [transforms.Resize(size), # 将图像大小调整为 450x300 transforms.ToTensor()] # 将 PIL 图像转换为 Tensor] if normalize: transform_lst.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) transform = transforms.Compose(transform_lst) return transform def load_image(image_path, transform): """ 加载图像 :param image_path: 图像路径 :param transform: 应用图像变换 :return: """ image = Image.open(image_path).convert('RGB') image = transform(image).unsqueeze(0) return image def save_model(model: torch.nn.Module, save_path): """ 保存pytorch模型文件 :param model: :param save_path: :return: """ save_dir, filename = os.path.split(save_path) if save_dir and not os.path.exists(save_dir): os.makedirs(save_dir) torch.save(model.state_dict(), os.path.join(save_dir, filename)) def denormalize(tensor): """反归一化张量以将其转换回图像格式""" mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1).to(tensor.device) std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1).to(tensor.device) tensor = tensor * std + mean tensor = tensor.clamp(0, 1) return tensor.to(tensor.device) def gram_matrix(feature): """ 计算特征图的格莱姆矩阵,作为风格特征表示 :param feature:输入特征图 :return: 格莱姆矩阵 """ b, c, h, w = feature.size() feature = feature.view(b, c, h * w) # 沿着h,w维度拉平 G = torch.bmm(feature, feature.transpose(1, 2)) # G = torch.mm(feature, feature.t()) # mm, t()仅仅适用于二维矩阵的运算 return G def calculate_content_loss(original_feat, generated_feat) -> torch.Tensor: """计算内容损失,即生成特征图与标准特征图的规范化误差平方和""" b, c, h, w = original_feat.shape x = 2. * c * h * w # 规范化系数 return torch.sum((generated_feat - original_feat)**2) / x def calculate_style_loss(style_feat, generated_feat) -> torch.Tensor: """计算风格损失,即生成特征图与标准特征图的格拉姆矩阵的规范化误差平方和""" b, c, h, w = style_feat.shape G = gram_matrix(generated_feat) A = gram_matrix(style_feat) x = 4. * ((h * w) ** 2) * (c ** 2) # 规范化系数 return torch.sum((G - A)**2) / x def save_image(tensor, output_dir, filename, denormalization=True): """ 保存图像到OUTPUT_DIR :param tensor: [0-1]区间的图像张量,形状为(1, 3, h, w)或(3, h, w) :param output_dir: 输出路径 :param filename: 文件名 :param denormalization: 是否使用反规范化 :return: """ if not os.path.exists(output_dir): os.makedirs(output_dir) if denormalization: tensor = denormalize(tensor) image = transforms.ToPILImage()(tensor[0].cpu().detach()) # 只保存第一张 image.save(os.path.join(output_dir, filename))