# BrainFM / Trainer / models / __init__.py
# Repository page header: uploaded by peirong26, "Upload 187 files", commit 2571f24 (verified).
"""
Submodule interface.
"""
import numpy as np
import torch
from .backbone import build_backbone
from .criterion import *
from .evaluator import Evaluator
from .head import get_head
from .joiner import get_processors, get_joiner
import utils.misc as utils
#########################################
# some constants
# Segmentation label values used when only the left hemisphere is generated.
label_list_segmentation_brainseg_left = [
    0, 1, 2, 3, 4, 7, 8, 9, 10, 14, 15, 17, 31, 34, 36, 38, 40, 42,
]
n_labels_brainseg_left = len(label_list_segmentation_brainseg_left)

# Full label list. The first `n_neutral_labels_brainseg_with_extracerebral`
# entries are the "neutral" labels; the remainder is split into two equally
# sized lateral groups (presumably left then right hemisphere -- confirm).
label_list_segmentation_brainseg_with_extracerebral = [
    # neutral labels
    0, 11, 12, 13, 16, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46,
    # first lateral group
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 14, 15, 17, 47, 49, 51, 53, 55,
    # second lateral group
    18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 48, 50, 52, 54, 56,
]
n_neutral_labels_brainseg_with_extracerebral = 20
n_labels_brainseg_with_extracerebral = len(label_list_segmentation_brainseg_with_extracerebral)

# Number of labels in each lateral group.
nlat = (n_labels_brainseg_with_extracerebral - n_neutral_labels_brainseg_with_extracerebral) // 2

# Index permutation over the label list: neutral indices stay in place and the
# two lateral groups trade positions.
vflip = np.concatenate((
    np.arange(n_neutral_labels_brainseg_with_extracerebral),
    np.arange(n_neutral_labels_brainseg_with_extracerebral + nlat, n_labels_brainseg_with_extracerebral),
    np.arange(n_neutral_labels_brainseg_with_extracerebral, n_neutral_labels_brainseg_with_extracerebral + nlat),
))
############################################
############# helper functions #############
############################################
def process_args(gen_args, train_args, task):
    """
    Derive the active task list and the per-task output configuration.

    The names of all truthy attributes of `task` become `gen_args.tasks`.
    `train_args.size` mirrors `gen_args.generator.size`, and the segmentation
    label list / label count on `gen_args` are picked according to
    `gen_args.generator.left_hemis_only`. `train_args.out_channels`,
    `output_names`, `aux_output_names` and `target_names` are rebuilt from
    scratch; they stay empty when the 'contrastive' task is active.

    Task names checked: T1, T2, FLAIR, CT, bias_field, segmentation, distance,
    registration, surface, super_resolution, pathology, age, contrastive.

    Returns the same (mutated) `gen_args` and `train_args` objects.
    """
    gen_args.tasks = [name for name, enabled in vars(task).items() if enabled]

    gen_args.generator.size = gen_args.generator.size # update real sample size (if sample_size is downsampled)
    train_args.size = gen_args.generator.size

    if gen_args.generator.left_hemis_only:
        label_list, n_labels = label_list_segmentation_brainseg_left, n_labels_brainseg_left
    else:
        label_list, n_labels = (label_list_segmentation_brainseg_with_extracerebral,
                                n_labels_brainseg_with_extracerebral)
    gen_args.label_list_segmentation = label_list
    gen_args.n_labels = n_labels

    train_args.out_channels = {}
    train_args.output_names = []
    train_args.aux_output_names = []
    train_args.target_names = []

    # Contrastive training registers no per-task outputs at all.
    if 'contrastive' in gen_args.tasks:
        return gen_args, train_args

    # Intensity modalities: one channel, plus a second one (reported as
    # '<modality>_sigma') when an uncertainty loss is configured.
    for modality in ('T1', 'T2', 'FLAIR', 'CT'):
        if modality not in gen_args.tasks:
            continue
        with_sigma = train_args.losses.uncertainty is not None
        train_args.out_channels[modality] = 2 if with_sigma else 1
        train_args.output_names += [modality]
        train_args.target_names += [modality]
        if with_sigma:
            train_args.aux_output_names += [modality + '_sigma']

    if 'bias_field' in gen_args.tasks:
        # The head key is 'bias_field_log' while the reported name is 'bias_field'.
        train_args.out_channels['bias_field_log'] = 1 if train_args.losses.uncertainty is None else 2
        train_args.output_names += ['bias_field']
        train_args.target_names += ['bias_field']

    if 'segmentation' in gen_args.tasks:
        train_args.out_channels['segmentation'] = gen_args.n_labels
        train_args.output_names += ['label']
        train_args.target_names += ['label']

    if 'distance' in gen_args.tasks:
        # Two channels for a single hemisphere, four for both.
        if gen_args.generator.left_hemis_only:
            n_distance, distance_names = 2, ['distance', 'lp', 'lw']
        else:
            n_distance, distance_names = 4, ['distance', 'lp', 'lw', 'rp', 'rw']
        train_args.out_channels['distance'] = n_distance
        train_args.output_names += distance_names
        train_args.target_names += distance_names

    if 'registration' in gen_args.tasks:
        registration_names = ['registration', 'regx', 'regy', 'regz']
        train_args.out_channels['registration'] = 3
        train_args.output_names += registration_names
        train_args.target_names += registration_names

    if 'surface' in gen_args.tasks:
        train_args.out_channels['surface'] = 8
        train_args.output_names += ['surface']
        train_args.target_names += ['surface']

    if 'super_resolution' in gen_args.tasks:
        sr_names = ['high_res', 'high_res_residual']
        train_args.out_channels['high_res_residual'] = 1 if train_args.losses.uncertainty is None else 2
        train_args.output_names += sr_names
        train_args.target_names += sr_names

    if 'pathology' in gen_args.tasks:
        train_args.out_channels['pathology'] = 1
        train_args.output_names += ['pathology']
        train_args.target_names += ['pathology']

    if 'age' in gen_args.tasks:
        # Channel count -1 with no output/target names registered.
        train_args.out_channels['age'] = -1

    # NOTE(review): assumed to apply only in the non-contrastive case -- confirm.
    if train_args.losses.implicit_pathol: # TODO
        train_args.output_names += ['implicit_pathol_orig']
        train_args.output_names += ['implicit_pathol_pred']

    #assert len(train_args.output_names) > 0
    return gen_args, train_args
