Spaces:
Running
Running
| import numpy as np | |
| import torch as T | |
| from monai.data import (DataLoader, Dataset) | |
| from monai.metrics import compute_meandice | |
| from monai.transforms import ( | |
| AddChanneld, | |
| Compose, | |
| LoadImaged, | |
| NormalizeIntensityd, | |
| ToTensord, | |
| Resized, | |
| RandFlipd, | |
| RandRotate90d, | |
| ThresholdIntensityd, | |
| ScaleIntensityRanged | |
| ) | |
| from Engine.utils import ( | |
| load_model, | |
| read_yaml | |
| ) | |
| from Engine.utils import ( | |
| load_model, | |
| read_yaml | |
| ) | |
| from Engine.utils import ( | |
| load_model, | |
| read_yaml | |
| ) | |
def init_plot_file(plot_path):
    """Create (or truncate) the file at ``plot_path`` and write the CSV header.

    The header line names four columns: ``step``, ``train_loss``,
    ``val_loss`` and ``dice_score``.
    """
    header = "step,train_loss,val_loss,dice_score\n"
    with open(plot_path, "w+") as handle:
        handle.write(header)
def append_metrics(plot_path, epoch, train_loss, val_loss, dice_score):
    """Append one comma-separated row of metrics to the file at ``plot_path``.

    The row is written as ``epoch,train_loss,val_loss,dice_score`` followed
    by a newline; each value uses its default string formatting.
    """
    with open(plot_path, 'a') as handle:
        columns = (epoch, train_loss, val_loss, dice_score)
        handle.write(",".join(format(value) for value in columns) + "\n")
def _run_validation(model, loss_function, val_loader, device, threshold):
    """Evaluate ``model`` on ``val_loader`` and return ``(mean_loss, mean_dice)``.

    Both values are averaged over the number of validation batches. When the
    loader yields no batches, ``(inf, 0)`` is returned so the caller never
    treats an empty validation pass as a new best result.
    """
    model.eval()
    batches = 0
    loss_total = 0.0
    dice_total = 0.0
    with T.no_grad():
        for batch_data in val_loader:
            batches += 1
            inputs = batch_data["image"].to(device)
            labels = batch_data["label"].to(device)
            outputs = model(inputs)
            loss_total += loss_function(outputs, labels).item()
            # Binarise on whatever device the outputs already live on; the
            # loss above is computed on the raw (un-thresholded) outputs.
            binary = (outputs >= threshold).to(outputs.dtype)
            dice_total += T.mean(
                compute_meandice(binary.unsqueeze(0), labels.unsqueeze(0))
            ).item()
    if batches == 0:
        return float('inf'), 0
    return loss_total / batches, dice_total / batches


def train(model, loss_function, optimizer, train_loader, val_loader, device, train_config):
    """Train ``model`` for ``train_config['max_epochs']`` epochs.

    Each training batch must provide the keys ``"image"``, ``"boxes"`` and
    ``"label"``; the loss is obtained from
    ``model.train_step(inputs, labels, loss_function, boxes)``.

    Keys read from ``train_config``:
        max_epochs:       number of epochs to run.
        save_frequency:   write ``model_<epoch>.pth`` every this many epochs.
        val_frequency:    run validation every this many epochs.
        batch_size:       number of training steps between two metric rows
                          (the row holds the mean step loss over that window).
        metric_path:      CSV file that ``append_metrics`` appends to.
        model_directory:  directory receiving the checkpoints.
        output_threshold: cut-off used to binarise outputs for the Dice score.

    Checkpoints written: ``model_<epoch>.pth`` (periodic, after the epoch has
    been trained), ``model_best.pth`` (lowest validation loss so far) and
    ``model_last.pth`` (refreshed at the end of every epoch).
    """
    import os  # local import keeps this block self-contained

    model_dir = train_config['model_directory']
    max_epochs = train_config['max_epochs']
    log_every = train_config['batch_size']

    best_val_loss = float('inf')
    val_loss = float('inf')
    dcs_metric = 0
    total_steps = 0
    step_loss = 0

    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        train_loss = 0
        step = 0
        # Loop-invariant: only used for the progress message below.
        epoch_len = len(train_loader.dataset) // train_loader.batch_size

        for batch_data in train_loader:
            total_steps += 1
            step += 1
            inputs = batch_data["image"].to(device)
            boxes = batch_data["boxes"].to(device)
            labels = batch_data["label"].to(device)

            # Clear the gradients of the previous step; otherwise every update
            # would be driven by the running sum of all earlier gradients.
            optimizer.zero_grad()
            loss = model.train_step(inputs, labels, loss_function, boxes)
            loss.backward()
            optimizer.step()

            # Read the scalar once; a plain float holds no autograd graph.
            loss_value = loss.item()
            train_loss += loss_value
            step_loss += loss_value
            print(
                f"step {step}/{epoch_len}, "
                f"train_loss: {loss_value:.6f}")

            if total_steps % log_every == 0:
                step_loss /= log_every
                append_metrics(train_config['metric_path'], total_steps, step_loss, val_loss, dcs_metric)
                step_loss = 0

        # The loss was summed once per batch, so average over batches.
        train_loss /= max(step, 1)
        print(f"epoch {epoch + 1} average loss: {train_loss:.6f}")

        # Saved after the epoch's updates so the file name matches the number
        # of epochs the stored weights have actually been trained for.
        if (epoch + 1) % train_config['save_frequency'] == 0:
            T.save(model.state_dict(), os.path.join(model_dir, f"model_{epoch + 1}.pth"))

        if (epoch + 1) % train_config['val_frequency'] == 0:
            print("Start eval")
            val_loss, dcs_metric = _run_validation(
                model, loss_function, val_loader, device,
                train_config['output_threshold'])
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                T.save(model.state_dict(), os.path.join(model_dir, "model_best.pth"))
            print(f"epoch {epoch + 1} validation loss: {val_loss:.6f}")
            print(f"epoch {epoch + 1} validation dice score: {dcs_metric:.6f}")

        T.save(model.state_dict(), os.path.join(model_dir, "model_last.pth"))
