Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from src.utils import feature_normalize | |
| ### START### CONTEXTUAL LOSS #### | |
class ContextualLoss(nn.Module):
    """Contextual loss between two feature maps (Mechrez et al., 2018).

    Compares X and Y features without requiring spatial alignment by
    matching features through a softmax-normalized cosine-similarity
    affinity matrix.

    NOTE(review): the original docstring said "input is Al, Bl, channel = 1,
    range ~ [0, 255]", but the code itself accepts any (batch, depth, H, W)
    feature tensors — confirm intended inputs against callers.
    """

    def __init__(self):
        # No state beyond the nn.Module machinery; the removed `return None`
        # was a no-op in __init__.
        super(ContextualLoss, self).__init__()

    def forward(self, X_features, Y_features, h=0.1, feature_centering=True):
        """Compute the per-sample contextual loss.

        Args:
            X_features: (batch, depth, H, W) feature tensor.
            Y_features: (batch, depth, H, W) reference feature tensor.
            h: bandwidth of the exponential affinity kernel.
            feature_centering: if True, subtract the per-channel spatial mean
                of Y_features from BOTH inputs before normalization.

        Returns:
            Tensor of shape (batch,): -log(CX) per sample.
        """
        batch_size = X_features.shape[0]
        feature_depth = X_features.shape[1]

        # Center both maps on Y's per-channel mean (computed once — the
        # original computed the identical mean twice).
        if feature_centering:
            mean_Y = (
                Y_features.view(batch_size, feature_depth, -1)
                .mean(dim=-1)
                .unsqueeze(dim=-1)
                .unsqueeze(dim=-1)
            )
            X_features = X_features - mean_Y
            Y_features = Y_features - mean_Y

        # L2-normalize feature vectors and flatten spatial dims.
        X_features = feature_normalize(X_features).view(
            batch_size, feature_depth, -1
        )  # batch_size * feature_depth * feature_size^2
        Y_features = feature_normalize(Y_features).view(
            batch_size, feature_depth, -1
        )  # batch_size * feature_depth * feature_size^2

        # Cosine distance = 1 - similarity, for every (x_i, y_j) pair.
        X_features_permute = X_features.permute(0, 2, 1)  # batch_size * feature_size^2 * feature_depth
        d = 1 - torch.matmul(X_features_permute, Y_features)  # batch_size * feature_size^2 * feature_size^2

        # Normalized distance: dij_bar = d_ij / (min_k d_ik + eps).
        d_norm = d / (torch.min(d, dim=-1, keepdim=True)[0] + 1e-5)  # batch_size * feature_size^2 * feature_size^2

        # Pairwise affinity, softmax-normalized over the last axis.
        w = torch.exp((1 - d_norm) / h)
        A_ij = w / torch.sum(w, dim=-1, keepdim=True)

        # Per sample: best-matching x_i for each y_j (max over dim=1),
        # averaged over j. The sibling ContextualLoss_forward intentionally
        # uses the opposite direction.
        CX = torch.mean(torch.max(A_ij, dim=1)[0], dim=-1)
        return -torch.log(CX)
class ContextualLoss_forward(nn.Module):
    """Forward-direction contextual loss (Mechrez et al., 2018).

    Same construction as ContextualLoss, but CX takes, for each X position,
    its best-matching Y feature (max over dim=-1) and averages over X
    positions (dim=1) — the opposite matching direction of the sibling class.

    NOTE(review): the original docstring said "input is Al, Bl, channel = 1,
    range ~ [0, 255]", but the code itself accepts any (batch, depth, H, W)
    feature tensors — confirm intended inputs against callers.
    """

    def __init__(self):
        # No state beyond the nn.Module machinery; the removed `return None`
        # was a no-op in __init__.
        super(ContextualLoss_forward, self).__init__()

    def forward(self, X_features, Y_features, h=0.1, feature_centering=True):
        """Compute the per-sample forward contextual loss.

        Args:
            X_features: (batch, depth, H, W) feature tensor.
            Y_features: (batch, depth, H, W) reference feature tensor.
            h: bandwidth of the exponential affinity kernel.
            feature_centering: if True, subtract the per-channel spatial mean
                of Y_features from BOTH inputs before normalization.

        Returns:
            Tensor of shape (batch,): -log(CX) per sample.
        """
        batch_size = X_features.shape[0]
        feature_depth = X_features.shape[1]

        # Center both maps on Y's per-channel mean (computed once — the
        # original computed the identical mean twice).
        if feature_centering:
            mean_Y = (
                Y_features.view(batch_size, feature_depth, -1)
                .mean(dim=-1)
                .unsqueeze(dim=-1)
                .unsqueeze(dim=-1)
            )
            X_features = X_features - mean_Y
            Y_features = Y_features - mean_Y

        # L2-normalize feature vectors and flatten spatial dims.
        X_features = feature_normalize(X_features).view(
            batch_size, feature_depth, -1
        )  # batch_size * feature_depth * feature_size^2
        Y_features = feature_normalize(Y_features).view(
            batch_size, feature_depth, -1
        )  # batch_size * feature_depth * feature_size^2

        # Cosine distance = 1 - similarity, for every (x_i, y_j) pair.
        X_features_permute = X_features.permute(0, 2, 1)  # batch_size * feature_size^2 * feature_depth
        d = 1 - torch.matmul(X_features_permute, Y_features)  # batch_size * feature_size^2 * feature_size^2

        # Normalized distance: dij_bar = d_ij / (min_k d_ik + eps).
        d_norm = d / (torch.min(d, dim=-1, keepdim=True)[0] + 1e-5)  # batch_size * feature_size^2 * feature_size^2

        # Pairwise affinity, softmax-normalized over the last axis.
        w = torch.exp((1 - d_norm) / h)
        A_ij = w / torch.sum(w, dim=-1, keepdim=True)

        # Per sample: best-matching y_j for each x_i (max over dim=-1),
        # averaged over i (dim=1).
        CX = torch.mean(torch.max(A_ij, dim=-1)[0], dim=1)
        return -torch.log(CX)