############################################
################ CRITERIONS ################
############################################
def get_evaluator(args, task, device):
    """
    Build the `Evaluator` for the given task specification.

    When any of 'T1', 'T2', 'FLAIR' or 'CT' is contained in `task`, only the
    feature metrics are used. Otherwise metrics are collected for every one of
    'super_resolution', 'bias_field', 'segmentation' and 'pathology' that is
    contained in `task`.

    Fails an assertion when no metric was selected.
    """
    image_modalities = ('T1', 'T2', 'FLAIR', 'CT')

    if any(modality in task for modality in image_modalities):
        metric_names = ['feat_ssim', 'feat_ms_ssim', 'feat_l1']
    else:
        # NOTE(review): nesting of all entries below under this branch is
        # assumed -- confirm against the original layout.
        metrics_by_task = (
            # TODO: this entry can never match, since this branch is only
            # reached when 'T1' is not contained in `task`.
            ('T1', ['recon_l1', 'recon_psnr', 'recon_ssim', 'recon_ms_ssim']),
            ('super_resolution', ['sr_l1', 'sr_psnr', 'sr_ssim', 'sr_ms_ssim']),
            ('bias_field', ['bf_normalized_l2', 'bf_corrected_l1']),
            ('segmentation', ['seg_dice']),
            ('pathology', ['pathol_dice']),
        )
        metric_names = []
        for key, names in metrics_by_task:
            if key in task:
                metric_names += names

    assert len(metric_names) > 0

    return Evaluator(
        args = args,
        metric_names = metric_names,
        device = device,
    )