def _collect_instances(data_paths, split):
    """Build the sample dicts listed under ``split`` in every dataset YAML.

    Each YAML in ``data_paths`` supplies ``image_prefix``, ``label_prefix``
    and ``boxes_prefix``; every entry of its ``split`` list contributes one
    dict whose three paths are the respective prefix plus ``entry['label']``.
    """
    instances = []
    for path in data_paths:
        spec = read_yaml(path)  # parsed once: holds the prefixes and the listing
        for entry in spec[split]:
            name = entry['label']
            instances.append({
                'image': spec['image_prefix'] + name,
                'label': spec['label_prefix'] + name,
                'boxes': spec['boxes_prefix'] + name,
            })
    return instances


def initiate(config_path):
    """Read the YAML config at ``config_path``, build the loaders and start training.

    Steps: reset the metrics file named by ``config['train']['metric_path']``,
    collect the samples from the dataset files in
    ``config['data']['train_dataset']``, wrap them in MONAI datasets (random
    flips and 90-degree rotations for training, plain loading for
    validation), build the model via ``load_model(config['model'])``,
    optionally restore weights from ``config['model']['weights']``, and call
    ``train`` with ``config['train']``.
    """
    config = read_yaml(config_path)
    init_plot_file(config['train']['metric_path'])

    data_paths = config["data"]["train_dataset"]
    combined_train = _collect_instances(data_paths, 'train')
    # NOTE(review): the validation list is built from the same files and the
    # same 'train' split as the training list, so validation runs on training
    # samples. Kept as-is because the intended split key is not visible here;
    # confirm and point this at the real validation split.
    combined_val = _collect_instances(data_paths, 'train')

    keys = ["image", "boxes", "label"]
    aug_prob = config["data"]["aug_prob"]

    train_transform = Compose(
        [
            LoadImaged(keys=keys),
            AddChanneld(keys=keys),
            RandFlipd(keys=keys, prob=aug_prob, spatial_axis=0),
            RandFlipd(keys=keys, prob=aug_prob, spatial_axis=1),
            RandFlipd(keys=keys, prob=aug_prob, spatial_axis=2),
            RandRotate90d(keys=keys, prob=aug_prob, spatial_axes=(0, 1)),
            RandRotate90d(keys=keys, prob=aug_prob, spatial_axes=(0, 2)),
            RandRotate90d(keys=keys, prob=aug_prob, spatial_axes=(1, 2)),
            ToTensord(keys=keys),
        ]
    )
    val_transform = Compose(
        [
            LoadImaged(keys=keys),
            AddChanneld(keys=keys),
            ToTensord(keys=keys),
        ]
    )

    train_dataset = Dataset(combined_train, train_transform)
    train_loader = T.utils.data.DataLoader(train_dataset, 1, shuffle=True)
    val_dataset = Dataset(combined_val, val_transform)
    val_loader = T.utils.data.DataLoader(val_dataset, 1)

    device = T.device(config["device"])
    model, loss, optimizer = load_model(config['model'])
    if config["model"]["weights"]:
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be restored when training runs on the configured device.
        model.load_state_dict(T.load(config["model"]["weights"], map_location=device))
    model.to(device)

    print("initiates training!")
    train(model, loss, optimizer, train_loader, val_loader, device, config['train'])