| ### END### CONTEXTUAL LOSS #### | |
| ########################## | |
def mse_loss_fn(input, target=0):
    """Mean squared error between `input` and `target` (scalar 0 by default)."""
    diff = input - target
    return (diff * diff).mean()
| ### START### PERCEPTUAL LOSS ### | |
def Perceptual_loss(domain_invariant, weight_perceptual):
    """Build a closure computing the perceptual loss on relu5_1 features.

    Args:
        domain_invariant: if truthy, instance-normalize both feature maps
            before the MSE (with an extra 1e5 * 0.2 scaling).
        weight_perceptual: multiplicative weight applied to the loss.

    Returns:
        A function (A_relu5_1, predict_relu5_1) -> weighted feature loss.
    """
    instancenorm = nn.InstanceNorm2d(512, affine=False)

    def compute(A_relu5_1, predict_relu5_1):
        # The reference features are detached: only the prediction receives
        # gradients.
        target = A_relu5_1.detach()
        if not domain_invariant:
            return mse_loss_fn(predict_relu5_1, target) * weight_perceptual
        # Instance normalization discards per-channel statistics before the
        # comparison.
        return (
            mse_loss_fn(instancenorm(predict_relu5_1), instancenorm(target)) * weight_perceptual * 1e5 * 0.2
        )

    return compute
| ### END### PERCEPTUAL LOSS ### | |
def l1_loss_fn(input, target=0):
    """Mean absolute error between `input` and `target` (scalar 0 by default)."""
    return (input - target).abs().mean()
| ### END################# | |
| ### START### ADVERSARIAL LOSS ### | |
def generator_loss_fn(real_data_lab, fake_data_lab, discriminator, weight_gan, device):
    """Relativistic average LSGAN generator loss.

    Args:
        real_data_lab: batch of real (Lab) images.
        fake_data_lab: batch of generated (Lab) images; gradients flow
            through these.
        discriminator: callable returning a (prediction, features) tuple.
        weight_gan: loss weight; when not positive the GAN term is skipped.
        device: device for the zero tensor returned when the term is skipped.

    Returns:
        Weighted generator loss tensor (single-element zero tensor when
        weight_gan <= 0).
    """
    # Guard clause: adversarial term disabled.
    if weight_gan <= 0:
        # torch.tensor replaces the legacy torch.Tensor constructor; the
        # value/dtype (float32 zero) is unchanged.
        return torch.tensor([0.0], device=device)

    y_pred_fake, _ = discriminator(fake_data_lab)
    y_pred_real, _ = discriminator(real_data_lab)
    y = torch.ones_like(y_pred_real)
    # RaLSGAN: real scores are pushed below the average fake score and fake
    # scores above the average real score.
    generator_loss = (
        (
            torch.mean((y_pred_real - torch.mean(y_pred_fake) + y) ** 2)
            + torch.mean((y_pred_fake - torch.mean(y_pred_real) - y) ** 2)
        )
        / 2
        * weight_gan
    )
    return generator_loss
