| |
| |
| """ |
| Adapted from MONAI Tutorial: https://github.com/Project-MONAI/tutorials/tree/main/2d_segmentation/torch |
| """ |
|
|
| import argparse |
| import os |
|
|
| join = os.path.join |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader |
| from torch.utils.tensorboard import SummaryWriter |
| from stardist import star_dist,edt_prob |
| from stardist import dist_to_coord, non_maximum_suppression, polygons_to_label |
| from stardist import random_label_cmap,ray_angles |
| import monai |
| from collections import OrderedDict |
| from compute_metric import eval_tp_fp_fn,remove_boundary_cells |
| from monai.data import decollate_batch, PILReader |
| from monai.inferers import sliding_window_inference |
| from monai.metrics import DiceMetric |
| from monai.transforms import ( |
| Activations, |
| AsChannelFirstd, |
| AddChanneld, |
| AsDiscrete, |
| Compose, |
| LoadImaged, |
| SpatialPadd, |
| RandSpatialCropd, |
| RandRotate90d, |
| ScaleIntensityd, |
| RandAxisFlipd, |
| RandZoomd, |
| RandGaussianNoised, |
| RandAdjustContrastd, |
| RandGaussianSmoothd, |
| RandHistogramShiftd, |
| EnsureTyped, |
| EnsureType, |
| ) |
| from monai.visualize import plot_2d_or_3d_image |
| import matplotlib.pyplot as plt |
| from datetime import datetime |
| import shutil |
| import tqdm |
| from models.unetr2d import UNETR2D |
| from models.swin_unetr import SwinUNETR |
| from models.flexible_unet import FlexibleUNet |
| from models.flexible_unet_convext import FlexibleUNetConvext |
# Reaching this line means every import above resolved.
print("Successfully imported all requirements!")
# cuDNN is switched off for the whole process.
torch.backends.cudnn.enabled = False
|
|
def _parse_args():
    """Parse the command-line options of the training script."""
    parser = argparse.ArgumentParser("Baseline for Microscopy image segmentation")

    parser.add_argument(
        "--data_path",
        default="/data2/liuchenyu/external_processed/split",
        type=str,
        help="dataset root; expects train/images, train/tif, test/images, test/tif",
    )
    parser.add_argument(
        "--extra_data_path",
        default="/data/louwei/nips_comp/train_cellpose_multi0/",
        type=str,
        help="additional training set root; expects train/images and train/tif",
    )
    parser.add_argument(
        "--work_dir",
        default="/data/louwei/nips_comp/convnext_fold0",
        help="path where to save models and logs",
    )
    parser.add_argument("--seed", default=2022, type=int)

    parser.add_argument("--num_workers", default=8, type=int)
    # torch.distributed.launch passes --local_rank (or --local-rank on newer
    # releases); torchrun only exports LOCAL_RANK. Accept all three so that
    # torch.cuda.set_device never receives None.
    parser.add_argument(
        "--local_rank",
        "--local-rank",
        dest="local_rank",
        type=int,
        default=int(os.environ.get("LOCAL_RANK", 0)),
        help="GPU index of this process on the local node",
    )

    parser.add_argument(
        "--model_name",
        default="efficientunet",
        help="select model: unet, efficientunet, swinunetr",
    )
    parser.add_argument("--num_class", default=3, type=int, help="segmentation classes")
    parser.add_argument(
        "--input_size", default=512, type=int, help="side length of the square training crop"
    )

    parser.add_argument("--batch_size", default=16, type=int, help="Batch size per GPU")
    parser.add_argument("--max_epochs", default=2000, type=int)
    parser.add_argument("--val_interval", default=5, type=int)
    # Accepted for command-line compatibility; the loop below does no early stopping.
    parser.add_argument("--epoch_tolerance", default=100, type=int)
    parser.add_argument("--initial_lr", type=float, default=1e-4, help="learning rate")

    return parser.parse_args()


