""" losses for DRDM """ import numpy as np import sys import torch import torch.nn.functional as F EPS=1e-7 # eps_scale = 10e-5 # eps_scale = 10e-4 # eps_scale = 1e-4 eps_scale = 1e-5 class LMSE(torch.nn.Module): """ Labeled Mean Square Error (LMSE) """ def __init__(self, eps=1e-7, relate_eps=5e-1, win=None, smooth=False): super(LMSE, self).__init__() self.eps = eps self.relate_eps = relate_eps self.ndims = 3 self.smooth = smooth self.win = win # Set window size if self.win is None: self.win = [5] * self.ndims if smooth: self.kernels = self._build_kernel(std=0.0) def _build_kernel(self, std=0.0): if std == 0.0: return torch.ones([1, 1, *self.win]) else: tail = int(np.ceil(std)) * 3 k = torch.exp(-0.5 * torch.arange(-tail, tail + 1, dtype=torch.float32) ** 2 / std ** 2) kernel = k / torch.sum(k) kernel = kernel.view(-1, 1, 1) * kernel.view(1, -1, 1) * kernel.view(1, 1, -1) return kernel.unsqueeze(0).unsqueeze(0) def forward(self, I, J, label=None): """ Computes the labeled mean squared error between I and J (ref). If label is provided, computes the MSE only over the labeled regions. """ padding = [(w-1) // 2 for w in self.win] if self.smooth: I = torch.nn.functional.conv3d(I, self.kernels, stride=1, padding=padding) J = torch.nn.functional.conv3d(J, self.kernels, stride=1, padding=padding) mse = (I - J) ** 2 if self.relate_eps is not None: mse = mse/((J**2) + self.relate_eps) if label is not None: label = label.float() mse = mse * label mse_sum = torch.sum(mse, dim=(2, 3, 4)) label_sum = torch.sum(label, dim=(2, 3, 4)) + self.eps loss = torch.mean(mse_sum / label_sum) else: loss = torch.mean(mse) return loss class LNCC(torch.nn.Module): """ Local (over window) normalized cross-correlation (LNCC) """ def __init__(self, win=None, num_ch=1, eps=1e-7, central=True, smooth=False): super(LNCC, self).__init__() self.win = win self.eps = eps self.central = central self.ndims = 3 self.strides = [1] * (self.ndims + 2) self.smooth = smooth # Set window size if self.win is None: self.win = [11] * self.ndims if smooth: self.kernels = self._build_kernel(std=0.5) self.sum_filt = self._build_kernel(std=0.0) def _build_kernel(self, std=0.0): if std == 0.0: return torch.ones([1, 1, *self.win]) else: tail = int(np.ceil(std)) * 3 k = torch.exp(-0.5 * torch.arange(-tail, tail + 1, dtype=torch.float32) ** 2 / std ** 2) kernel = k / torch.sum(k) kernel = kernel.view(-1, 1, 1) * kernel.view(1, -1, 1) * kernel.view(1, 1, -1) return kernel.unsqueeze(0).unsqueeze(0) def lncc(self, I, J, label=None): self.sum_filt = self.sum_filt.to(I.device) padding = [(w-1) // 2 for w in self.win] if self.smooth: I = torch.nn.functional.conv3d(I, self.kernels, stride=1, padding=padding) J = torch.nn.functional.conv3d(J, self.kernels, stride=1, padding=padding) # Compute CC squares I2 = I * I J2 = J * J IJ = I * J if self.central: # Compute local sums via convolution I_sum = torch.nn.functional.conv3d(I, self.sum_filt, stride=1, padding=padding) J_sum = torch.nn.functional.conv3d(J, self.sum_filt, stride=1, padding=padding) I2_sum = torch.nn.functional.conv3d(I2, self.sum_filt, stride=1, padding=padding) J2_sum = torch.nn.functional.conv3d(J2, self.sum_filt, stride=1, padding=padding) IJ_sum = torch.nn.functional.conv3d(IJ, self.sum_filt, stride=1, padding=padding) # Compute cross-correlation win_size = np.prod(self.win) cross = IJ_sum - (I_sum * J_sum) / win_size I_var = I2_sum - (I_sum * I_sum) / win_size J_var = J2_sum - (J_sum * J_sum) / win_size else: # Compute local sums via convolution I2_sum = torch.nn.functional.conv3d(I2, self.sum_filt, stride=1, padding=padding) J2_sum = torch.nn.functional.conv3d(J2, self.sum_filt, stride=1, padding=padding) IJ_sum = torch.nn.functional.conv3d(IJ, self.sum_filt, stride=1, padding=padding) cross = IJ_sum I_var = I2_sum J_var = J2_sum cc = (cross * cross) / (I_var * J_var + self.eps) if label is not None: label = label.float() cc = torch.sum(cc * label, dim=(2, 3, 4)) / (torch.sum(label, dim=(2, 3, 4)) + self.eps) return torch.mean(cc) def forward(self, I, J, label=None): return -self.lncc(I, J, label=label) class NCC(torch.nn.Module): # def __init__(self, eps_scale=10e-7,img_sz=256): def __init__(self, eps_scale=10e-5,img_sz=256): super(NCC, self).__init__() self.eps_scale=eps_scale#*img_sz/256 # self.scale=10e4 self.scale=1e2 def forward(self,pred,inv_lab=None,ddf_stn=None,mask=None): if ddf_stn is None: trm_pred=pred else: trm_pred=-ddf_stn(pred, inv_lab) trm_pred = self.scale * trm_pred inv_lab = self.scale * inv_lab if mask is None: loss_gen = torch.mean(torch.sum(trm_pred*inv_lab,dim=1)/(torch.sqrt(torch.sum(torch.square(trm_pred),dim=1)*torch.sum(torch.square(inv_lab),dim=1)+self.eps_scale))) else: batch_size = inv_lab.shape[0] loss_gen = torch.sum(torch.sum(trm_pred*inv_lab,dim=1)*mask/(torch.sqrt(torch.sum(torch.square(trm_pred),dim=1)*torch.sum(torch.square(inv_lab),dim=1)+self.eps_scale)))/torch.sum(mask)/batch_size return loss_gen class MRSE(torch.nn.Module): def __init__(self, eps_scale=eps_scale,img_sz=256): super(MRSE, self).__init__() self.eps_scale=eps_scale#*img_sz/256 self.scale = 10e1 def forward(self,pred,inv_lab=None,ddf_stn=None,mask=None): if ddf_stn is None: trm_pred=pred else: trm_pred=-ddf_stn(pred, inv_lab) trm_pred = self.scale * trm_pred inv_lab = self.scale * inv_lab if mask is None: loss_gen = torch.mean( torch.sum(torch.square(trm_pred + inv_lab), dim=1) / (torch.sum(torch.square(inv_lab), dim=1) + self.eps_scale) ) else: batch_size = inv_lab.shape[0] loss_gen = torch.sum( torch.sum(torch.square(trm_pred + inv_lab), dim=1) * mask / (torch.sum(torch.square(inv_lab), dim=1) + self.eps_scale) )/torch.sum(mask)/batch_size return loss_gen/1 class RMSE(torch.nn.Module): def __init__(self, eps_scale=eps_scale,img_sz=256,ndims=2): super(RMSE, self).__init__() self.eps_scale=eps_scale#*img_sz/256 self.ndims=ndims def forward(self,pred,inv_lab=None,ddf_stn=None): if ddf_stn is None: trm_pred=pred else: trm_pred=-ddf_stn(pred, inv_lab) loss_gen = torch.mean(torch.mean(torch.sum(torch.square(trm_pred - inv_lab), dim=1), dim=list(range(1, 1 + self.ndims))) / ( torch.mean(torch.sum(torch.square(inv_lab), dim=1), dim=list(range(1, 1 + self.ndims))) + self.eps_scale)) return loss_gen # loss_gen = torch.mean(torch.mean(torch.sum(torch.square(ddf_stn(pre_dvf_I, dvf_I) + dvf_I), dim=1),dim=list(range(1,1+ndims))) / (torch.mean(torch.sum(torch.square(dvf_I), dim=1),dim=list(range(1,1+ndims))) + EPS)) class Grad(torch.nn.Module): """ N-D gradient loss """ def __init__(self, penalty=['l1'],ndims=3, eps=1e-8, outrange_weight=1e4,outrange_thresh=0.5, detj_weight=2, apear_scale=4, dist=1, sign=1,waive_thresh=10**-5): super(Grad, self).__init__() self.penalty = penalty self.eps = eps self.outrange_weight = outrange_weight self.detj_weight=detj_weight self.apear_scale = apear_scale self.ndims=ndims self.max_sz = torch.reshape(torch.tensor([outrange_thresh]*ndims, dtype=torch.float32) , [1]+[ndims]+[1]*(ndims)) self.act = torch.nn.ReLU(inplace=False) self.dist=dist self.sign=sign self.waive_thresh=waive_thresh def _diffs(self, y,dist=None): if dist is None: dist=self.dist # vol_shape = y.size()[2:] # vol_shape = y.get_shape().as_list()[1:-1] # ndims = len(vol_shape) df = [None] * self.ndims for i in range(self.ndims): d = i + 2 # permute dimensions to put the ith dimension first r = [d, *range(d), *range(d + 1, self.ndims + 2)] yp = y.permute(r) dfi = (yp[dist:, ...] - yp[:-dist, ...])/float(dist) # permute back # note: this might not be necessary for this loss specifically, # since the results are just summed over anyway. r = [*range(1, d + 1), 0, *range(d + 1, self.ndims + 2)] df[i] = dfi.permute(r) return df def _eq_diffs(self, y,dist=None): if dist is None: dist=self.dist # vol_shape = y.get_shape().as_list()[1:-1] vol_shape = y.size()[2:] ndims = len(vol_shape) pad = [0, 0] * (ndims + 1) +[dist, 0] pad1 = [0, 0] * (ndims + 1) +[0, dist] # df = [None, None] * ndims df = [None] * ndims for i in range(ndims): d = i + 2 r=[d, *range(d), *range(d + 1, ndims + 2)] ri=[*range(1, d + 1), 0, *range(d + 1, ndims + 2)] yt = y.permute(r) dy=(yt[dist:, ...] - yt[:-dist, ...])/float(dist) df[i] = (F.pad(dy, pad,mode='constant',value=0)).permute(ri) # df[2*i] = (F.pad(dy, pad,mode='constant',value=0)).permute(ri) # df[2*i+1] = (F.pad(dy, pad1, mode='constant', value=0)).permute(ri) y.permute(ri) return df def _weighted_diffs_error(self, y,dist=None,w=None,expect=None,mean_dim=None): if dist is None: dist=self.dist vol_shape = y.size()[2:] ndims = len(vol_shape) df = [None] * ndims for i in range(ndims): d = i + 2 r=[d, *range(d), *range(d + 1, ndims + 2)] ri=[*range(1, d + 1), 0, *range(d + 1, ndims + 2)] yt = y.permute(r) wt = w.permute(r) dy=(torch.abs(yt[dist:, ...] - yt[:-dist, ...])-expect.permute(r))*(wt[dist:, ...]*wt[:-dist, ...]) df[i] = torch.mean((dy).permute(ri),dim=mean_dim,keepdim=True) y.permute(ri) w.permute(ri) return df def _outl_dist(self, y,range_thresh=0.2): self.device = y.device vol_shape = y.size()[2:] self.max_sz=self.max_sz.to(self.device) act=torch.nn.ReLU(inplace=True) loss=0. for i in range(self.ndims): d = i + 2 # permute dimensions to put the ith dimension first r = [d, *range(d), *range(d + 1, self.ndims + 2)] ri = [*range(1, d + 1), 0, *range(d + 1, self.ndims + 2)] yt = y.permute(r) loss += torch.mean(torch.square(act(-range_thresh-yt[0,:,i, ...])))+torch.mean(torch.square(act(yt[-1,:,i, ...]-range_thresh))) # loss += torch.mean(torch.square(act(-range_thresh-yt[0,:,i, ...])+act(yt[-1,:,i, ...]-range_thresh))) y.permute(ri) return loss/self.ndims def _center_dist(self, y): self.device = y.device vol_shape = y.size()[2:] self.max_sz=self.max_sz.to(self.device) select_loc = [s // 2 for s in vol_shape] if self.ndims==3: # return torch.mean(self.act(torch.abs(y[:,:, select_loc[0], select_loc[1], select_loc[2]]) - self.max_sz)) return torch.mean(torch.square(self.act(torch.abs(y[:, :, select_loc[0], select_loc[1], select_loc[2]]) - self.max_sz))) elif self.ndims == 2: # return torch.mean(self.act(torch.abs(y[:, :, select_loc[0], select_loc[1]]) - self.max_sz)) return torch.mean(torch.square(self.act(torch.abs(y[:, :, select_loc[0], select_loc[1]]) - self.max_sz))) # def _eval_detJ(self, disp=None, weight=None): # weight = 1 # if self.ndims==3: # detj = (disp[0][:, 0, ...] * disp[1][:, 1, ...] * disp[2][:, 2, ...]) + ( # disp[0][:, 1, ...] * disp[1][:, 2, ...] * disp[2][:, 0, ...]) + ( # disp[0][:, 2, ...] * disp[1][:, 0, ...] * disp[2][:, 1, ...]) - ( # disp[0][:, 2, ...] * disp[1][:, 1, ...] * disp[2][:, 0, ...]) - ( # disp[0][:, 0, ...] * disp[1][:, 2, ...] * disp[2][:, 1, ...]) - ( # disp[0][:, 1, ...] * disp[1][:, 0, ...] * disp[2][:, 2, ...]) # elif self.ndims==2: # detj = (disp[0][:, 0, ...] * disp[1][:, 1, ...]) - (disp[0][:, 1, ...] * disp[1][:, 0, ...]) # return detj * weight def _eval_detJ(self, disp, add_identity=True, spacing=1.0): """ disp: list length ndims disp[i] is derivative wrt spatial dim i (forward diff), tensor shape [B, C=ndims, ...] add_identity: True if y_pred is displacement u and phi=x+u spacing: voxel spacing (or 1.0). If you care about physical units, divide derivatives by spacing (and dist). Sign won't change. """ # Optional scaling (won't affect sign as long as spacing>0) if spacing != 1.0: disp = [d / spacing for d in disp] if self.ndims == 2: dux_dx = disp[0][:, 0, ...] duy_dx = disp[0][:, 1, ...] dux_dy = disp[1][:, 0, ...] duy_dy = disp[1][:, 1, ...] if add_identity: j11 = 1.0 + dux_dx j22 = 1.0 + duy_dy else: j11 = dux_dx j22 = duy_dy detj = j11 * j22 - dux_dy * duy_dx return detj elif self.ndims == 3: dux_dx = disp[0][:, 0, ...] duy_dx = disp[0][:, 1, ...] duz_dx = disp[0][:, 2, ...] dux_dy = disp[1][:, 0, ...] duy_dy = disp[1][:, 1, ...] duz_dy = disp[1][:, 2, ...] dux_dz = disp[2][:, 0, ...] duy_dz = disp[2][:, 1, ...] duz_dz = disp[2][:, 2, ...] if add_identity: j11 = 1.0 + dux_dx j22 = 1.0 + duy_dy j33 = 1.0 + duz_dz else: j11 = dux_dx j22 = duy_dy j33 = duz_dz j12 = dux_dy; j13 = dux_dz j21 = duy_dx; j23 = duy_dz j31 = duz_dx; j32 = duz_dy detj = ( j11 * (j22 * j33 - j23 * j32) - j12 * (j21 * j33 - j23 * j31) + j13 * (j21 * j32 - j22 * j31) ) return detj else: raise ValueError(f"Unsupported ndims={self.ndims}") def forward(self, y_pred=None,x_in=None, img=None, msk=None): reg_loss = 0 act=torch.nn.ReLU(inplace=True) dg = 1 if img is not None: dg = torch.exp(-self.apear_scale * sum([torch.sum(g * g, dim=1, keepdim=True) for g in self._eq_diffs(img)]) / torch.sum(torch.square(0.2 + img), dim=1, keepdim=True)) if msk is not None: dg = dg * msk if 'l1' in self.penalty: df = [torch.mean(dg*F.relu(torch.abs(f) - self.waive_thresh,inplace=True)) for f in self._eq_diffs(y_pred)] reg_loss += sum(df) / len(df) if 'l2' in self.penalty: df = [torch.mean(dg*F.relu(f * f - self.waive_thresh**2,inplace=True)) for f in self._eq_diffs(y_pred)] reg_loss += torch.sqrt(sum(df) / len(df)) if 'negdetj' in self.penalty: df = self.detj_weight*torch.mean(act(-self._eval_detJ(self._eq_diffs(y_pred,dist=1)))) # , dg[...,0]) reg_loss += 0.5*df if 'range' in self.penalty: reg_loss += self.outrange_weight * (self._center_dist(y_pred)) #self._outl_dist(y_pred))#+ if 'param' in self.penalty or 'detj' in self.penalty or 'std' in self.penalty: mean_dim=list(range(1, self.ndims + 2)) dg = torch.sum(torch.abs(img),dim=1,keepdim=True)* torch.exp(-self.apear_scale * torch.nn.ReLU(inplace=True)(.1-sum([torch.sum(g * g, dim=1, keepdim=True) for g in self._eq_diffs(img,dist=3)]) / torch.sum(torch.square(.1 + img), dim=1, keepdim=True))) dg = dg/(EPS+torch.mean(dg,dim=mean_dim,keepdim=True)) y_pred = torch.clamp(y_pred, min=-0.8, max=0.8) x_in = x_in if isinstance(x_in,list) else [x_in] if 'std' in self.penalty: reg_loss += self.sign*torch.mean(torch.clamp(grad_std((y_pred-torch.mean(y_pred,dim=list(range(2,ndims+2)),keepdim=True))*dg), max=.2, min=0)) if 'param' in self.penalty: for id, d in enumerate(self.dist): df = torch.mean(torch.abs(sum(self._weighted_diffs_error(y_pred, dist=d, w=dg, expect=torch.abs(x_in[-1][:, id:id + 1, ...]),mean_dim=mean_dim)))) reg_loss += 1 * (df) / len(self.dist) if 'detj' in self.penalty: df = torch.mean(torch.abs( torch.mean((torch.abs(self._eval_detJ(self._eq_diffs(y_pred, dist=1))) - torch.abs(x_in[0])) * dg, dim=mean_dim))) reg_loss += 0.5*df return reg_loss def avg_std_skew_kurt(array,ndims=2): dim = list(range(2, ndims + 2)) mean = torch.mean(array,dim=dim) diffs = array - mean var = torch.mean(torch.pow(diffs, 2.0),dim=dim) std = torch.pow(var, 0.5) zscores = diffs / std skews = torch.mean(torch.pow(zscores, 3.0),dim=dim) kurtoses = torch.mean(torch.pow(zscores, 4.0),dim=dim) - 3.0 return [mean,std,skews,kurtoses] def grad_std(array,ndims=2): dim = list(range(2, ndims + 2)) array=torch.clamp(array,min=-0.8,max=0.8) dim0=list(range(1,ndims+2)) std = torch.sqrt(torch.mean(torch.square(array - torch.mean(array, dim=dim, keepdim=True)), dim=dim0)) return std def avg_std(array,ndims=2): dim = list(range(2, ndims + 2)) return [torch.mean(array,dim=dim),grad_std(array,dim=dim)] if __name__ == "__main__": ndims=2 dist=[16,32] ddf = torch.rand(1,2,128,128) # ddf[:,:,0,:]=ddf[:,:,0,:]-1 # ddf[:,:,1,:]=ddf[:,:,1,:]+1 # ddf[:,:,0,0]=ddf[:,:,0,0] -1 # ddf[:,:,1,1]=ddf[:,:,1,1] +1 # ddf[:,0,0,1]=ddf[:,0,0,1] +1 # ddf[:,1,0,1]=ddf[:,1,0,1] -1 # ddf[:,0,0,1]=ddf[:,0,0,1] -1 # ddf[:,1,0,1]=ddf[:,1,0,1] +1 # ddf[:,1,1,0]=ddf[:,1,1,0] -1 # ddf[:,0,1,0]=ddf[:,0,1,0] +1 ddf=ddf img = torch.rand(1,1,128,128) x_in=np.reshape([0.2,0.3],newshape=[1,ndims]+[1]*ndims) x_in=[torch.tensor(x_in).type(torch.float32),0.] Loss_detj = Grad(penalty=['detj'],ndims=ndims,dist=dist) loss_detj = Loss_detj(ddf,x_in,img) print(loss_detj)