Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class ContentLoss(nn.Module):
    """Pass-through layer that records the MSE between its input and a fixed target.

    The module does not alter the data flowing through it; on every forward
    call it stores the mean-squared error against ``target`` in ``self.loss``
    and hands the input back unchanged.

    Args:
        target: tensor the input is compared against; detached at construction.
    """

    def __init__(self, target):
        super().__init__()
        # Detach so the reference is a constant and gradients never flow into it.
        frozen = target.detach()
        self.target = frozen

    def forward(self, input):
        """Store ``mse_loss(input, target)`` in ``self.loss`` and return ``input`` as-is."""
        self.loss = F.mse_loss(input=input, target=self.target)
        return input
class StyleLoss(nn.Module):
    """Pass-through layer that records a Gram-matrix (style) loss against a target.

    On every forward call the MSE between the Gram matrix of the input and the
    Gram matrix of ``target_feature`` is stored in ``self.loss``; the input is
    returned unchanged.

    Args:
        target_feature: 4-D tensor whose Gram matrix is the fixed target.
    """

    def __init__(self, target_feature):
        super().__init__()
        # Computed once and detached so the target Gram matrix is a constant.
        self.target = self.gram_matrix(target_feature).detach()

    def gram_matrix(self, input):
        """Return the normalised Gram matrix of a 4-D tensor.

        The first two dimensions are flattened into rows and the last two
        into columns, so the result has shape ``(a*b, a*b)`` and is divided
        by ``a*b*c*d``. If ``a > 1``, rows from every index of the first
        dimension are combined into a single Gram matrix. The input is
        presumably laid out as (batch, channels, height, width); confirm
        against the caller.

        Args:
            input: tensor of size ``(a, b, c, d)``.

        Returns:
            Tensor of shape ``(a*b, a*b)``.
        """
        a, b, c, d = input.size()
        # reshape instead of view: view raises on non-contiguous inputs
        # (e.g. after permute/transpose); reshape returns a view when possible
        # and copies only when it must, giving identical values either way.
        features = input.reshape(a * b, c * d)
        gram = torch.mm(features, features.t())
        return gram.div(a * b * c * d)

    def forward(self, input):
        """Store the Gram-matrix MSE in ``self.loss`` and return ``input`` unchanged."""
        G = self.gram_matrix(input)
        self.loss = F.mse_loss(G, self.target)
        return input
class Normalization(nn.Module):
    """Normalise an image tensor with per-channel mean and standard deviation.

    Computes ``(img - mean) / std``, with ``mean`` and ``std`` reshaped to
    ``(-1, 1, 1)``. They therefore align with dim -3 of ``img`` (presumably
    the channel dimension; confirm against the caller) and broadcast over
    the last two dims.

    Args:
        mean: per-channel means (sequence, array or tensor).
        std: per-channel standard deviations (sequence, array or tensor).
    """

    def __init__(self, mean, std):
        super().__init__()
        # as_tensor(...).detach().clone() gives an independent copy with the
        # same dtype/device semantics as torch.tensor(...), but without the
        # copy-construct warning torch.tensor emits for tensor arguments.
        mean_t = torch.as_tensor(mean).detach().clone().view(-1, 1, 1)
        std_t = torch.as_tensor(std).detach().clone().view(-1, 1, 1)
        # Buffers follow Module.to()/.cuda()/.double(), so moving the model
        # no longer leaves the stats on the old device/dtype. persistent=False
        # keeps them out of state_dict(), so saved/loaded checkpoints keep the
        # same keys as before.
        self.register_buffer("mean", mean_t, persistent=False)
        self.register_buffer("std", std_t, persistent=False)

    def forward(self, img):
        """Return ``(img - mean) / std`` using the stored per-channel statistics."""
        return (img - self.mean) / self.std