# Omini3D / Diffusion / losses_ncc0.py
# Uploaded by maxmo2009 ("Initial upload: OmniMorph codebase", commit 75854b3, verified).
"""
losses for DRDM
"""
import numpy as np
import sys
import torch
import torch.nn.functional as F
# Small floor used to keep denominators away from zero (used by Grad).
EPS = 1e-7

# Default stabiliser for the denominators of the relative-error losses
# (default of MRSE / RMSE). Earlier values tried: 10e-5, 10e-4, 1e-4.
eps_scale = 1e-5
class LMSE(torch.nn.Module):
    """
    Labeled Mean Square Error (LMSE).

    Squared error between a prediction ``I`` and a reference ``J``, optionally
    made relative (divided by ``J**2 + relate_eps``), optionally computed on
    box-smoothed inputs, and optionally restricted to a label region.
    Inputs are 5-D volumes ``(B, C, D, H, W)`` (``conv3d`` and the label
    reduction over dims ``(2, 3, 4)`` require this rank).

    Args:
        eps: Added to the per-entry label count to avoid division by zero.
        relate_eps: If not None, each squared error is divided by
            ``J**2 + relate_eps``; ``None`` gives a plain squared error.
        win: Smoothing window size per spatial dim (default ``[5, 5, 5]``).
        smooth: If True, ``I`` and ``J`` are first convolved (per channel)
            with an un-normalised all-ones kernel of size ``win``.
    """

    def __init__(self, eps=1e-7, relate_eps=5e-1, win=None, smooth=False):
        super(LMSE, self).__init__()
        self.eps = eps
        self.relate_eps = relate_eps
        self.ndims = 3
        self.smooth = smooth
        self.win = win
        # Set window size
        if self.win is None:
            self.win = [5] * self.ndims
        if smooth:
            self.kernels = self._build_kernel(std=0.0)

    def _build_kernel(self, std=0.0):
        """Return a ``(1, 1, k, k, k)`` kernel: all ones of size ``win`` when
        ``std == 0``, otherwise a normalised separable Gaussian."""
        if std == 0.0:
            return torch.ones([1, 1, *self.win])
        tail = int(np.ceil(std)) * 3
        k = torch.exp(-0.5 * torch.arange(-tail, tail + 1, dtype=torch.float32) ** 2 / std ** 2)
        kernel = k / torch.sum(k)
        kernel = kernel.view(-1, 1, 1) * kernel.view(1, -1, 1) * kernel.view(1, 1, -1)
        return kernel.unsqueeze(0).unsqueeze(0)

    def _smooth(self, x):
        """Convolve every channel of ``x`` independently with ``self.kernels``."""
        channels = x.shape[1]
        # The kernel is created on CPU as float32: match the input's device and
        # dtype, and replicate it per channel for a grouped (depthwise) conv.
        weight = self.kernels.to(device=x.device, dtype=x.dtype).repeat(channels, 1, 1, 1, 1)
        padding = [(k - 1) // 2 for k in self.kernels.shape[2:]]
        return F.conv3d(x, weight, stride=1, padding=padding, groups=channels)

    def forward(self, I, J, label=None):
        """
        Computes the labeled mean squared error between I and J (ref).

        If label is provided, the error is summed over the labeled voxels,
        divided by the label count (+ eps) for each leading (batch, channel)
        entry, and averaged; otherwise it is averaged over all elements.

        Returns:
            Scalar tensor.
        """
        if self.smooth:
            I = self._smooth(I)
            J = self._smooth(J)
        mse = (I - J) ** 2
        if self.relate_eps is not None:
            mse = mse / ((J ** 2) + self.relate_eps)
        if label is None:
            return torch.mean(mse)
        label = label.float()
        mse_sum = torch.sum(mse * label, dim=(2, 3, 4))
        label_sum = torch.sum(label, dim=(2, 3, 4)) + self.eps
        return torch.mean(mse_sum / label_sum)
class LNCC(torch.nn.Module):
    """
    Local (over window) normalized cross-correlation (LNCC).

    ``lncc`` returns the mean of the squared local correlation
    ``cross**2 / (I_var * J_var + eps)``. ``forward`` returns its negative, so
    the loss decreases as I and J become more locally correlated. For identical
    inputs it is close to -1 wherever the local variance is well above ``eps``.
    Inputs are 5-D volumes ``(B, C, D, H, W)``; every channel is filtered
    independently.

    Args:
        win: Local window size per spatial dim (default ``[11, 11, 11]``).
        num_ch: Not used; the channel count is taken from the input.
        eps: Stabiliser added to the denominator and to the label count.
        central: If True, subtract the local window means before correlating;
            otherwise use raw local second moments.
        smooth: If True, pre-smooth I and J with a normalised Gaussian
            (std 0.5, 7 taps per dim).
    """

    def __init__(self, win=None, num_ch=1, eps=1e-7, central=True, smooth=False):
        super(LNCC, self).__init__()
        self.win = win
        self.eps = eps
        self.central = central
        self.ndims = 3
        self.strides = [1] * (self.ndims + 2)
        self.smooth = smooth
        # Set window size
        if self.win is None:
            self.win = [11] * self.ndims
        if smooth:
            self.kernels = self._build_kernel(std=0.5)
        self.sum_filt = self._build_kernel(std=0.0)

    def _build_kernel(self, std=0.0):
        """Return a ``(1, 1, k, k, k)`` kernel: all ones of size ``win`` when
        ``std == 0``, otherwise a normalised separable Gaussian."""
        if std == 0.0:
            return torch.ones([1, 1, *self.win])
        tail = int(np.ceil(std)) * 3
        k = torch.exp(-0.5 * torch.arange(-tail, tail + 1, dtype=torch.float32) ** 2 / std ** 2)
        kernel = k / torch.sum(k)
        kernel = kernel.view(-1, 1, 1) * kernel.view(1, -1, 1) * kernel.view(1, 1, -1)
        return kernel.unsqueeze(0).unsqueeze(0)

    @staticmethod
    def _filter(x, kernel):
        """Size-preserving depthwise conv3d of every channel of ``x`` with ``kernel``."""
        channels = x.shape[1]
        # Kernels are built on CPU as float32: match the input's device and
        # dtype, and replicate per channel for a grouped (depthwise) conv.
        weight = kernel.to(device=x.device, dtype=x.dtype).repeat(channels, 1, 1, 1, 1)
        # Pad according to the kernel's own size. The Gaussian smoothing kernel
        # is 7 taps wide regardless of ``win``, so padding derived from ``win``
        # would change the output size.
        padding = [(k - 1) // 2 for k in kernel.shape[2:]]
        return F.conv3d(x, weight, stride=1, padding=padding, groups=channels)

    def lncc(self, I, J, label=None):
        """Mean squared local correlation of I and J, label-weighted if given."""
        if self.smooth:
            I = self._filter(I, self.kernels)
            J = self._filter(J, self.kernels)
        # Local sums of the second-order products.
        I2_sum = self._filter(I * I, self.sum_filt)
        J2_sum = self._filter(J * J, self.sum_filt)
        IJ_sum = self._filter(I * J, self.sum_filt)
        if self.central:
            # Subtract local means: sum((a - mean_a)(b - mean_b)) over the window.
            I_sum = self._filter(I, self.sum_filt)
            J_sum = self._filter(J, self.sum_filt)
            win_size = np.prod(self.win)
            cross = IJ_sum - (I_sum * J_sum) / win_size
            I_var = I2_sum - (I_sum * I_sum) / win_size
            J_var = J2_sum - (J_sum * J_sum) / win_size
        else:
            cross, I_var, J_var = IJ_sum, I2_sum, J2_sum
        cc = (cross * cross) / (I_var * J_var + self.eps)
        if label is not None:
            label = label.float()
            cc = torch.sum(cc * label, dim=(2, 3, 4)) / (torch.sum(label, dim=(2, 3, 4)) + self.eps)
        return torch.mean(cc)

    def forward(self, I, J, label=None):
        """Return ``-lncc(I, J, label)`` (lower is better)."""
        return -self.lncc(I, J, label=label)
class NCC(torch.nn.Module):
    """
    Channel-wise cosine similarity ("NCC") between a predicted field and a label.

    At every location, the vectors along dim 1 of the scaled prediction ``a``
    and the scaled label ``b`` are compared with
    ``<a, b> / sqrt(|a|^2 * |b|^2 + eps_scale)``. The result is averaged,
    optionally under a mask. The value is a similarity (about 1 for parallel
    vectors, -1 for opposite ones); it is not negated.

    Args:
        eps_scale: Stabiliser under the square root, applied after both inputs
            are multiplied by ``self.scale`` (100).
        img_sz: Unused.
    """

    def __init__(self, eps_scale=10e-5, img_sz=256):
        super(NCC, self).__init__()
        self.eps_scale = eps_scale
        # Both inputs are multiplied by this factor; it cancels in the ratio
        # except relative to eps_scale.
        self.scale = 1e2

    def forward(self, pred, inv_lab=None, ddf_stn=None, mask=None):
        """
        Args:
            pred: Predicted field; compared directly when ``ddf_stn`` is None.
            inv_lab: Reference field. Required despite the ``None`` default.
            ddf_stn: Optional callable. If given, ``-ddf_stn(pred, inv_lab)``
                is compared instead of ``pred``.
            mask: Optional weights broadcastable against the per-location
                similarity map (the inputs with dim 1 reduced).

        Returns:
            Scalar tensor: the mean (or masked mean) similarity.

        Raises:
            ValueError: if ``inv_lab`` is None.
        """
        if inv_lab is None:
            raise ValueError("NCC.forward requires inv_lab")
        trm_pred = pred if ddf_stn is None else -ddf_stn(pred, inv_lab)
        trm_pred = self.scale * trm_pred
        inv_lab = self.scale * inv_lab
        sim = torch.sum(trm_pred * inv_lab, dim=1) / torch.sqrt(
            torch.sum(torch.square(trm_pred), dim=1) * torch.sum(torch.square(inv_lab), dim=1)
            + self.eps_scale)
        if mask is None:
            return torch.mean(sim)
        weighted = sim * mask
        # Normalise by the number of (broadcast) mask entries that contributed.
        # For a mask shared across the batch this equals sum(mask) * batch_size,
        # matching the previous formula. For a per-sample mask it no longer
        # divides by the batch size a second time.
        return torch.sum(weighted) / torch.sum(mask.expand_as(weighted))
class MRSE(torch.nn.Module):
    """
    Relative squared error between the prediction and the negated label.

    At every location it computes ``|a + b|^2 / (|b|^2 + eps_scale)``, where
    ``a`` is the scaled prediction, ``b`` is the scaled label, and the squared
    norms are taken along dim 1. The result is then averaged, optionally under
    a mask. The loss is zero when the prediction equals ``-inv_lab``. With
    ``ddf_stn``, the prediction is ``-ddf_stn(pred, inv_lab)``, so the loss is
    zero when ``ddf_stn(pred, inv_lab) == inv_lab``.

    Args:
        eps_scale: Stabiliser added to the denominator, applied after both
            inputs are multiplied by ``self.scale`` (100).
        img_sz: Unused.
    """

    def __init__(self, eps_scale=eps_scale, img_sz=256):
        super(MRSE, self).__init__()
        self.eps_scale = eps_scale
        # Both inputs are multiplied by this factor (100); it cancels in the
        # ratio except relative to eps_scale.
        self.scale = 10e1

    def forward(self, pred, inv_lab=None, ddf_stn=None, mask=None):
        """
        Args:
            pred: Predicted field.
            inv_lab: Reference field. Required despite the ``None`` default.
            ddf_stn: Optional callable; if given, ``-ddf_stn(pred, inv_lab)``
                is used as the prediction.
            mask: Optional weights broadcastable against the per-location
                error map (the inputs with dim 1 reduced).

        Returns:
            Scalar tensor: the mean (or masked mean) relative error.

        Raises:
            ValueError: if ``inv_lab`` is None.
        """
        if inv_lab is None:
            raise ValueError("MRSE.forward requires inv_lab")
        trm_pred = pred if ddf_stn is None else -ddf_stn(pred, inv_lab)
        trm_pred = self.scale * trm_pred
        inv_lab = self.scale * inv_lab
        rel_err = torch.sum(torch.square(trm_pred + inv_lab), dim=1) / (
            torch.sum(torch.square(inv_lab), dim=1) + self.eps_scale)
        if mask is None:
            return torch.mean(rel_err)
        weighted = rel_err * mask
        # Normalise by the number of (broadcast) mask entries that contributed.
        # This equals sum(mask) * batch_size for a batch-shared mask, matching
        # the previous formula, and avoids a second division by the batch size
        # for per-sample masks.
        return torch.sum(weighted) / torch.sum(mask.expand_as(weighted))
class RMSE(torch.nn.Module):
    """
    Relative mean squared error between a (possibly transformed) prediction and a label.

    The squared norm along dim 1 of ``trm_pred - inv_lab`` is averaged over
    dims ``1 .. ndims`` of the reduced map. It is divided by the same average
    of ``|inv_lab|^2`` plus ``eps_scale``, and the ratio is averaged over what
    remains (the batch, when the input has exactly ``ndims`` spatial dims).

    Args:
        eps_scale: Stabiliser added to the denominator.
        img_sz: Unused.
        ndims: Number of spatial dims to average over.
    """

    def __init__(self, eps_scale=eps_scale, img_sz=256, ndims=2):
        super(RMSE, self).__init__()
        self.eps_scale = eps_scale
        self.ndims = ndims

    def forward(self, pred, inv_lab=None, ddf_stn=None):
        """Return the scalar relative error; with ``ddf_stn`` the prediction
        used is ``-ddf_stn(pred, inv_lab)``."""
        target = pred if ddf_stn is None else -ddf_stn(pred, inv_lab)
        spatial = list(range(1, 1 + self.ndims))
        err_power = torch.sum(torch.square(target - inv_lab), dim=1).mean(dim=spatial)
        ref_power = torch.sum(torch.square(inv_lab), dim=1).mean(dim=spatial)
        return torch.mean(err_power / (ref_power + self.eps_scale))
# loss_gen = torch.mean(torch.mean(torch.sum(torch.square(ddf_stn(pre_dvf_I, dvf_I) + dvf_I), dim=1),dim=list(range(1,1+ndims))) / (torch.mean(torch.sum(torch.square(dvf_I), dim=1),dim=list(range(1,1+ndims))) + EPS))
class Grad(torch.nn.Module):
    """
    N-D gradient loss (smoothness regulariser) with optional extra penalty terms.

    ``forward`` sums the terms named in ``penalty``:

    * ``'l1'``: mean of ``relu(|forward diff| - waive_thresh)`` of ``y_pred``.
    * ``'l2'``: square root of the mean of ``relu(diff**2 - waive_thresh**2)``.
      Both ``'l1'`` and ``'l2'`` are weighted by an image-edge-aware factor when
      ``img`` is given and by ``msk`` when given.
    * ``'negdetj'``: penalises negative Jacobian determinants of the map
      ``x + y_pred`` (folding).
    * ``'range'``: penalises ``|y_pred|`` at the central voxel exceeding
      ``outrange_thresh``.
    * ``'param'`` / ``'detj'`` / ``'std'``: image-weighted terms comparing
      statistics of ``y_pred`` with targets in ``x_in``. These require ``img``,
      and ``'param'`` requires ``dist`` to be a list of ints.

    ``y_pred`` is indexed as ``(B, C, *spatial)`` with ``ndims`` spatial dims;
    the Jacobian terms read channels ``0 .. ndims-1``.
    """

    def __init__(self, penalty=None, ndims=3, eps=1e-8, outrange_weight=1e4, outrange_thresh=0.5,
                 detj_weight=2, apear_scale=4, dist=1, sign=1, waive_thresh=10**-5):
        super(Grad, self).__init__()
        # A fresh list per instance; a literal list default would be shared.
        self.penalty = ['l1'] if penalty is None else penalty
        self.eps = eps
        self.outrange_weight = outrange_weight
        self.detj_weight = detj_weight
        self.apear_scale = apear_scale
        self.ndims = ndims
        self.max_sz = torch.reshape(torch.tensor([outrange_thresh] * ndims, dtype=torch.float32),
                                    [1] + [ndims] + [1] * ndims)
        self.act = torch.nn.ReLU(inplace=False)
        self.dist = dist
        self.sign = sign
        self.waive_thresh = waive_thresh

    def _diffs(self, y, dist=None):
        """Forward differences of ``y`` along each of ``self.ndims`` spatial dims,
        divided by ``dist``; each output is ``dist`` shorter along its dim."""
        if dist is None:
            dist = self.dist
        df = [None] * self.ndims
        for i in range(self.ndims):
            d = i + 2
            # permute dimensions to put the ith spatial dimension first
            r = [d, *range(d), *range(d + 1, self.ndims + 2)]
            yp = y.permute(r)
            dfi = (yp[dist:, ...] - yp[:-dist, ...]) / float(dist)
            # permute back
            r = [*range(1, d + 1), 0, *range(d + 1, self.ndims + 2)]
            df[i] = dfi.permute(r)
        return df

    def _eq_diffs(self, y, dist=None):
        """Forward differences of ``y`` (divided by ``dist``) along each spatial
        dim, zero-padded at the start so every output has ``y``'s shape."""
        if dist is None:
            dist = self.dist
        ndims = y.dim() - 2
        # F.pad pairs run from the last dim backwards; the final pair pads
        # dim 0 (the permuted spatial axis) with ``dist`` zeros at the front.
        pad = [0, 0] * (ndims + 1) + [dist, 0]
        df = [None] * ndims
        for i in range(ndims):
            d = i + 2
            r = [d, *range(d), *range(d + 1, ndims + 2)]
            ri = [*range(1, d + 1), 0, *range(d + 1, ndims + 2)]
            yt = y.permute(r)
            dy = (yt[dist:, ...] - yt[:-dist, ...]) / float(dist)
            df[i] = F.pad(dy, pad, mode='constant', value=0).permute(ri)
        return df

    def _weighted_diffs_error(self, y, dist=None, w=None, expect=None, mean_dim=None):
        """Per spatial dim: mean over ``mean_dim`` of ``(|y[k+dist] - y[k]| - expect)``
        weighted by ``w[k+dist] * w[k]`` (no division by ``dist``)."""
        if dist is None:
            dist = self.dist
        ndims = y.dim() - 2
        df = [None] * ndims
        for i in range(ndims):
            d = i + 2
            r = [d, *range(d), *range(d + 1, ndims + 2)]
            ri = [*range(1, d + 1), 0, *range(d + 1, ndims + 2)]
            yt = y.permute(r)
            wt = w.permute(r)
            dy = (torch.abs(yt[dist:, ...] - yt[:-dist, ...]) - expect.permute(r)) * (wt[dist:, ...] * wt[:-dist, ...])
            df[i] = torch.mean(dy.permute(ri), dim=mean_dim, keepdim=True)
        return df

    def _outl_dist(self, y, range_thresh=0.2):
        """Average over spatial dims ``i`` of the squared amount by which channel
        ``i`` on the first slice falls below ``-range_thresh`` plus the amount
        on the last slice rises above ``range_thresh``."""
        self.device = y.device
        self.max_sz = self.max_sz.to(self.device)
        act = torch.nn.ReLU(inplace=True)
        loss = 0.
        for i in range(self.ndims):
            d = i + 2
            # permute dimensions to put the ith spatial dimension first
            r = [d, *range(d), *range(d + 1, self.ndims + 2)]
            yt = y.permute(r)
            loss += (torch.mean(torch.square(act(-range_thresh - yt[0, :, i, ...])))
                     + torch.mean(torch.square(act(yt[-1, :, i, ...] - range_thresh))))
        return loss / self.ndims

    def _center_dist(self, y):
        """Mean squared excess of ``|y|`` at the central voxel over ``outrange_thresh``."""
        self.device = y.device
        self.max_sz = self.max_sz.to(self.device)
        # Centre index of the first ``self.ndims`` spatial dims; tuple indexing
        # works for any number of spatial dims (previously only 2 or 3).
        select_loc = tuple(s // 2 for s in y.size()[2:2 + self.ndims])
        center = y[(slice(None), slice(None)) + select_loc]
        return torch.mean(torch.square(self.act(torch.abs(center) - self.max_sz)))

    def _eval_detJ(self, disp, add_identity=True, spacing=1.0):
        """
        disp: list length ndims
              disp[i] is derivative wrt spatial dim i (forward diff),
              tensor shape [B, C=ndims, ...]
        add_identity: True if y_pred is displacement u and phi=x+u
        spacing: voxel spacing (or 1.0). If you care about physical units,
                 divide derivatives by spacing (and dist). Sign won't change.
        """
        # Optional scaling (won't affect sign as long as spacing>0)
        if spacing != 1.0:
            disp = [d / spacing for d in disp]
        if self.ndims == 2:
            dux_dx = disp[0][:, 0, ...]
            duy_dx = disp[0][:, 1, ...]
            dux_dy = disp[1][:, 0, ...]
            duy_dy = disp[1][:, 1, ...]
            if add_identity:
                j11 = 1.0 + dux_dx
                j22 = 1.0 + duy_dy
            else:
                j11 = dux_dx
                j22 = duy_dy
            return j11 * j22 - dux_dy * duy_dx
        elif self.ndims == 3:
            dux_dx = disp[0][:, 0, ...]
            duy_dx = disp[0][:, 1, ...]
            duz_dx = disp[0][:, 2, ...]
            dux_dy = disp[1][:, 0, ...]
            duy_dy = disp[1][:, 1, ...]
            duz_dy = disp[1][:, 2, ...]
            dux_dz = disp[2][:, 0, ...]
            duy_dz = disp[2][:, 1, ...]
            duz_dz = disp[2][:, 2, ...]
            if add_identity:
                j11 = 1.0 + dux_dx
                j22 = 1.0 + duy_dy
                j33 = 1.0 + duz_dz
            else:
                j11 = dux_dx
                j22 = duy_dy
                j33 = duz_dz
            j12 = dux_dy
            j13 = dux_dz
            j21 = duy_dx
            j23 = duy_dz
            j31 = duz_dx
            j32 = duz_dy
            # Cofactor expansion along the first row.
            return (
                j11 * (j22 * j33 - j23 * j32)
                - j12 * (j21 * j33 - j23 * j31)
                + j13 * (j21 * j32 - j22 * j31)
            )
        else:
            raise ValueError(f"Unsupported ndims={self.ndims}")

    def forward(self, y_pred=None, x_in=None, img=None, msk=None):
        """
        Args:
            y_pred: Field to regularise, ``(B, C, *spatial)``.
            x_in: Targets for the 'param' / 'detj' terms, a tensor or a list.
                'param' reads ``x_in[-1][:, k:k+1]`` for the k-th entry of
                ``dist``; 'detj' reads ``x_in[0]``.
            img: Optional image used to build the edge-aware weight; required
                by 'param', 'detj' and 'std'.
            msk: Optional multiplicative weight for the 'l1' / 'l2' terms.

        Returns:
            Sum of the selected penalty terms (``0`` if none is selected).
        """
        reg_loss = 0
        act = torch.nn.ReLU(inplace=True)
        dg = 1
        if img is not None:
            # Down-weight regions where the image has strong edges.
            edge = sum([torch.sum(g * g, dim=1, keepdim=True) for g in self._eq_diffs(img)])
            dg = torch.exp(-self.apear_scale * edge / torch.sum(torch.square(0.2 + img), dim=1, keepdim=True))
        if msk is not None:
            dg = dg * msk
        if 'l1' in self.penalty:
            df = [torch.mean(dg * F.relu(torch.abs(f) - self.waive_thresh, inplace=True))
                  for f in self._eq_diffs(y_pred)]
            reg_loss += sum(df) / len(df)
        if 'l2' in self.penalty:
            df = [torch.mean(dg * F.relu(f * f - self.waive_thresh ** 2, inplace=True))
                  for f in self._eq_diffs(y_pred)]
            # Clamp before sqrt: d/dx sqrt(x) is infinite at 0, which produces
            # NaN gradients (inf * 0) when every difference is below waive_thresh.
            reg_loss += torch.sqrt(torch.clamp(sum(df) / len(df), min=self.eps))
        if 'negdetj' in self.penalty:
            detj = self._eval_detJ(self._eq_diffs(y_pred, dist=1))
            reg_loss += 0.5 * (self.detj_weight * torch.mean(act(-detj)))
        if 'range' in self.penalty:
            reg_loss += self.outrange_weight * self._center_dist(y_pred)
        if 'param' in self.penalty or 'detj' in self.penalty or 'std' in self.penalty:
            mean_dim = list(range(1, self.ndims + 2))
            # Image-based weight (dist-3 differences), normalised to unit mean per sample.
            edge = sum([torch.sum(g * g, dim=1, keepdim=True) for g in self._eq_diffs(img, dist=3)])
            dg = torch.sum(torch.abs(img), dim=1, keepdim=True) * torch.exp(
                -self.apear_scale * torch.nn.ReLU(inplace=True)(
                    .1 - edge / torch.sum(torch.square(.1 + img), dim=1, keepdim=True)))
            dg = dg / (EPS + torch.mean(dg, dim=mean_dim, keepdim=True))
            y_pred = torch.clamp(y_pred, min=-0.8, max=0.8)
            x_in = x_in if isinstance(x_in, list) else [x_in]
            if 'std' in self.penalty:
                # Uses self.ndims (a bare ``ndims`` here was an undefined name).
                spatial = list(range(2, self.ndims + 2))
                centred = (y_pred - torch.mean(y_pred, dim=spatial, keepdim=True)) * dg
                reg_loss += self.sign * torch.mean(
                    torch.clamp(grad_std(centred, ndims=self.ndims), max=.2, min=0))
            if 'param' in self.penalty:
                for idx, d in enumerate(self.dist):
                    errs = self._weighted_diffs_error(
                        y_pred, dist=d, w=dg, expect=torch.abs(x_in[-1][:, idx:idx + 1, ...]),
                        mean_dim=mean_dim)
                    df = torch.mean(torch.abs(sum(errs)))
                    reg_loss += 1 * (df) / len(self.dist)
            if 'detj' in self.penalty:
                detj = self._eval_detJ(self._eq_diffs(y_pred, dist=1))
                df = torch.mean(torch.abs(
                    torch.mean((torch.abs(detj) - torch.abs(x_in[0])) * dg, dim=mean_dim)))
                reg_loss += 0.5 * df
        return reg_loss
def avg_std_skew_kurt(array, ndims=2):
    """
    Per-(batch, channel) moments of ``array`` over its spatial dims.

    Args:
        array: Tensor indexed as ``(B, C, *spatial)``; dims ``2 .. ndims+1``
            are reduced.
        ndims: Number of spatial dims to reduce.

    Returns:
        ``[mean, std, skew, excess_kurtosis]``, each with the reduced dims
        removed (``(B, C)`` when ``array`` has exactly ``ndims`` spatial dims).
        ``std`` is the population (biased) standard deviation. Inputs that are
        constant over the reduced dims give ``std == 0`` and NaN skew/kurtosis.
    """
    dim = list(range(2, ndims + 2))
    out_shape = [s for i, s in enumerate(array.shape) if i not in dim]
    # keepdim=True so the statistics broadcast back over the reduced dims.
    # Without it, the (B, C) statistics were aligned with the trailing spatial
    # dims and gave wrong results whenever B or C > 1.
    mean = torch.mean(array, dim=dim, keepdim=True)
    diffs = array - mean
    var = torch.mean(torch.pow(diffs, 2.0), dim=dim, keepdim=True)
    std = torch.pow(var, 0.5)
    zscores = diffs / std
    skews = torch.mean(torch.pow(zscores, 3.0), dim=dim)
    kurtoses = torch.mean(torch.pow(zscores, 4.0), dim=dim) - 3.0
    return [mean.reshape(out_shape), std.reshape(out_shape), skews, kurtoses]
def grad_std(array, ndims=2):
    """
    Pooled spread of ``array`` after clamping it to ``[-0.8, 0.8]``.

    Deviations are taken from the per-(batch, channel) mean over the spatial
    dims ``2 .. ndims+1``. Their mean square is then taken over dims
    ``1 .. ndims+1`` (channels and spatial), and the square root is returned.
    This gives one value per batch entry when ``array`` has exactly ``ndims``
    spatial dims.
    """
    clipped = torch.clamp(array, min=-0.8, max=0.8)
    spatial_dims = list(range(2, ndims + 2))
    pooled_dims = list(range(1, ndims + 2))
    centred = clipped - clipped.mean(dim=spatial_dims, keepdim=True)
    return centred.square().mean(dim=pooled_dims).sqrt()
def avg_std(array, ndims=2):
    """
    Spatial mean of ``array`` and its pooled spread.

    Returns:
        ``[mean, std]``. ``mean`` is ``array`` averaged over dims
        ``2 .. ndims+1``; ``std`` is ``grad_std(array, ndims=ndims)``.
    """
    dim = list(range(2, ndims + 2))
    # grad_std takes ``ndims``; it has no ``dim`` parameter, so the previous
    # ``grad_std(array, dim=dim)`` raised TypeError on every call.
    return [torch.mean(array, dim=dim), grad_std(array, ndims=ndims)]
if __name__ == "__main__":
    # Smoke test: evaluate Grad's 'detj' penalty on a random 2-D field and image.
    ndims = 2
    dist = [16, 32]
    ddf = torch.rand(1, 2, 128, 128)
    img = torch.rand(1, 1, 128, 128)
    # x_in[0] has shape (1, ndims, 1, ..., 1); the 'detj' penalty compares
    # |det J| against it. Pass the shape positionally, because the ``newshape=``
    # keyword of np.reshape is deprecated since NumPy 2.1.
    x_in = np.reshape([0.2, 0.3], [1, ndims] + [1] * ndims)
    x_in = [torch.tensor(x_in).type(torch.float32), 0.]
    Loss_detj = Grad(penalty=['detj'], ndims=ndims, dist=dist)
    loss_detj = Loss_detj(ddf, x_in, img)
    print(loss_detj)