def get_criterion(gen_args, train_args, tasks, device, exclude_keys = None):
    """
    Assemble the training criterion for the given task list.

    Args:
        gen_args: generator arguments, forwarded to the criterion.
        train_args: training arguments; `train_args.weights.*` supplies the
            loss weights and `train_args.losses.*` toggles optional terms.
        tasks: iterable of task names. Each name is matched by containment
            against: T1 / T2 / FLAIR / CT, segmentation, bias_field,
            super_resolution, distance, registration, surface, age, pathology.
        device: forwarded to the criterion.
        exclude_keys: optional collection of task keys to skip; only
            'pathology' is consulted here. Defaults to excluding nothing.

    Returns:
        A `SetCriterion` carrying only the contrastive loss when 'contrastive'
        is in `tasks`; otherwise a `SetMultiCriterion` over all collected
        losses.

    Fails an assertion when no loss was selected.
    """
    # A None default avoids sharing one mutable list across calls.
    excluded = () if exclude_keys is None else exclude_keys

    loss_names = []
    weight_dict = {}

    if 'contrastive' in tasks:
        # Contrastive training uses a single dedicated loss.
        loss_names += ['contrastive']
        weight_dict['loss_contrastive'] = train_args.weights.contrastive
        return SetCriterion(
            gen_args = gen_args,
            train_args = train_args,
            weight_dict = weight_dict,
            loss_names = loss_names,
            device = device,
        )

    for task in tasks:

        if 'T1' in task or 'T2' in task or 'FLAIR' in task or 'CT' in task:
            # Image losses are named after the task itself.
            name = task
            loss_names += [name]
            weight_dict['loss_%s' % name] = train_args.weights.image
            if train_args.losses.image_grad:
                loss_names += ['%s_grad' % name]
                weight_dict['loss_%s_grad' % name] = train_args.weights.image_grad

        if 'segmentation' in task:
            loss_names += ['seg_ce', 'seg_dice']
            weight_dict['loss_seg_ce'] = train_args.weights.seg_ce
            weight_dict['loss_seg_dice'] = train_args.weights.seg_dice

        if 'bias_field' in task:
            loss_names += ['bias_field_log']
            weight_dict['loss_bias_field_log'] = train_args.weights.bias_field_log

        if 'super_resolution' in task:
            # Super-resolution shares the image / image-gradient weights.
            loss_names += ['SR']
            weight_dict['loss_SR'] = train_args.weights.image
            if train_args.losses.image_grad:
                loss_names += ['SR_grad']
                weight_dict['loss_SR_grad'] = train_args.weights.image_grad

        if 'distance' in task:
            loss_names += ['distance']
            weight_dict['loss_distance'] = train_args.weights.distance

        if 'registration' in task:
            loss_names += ['registration']
            weight_dict['loss_registration'] = train_args.weights.registration
            if train_args.losses.registration_grad:
                loss_names += ['registration_grad']
                weight_dict['loss_registration_grad'] = train_args.weights.registration_grad
            if train_args.losses.registration_smooth:
                loss_names += ['registration_smooth']
                weight_dict['loss_registration_smooth'] = train_args.weights.registration_smooth
            if train_args.losses.registration_hessian:
                loss_names += ['registration_hessian']
                weight_dict['loss_registration_hessian'] = train_args.weights.registration_hessian

        if 'surface' in task:
            loss_names += ['surface']
            weight_dict['loss_surface'] = train_args.weights.surface

        if 'age' in task:
            loss_names += ['age']
            weight_dict['loss_age'] = train_args.weights.age

        if 'pathology' in task and 'pathology' not in excluded:
            loss_names += ['pathol_ce', 'pathol_dice']
            weight_dict['loss_pathol_ce'] = train_args.weights.pathol_ce
            weight_dict['loss_pathol_dice'] = train_args.weights.pathol_dice

    # NOTE(review): assumed to be independent of the individual tasks and
    # therefore added once, after the loop -- confirm against the original
    # layout.
    if train_args.losses.implicit_pathol:
        loss_names += ['implicit_pathol_ce', 'implicit_pathol_dice']
        weight_dict['loss_implicit_pathol_ce'] = train_args.weights.implicit_pathol_ce
        weight_dict['loss_implicit_pathol_dice'] = train_args.weights.implicit_pathol_dice

    assert len(loss_names) > 0

    criterion = SetMultiCriterion(
        gen_args = gen_args,
        train_args = train_args,
        weight_dict = weight_dict,
        loss_names = loss_names,
        device = device,
    )
    return criterion
