"""Weight-initialisation, regularisation and debug-plot helpers for PyTorch models."""
import os

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F

try:
    import matplotlib.pyplot as plt
except ImportError:
    # Only save_mask_prediction_example needs matplotlib; keep the rest of the
    # module importable on machines without a plotting stack.
    plt = None

# Layer types handled by the init_weights* helpers.
_CONV_TYPES = (nn.Conv2d, nn.ConvTranspose2d)


def truncated_normal_(tensor, mean=0, std=1):
    """Fill ``tensor`` in place from a normal distribution truncated at 2 std.

    Standard-normal values are drawn and every entry outside the open
    interval (-2, 2) is redrawn until none remain; the result is then scaled
    by ``std`` and shifted by ``mean``.

    Args:
        tensor: Floating-point tensor to overwrite.
        mean: Mean of the underlying normal distribution.
        std: Standard deviation of the underlying normal distribution.

    Returns:
        The same ``tensor`` object, to allow call chaining.
    """
    # no_grad lets us overwrite parameters that require grad without the
    # write being recorded by autograd.
    with torch.no_grad():
        tensor.normal_()
        out_of_range = (tensor <= -2) | (tensor >= 2)
        # Rejection sampling: about 4.6% of entries are redrawn per pass, so
        # this loop terminates after a handful of iterations.
        while out_of_range.any():
            redraw = torch.empty_like(tensor).normal_()
            tensor.copy_(torch.where(out_of_range, redraw, tensor))
            out_of_range = (tensor <= -2) | (tensor >= 2)
        tensor.mul_(std).add_(mean)
    return tensor


def init_weights(m):
    """Initialise a 2-D (transposed) convolution with Kaiming-normal weights.

    Intended as a callback for ``nn.Module.apply``. Modules that are not
    ``nn.Conv2d`` / ``nn.ConvTranspose2d`` are left untouched. The bias, when
    the layer has one, is drawn from a truncated normal with std 0.001.

    Args:
        m: Any ``nn.Module``.
    """
    if not isinstance(m, _CONV_TYPES):
        return
    nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
    # Layers built with bias=False expose ``bias`` as None.
    if m.bias is not None:
        truncated_normal_(m.bias, mean=0, std=0.001)


def init_weights_orthogonal_normal(m):
    """Initialise a 2-D (transposed) convolution with orthogonal weights.

    Intended as a callback for ``nn.Module.apply``. Modules that are not
    ``nn.Conv2d`` / ``nn.ConvTranspose2d`` are left untouched. The bias, when
    the layer has one, is drawn from a truncated normal with std 0.001.

    Args:
        m: Any ``nn.Module``.
    """
    if not isinstance(m, _CONV_TYPES):
        return
    nn.init.orthogonal_(m.weight)
    # Layers built with bias=False expose ``bias`` as None.
    if m.bias is not None:
        truncated_normal_(m.bias, mean=0, std=0.001)


def l2_regularisation(m):
    """Return the sum of the 2-norms of all parameters of ``m``.

    Note that this sums the (unsquared) norm of each parameter tensor.

    Args:
        m: Module whose ``parameters()`` are accumulated.

    Returns:
        A scalar tensor. For a module without parameters this is a zero
        scalar, so the result can always be added to a loss.
    """
    norms = [W.norm(2) for W in m.parameters()]
    if not norms:
        return torch.zeros(())
    # Start the fold from the first norm so dtype and device follow the
    # parameters themselves.
    return sum(norms[1:], norms[0])


def save_mask_prediction_example(mask, pred, iter):
    """Save the first slice of ``pred`` and ``mask`` as greyscale PNG files.

    Writes ``images/<iter>_prediction.png`` and ``images/<iter>_mask.png``
    relative to the current working directory, creating ``images/`` if needed.

    Args:
        mask: Array-like indexable as ``mask[0, :, :]``.
        pred: Array-like indexable as ``pred[0, :, :]``.
        iter: Value used as the file-name prefix via ``str(iter)``. The name
            shadows the builtin but is kept for keyword-argument callers.

    Raises:
        ImportError: If matplotlib is not installed.
    """
    if plt is None:
        raise ImportError(
            "matplotlib is required for save_mask_prediction_example"
        )
    os.makedirs('images', exist_ok=True)
    prefix = 'images/' + str(iter)
    for image, suffix in ((pred, "_prediction.png"), (mask, "_mask.png")):
        # One figure per image, closed afterwards, so the images are not
        # stacked on shared axes and figures do not pile up across calls.
        fig = plt.figure()
        plt.imshow(image[0, :, :], cmap='Greys')
        plt.savefig(prefix + suffix)
        plt.close(fig)