# case_dif/util/losses.py — author: Enes Bol, commit fd4bbc8 ("initial")
"""
author: Min Seok Lee and Wooseok Shin
"""
import torch
import torch.nn.functional as F
def Optimizer(args, model):
    """Build the optimizer selected by ``args.optimizer``.

    Args:
        args: namespace providing ``optimizer`` ('Adam' or 'SGD'), ``lr``
            and, for Adam, ``weight_decay``.
        model: model whose parameters will be optimized.

    Returns:
        A configured ``torch.optim.Optimizer``.

    Raises:
        ValueError: if ``args.optimizer`` is not one of the supported names
            (previously this fell through to an ``UnboundLocalError``).
    """
    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr,
                                     weight_decay=args.weight_decay)
    elif args.optimizer == 'SGD':
        # Momentum / weight decay are fixed constants for the SGD branch,
        # matching the original configuration.
        optimizer = torch.optim.SGD(params=model.parameters(), lr=args.lr,
                                    momentum=0.9, weight_decay=5e-4)
    else:
        raise ValueError(f"Unsupported optimizer: {args.optimizer!r}")
    return optimizer
def Scheduler(args, optimizer):
    """Build the LR scheduler selected by ``args.scheduler``.

    Args:
        args: namespace providing ``scheduler`` ('Reduce' or 'Step') and,
            for 'Reduce', ``lr_factor`` and ``patience``.
        optimizer: the optimizer whose learning rate is scheduled.

    Returns:
        A configured ``torch.optim.lr_scheduler`` instance.

    Raises:
        ValueError: if ``args.scheduler`` is not one of the supported names
            (previously this fell through to an ``UnboundLocalError``).
    """
    if args.scheduler == 'Reduce':
        # Monitors a minimized metric (e.g. validation loss).
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=args.lr_factor, patience=args.patience)
    elif args.scheduler == 'Step':
        # Fixed decay: multiply LR by 0.9 every 2 epochs (original constants).
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=2, gamma=0.9)
    else:
        raise ValueError(f"Unsupported scheduler: {args.scheduler!r}")
    return scheduler
def Criterion(args):
    """Select the loss function named by ``args.criterion``.

    Args:
        args: namespace providing ``criterion`` ('API' for the adaptive
            pixel-intensity loss, or 'bce' for plain binary cross-entropy).

    Returns:
        A callable ``loss(pred, target)``.

    Raises:
        ValueError: if ``args.criterion`` is not one of the supported names
            (previously this fell through to an ``UnboundLocalError``).
    """
    if args.criterion == 'API':
        criterion = adaptive_pixel_intensity_loss
    elif args.criterion == 'bce':
        criterion = torch.nn.BCELoss()
    else:
        raise ValueError(f"Unsupported criterion: {args.criterion!r}")
    return criterion
def adaptive_pixel_intensity_loss(pred, mask):
    """Adaptive pixel-intensity (API) loss: ω-weighted BCE + IoU + L1.

    Each pixel is weighted by ``omega``, built from the local deviation of
    the mask at three average-pooling scales (3/15/31), so pixels near
    object boundaries contribute more.

    Args:
        pred: predicted probabilities in [0, 1], shape (B, C, H, W).
        mask: ground-truth mask, same shape as ``pred``.

    Returns:
        Scalar loss tensor: mean of 0.7 * (weighted BCE + IoU + L1).
    """
    # Boundary-sensitive weight map: |local mean - mask| at three scales.
    w1 = torch.abs(F.avg_pool2d(mask, kernel_size=3, stride=1, padding=1) - mask)
    w2 = torch.abs(F.avg_pool2d(mask, kernel_size=15, stride=1, padding=7) - mask)
    w3 = torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
    omega = 1 + 0.5 * (w1 + w2 + w3) * mask

    # BUGFIX: the original passed the deprecated ``reduce=None``, which is
    # equivalent to reduction='mean' and collapses bce/mae to scalars,
    # silently defeating the per-pixel omega weighting below.
    # ``reduction='none'`` keeps the per-pixel map as intended.
    bce = F.binary_cross_entropy(pred, mask, reduction='none')
    abce = (omega * bce).sum(dim=(2, 3)) / (omega + 0.5).sum(dim=(2, 3))

    # Weighted soft IoU over spatial dims; +1 smooths the empty-mask case.
    inter = ((pred * mask) * omega).sum(dim=(2, 3))
    union = ((pred + mask) * omega).sum(dim=(2, 3))
    aiou = 1 - (inter + 1) / (union - inter + 1)

    mae = F.l1_loss(pred, mask, reduction='none')  # BUGFIX: was reduce=None
    amae = (omega * mae).sum(dim=(2, 3)) / (omega - 1).sum(dim=(2, 3))

    return (0.7 * abce + 0.7 * aiou + 0.7 * amae).mean()