def _split_channels(volume, names):
    """Map each name to the matching channel of `volume` along dim 1, kept as a singleton channel dim."""
    return {name: volume[:, idx][:, None] for idx, name in enumerate(names)}


def _labels_from_scores(scores, label_list):
    """Argmax over dim 1 of `scores`, then map the winning index to its value in `label_list`."""
    lookup = torch.tensor(label_list, device = scores.device)
    # keepdim = True keeps the channel axis: (b, n_labels, ...) -> (b, 1, ...)
    return lookup[torch.argmax(scores, 1, keepdim = True)]


def _fake_cortical(lw, lp, a = 2):
    """Combine one hemisphere's 'w' and 'p' distance channels into a synthetic intensity map."""
    return 70 * (1 - (torch.tanh(a * (lw + 0.3)) + 1) / 2) + 40 * (1 - (torch.tanh(a * lp) + 1) / 2)


def get_postprocessor(gen_args, train_args, outputs, samples, target, feats, tasks):
    """
    Post-process network outputs, input samples and the target in place.

    Args:
        gen_args: reads `generator.left_hemis_only` and `label_list_segmentation`.
        train_args: unused here.
        outputs: list of output dicts, one per sample.
        samples: list of sample dicts, indexed in step with `outputs`.
        target: target dict, or None to skip target processing.
        feats: list of output dicts from a pre-trained feature extractor, or
            None; `feats[i]['feat']` is copied into `outputs[i]['feat']`.
        tasks: collection of active task names.

    Per active task:
        distance: split into 'lp', 'lw' (plus 'rp', 'rw' unless
            `left_hemis_only`) and the stacked 'distance' entry is removed;
            outputs additionally get 'fake_cortical'.
        registration: split into 'regx', 'regy', 'regz'; the stacked entry is
            removed.
        CT: multiplied by 1000 (presumably a unit rescale -- confirm).
        segmentation: 'label' is added via argmax over dim 1 and lookup in
            `gen_args.label_list_segmentation`; 'segmentation' is kept.
        super_resolution: 'high_res' = 'high_res_residual' + sample 'input'.
        bias_field: 'bias_field' = exp('bias_field_log'); the log entry is
            removed.

    Returns:
        The same `outputs`, `samples` and `target` objects.
    """
    left_only = gen_args.generator.left_hemis_only if 'distance' in tasks else None
    distance_names = ('lp', 'lw') if left_only else ('lp', 'lw', 'rp', 'rw')
    registration_names = ('regx', 'regy', 'regz')

    if target is not None:
        if 'distance' in tasks:
            target.update(_split_channels(target['distance'], distance_names))
            del target['distance']
        if 'registration' in tasks:
            target.update(_split_channels(target['registration'], registration_names))
            del target['registration']
        if 'CT' in tasks:
            target['CT'] = target['CT'] * 1000
        if 'segmentation' in tasks:
            target['label'] = _labels_from_scores(target['segmentation'], gen_args.label_list_segmentation)

    for i, output in enumerate(outputs):

        if feats is not None:
            output.update({'feat': feats[i]['feat']})

        if 'super_resolution' in tasks:
            output.update({'high_res': output['high_res_residual'] + samples[i]['input']})
            # NOTE(review): sample-side handling is assumed to be nested under
            # the task check (here and for bias_field below) -- confirm.
            if 'high_res_residual' in samples[i]:
                samples[i].update({'high_res': samples[i]['high_res_residual'] + samples[i]['input']})

        if 'bias_field' in tasks:
            output.update({'bias_field': torch.exp(output['bias_field_log'])})
            del output['bias_field_log']
            if 'bias_field_log' in samples[i]:
                samples[i].update({'bias_field': torch.exp(samples[i]['bias_field_log'])})
                del samples[i]['bias_field_log']

        if 'distance' in tasks:
            output.update(_split_channels(output['distance'], distance_names))
            fake = _fake_cortical(output['lw'], output['lp'])
            if not left_only:
                # Both hemispheres present: sum the two per-hemisphere maps.
                fake = fake + _fake_cortical(output['rw'], output['rp'])
            output.update({'fake_cortical': fake})
            del output['distance']

        if 'registration' in tasks:
            output.update(_split_channels(output['registration'], registration_names))
            del output['registration']

        if 'segmentation' in tasks:
            output['label'] = _labels_from_scores(output['segmentation'], gen_args.label_list_segmentation)

        if 'CT' in tasks:
            output['CT'] = output['CT'] * 1000

    return outputs, samples, target