def _paired_files(img_dir, gt_dir):
    """Return ``(image_paths, label_paths)`` for every file in *img_dir*.

    The label of an image is ``<text before the first dot of its name>.tif``
    inside *gt_dir*. Images are taken in sorted name order.
    """
    img_names = sorted(os.listdir(img_dir))
    img_files = [join(img_dir, name) for name in img_names]
    gt_files = [join(gt_dir, name.split(".")[0] + ".tif") for name in img_names]
    return img_files, gt_files


def _build_model(model_name, num_class, input_size, n_rays, pretrained, device):
    """Instantiate the network selected by *model_name* and move it to *device*.

    Raises:
        ValueError: if *model_name* is not one of the supported names.
    """
    name = model_name.lower()
    if name == "unet":
        # NOTE(review): this branch emits ``num_class`` channels, while the
        # training loop unpacks two outputs (distances, probability) from the
        # model -- confirm it is still meant to be usable.
        model = monai.networks.nets.UNet(
            spatial_dims=2,
            in_channels=3,
            out_channels=num_class,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        )
    elif name == "efficientunet":
        model = FlexibleUNetConvext(
            in_channels=3,
            out_channels=n_rays + 1,
            backbone="convnext_small",
            pretrained=pretrained,
        )
    elif name == "swinunetr":
        model = SwinUNETR(
            img_size=(input_size, input_size),
            in_channels=3,
            out_channels=n_rays + 1,
            feature_size=24,
            spatial_dims=2,
        )
    else:
        raise ValueError(
            f"unknown --model_name {model_name!r}; expected unet, efficientunet or swinunetr"
        )
    return model.to(device)


def _param_groups(model, initial_lr):
    """Build optimizer parameter groups: encoder at 0.1x LR, the rest at full LR.

    Models without an ``encoder`` attribute get a single group at *initial_lr*.
    """
    encoder = getattr(model, "encoder", None)
    if encoder is None:
        return [{"params": list(model.parameters()), "lr": initial_lr}]
    encoder_params = list(encoder.parameters())
    encoder_ids = {id(p) for p in encoder_params}
    other_params = [p for p in model.parameters() if id(p) not in encoder_ids]
    return [
        {"params": other_params, "lr": initial_lr},
        {"params": encoder_params, "lr": initial_lr * 0.1},
    ]


def _stardist_targets(labels, n_rays):
    """Convert a batch of instance-label maps into stacked regression targets.

    For each sample the first channel is passed to ``star_dist`` (giving
    ``n_rays`` channels after the transpose) and to ``edt_prob`` (one channel);
    the result has ``n_rays + 1`` channels with the ``edt_prob`` map last.
    """
    targets = []
    for i in range(labels.shape[0]):
        label = labels[i][0]
        distances = np.transpose(star_dist(label, n_rays), (2, 0, 1))
        obj_probabilities = np.expand_dims(edt_prob(label.astype(int)), 0)
        targets.append(np.concatenate((distances, obj_probabilities), axis=0))
    return np.stack(targets)


def _f1_score(tp, fp, fn):
    """Return the F1 score rounded to 4 decimals; 0 when there are no true positives."""
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return float(np.round(2 * (precision * recall) / (precision + recall), 4))


