| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import json |
| | import logging |
| | import os |
| | import random |
| | import sys |
| | import time |
| | from datetime import datetime |
| | from typing import Sequence, Union |
| |
|
| | import monai |
| | import numpy as np |
| | import torch |
| | import torch.distributed as dist |
| | import torch.nn.functional as F |
| | import yaml |
| | from monai import transforms |
| | from monai.bundle import ConfigParser |
| | from monai.data import ThreadDataLoader, partition_dataset |
| | from monai.inferers import sliding_window_inference |
| | from monai.metrics import compute_dice |
| | from monai.utils import set_determinism |
| | from torch.nn.parallel import DistributedDataParallel |
| | from torch.utils.tensorboard import SummaryWriter |
| |
|
| |
|
def run(config_file: Union[str, Sequence[str]]):
    """Run DiNTS differentiable neural-architecture search for 3D segmentation.

    The training data is split in half: one half (``train_files_w``) updates the
    supernet *weights* with SGD, the other half (``train_files_a``) updates the
    *architecture* parameters ``log_alpha_a`` / ``log_alpha_c`` with Adam.
    Architecture updates start only after ``num_epochs_warmup`` epochs.
    Supports optional AMP and multi-GPU ``DistributedDataParallel`` (one process
    per GPU, NCCL backend). Every ``num_epochs_per_validation`` epochs the model
    is validated with sliding-window inference; when the mean dice improves, the
    decoded architecture is checkpointed under ``arch_ckpt_path``.

    Args:
        config_file: path (or sequence of paths) to a MONAI bundle config read
            by ``ConfigParser``; must define the keys read below as well as the
            parsed components ``transform_train``, ``transform_validation``,
            ``network``, ``dints_space`` and ``loss``.
    """
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    parser = ConfigParser()
    parser.read_config(config_file)

    # Hyper-parameters and paths come straight from the bundle config.
    arch_ckpt_path = parser["arch_ckpt_path"]
    amp = parser["amp"]
    data_file_base_dir = parser["data_file_base_dir"]
    data_list_file_path = parser["data_list_file_path"]
    determ = parser["determ"]
    learning_rate = parser["learning_rate"]
    learning_rate_arch = parser["learning_rate_arch"]
    learning_rate_milestones = np.array(parser["learning_rate_milestones"])
    num_images_per_batch = parser["num_images_per_batch"]
    num_epochs = parser["num_epochs"]
    num_epochs_per_validation = parser["num_epochs_per_validation"]
    num_epochs_warmup = parser["num_epochs_warmup"]
    num_sw_batch_size = parser["num_sw_batch_size"]
    output_classes = parser["output_classes"]
    overlap_ratio = parser["overlap_ratio"]
    patch_size_valid = parser["patch_size_valid"]
    ram_cost_factor = parser["ram_cost_factor"]
    print("[info] GPU RAM cost factor:", ram_cost_factor)

    train_transforms = parser.get_parsed_content("transform_train")
    val_transforms = parser.get_parsed_content("transform_validation")

    # Deterministic training (fixed seed) if requested by the config.
    if determ:
        set_determinism(seed=0)

    # One process per GPU; DDP is initialized only for multi-GPU runs.
    print("[info] number of GPUs:", torch.cuda.device_count())
    if torch.cuda.device_count() > 1:
        dist.init_process_group(backend="nccl", init_method="env://")
        world_size = dist.get_world_size()
    else:
        world_size = 1
    print("[info] world_size:", world_size)

    with open(data_list_file_path, "r") as f:
        json_data = json.load(f)

    list_train = json_data["training"]
    list_valid = json_data["validation"]

    # Collect training pairs, silently skipping entries whose files are missing.
    files = []
    for _i in range(len(list_train)):
        str_img = os.path.join(data_file_base_dir, list_train[_i]["image"])
        str_seg = os.path.join(data_file_base_dir, list_train[_i]["label"])

        if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)):
            continue

        files.append({"image": str_img, "label": str_seg})
    train_files = files

    random.shuffle(train_files)

    # First half trains the supernet weights; each rank keeps its own shard.
    train_files_w = train_files[: len(train_files) // 2]
    if torch.cuda.device_count() > 1:
        train_files_w = partition_dataset(
            data=train_files_w, shuffle=True, num_partitions=world_size, even_divisible=True
        )[dist.get_rank()]

    # Second half trains the architecture parameters.
    train_files_a = train_files[len(train_files) // 2 :]
    if torch.cuda.device_count() > 1:
        train_files_a = partition_dataset(
            data=train_files_a, shuffle=True, num_partitions=world_size, even_divisible=True
        )[dist.get_rank()]

    # Collect validation pairs, with the same missing-file skip as above.
    files = []
    for _i in range(len(list_valid)):
        str_img = os.path.join(data_file_base_dir, list_valid[_i]["image"])
        str_seg = os.path.join(data_file_base_dir, list_valid[_i]["label"])

        if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)):
            continue

        files.append({"image": str_img, "label": str_seg})
    val_files = files

    if torch.cuda.device_count() > 1:
        val_files = partition_dataset(data=val_files, shuffle=False, num_partitions=world_size, even_divisible=False)[
            dist.get_rank()
        ]

    if torch.cuda.device_count() > 1:
        device = torch.device(f"cuda:{dist.get_rank()}")
    else:
        device = torch.device("cuda:0")
    torch.cuda.set_device(device)

    # Multi-GPU runs can fully cache each rank's shard; a single GPU caches
    # only 1/8 of the (larger, unsharded) dataset to bound host RAM usage.
    if torch.cuda.device_count() > 1:
        train_ds_a = monai.data.CacheDataset(
            data=train_files_a, transform=train_transforms, cache_rate=1.0, num_workers=8
        )
        train_ds_w = monai.data.CacheDataset(
            data=train_files_w, transform=train_transforms, cache_rate=1.0, num_workers=8
        )
        val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=2)
    else:
        train_ds_a = monai.data.CacheDataset(
            data=train_files_a, transform=train_transforms, cache_rate=0.125, num_workers=8
        )
        train_ds_w = monai.data.CacheDataset(
            data=train_files_w, transform=train_transforms, cache_rate=0.125, num_workers=8
        )
        val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=0.125, num_workers=2)

    train_loader_a = ThreadDataLoader(train_ds_a, num_workers=6, batch_size=num_images_per_batch, shuffle=True)
    train_loader_w = ThreadDataLoader(train_ds_w, num_workers=6, batch_size=num_images_per_batch, shuffle=True)
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1, shuffle=False)

    model = parser.get_parsed_content("network")
    dints_space = parser.get_parsed_content("dints_space")

    model = model.to(device)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Post-processing: argmax + one-hot for predictions, one-hot for labels,
    # so `compute_dice` sees matching one-hot tensors.
    post_pred = transforms.Compose(
        [transforms.EnsureType(), transforms.AsDiscrete(argmax=True, to_onehot=output_classes)]
    )
    post_label = transforms.Compose([transforms.EnsureType(), transforms.AsDiscrete(to_onehot=output_classes)])

    loss_func = parser.get_parsed_content("loss")

    # Weight optimizer (SGD) plus one Adam optimizer per architecture tensor.
    # Learning rates are scaled linearly with the number of ranks.
    optimizer = torch.optim.SGD(
        model.weight_parameters(), lr=learning_rate * world_size, momentum=0.9, weight_decay=0.00004
    )
    arch_optimizer_a = torch.optim.Adam(
        [dints_space.log_alpha_a], lr=learning_rate_arch * world_size, betas=(0.5, 0.999), weight_decay=0.0
    )
    arch_optimizer_c = torch.optim.Adam(
        [dints_space.log_alpha_c], lr=learning_rate_arch * world_size, betas=(0.5, 0.999), weight_decay=0.0
    )

    if torch.cuda.device_count() > 1:
        model = DistributedDataParallel(model, device_ids=[device], find_unused_parameters=True)

    if amp:
        from torch.cuda.amp import GradScaler, autocast

        scaler = GradScaler()
        if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
            print("[info] amp enabled")

    val_interval = num_epochs_per_validation
    best_metric = -1
    best_metric_epoch = -1
    idx_iter = 0

    # Rank 0 owns all logging artifacts (TensorBoard events, CSV history).
    if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
        writer = SummaryWriter(log_dir=os.path.join(arch_ckpt_path, "Events"))

        with open(os.path.join(arch_ckpt_path, "accuracy_history.csv"), "a") as f:
            f.write("epoch\tmetric\tloss\tlr\ttime\titer\n")

    # Architecture batches are drawn from a manually advanced iterator so the
    # arch loader can be re-wound independently of the weight loader.
    dataloader_a_iterator = iter(train_loader_a)

    start_time = time.time()
    for epoch in range(num_epochs):
        # Step decay: halve the LR once for every milestone already passed
        # (milestones are fractions of the post-warmup schedule).
        decay = 0.5 ** np.sum(
            [(epoch - num_epochs_warmup) / (num_epochs - num_epochs_warmup) > learning_rate_milestones]
        )
        lr = learning_rate * decay * world_size
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
            print("-" * 10)
            print(f"epoch {epoch + 1}/{num_epochs}")
            print("learning rate is set to {}".format(lr))

        model.train()
        epoch_loss = 0
        # [sum of losses, number of steps] — kept as a tensor so it can be
        # all-reduced across ranks after the epoch.
        loss_torch = torch.zeros(2, dtype=torch.float, device=device)
        epoch_loss_arch = 0
        loss_torch_arch = torch.zeros(2, dtype=torch.float, device=device)
        step = 0

        for batch_data in train_loader_w:
            step += 1
            inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)

            # Phase 1: update supernet weights only — freeze the alphas.
            if world_size == 1:
                for _ in model.weight_parameters():
                    _.requires_grad = True
            else:
                for _ in model.module.weight_parameters():
                    _.requires_grad = True
            dints_space.log_alpha_a.requires_grad = False
            dints_space.log_alpha_c.requires_grad = False

            optimizer.zero_grad()

            if amp:
                with autocast():
                    outputs = model(inputs)
                    # Binary case: flip channels and invert labels so the
                    # foreground channel drives the loss.
                    if output_classes == 2:
                        loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels)
                    else:
                        loss = loss_func(outputs, labels)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(inputs)
                if output_classes == 2:
                    loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels)
                else:
                    loss = loss_func(outputs, labels)
                loss.backward()
                optimizer.step()

            epoch_loss += loss.item()
            loss_torch[0] += loss.item()
            loss_torch[1] += 1.0
            epoch_len = len(train_loader_w)
            idx_iter += 1

            if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
                print("[{0}] ".format(str(datetime.now())[:19]) + f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)

            # During warmup only the weights are trained; skip the arch step.
            if epoch < num_epochs_warmup:
                continue

            # Phase 2: update architecture parameters on the search split.
            try:
                sample_a = next(dataloader_a_iterator)
            except StopIteration:
                dataloader_a_iterator = iter(train_loader_a)
                sample_a = next(dataloader_a_iterator)
            inputs_search, labels_search = (sample_a["image"].to(device), sample_a["label"].to(device))
            if world_size == 1:
                for _ in model.weight_parameters():
                    _.requires_grad = False
            else:
                for _ in model.module.weight_parameters():
                    _.requires_grad = False
            dints_space.log_alpha_a.requires_grad = True
            dints_space.log_alpha_c.requires_grad = True

            # Regularizers for the architecture loss.
            entropy_alpha_c = torch.tensor(0.0).to(device)
            entropy_alpha_a = torch.tensor(0.0).to(device)
            ram_cost_full = torch.tensor(0.0).to(device)
            ram_cost_usage = torch.tensor(0.0).to(device)
            ram_cost_loss = torch.tensor(0.0).to(device)
            topology_loss = torch.tensor(0.0).to(device)

            probs_a, arch_code_prob_a = dints_space.get_prob_a(child=True)
            entropy_alpha_a = -((probs_a) * torch.log(probs_a + 1e-5)).mean()
            entropy_alpha_c = -(
                F.softmax(dints_space.log_alpha_c, dim=-1) * F.log_softmax(dints_space.log_alpha_c, dim=-1)
            ).mean()
            topology_loss = dints_space.get_topology_entropy(probs_a)

            # Penalize deviation of the searched network's RAM usage from the
            # configured fraction of the full supernet's usage.
            ram_cost_full = dints_space.get_ram_cost_usage(inputs.shape, full=True)
            ram_cost_usage = dints_space.get_ram_cost_usage(inputs.shape)
            ram_cost_loss = torch.abs(ram_cost_factor - ram_cost_usage / ram_cost_full)

            arch_optimizer_a.zero_grad()
            arch_optimizer_c.zero_grad()

            # Penalty weight ramps from 0 to 1 over the post-warmup epochs.
            combination_weights = (epoch - num_epochs_warmup) / (num_epochs - num_epochs_warmup)

            # NOTE(review): the AMP branch scales the WHOLE penalty by
            # `combination_weights`, while the non-AMP branch scales only the
            # entropy terms. The intended formula is ambiguous from this file
            # alone, so the discrepancy is preserved — confirm which is correct.
            if amp:
                with autocast():
                    outputs_search = model(inputs_search)
                    if output_classes == 2:
                        loss = loss_func(torch.flip(outputs_search, dims=[1]), 1 - labels_search)
                    else:
                        loss = loss_func(outputs_search, labels_search)

                    loss += combination_weights * (
                        (entropy_alpha_a + entropy_alpha_c) + ram_cost_loss + 0.001 * topology_loss
                    )

                scaler.scale(loss).backward()
                scaler.step(arch_optimizer_a)
                scaler.step(arch_optimizer_c)
                scaler.update()
            else:
                outputs_search = model(inputs_search)
                if output_classes == 2:
                    loss = loss_func(torch.flip(outputs_search, dims=[1]), 1 - labels_search)
                else:
                    loss = loss_func(outputs_search, labels_search)

                loss += 1.0 * (
                    combination_weights * (entropy_alpha_a + entropy_alpha_c) + ram_cost_loss + 0.001 * topology_loss
                )

                loss.backward()
                arch_optimizer_a.step()
                arch_optimizer_c.step()

            epoch_loss_arch += loss.item()
            loss_torch_arch[0] += loss.item()
            loss_torch_arch[1] += 1.0

            if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
                print(
                    "[{0}] ".format(str(datetime.now())[:19])
                    + f"{step}/{epoch_len}, train_loss_arch: {loss.item():.4f}"
                )
                writer.add_scalar("train_loss_arch", loss.item(), epoch_len * epoch + step)

        # Average both loss accumulators across ranks before reporting.
        # FIX: the arch-loss accumulator was previously not all-reduced, so the
        # printed "average arch loss" only reflected rank 0's local steps.
        if torch.cuda.device_count() > 1:
            dist.barrier()
            dist.all_reduce(loss_torch, op=torch.distributed.ReduceOp.SUM)
            dist.all_reduce(loss_torch_arch, op=torch.distributed.ReduceOp.SUM)

        loss_torch = loss_torch.tolist()
        loss_torch_arch = loss_torch_arch.tolist()
        if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
            loss_torch_epoch = loss_torch[0] / loss_torch[1]
            print(
                f"epoch {epoch + 1} average loss: {loss_torch_epoch:.4f}, "
                f"best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
            )

            if epoch >= num_epochs_warmup:
                loss_torch_arch_epoch = loss_torch_arch[0] / loss_torch_arch[1]
                print(
                    f"epoch {epoch + 1} average arch loss: {loss_torch_arch_epoch:.4f}, "
                    f"best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
                )

        if (epoch + 1) % val_interval == 0 or (epoch + 1) == num_epochs:
            torch.cuda.empty_cache()
            model.eval()
            with torch.no_grad():
                # For each foreground class, accumulate [masked dice sum,
                # valid-sample count] so NaN dice values are excluded.
                metric = torch.zeros((output_classes - 1) * 2, dtype=torch.float, device=device)
                metric_sum = 0.0
                metric_count = 0
                metric_mat = []
                val_images = None
                val_labels = None
                val_outputs = None

                _index = 0
                for val_data in val_loader:
                    val_images = val_data["image"].to(device)
                    val_labels = val_data["label"].to(device)

                    roi_size = patch_size_valid
                    sw_batch_size = num_sw_batch_size

                    if amp:
                        with torch.cuda.amp.autocast():
                            pred = sliding_window_inference(
                                val_images,
                                roi_size,
                                sw_batch_size,
                                lambda x: model(x),
                                mode="gaussian",
                                overlap=overlap_ratio,
                            )
                    else:
                        pred = sliding_window_inference(
                            val_images,
                            roi_size,
                            sw_batch_size,
                            lambda x: model(x),
                            mode="gaussian",
                            overlap=overlap_ratio,
                        )
                    val_outputs = pred

                    val_outputs = post_pred(val_outputs[0, ...])
                    val_outputs = val_outputs[None, ...]
                    val_labels = post_label(val_labels[0, ...])
                    val_labels = val_labels[None, ...]

                    value = compute_dice(y_pred=val_outputs, y=val_labels, include_background=False)

                    print(_index + 1, "/", len(val_loader), value)

                    metric_count += len(value)
                    metric_sum += value.sum().item()
                    metric_vals = value.cpu().numpy()
                    if len(metric_mat) == 0:
                        metric_mat = metric_vals
                    else:
                        metric_mat = np.concatenate((metric_mat, metric_vals), axis=0)

                    for _c in range(output_classes - 1):
                        val0 = torch.nan_to_num(value[0, _c], nan=0.0)
                        # FIX: the validity mask must test class _c itself;
                        # it previously tested value[0, 0], so class 0's NaN
                        # status (in)validated every class's contribution.
                        val1 = 1.0 - torch.isnan(value[0, _c]).float()
                        metric[2 * _c] += val0 * val1
                        metric[2 * _c + 1] += val1

                    _index += 1

                # Sum per-class [dice, count] pairs over all ranks.
                if torch.cuda.device_count() > 1:
                    dist.barrier()
                    dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)

                metric = metric.tolist()
                if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
                    for _c in range(output_classes - 1):
                        print("evaluation metric - class {0:d}:".format(_c + 1), metric[2 * _c] / metric[2 * _c + 1])
                    avg_metric = 0
                    for _c in range(output_classes - 1):
                        avg_metric += metric[2 * _c] / metric[2 * _c + 1]
                    avg_metric = avg_metric / float(output_classes - 1)
                    print("avg_metric", avg_metric)

                    if avg_metric > best_metric:
                        best_metric = avg_metric
                        best_metric_epoch = epoch + 1
                        best_metric_iterations = idx_iter

                        # Decode and checkpoint the current best architecture.
                        (node_a_d, arch_code_a_d, arch_code_c_d, arch_code_a_max_d) = dints_space.decode()
                        torch.save(
                            {
                                "node_a": node_a_d,
                                "arch_code_a": arch_code_a_d,
                                "arch_code_a_max": arch_code_a_max_d,
                                "arch_code_c": arch_code_c_d,
                                "iter_num": idx_iter,
                                "epochs": epoch + 1,
                                "best_dsc": best_metric,
                                "best_path": best_metric_iterations,
                            },
                            os.path.join(arch_ckpt_path, "search_code_" + str(idx_iter) + ".pt"),
                        )
                        print("saved new best metric model")

                    dict_file = {}
                    dict_file["best_avg_dice_score"] = float(best_metric)
                    dict_file["best_avg_dice_score_epoch"] = int(best_metric_epoch)
                    dict_file["best_avg_dice_score_iteration"] = int(idx_iter)
                    with open(os.path.join(arch_ckpt_path, "progress.yaml"), "w") as out_file:
                        _ = yaml.dump(dict_file, stream=out_file)

                    print(
                        "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                            epoch + 1, avg_metric, best_metric, best_metric_epoch
                        )
                    )

                    current_time = time.time()
                    elapsed_time = (current_time - start_time) / 60.0
                    with open(os.path.join(arch_ckpt_path, "accuracy_history.csv"), "a") as f:
                        f.write(
                            "{0:d}\t{1:.5f}\t{2:.5f}\t{3:.5f}\t{4:.1f}\t{5:d}\n".format(
                                epoch + 1, avg_metric, loss_torch_epoch, lr, elapsed_time, idx_iter
                            )
                        )

            if torch.cuda.device_count() > 1:
                dist.barrier()

            torch.cuda.empty_cache()

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")

    if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
        writer.close()

    if torch.cuda.device_count() > 1:
        dist.destroy_process_group()
| |
|
| |
|
if __name__ == "__main__":
    # `fire` is an optional CLI dependency, imported lazily so that importing
    # this module (e.g. to call run() programmatically) never requires it.
    from monai.utils import optional_import

    # Expose the module's functions (notably `run`) as a command-line
    # interface, e.g. `python <script> run --config_file=<path>`.
    fire, _ = optional_import("fire")
    fire.Fire()
| |
|