#############################################
################ OPTIMIZERS #################
#############################################
def build_optimizer(train_args, params_groups):
    """
    Instantiate the optimizer selected by `train_args.optimizer`.

    Args:
        train_args: training arguments; `optimizer` must be one of
            "adam", "adamw", "sgd" or "lars".
        params_groups: parameters / parameter groups handed to the optimizer.

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: if `train_args.optimizer` is not a supported name.
    """
    name = train_args.optimizer
    if name == "adam":
        return torch.optim.Adam(params_groups)
    if name == "adamw":
        return torch.optim.AdamW(params_groups) # to use with ViTs
    if name == "sgd":
        return torch.optim.SGD(params_groups, lr=0, momentum=0.9) # lr is set by scheduler
    if name == "lars":
        return utils.LARS(params_groups) # to use with convnet and large batches
    # Raise rather than fall through and return None for an unknown name.
    raise ValueError(f'optim type {name!r} is not supported!')
def build_schedulers(train_args, itr_per_epoch, lr, min_lr):
    """
    Build the learning-rate and weight-decay schedules.

    Args:
        train_args: training arguments; reads `lr_scheduler` ("cosine" or
            "multistep"), `n_epochs`, `warmup_epochs`, `weight_decay`,
            `weight_decay_end` and, for "multistep", `lr_drops` and
            `lr_drop_multi`.
        itr_per_epoch: number of iterations per epoch.
        lr: base learning rate.
        min_lr: final learning rate; used by the cosine schedule only.

    Returns:
        (lr_scheduler, wd_scheduler) as produced by `utils.cosine_scheduler` /
        `utils.multistep_scheduler`.

    Raises:
        ValueError: if `train_args.lr_scheduler` is not a supported name.
    """
    if train_args.lr_scheduler == "cosine":
        lr_scheduler = utils.cosine_scheduler(
            lr, # * (args.batch_size * utils.get_world_size()) / 256., # linear scaling rule
            min_lr,
            train_args.n_epochs, itr_per_epoch,
            warmup_epochs=train_args.warmup_epochs
        )
    elif train_args.lr_scheduler == "multistep":
        lr_scheduler = utils.multistep_scheduler(
            lr,
            train_args.lr_drops,
            train_args.n_epochs, itr_per_epoch,
            warmup_epochs=train_args.warmup_epochs,
            gamma=train_args.lr_drop_multi
        )
    else:
        # Without this branch `lr_scheduler` would be unbound at the return.
        raise ValueError(f'lr scheduler type {train_args.lr_scheduler!r} is not supported!')

    # Weight decay always follows a cosine schedule.
    wd_scheduler = utils.cosine_scheduler(
        train_args.weight_decay, # set as 0 to disable it
        train_args.weight_decay_end,
        train_args.n_epochs, itr_per_epoch
    )
    return lr_scheduler, wd_scheduler