def main():
    """Train a StarDist-style network with DistributedDataParallel.

    The model is trained to predict ``n_rays`` distance channels and one
    probability channel. From epoch 8 on, every ``--val_interval`` epochs the
    model is evaluated with sliding-window inference, StarDist non-maximum
    suppression and an instance-level F1 score. Checkpoints, the TensorBoard
    log and ``train_log.npz`` are written by rank 0 only, because every rank
    targets the same ``work_dir``.
    """
    args = _parse_args()
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend="nccl")
    # All ranks share one output directory: only rank 0 may write files.
    is_main_process = torch.distributed.get_rank() == 0
    monai.config.print_config()
    n_rays = 32
    pre_trained = True

    np.random.seed(args.seed)
    model_path = join(args.work_dir, args.model_name + "_3class")
    os.makedirs(model_path, exist_ok=True)

    # Keep a copy of this script and of the model definition next to the weights.
    model_file = "models/flexible_unet_convext.py"
    if is_main_process:
        shutil.copyfile(__file__, join(model_path, os.path.basename(__file__)))
        shutil.copyfile(model_file, join(model_path, os.path.basename(model_file)))

    # Training set = main split + the additional data set.
    train_img_files, train_gt_files = _paired_files(
        join(args.data_path, "train/images"), join(args.data_path, "train/tif")
    )
    extra_img_files, extra_gt_files = _paired_files(
        join(args.extra_data_path, "train/images"), join(args.extra_data_path, "train/tif")
    )
    val_img_files, val_gt_files = _paired_files(
        join(args.data_path, "test/images"), join(args.data_path, "test/tif")
    )

    train_files = [
        {"img": img, "label": gt}
        for img, gt in zip(train_img_files + extra_img_files, train_gt_files + extra_gt_files)
    ]
    val_files = [{"img": img, "label": gt} for img, gt in zip(val_img_files, val_gt_files)]
    print(
        f"training image num: {len(train_files)}, validation image num: {len(val_files)}"
    )

    train_transforms = Compose(
        [
            LoadImaged(keys=["img", "label"], reader=PILReader, dtype=np.float32),
            AddChanneld(keys=["label"], allow_missing_keys=True),
            AsChannelFirstd(keys=["img"], channel_dim=-1, allow_missing_keys=True),
            SpatialPadd(keys=["img", "label"], spatial_size=args.input_size),
            RandSpatialCropd(
                keys=["img", "label"], roi_size=args.input_size, random_size=False
            ),
            RandAxisFlipd(keys=["img", "label"], prob=0.5),
            RandRotate90d(keys=["img", "label"], prob=0.5, spatial_axes=[0, 1]),
            RandGaussianNoised(keys=["img"], prob=0.25, mean=0, std=0.1),
            RandAdjustContrastd(keys=["img"], prob=0.25, gamma=(1, 2)),
            RandGaussianSmoothd(keys=["img"], prob=0.25, sigma_x=(1, 2)),
            RandHistogramShiftd(keys=["img"], prob=0.25, num_control_points=3),
            RandZoomd(
                keys=["img", "label"],
                prob=0.15,
                min_zoom=0.5,
                max_zoom=2,
                mode=["area", "nearest"],
            ),
            EnsureTyped(keys=["img", "label"]),
        ]
    )

    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "label"], reader=PILReader, dtype=np.float32),
            AddChanneld(keys=["label"], allow_missing_keys=True),
            AsChannelFirstd(keys=["img"], channel_dim=-1, allow_missing_keys=True),
            EnsureTyped(keys=["img", "label"]),
        ]
    )

    # Load one sample up front so a broken data pipeline fails before training.
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=1, num_workers=4)
    check_data = monai.utils.misc.first(check_loader)
    print(
        "sanity check:",
        check_data["img"].shape,
        torch.max(check_data["img"]),
        check_data["label"].shape,
        torch.max(check_data["label"]),
    )

    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # NOTE(review): no DistributedSampler is used, so every rank iterates the
    # full training set each epoch -- confirm this is intended.
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=torch.cuda.is_available(),
    )

    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _build_model(
        args.model_name, args.num_class, args.input_size, n_rays, pre_trained, device
    )

    loss_dice = monai.losses.DiceLoss(squared_pred=True, jaccard=True)
    loss_bce = nn.BCELoss()
    loss_dist_mae = nn.L1Loss()

    initial_lr = args.initial_lr
    optimizer = torch.optim.AdamW(_param_groups(model, initial_lr), initial_lr)

    print('distributed model')
    model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
    print('successful model')

    max_epochs = args.max_epochs
    val_interval = args.val_interval
    first_val_epoch = 8  # no checkpointing / validation before this epoch
    roi_size = (args.input_size, args.input_size)
    sw_batch_size = 4
    epoch_loss_values = []
    metric_values = []  # mean validation F1 of every validation pass
    writer = SummaryWriter(model_path) if is_main_process else None
    max_f1 = 0

    for epoch in range(max_epochs):
        model.train()
        epoch_loss = 0.0
        epoch_loss_prob = 0.0
        epoch_loss_dist_2 = 0.0
        epoch_loss_dist_1 = 0.0
        step = 0  # stays 0 if the loader yields nothing
        for step, batch_data in enumerate(tqdm.tqdm(train_loader), 1):
            inputs, labels = batch_data["img"], batch_data["label"]
            labels = _stardist_targets(labels, n_rays)

            inputs = torch.tensor(inputs).to(device)
            labels = torch.tensor(labels).to(device)

            optimizer.zero_grad()
            dist_output, prob_output = model(inputs)

            dist_label = labels[:, :n_rays, :, :]
            prob_label = torch.unsqueeze(labels[:, -1, :, :], 1)

            # Distance losses are weighted by the target probability map.
            loss_dist_1 = loss_dice(dist_output * prob_label, dist_label * prob_label)
            loss_prob = loss_bce(prob_output, prob_label)
            loss_dist_2 = loss_dist_mae(dist_output * prob_label, dist_label * prob_label)

            loss = loss_prob + loss_dist_2 * 0.3 + loss_dist_1
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_loss_prob += loss_prob.item()
            epoch_loss_dist_2 += loss_dist_2.item()
            epoch_loss_dist_1 += loss_dist_1.item()

        num_steps = max(step, 1)  # avoid dividing by zero on an empty loader
        epoch_loss /= num_steps
        epoch_loss_prob /= num_steps
        epoch_loss_dist_2 /= num_steps
        epoch_loss_dist_1 /= num_steps
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch} average loss: {epoch_loss:.4f}")
        if writer is not None:
            writer.add_scalar("train_loss", epoch_loss, epoch)
        print('dist dice: '+str(epoch_loss_dist_1)+' dist mae: '+str(epoch_loss_dist_2)+' prob bce: '+str(epoch_loss_prob))

        if epoch < first_val_epoch or epoch % val_interval != 0:
            continue

        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.module.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": epoch_loss_values,
        }
        if is_main_process:
            torch.save(checkpoint, join(model_path, str(epoch) + ".pth"))

        # Every rank runs validation so that forward passes through the DDP
        # wrapper stay in step across processes; only rank 0 records results.
        model.eval()
        with torch.no_grad():
            seg_metric = OrderedDict()
            seg_metric['F1_Score'] = []
            for val_data in tqdm.tqdm(val_loader):
                val_images = val_data["img"].to(device)
                output_dist, output_prob = sliding_window_inference(
                    val_images, roi_size, sw_batch_size, model
                )
                val_labels = val_data["label"][0][0].cpu().numpy()
                prob = output_prob[0][0].cpu().numpy()
                dist = output_dist[0].cpu().numpy()

                # Channels-last layout and a small positive floor on the
                # distances before non-maximum suppression.
                dist = np.transpose(dist, (1, 2, 0))
                dist = np.maximum(1e-3, dist)
                points, probi, disti = non_maximum_suppression(
                    dist, prob, prob_thresh=0.5, nms_thresh=0.4
                )

                star_label = polygons_to_label(disti, points, prob=probi, shape=prob.shape)
                gt = remove_boundary_cells(val_labels.astype(np.int32))
                seg = remove_boundary_cells(star_label.astype(np.int32))
                tp, fp, fn = eval_tp_fp_fn(gt, seg, threshold=0.5)
                seg_metric['F1_Score'].append(_f1_score(tp, fp, fn))

        scores = seg_metric['F1_Score']
        avg_f1 = float(np.mean(scores)) if scores else 0.0
        metric_values.append(avg_f1)
        if writer is not None:
            writer.add_scalar("val_f1score", avg_f1, epoch)
        if avg_f1 > max_f1:
            max_f1 = avg_f1
            print(str(epoch) + 'f1 score: ' + str(max_f1))
            if is_main_process:
                torch.save(checkpoint, join(model_path, "best_model.pth"))

    if writer is not None:
        writer.close()
    if is_main_process:
        # The "val_dice" key is kept for existing readers; it holds the mean
        # validation F1 recorded at each validation pass.
        np.savez_compressed(
            join(model_path, "train_log.npz"),
            val_dice=metric_values,
            epoch_loss=epoch_loss_values,
        )
|
|
|
|
# Start training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|