| """
|
| losses for DRDM
|
| """
|
|
|
| import numpy as np
|
| import sys
|
| import torch
|
| import torch.nn.functional as F
|
|
|
|
|
# Small constant added to a mean weight before dividing by it (see Grad.forward).
EPS = 1e-7

# Default stabiliser added to the denominators of the relative-error losses
# (used as the default `eps_scale` argument of MRSE and RMSE).
eps_scale = 1e-5
|
|
|
|
|
|
|
class LMSE(torch.nn.Module):
    """
    Labeled Mean Square Error (LMSE).

    Computes the (optionally relative) squared error between two volumes and,
    when a label is given, averages it only over the labeled voxels.

    Args:
        eps: small constant added to the label sum to avoid division by zero.
        relate_eps: if not None, the squared error is divided by
            ``J**2 + relate_eps`` (relative error); pass None for plain MSE.
        win: box-filter window size per spatial dim (defaults to ``[5, 5, 5]``).
        smooth: if True, both inputs are filtered with an all-ones box kernel
            of size ``win`` before the error is computed.
    """

    def __init__(self, eps=1e-7, relate_eps=5e-1, win=None, smooth=False):
        super(LMSE, self).__init__()
        self.eps = eps
        self.relate_eps = relate_eps
        self.ndims = 3
        self.smooth = smooth
        self.win = [5] * self.ndims if win is None else win
        if smooth:
            self.kernels = self._build_kernel(std=0.0)

    def _build_kernel(self, std=0.0):
        """Return a ``[1, 1, k, k, k]`` kernel: all-ones box (std == 0) or Gaussian."""
        if std == 0.0:
            return torch.ones([1, 1, *self.win])
        tail = int(np.ceil(std)) * 3
        k = torch.exp(-0.5 * torch.arange(-tail, tail + 1, dtype=torch.float32) ** 2 / std ** 2)
        kernel = k / torch.sum(k)
        kernel = kernel.view(-1, 1, 1) * kernel.view(1, -1, 1) * kernel.view(1, 1, -1)
        return kernel.unsqueeze(0).unsqueeze(0)

    def forward(self, I, J, label=None):
        """
        Computes the labeled mean squared error between I and J (ref).
        If label is provided, computes the MSE only over the labeled regions.

        The reductions over dims (2, 3, 4) and the conv3d calls mean the inputs
        are expected to be 5-D; the single-channel kernel means smoothing only
        works for single-channel inputs.
        """
        if self.smooth:
            # The kernel is a plain tensor (not a buffer), so it does not follow
            # module.to(...); move it to the input's device/dtype explicitly.
            self.kernels = self.kernels.to(device=I.device, dtype=I.dtype)
            padding = [(w - 1) // 2 for w in self.win]
            I = torch.nn.functional.conv3d(I, self.kernels, stride=1, padding=padding)
            J = torch.nn.functional.conv3d(J, self.kernels, stride=1, padding=padding)

        mse = (I - J) ** 2
        if self.relate_eps is not None:
            # Relative error: normalise by the reference intensity.
            mse = mse / ((J ** 2) + self.relate_eps)

        if label is None:
            return torch.mean(mse)

        label = label.float()
        mse_sum = torch.sum(mse * label, dim=(2, 3, 4))
        label_sum = torch.sum(label, dim=(2, 3, 4)) + self.eps
        return torch.mean(mse_sum / label_sum)
|
|
|
class LNCC(torch.nn.Module):
    """
    Local (over window) normalized cross-correlation (LNCC).

    ``forward`` returns the negated mean squared local correlation so that it
    can be minimised as a loss.

    Args:
        win: window size per spatial dim (defaults to ``[11, 11, 11]``).
        num_ch: accepted for interface compatibility; not used by this class.
        eps: small constant added to denominators.
        central: if True, window means are subtracted (covariance / variance);
            otherwise raw second moments are used.
        smooth: if True, both inputs are pre-filtered with a Gaussian
            (std = 0.5) before the correlation is computed.
    """

    def __init__(self, win=None, num_ch=1, eps=1e-7, central=True, smooth=False):
        super(LNCC, self).__init__()
        self.win = win
        self.eps = eps
        self.central = central
        self.ndims = 3
        self.strides = [1] * (self.ndims + 2)
        self.smooth = smooth

        if self.win is None:
            self.win = [11] * self.ndims

        # The box (summing) filter is needed by lncc() regardless of `smooth`.
        self.sum_filt = self._build_kernel(std=0.0)
        if smooth:
            self.kernels = self._build_kernel(std=0.5)

    def _build_kernel(self, std=0.0):
        """Return a ``[1, 1, k, k, k]`` kernel: all-ones box (std == 0) or Gaussian."""
        if std == 0.0:
            return torch.ones([1, 1, *self.win])
        tail = int(np.ceil(std)) * 3
        k = torch.exp(-0.5 * torch.arange(-tail, tail + 1, dtype=torch.float32) ** 2 / std ** 2)
        kernel = k / torch.sum(k)
        kernel = kernel.view(-1, 1, 1) * kernel.view(1, -1, 1) * kernel.view(1, 1, -1)
        return kernel.unsqueeze(0).unsqueeze(0)

    def lncc(self, I, J, label=None):
        """
        Return the mean squared local normalized cross-correlation of I and J.

        If ``label`` is given, the correlation map is averaged over the labeled
        voxels only (per sample/channel) before the final mean.
        """
        conv3d = torch.nn.functional.conv3d
        # Kernels are plain tensors, so they must follow the input explicitly.
        self.sum_filt = self.sum_filt.to(device=I.device, dtype=I.dtype)
        padding = [(w - 1) // 2 for w in self.win]

        if self.smooth:
            self.kernels = self.kernels.to(device=I.device, dtype=I.dtype)
            # Pad according to the Gaussian kernel's own size (not the
            # correlation window) so the spatial shape is preserved.
            smooth_pad = [(k - 1) // 2 for k in self.kernels.shape[2:]]
            I = conv3d(I, self.kernels, stride=1, padding=smooth_pad)
            J = conv3d(J, self.kernels, stride=1, padding=smooth_pad)

        def window_sum(t):
            return conv3d(t, self.sum_filt, stride=1, padding=padding)

        I2_sum = window_sum(I * I)
        J2_sum = window_sum(J * J)
        IJ_sum = window_sum(I * J)

        if self.central:
            I_sum = window_sum(I)
            J_sum = window_sum(J)
            win_size = float(np.prod(self.win))
            cross = IJ_sum - (I_sum * J_sum) / win_size
            I_var = I2_sum - (I_sum * I_sum) / win_size
            J_var = J2_sum - (J_sum * J_sum) / win_size
        else:
            cross = IJ_sum
            I_var = I2_sum
            J_var = J2_sum

        cc = (cross * cross) / (I_var * J_var + self.eps)
        if label is not None:
            label = label.float()
            cc = torch.sum(cc * label, dim=(2, 3, 4)) / (torch.sum(label, dim=(2, 3, 4)) + self.eps)

        return torch.mean(cc)

    def forward(self, I, J, label=None):
        """Return the negated LNCC so that higher correlation gives lower loss."""
        return -self.lncc(I, J, label=label)
|
|
|
|
|
|
|
class NCC(torch.nn.Module):
    """
    Normalized cross-correlation taken over the channel axis (dim 1).

    Both inputs are multiplied by a fixed scale (1e2) before the ratio
    ``sum_c(a * b) / sqrt(sum_c(a^2) * sum_c(b^2) + eps_scale)`` is formed.
    ``img_sz`` is accepted but not used.
    """

    def __init__(self, eps_scale=10e-5, img_sz=256):
        super(NCC, self).__init__()
        self.eps_scale = eps_scale
        self.scale = 1e2

    def forward(self, pred, inv_lab=None, ddf_stn=None, mask=None):
        """
        Return the mean (or mask-weighted) channel-wise correlation.

        When ``ddf_stn`` is given, the prediction is replaced by
        ``-ddf_stn(pred, inv_lab)`` before scaling.
        """
        warped = pred if ddf_stn is None else -ddf_stn(pred, inv_lab)
        warped = self.scale * warped
        target = self.scale * inv_lab

        dot = torch.sum(warped * target, dim=1)
        energy = torch.sum(torch.square(warped), dim=1) * torch.sum(torch.square(target), dim=1)
        norm = torch.sqrt(energy + self.eps_scale)

        if mask is None:
            return torch.mean(dot / norm)

        n_batch = target.shape[0]
        return torch.sum(dot * mask / norm) / torch.sum(mask) / n_batch
|
|
|
class MRSE(torch.nn.Module):
    """
    Relative squared error taken over the channel axis (dim 1).

    Both inputs are multiplied by a fixed scale (10e1) and the per-location
    value ``sum_c((pred + ref)^2) / (sum_c(ref^2) + eps_scale)`` is averaged.
    Note that the numerator uses the *sum* of the two fields.
    ``img_sz`` is accepted but not used.
    """

    def __init__(self, eps_scale=eps_scale, img_sz=256):
        super(MRSE, self).__init__()
        self.eps_scale = eps_scale
        self.scale = 10e1

    def forward(self, pred, inv_lab=None, ddf_stn=None, mask=None):
        """
        Return the mean (or mask-weighted) relative squared error.

        When ``ddf_stn`` is given, the prediction is replaced by
        ``-ddf_stn(pred, inv_lab)`` before scaling.
        """
        warped = pred if ddf_stn is None else -ddf_stn(pred, inv_lab)
        warped = self.scale * warped
        target = self.scale * inv_lab

        numerator = torch.sum(torch.square(warped + target), dim=1)
        denominator = torch.sum(torch.square(target), dim=1) + self.eps_scale

        if mask is None:
            loss_gen = torch.mean(numerator / denominator)
        else:
            n_batch = target.shape[0]
            loss_gen = torch.sum(numerator * mask / denominator) / torch.sum(mask) / n_batch
        return loss_gen / 1
|
|
|
class RMSE(torch.nn.Module):
    """
    Relative mean squared error.

    Per sample, the spatial mean of ``sum_c((pred - ref)^2)`` is divided by the
    spatial mean of ``sum_c(ref^2)`` plus ``eps_scale``; the result is averaged
    over the batch. ``img_sz`` is accepted but not used.
    """

    def __init__(self, eps_scale=eps_scale, img_sz=256, ndims=2):
        super(RMSE, self).__init__()
        self.eps_scale = eps_scale
        self.ndims = ndims

    def forward(self, pred, inv_lab=None, ddf_stn=None):
        """
        Return the batch-mean relative error between the prediction and ``inv_lab``.

        When ``ddf_stn`` is given, the prediction is replaced by
        ``-ddf_stn(pred, inv_lab)``.
        """
        warped = pred if ddf_stn is None else -ddf_stn(pred, inv_lab)

        # After summing over channels the spatial axes are 1..ndims.
        spatial_axes = list(range(1, 1 + self.ndims))
        err_energy = torch.mean(torch.sum(torch.square(warped - inv_lab), dim=1), dim=spatial_axes)
        ref_energy = torch.mean(torch.sum(torch.square(inv_lab), dim=1), dim=spatial_axes)
        return torch.mean(err_energy / (ref_energy + self.eps_scale))
|
|
|
|
|
class Grad(torch.nn.Module):
    """
    N-D gradient loss.

    ``penalty`` selects which terms ``forward`` adds up:

    * ``'l1'`` / ``'l2'``: mean absolute / root-mean-square finite difference
      of ``y_pred``, ignoring magnitudes below ``waive_thresh`` and optionally
      weighted by an image-edge term and a mask.
    * ``'negdetj'``: penalises negative Jacobian determinants.
    * ``'range'``: penalises the centre voxel exceeding ``outrange_thresh``.
    * ``'param'``, ``'detj'``, ``'std'``: statistics-matching terms that need
      ``img`` (and, for ``'param'`` / ``'detj'``, ``x_in``).

    Args:
        penalty: list of term names; ``None`` means ``['l1']``.
        ndims: number of spatial dimensions (2 or 3 for the Jacobian terms).
        eps: stored on the instance; not used by the current implementation.
        outrange_weight: weight of the ``'range'`` term.
        outrange_thresh: threshold used by the ``'range'`` term.
        detj_weight: weight of the ``'negdetj'`` term.
        apear_scale: scale inside the exponential image-edge weights.
        dist: finite-difference step. An int for ``'l1'`` / ``'l2'``; the
            ``'param'`` term iterates over it (an int is treated as ``[dist]``).
        sign: sign applied to the ``'std'`` term.
        waive_thresh: gradient magnitude below which ``'l1'`` / ``'l2'`` add nothing.
    """

    def __init__(self, penalty=None, ndims=3, eps=1e-8, outrange_weight=1e4,
                 outrange_thresh=0.5, detj_weight=2, apear_scale=4, dist=1,
                 sign=1, waive_thresh=10**-5):
        super(Grad, self).__init__()
        # A fresh list per instance (avoids sharing a mutable default).
        self.penalty = ['l1'] if penalty is None else penalty
        self.eps = eps
        self.outrange_weight = outrange_weight
        self.detj_weight = detj_weight
        self.apear_scale = apear_scale
        self.ndims = ndims
        self.max_sz = torch.reshape(
            torch.tensor([outrange_thresh] * ndims, dtype=torch.float32),
            [1, ndims] + [1] * ndims,
        )
        self.act = torch.nn.ReLU(inplace=False)
        self.dist = dist
        self.sign = sign
        self.waive_thresh = waive_thresh

    @staticmethod
    def _axis_diff(y, axis, dist):
        """Return ``y[dist:] - y[:-dist]`` taken along ``axis`` (not divided by dist)."""
        lead = (slice(None),) * axis
        return y[lead + (slice(dist, None),)] - y[lead + (slice(None, -dist),)]

    def _diffs(self, y, dist=None):
        """Forward differences of ``y`` along each spatial axis, divided by ``dist``.

        Each output is ``dist`` entries shorter than ``y`` along its own axis.
        """
        if dist is None:
            dist = self.dist
        return [self._axis_diff(y, i + 2, dist) / float(dist) for i in range(self.ndims)]

    def _eq_diffs(self, y, dist=None):
        """Like ``_diffs`` but zero-padded at the front of each axis.

        The padding restores the input shape, so the per-axis results can be
        combined element-wise. The number of spatial axes is taken from ``y``.
        """
        if dist is None:
            dist = self.dist
        n_total = y.dim()
        df = []
        for axis in range(2, n_total):
            dy = self._axis_diff(y, axis, dist) / float(dist)
            # F.pad lists pads from the last dim backwards; leave the trailing
            # dims alone and prepend `dist` zeros along `axis`.
            pad = [0, 0] * (n_total - 1 - axis) + [dist, 0]
            df.append(F.pad(dy, pad, mode='constant', value=0))
        return df

    def _weighted_diffs_error(self, y, dist=None, w=None, expect=None, mean_dim=None):
        """Weighted mean deviation of ``|finite difference|`` from ``expect``, per axis.

        For each spatial axis the (un-normalised) absolute difference at step
        ``dist`` minus ``expect`` is weighted by the product of ``w`` at both
        end points and averaged over ``mean_dim`` (keeping dims).
        """
        if dist is None:
            dist = self.dist
        df = []
        for axis in range(2, y.dim()):
            lead = (slice(None),) * axis
            w_pair = w[lead + (slice(dist, None),)] * w[lead + (slice(None, -dist),)]
            dy = (torch.abs(self._axis_diff(y, axis, dist)) - expect) * w_pair
            df.append(torch.mean(dy, dim=mean_dim, keepdim=True))
        return df

    def _outl_dist(self, y, range_thresh=0.2):
        """Penalise boundary values that exceed ``range_thresh``.

        For axis ``i``, the first slice of channel ``i`` is penalised when it
        is below ``-range_thresh`` and the last slice when it is above
        ``+range_thresh``.
        """
        self.device = y.device
        self.max_sz = self.max_sz.to(self.device)
        loss = 0.
        for i in range(self.ndims):
            axis = i + 2
            first = y.select(axis, 0)[:, i]
            last = y.select(axis, -1)[:, i]
            loss += (torch.mean(torch.square(F.relu(-range_thresh - first)))
                     + torch.mean(torch.square(F.relu(last - range_thresh))))
        return loss / self.ndims

    def _center_dist(self, y):
        """Mean squared excess of ``|y|`` at the central voxel over the threshold."""
        self.device = y.device
        self.max_sz = self.max_sz.to(self.device)
        center = [s // 2 for s in y.size()[2:]]
        center_val = y[(slice(None), slice(None), *center)]
        # center_val is [B, C] while max_sz is [1, ndims, 1, ...]; they
        # broadcast against each other, and every max_sz entry is set to the
        # same threshold in __init__.
        return torch.mean(torch.square(self.act(torch.abs(center_val) - self.max_sz)))

    def _eval_detJ(self, disp, add_identity=True, spacing=1.0):
        """
        Jacobian determinant of a 2-D or 3-D field from its finite differences.

        disp: list length ndims
            disp[i] is derivative wrt spatial dim i (forward diff),
            tensor shape [B, C=ndims, ...]
        add_identity: True if y_pred is displacement u and phi=x+u
        spacing: voxel spacing (or 1.0). If you care about physical units,
            divide derivatives by spacing (and dist). Sign won't change.

        Returns a tensor of shape [B, ...] (no channel axis).
        Raises ValueError for ndims other than 2 or 3.
        """
        if spacing != 1.0:
            disp = [d / spacing for d in disp]
        identity = 1.0 if add_identity else 0.0

        if self.ndims == 2:
            j11 = identity + disp[0][:, 0, ...]
            j21 = disp[0][:, 1, ...]
            j12 = disp[1][:, 0, ...]
            j22 = identity + disp[1][:, 1, ...]
            return j11 * j22 - j12 * j21

        if self.ndims == 3:
            j11 = identity + disp[0][:, 0, ...]
            j21 = disp[0][:, 1, ...]
            j31 = disp[0][:, 2, ...]

            j12 = disp[1][:, 0, ...]
            j22 = identity + disp[1][:, 1, ...]
            j32 = disp[1][:, 2, ...]

            j13 = disp[2][:, 0, ...]
            j23 = disp[2][:, 1, ...]
            j33 = identity + disp[2][:, 2, ...]

            return (
                j11 * (j22 * j33 - j23 * j32)
                - j12 * (j21 * j33 - j23 * j31)
                + j13 * (j21 * j32 - j22 * j31)
            )

        raise ValueError(f"Unsupported ndims={self.ndims}")

    def _appearance_weight(self, img, msk):
        """Edge-aware weight ``exp(-scale * |grad img|^2 / (0.2 + img)^2)`` times mask."""
        dg = 1
        if img is not None:
            grad_energy = sum(torch.sum(g * g, dim=1, keepdim=True) for g in self._eq_diffs(img))
            dg = torch.exp(-self.apear_scale * grad_energy
                           / torch.sum(torch.square(0.2 + img), dim=1, keepdim=True))
        if msk is not None:
            dg = dg * msk
        return dg

    def forward(self, y_pred=None, x_in=None, img=None, msk=None):
        """
        Sum the penalty terms selected in ``self.penalty``.

        Args:
            y_pred: field to regularise, [B, C, *spatial].
            x_in: target statistics for ``'param'`` / ``'detj'`` (a tensor or a
                list of tensors; ``x_in[0]`` is used by ``'detj'`` and
                ``x_in[-1]`` by ``'param'``).
            img: image used for the edge-aware weights.
            msk: optional multiplicative mask for the ``'l1'`` / ``'l2'`` terms.

        Raises:
            ValueError: if a selected term needs ``img`` or ``x_in`` and it is None.
        """
        penalty = self.penalty
        reg_loss = 0

        if 'l1' in penalty or 'l2' in penalty:
            # Only these two terms use the edge-aware weight, so it is built
            # lazily; this also keeps list-valued `dist` usable for the other terms.
            dg = self._appearance_weight(img, msk)
            grads = self._eq_diffs(y_pred)

            if 'l1' in penalty:
                df = [torch.mean(dg * F.relu(torch.abs(f) - self.waive_thresh)) for f in grads]
                reg_loss += sum(df) / len(df)

            if 'l2' in penalty:
                df = [torch.mean(dg * F.relu(f * f - self.waive_thresh ** 2)) for f in grads]
                reg_loss += torch.sqrt(sum(df) / len(df))

        if 'negdetj' in penalty:
            detj = self._eval_detJ(self._eq_diffs(y_pred, dist=1))
            reg_loss += 0.5 * self.detj_weight * torch.mean(F.relu(-detj))

        if 'range' in penalty:
            reg_loss += self.outrange_weight * self._center_dist(y_pred)

        if 'param' in penalty or 'detj' in penalty or 'std' in penalty:
            if img is None:
                raise ValueError("penalties 'param', 'detj' and 'std' require `img`")
            if x_in is None and ('param' in penalty or 'detj' in penalty):
                raise ValueError("penalties 'param' and 'detj' require `x_in`")

            mean_dim = list(range(1, self.ndims + 2))
            grad_energy = sum(torch.sum(g * g, dim=1, keepdim=True)
                              for g in self._eq_diffs(img, dist=3))
            rel_energy = grad_energy / torch.sum(torch.square(.1 + img), dim=1, keepdim=True)
            dg = (torch.sum(torch.abs(img), dim=1, keepdim=True)
                  * torch.exp(-self.apear_scale * F.relu(.1 - rel_energy)))
            # Normalise the weight to unit mean per sample.
            dg = dg / (EPS + torch.mean(dg, dim=mean_dim, keepdim=True))

            y_pred = torch.clamp(y_pred, min=-0.8, max=0.8)
            x_in = x_in if isinstance(x_in, list) else [x_in]

            if 'std' in penalty:
                spatial_axes = list(range(2, self.ndims + 2))
                centred = y_pred - torch.mean(y_pred, dim=spatial_axes, keepdim=True)
                reg_loss += self.sign * torch.mean(
                    torch.clamp(grad_std(centred * dg, ndims=self.ndims), max=.2, min=0))

            if 'param' in penalty:
                dists = self.dist if isinstance(self.dist, (list, tuple)) else [self.dist]
                for idx, step in enumerate(dists):
                    expect = torch.abs(x_in[-1][:, idx:idx + 1, ...])
                    df = torch.mean(torch.abs(sum(self._weighted_diffs_error(
                        y_pred, dist=step, w=dg, expect=expect, mean_dim=mean_dim))))
                    reg_loss += df / len(dists)

            if 'detj' in penalty:
                # Give detJ a channel axis so it lines up with the [B, 1, ...]
                # weight instead of broadcasting batch against batch.
                detj = self._eval_detJ(self._eq_diffs(y_pred, dist=1)).unsqueeze(1)
                df = torch.mean(torch.abs(
                    torch.mean((torch.abs(detj) - torch.abs(x_in[0])) * dg, dim=mean_dim)))
                reg_loss += 0.5 * df

        return reg_loss
|
|
|
|
|
def avg_std_skew_kurt(array, ndims=2):
    """
    Return ``[mean, std, skewness, excess kurtosis]`` over the spatial axes.

    The statistics are reduced over dims ``2 .. ndims + 1``; each returned
    tensor keeps the two leading dims of ``array``. The standard deviation is
    the population one (mean of squared deviations). A zero standard deviation
    yields non-finite skewness / kurtosis.
    """
    dim = list(range(2, ndims + 2))
    # keepdim=True so the per-(sample, channel) statistics broadcast back
    # against the full array instead of being aligned with the trailing axes.
    mean = torch.mean(array, dim=dim, keepdim=True)
    diffs = array - mean
    var = torch.mean(torch.pow(diffs, 2.0), dim=dim, keepdim=True)
    std = torch.pow(var, 0.5)
    zscores = diffs / std
    skews = torch.mean(torch.pow(zscores, 3.0), dim=dim)
    kurtoses = torch.mean(torch.pow(zscores, 4.0), dim=dim) - 3.0
    return [mean.reshape(skews.shape), std.reshape(skews.shape), skews, kurtoses]
|
|
|
def grad_std(array, ndims=2):
    """
    Per-sample standard deviation of ``array`` after clamping to [-0.8, 0.8].

    Each channel is centred on its own spatial mean (dims ``2 .. ndims + 1``);
    the squared deviations are then averaged over the channel and spatial
    dims (``1 .. ndims + 1``) before the square root is taken.
    """
    spatial_axes = list(range(2, ndims + 2))
    reduce_axes = list(range(1, ndims + 2))
    clipped = torch.clamp(array, min=-0.8, max=0.8)
    centred = clipped - torch.mean(clipped, dim=spatial_axes, keepdim=True)
    return torch.sqrt(torch.mean(torch.square(centred), dim=reduce_axes))
|
|
|
def avg_std(array, ndims=2):
    """
    Return ``[mean, std]`` of ``array``.

    The mean is taken over the spatial dims ``2 .. ndims + 1``; the standard
    deviation comes from ``grad_std`` called with the same ``ndims``.
    """
    dim = list(range(2, ndims + 2))
    # grad_std takes `ndims`, not `dim`; passing `dim=` raised a TypeError.
    return [torch.mean(array, dim=dim), grad_std(array, ndims=ndims)]
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: evaluate the 'detj' penalty on a random 2-D field.
    ndims = 2
    dist = [16, 32]
    ddf = torch.rand(1, 2, 128, 128)
    img = torch.rand(1, 1, 128, 128)

    # Target statistics shaped [1, ndims, 1, 1] so they broadcast over the field.
    # The shape is passed positionally: the `newshape=` keyword is deprecated
    # in recent NumPy releases.
    x_in = np.reshape([0.2, 0.3], [1, ndims] + [1] * ndims)
    x_in = [torch.tensor(x_in).type(torch.float32), 0.]

    Loss_detj = Grad(penalty=['detj'], ndims=ndims, dist=dist)
    loss_detj = Loss_detj(ddf, x_in, img)
    print(loss_detj)
|
|
|