############################################
################## MODELS ##################
############################################
def build_model(gen_args, train_args, device = 'cpu'):
    """
    Build the single-stage model together with its processors and criterion.

    `process_args` is applied first, so the returned `gen_args` / `train_args`
    carry the derived task and output settings. The criterion and the model
    are moved to `device`.

    Returns:
        (gen_args, train_args, model, processors, criterion, postprocessor),
        where postprocessor is the `get_postprocessor` function itself.
    """
    gen_args, train_args = process_args(gen_args, train_args, task = gen_args.task)

    # Network = backbone + task heads, combined by the joiner.
    feature_net = build_backbone(train_args, train_args.backbone)
    task_heads = get_head(train_args, train_args.task_f_maps, train_args.out_channels, True, -1)
    network = get_joiner(gen_args.tasks, feature_net, task_heads, device)

    processors = get_processors(gen_args, train_args, gen_args.tasks, device)

    loss_fn = get_criterion(gen_args, train_args, gen_args.tasks, device)
    loss_fn.to(device)
    network.to(device)

    return gen_args, train_args, network, processors, loss_fn, get_postprocessor
def build_conditioned_model(gen_args, train_args, device = 'cpu'): # mask-conditioned inpaiting
    """
    Build the condition-driven model with its processors and criterion.

    The backbone receives one condition input per '+'-separated entry of
    `train_args.condition`. The head (built with stage = 1), the processors
    and the criterion are all built with 'pathology' excluded. The criterion
    and the model are moved to `device`.

    Returns:
        (gen_args, train_args, model, processors, criterion, postprocessor),
        where postprocessor is the `get_postprocessor` function itself.
    """
    gen_args, train_args = process_args(gen_args, train_args, task = gen_args.task)

    skipped = ['pathology']
    n_conditions = len(train_args.condition.split('+'))

    feature_net = build_backbone(train_args, train_args.backbone, num_cond = n_conditions)
    task_heads = get_head(train_args, train_args.task_f_maps, train_args.out_channels, True, -1,
                          stage = 1, exclude_keys = ['pathology'])
    network = get_joiner(gen_args.tasks, feature_net, task_heads, device)

    processors = get_processors(gen_args, train_args, gen_args.tasks, device, exclude_keys = ['pathology'])

    loss_fn = get_criterion(gen_args, train_args, gen_args.tasks, device, exclude_keys = skipped)
    loss_fn.to(device)
    network.to(device)

    return gen_args, train_args, network, processors, loss_fn, get_postprocessor
def build_inpaint_model(gen_args, train_args, device = 'cpu'): # two-stage inpainting
    """
    Build the two-stage inpainting setup: a pathology model (stage 0) and a
    pathology-conditioned task model (stage 1).

    `train_args.backbone` is expected to hold two '+'-separated backbone
    names: the first is used for stage 0 (no condition input), the second for
    stage 1 (one condition input). A single criterion is built over all tasks.
    The criterion and both models are moved to `device`.

    Returns:
        (gen_args, train_args, pathol_model, task_model, pathol_processors,
        task_processors, criterion, postprocessor), where postprocessor is the
        `get_postprocessor` function itself.
    """
    gen_args, train_args = process_args(gen_args, train_args, task = gen_args.task)

    backbone_names = train_args.backbone.split('+')

    # stage-0: pathology mask prediction
    pathol_backbone = build_backbone(train_args, backbone_names[0], num_cond = 0)
    pathol_head = get_head(train_args, train_args.task_f_maps, train_args.out_channels, True, -1, stage = 0)
    pathol_model = get_joiner(gen_args.tasks, pathol_backbone, pathol_head, device, postfix = '_pathol')
    # gen_args is passed first, matching every other get_processors call in
    # this module (the original call omitted it).
    pathol_processors = get_processors(gen_args, train_args, ['pathology'], device)

    # stage-1: pathology-mask-conditioned task prediction (inpainting)
    task_backbone = build_backbone(train_args, backbone_names[1], num_cond = 1)
    task_head = get_head(train_args, train_args.task_f_maps, train_args.out_channels, True, -1, stage = 1)
    task_model = get_joiner(gen_args.tasks, task_backbone, task_head, device, postfix = '_task')
    task_processors = get_processors(gen_args, train_args, gen_args.tasks, device, exclude_keys = ['pathology'])

    criterion = get_criterion(gen_args, train_args, gen_args.tasks, device)

    criterion.to(device)
    pathol_model.to(device)
    task_model.to(device)
    postprocessor = get_postprocessor

    return gen_args, train_args, pathol_model, task_model, pathol_processors, task_processors, criterion, postprocessor