def discriminator_loss_fn(real_data_lab, fake_data_lab, discriminator):
    """Relativistic average LSGAN discriminator loss.

    Both inputs are detached so only the discriminator receives gradients.
    """
    y_pred_fake, _ = discriminator(fake_data_lab.detach())
    y_pred_real, _ = discriminator(real_data_lab.detach())
    ones = torch.ones_like(y_pred_real)
    mean_fake = torch.mean(y_pred_fake)
    mean_real = torch.mean(y_pred_real)
    # Real scores should exceed the average fake score by 1, and vice versa.
    real_term = torch.mean((y_pred_real - mean_fake - ones) ** 2)
    fake_term = torch.mean((y_pred_fake - mean_real + ones) ** 2)
    return (real_term + fake_term) / 2
| ### END### ADVERSARIAL LOSS ##### | |
def consistent_loss_fn(
    I_current_lab_predict,
    I_last_ab_predict,
    I_current_nonlocal_lab_predict,
    I_last_nonlocal_lab_predict,
    flow_forward,
    mask,
    warping_layer,
    weight_consistent=0.02,
    weight_nonlocal_consistent=0.0,
    device="cuda",
):
    """Temporal consistency losses between consecutive predicted frames.

    The current prediction is warped with the forward flow and its ab
    channels are compared (masked MSE) against the previous frame's
    prediction; an optional second term does the same for the non-local
    predictions. A weight of zero disables the corresponding term.
    """

    def _masked_mse(prediction, target, weights):
        # Per-pixel squared error scaled by the (broadcast) mask.
        sq_err = (prediction - target) ** 2
        return (sq_err * weights.expand_as(sq_err)).mean()

    if weight_consistent:
        warped = warping_layer(I_current_lab_predict, flow_forward)
        # Channels 1:3 of the warped Lab prediction are the ab channels.
        consistent_loss = _masked_mse(warped[:, 1:3, :, :], I_last_ab_predict, mask) * weight_consistent
    else:
        consistent_loss = torch.Tensor([0]).to(device)

    if weight_nonlocal_consistent:
        warped_nonlocal = warping_layer(I_current_nonlocal_lab_predict, flow_forward)
        nonlocal_consistent_loss = (
            _masked_mse(
                warped_nonlocal[:, 1:3, :, :],
                I_last_nonlocal_lab_predict[:, 1:3, :, :],
                mask,
            )
            * weight_nonlocal_consistent
        )
    else:
        nonlocal_consistent_loss = torch.Tensor([0]).to(device)

    return consistent_loss + nonlocal_consistent_loss
| ### END### CONSISTENCY LOSS ##### | |
| ### START### SMOOTHNESS LOSS ### | |
def smoothness_loss_fn(
    I_current_l,
    I_current_lab,
    I_current_ab_predict,
    A_relu2_1,
    weighted_layer_color,
    nonlocal_weighted_layer,
    weight_smoothness=5.0,
    weight_nonlocal_smoothness=0.0,
    device="cuda",
):
    """Color smoothness losses on the predicted ab channels.

    The local term compares the predicted ab channels against a
    color-weighted version of themselves (weighted_layer_color); the
    optional non-local term does the same using relu2_1 features as the
    weighting guide. A weight of zero disables the corresponding term.
    """

    def _local_term(scale_factor=1.0):
        # Re-assemble the full predicted Lab image from the input luminance.
        predicted_lab = torch.cat((I_current_l, I_current_ab_predict), dim=1)
        weighted_ab = weighted_layer_color(
            I_current_lab,
            predicted_lab,
            patch_size=3,
            alpha=10,
            scale_factor=scale_factor,
        )
        rescaled_ab = nn.functional.interpolate(I_current_ab_predict, scale_factor=scale_factor)
        return mse_loss_fn(rescaled_ab, weighted_ab) * weight_smoothness

    def _nonlocal_term(scale_factor=0.25, alpha_nonlocal_smoothness=0.5):
        # Normalized relu2_1 features guide the non-local weighting; they are
        # detached so no gradient reaches the feature extractor through them.
        guide = feature_normalize(A_relu2_1)
        predicted_lab = torch.cat((I_current_l, I_current_ab_predict), dim=1)
        weighted_ab = nonlocal_weighted_layer(
            predicted_lab,
            guide.detach(),
            patch_size=3,
            alpha=alpha_nonlocal_smoothness,
            scale_factor=scale_factor,
        )
        rescaled_ab = nn.functional.interpolate(I_current_ab_predict, scale_factor=scale_factor)
        return mse_loss_fn(rescaled_ab, weighted_ab) * weight_nonlocal_smoothness

    if weight_smoothness:
        smoothness_loss = _local_term()
    else:
        smoothness_loss = torch.Tensor([0]).to(device)

    if weight_nonlocal_smoothness:
        nonlocal_smoothness_loss = _nonlocal_term()
    else:
        nonlocal_smoothness_loss = torch.Tensor([0]).to(device)

    return smoothness_loss + nonlocal_smoothness_loss
| ### END### SMOOTHNESS LOSS ##### | |