| | import argparse |
| | import os |
| | import ruamel_yaml as yaml |
| | import numpy as np |
| | import time |
| | import datetime |
| | import json |
| | from pathlib import Path |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch.utils.data import DataLoader |
| | import torch.backends.cudnn as cudnn |
| | from test_res_ft import test |
| | from tensorboardX import SummaryWriter |
| | import utils |
| | from models.resnet import ModelRes_ft |
| | from test_res_ft import test |
| | from dataset.dataset_siim_acr import SIIM_ACR_Dataset |
| | from scheduler import create_scheduler |
| | from optim import create_optimizer |
| | import warnings |
| |
|
| | warnings.filterwarnings("ignore") |
| |
|
| |
|
def train(
    model,
    data_loader,
    optimizer,
    criterion,
    epoch,
    warmup_steps,
    device,
    scheduler,
    args,
    config,
    writer,
):
    """Run one training epoch and return the averaged metrics.

    Args:
        model: classification network; called as ``model(image) -> logits``.
        data_loader: iterable yielding dicts with ``"image"`` and ``"label"``.
        optimizer: torch optimizer stepped once per batch.
        criterion: loss applied to ``(logits, labels)`` (BCEWithLogitsLoss here).
        epoch: 0-based epoch index; epoch 0 runs per-iteration LR warmup.
        warmup_steps: number of warmup scheduler steps (used only in epoch 0).
        device: target device for images and labels.
        scheduler: LR scheduler exposing ``step()`` and ``_get_lr()``.
        args: CLI namespace (unused here; kept for interface parity).
        config: config dict (unused here; kept for interface parity).
        writer: tensorboardX SummaryWriter for per-step loss logging.

    Returns:
        dict mapping meter name -> formatted global-average string.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter(
        "lr", utils.SmoothedValue(window_size=50, fmt="{value:.6f}")
    )
    metric_logger.add_meter(
        "loss", utils.SmoothedValue(window_size=50, fmt="{value:.6f}")
    )
    # Seed the meters so the first log line has values.
    metric_logger.update(loss=1.0)
    metric_logger.update(lr=scheduler._get_lr(epoch)[0])

    header = "Train Epoch: [{}]".format(epoch)
    print_freq = 50
    step_size = 100
    warmup_iterations = warmup_steps * step_size
    # Global step counter so TensorBoard curves are continuous across epochs.
    scalar_step = epoch * len(data_loader)

    for i, sample in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):
        image = sample["image"]
        label = sample["label"].float().to(device)
        input_image = image.to(device, non_blocking=True)

        optimizer.zero_grad()
        pred_class = model(input_image)

        loss = criterion(pred_class, label)
        loss.backward()
        optimizer.step()

        # Log the Python scalar, not the tensor, so no reference to the
        # autograd graph is handed to the writer; also read .item() once.
        loss_value = loss.item()
        writer.add_scalar("loss/loss", loss_value, scalar_step)
        scalar_step += 1

        metric_logger.update(loss=loss_value)
        # Per-iteration warmup schedule, first epoch only.
        if epoch == 0 and i % step_size == 0 and i <= warmup_iterations:
            scheduler.step(i // step_size)
        metric_logger.update(lr=scheduler._get_lr(epoch)[0])

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    return {
        k: "{:.6f}".format(meter.global_avg)
        for k, meter in metric_logger.meters.items()
    }
| |
|
| |
|
def valid(model, data_loader, criterion, epoch, device, config, writer):
    """Evaluate ``model`` on ``data_loader`` and return the mean loss.

    Args:
        model: classification network; called as ``model(image) -> logits``.
        data_loader: iterable yielding dicts with ``"image"`` and ``"label"``.
        criterion: loss applied to ``(logits, labels)``.
        epoch: 0-based epoch index, used to offset the TensorBoard step.
        device: target device for images and labels.
        config: config dict (unused here; kept for interface parity).
        writer: tensorboardX SummaryWriter for per-step loss logging.

    Returns:
        ``np.float64`` mean of the per-batch losses (callers invoke ``.item()``
        on the result, so the numpy scalar type is part of the contract).
    """
    model.eval()
    val_scalar_step = epoch * len(data_loader)
    val_losses = []
    for sample in data_loader:
        image = sample["image"]
        label = sample["label"].float().to(device)
        input_image = image.to(device, non_blocking=True)
        with torch.no_grad():
            pred_class = model(input_image)
            val_loss = criterion(pred_class, label)
            # Store / log the Python scalar, not the tensor.
            loss_value = val_loss.item()
            val_losses.append(loss_value)
            writer.add_scalar("val_loss/loss", loss_value, val_scalar_step)
            val_scalar_step += 1
    avg_val_loss = np.array(val_losses).mean()
    return avg_val_loss
| |
|
| |
|
def main(args, config):
    """Finetune ModelRes_ft on SIIM-ACR: train/validate each epoch, checkpoint
    the state, and keep the weights with the best test AUROC.

    Args:
        args: parsed CLI namespace (output_dir, checkpoint, pretrain_path,
            use_base, ddp, model_path, ...).
        config: config dict loaded from YAML ("schedular", "optimizer",
            "batch_size", dataset file keys, ...).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Total CUDA devices: ", torch.cuda.device_count())
    torch.set_default_tensor_type("torch.FloatTensor")

    start_epoch = 0
    max_epoch = config["schedular"]["epochs"]
    warmup_steps = config["schedular"]["warmup_epochs"]

    print("Creating dataset")
    # NOTE(review): the dataloader construction was missing from the original
    # file even though train_dataloader / val_dataloader are used below, which
    # raised NameError at runtime. Reconstructed from the SIIM_ACR_Dataset and
    # DataLoader imports — confirm the constructor signature and the
    # "train_file"/"valid_file"/"batch_size" config keys against
    # dataset/dataset_siim_acr.py and the YAML config.
    train_dataset = SIIM_ACR_Dataset(config["train_file"], is_train=True)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config["batch_size"],
        num_workers=4,
        pin_memory=True,
        shuffle=True,
        drop_last=True,
    )
    val_dataset = SIIM_ACR_Dataset(config["valid_file"], is_train=False)
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=config["batch_size"],
        num_workers=4,
        pin_memory=True,
        shuffle=False,
        drop_last=False,
    )

    model = ModelRes_ft(res_base_model="resnet50", out_size=1, use_base=args.use_base)
    if args.ddp:
        # Despite the flag name, this is single-process DataParallel.
        model = nn.DataParallel(
            model, device_ids=list(range(torch.cuda.device_count()))
        )
    model = model.to(device)

    arg_opt = utils.AttrDict(config["optimizer"])
    optimizer = create_optimizer(arg_opt, model)
    arg_sche = utils.AttrDict(config["schedular"])
    lr_scheduler, _ = create_scheduler(arg_sche, optimizer)

    criterion = nn.BCEWithLogitsLoss()

    if args.checkpoint:
        # Resume the full training state: weights, optimizer, scheduler, epoch.
        checkpoint = torch.load(args.checkpoint, map_location="cpu")
        state_dict = checkpoint["model"]
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        start_epoch = checkpoint["epoch"] + 1
        model.load_state_dict(state_dict)
        print("load checkpoint from %s" % args.checkpoint)
    elif args.pretrain_path:
        # Warm-start weights only, keeping just the keys this model also has.
        checkpoint = torch.load(args.pretrain_path, map_location="cpu")
        state_dict = checkpoint["model"]
        model_dict = model.state_dict()
        model_checkpoint = {k: v for k, v in state_dict.items() if k in model_dict}
        model_dict.update(model_checkpoint)
        model.load_state_dict(model_dict)
        print("load pretrain_path from %s" % args.pretrain_path)

    print("Start training")
    start_time = time.time()

    best_test_auc = 0.0
    writer = SummaryWriter(os.path.join(args.output_dir, "log"))
    for epoch in range(start_epoch, max_epoch):
        if epoch > 0:
            lr_scheduler.step(epoch + warmup_steps)
        train_stats = train(
            model,
            train_dataloader,
            optimizer,
            criterion,
            epoch,
            warmup_steps,
            device,
            lr_scheduler,
            args,
            config,
            writer,
        )

        # Read the loss meter explicitly; the original looped over all meters
        # and kept whichever one happened to iterate last.
        train_loss_epoch = train_stats["loss"]

        writer.add_scalar("loss/train_loss_epoch", float(train_loss_epoch), epoch)
        writer.add_scalar("loss/leaning_rate", lr_scheduler._get_lr(epoch)[0], epoch)

        val_loss = valid(
            model, val_dataloader, criterion, epoch, device, config, writer
        )
        writer.add_scalar("loss/val_loss_epoch", val_loss, epoch)

        if utils.is_main_process():
            log_stats = {
                **{f"train_{k}": v for k, v in train_stats.items()},
                "epoch": epoch,
                "val_loss": float(val_loss),
            }
            save_obj = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "config": config,
                "epoch": epoch,
            }
            torch.save(save_obj, os.path.join(args.output_dir, "checkpoint_state.pth"))

            with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                f.write(json.dumps(log_stats) + "\n")

            # Evaluate on the test split every epoch and keep the best weights.
            test_auc = test(args, config)
            print(best_test_auc, test_auc)
            if test_auc > best_test_auc:
                save_obj = {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                    "config": config,
                    "epoch": epoch,
                }
                torch.save(save_obj, os.path.join(args.output_dir, "best_test.pth"))
                best_test_auc = test_auc
                args.model_path = os.path.join(args.output_dir, "checkpoint_state.pth")

            with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                f.write(
                    "The average AUROC is {AUROC_avg:.4f}".format(AUROC_avg=test_auc)
                    + "\n"
                )

            # Periodic snapshot roughly every 20 epochs (epochs 21, 41, ...).
            if epoch % 20 == 1 and epoch > 1:
                save_obj = {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                    "config": config,
                    "epoch": epoch,
                }
                torch.save(
                    save_obj,
                    os.path.join(args.output_dir, "checkpoint_" + str(epoch) + ".pth"),
                )

    # Flush pending TensorBoard events before exiting.
    writer.close()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))
| |
|
| |
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        default="Sample_Finetuning_SIIMACR/I1_classification/configs/Res_train.yaml",
    )
    parser.add_argument("--checkpoint", default="")
    parser.add_argument("--model_path", default="")
    parser.add_argument("--pretrain_path", default="MeDSLIP_resnet50.pth")
    parser.add_argument(
        "--output_dir", default="Sample_Finetuning_SIIMACR/I1_classification/runs/"
    )
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--gpu", type=str, default="0", help="gpu")
    # type=bool is an argparse trap: bool("False") is True, so the flag could
    # never be disabled from the CLI. Parse the string explicitly; the default
    # and the truthy spellings behave exactly as before.
    parser.add_argument(
        "--use_base",
        type=lambda s: s.lower() in ("1", "true", "yes", "y"),
        default=True,
    )
    parser.add_argument("--ddp", action="store_true", help="use ddp")
    args = parser.parse_args()

    config = yaml.load(open(args.config, "r"), Loader=yaml.Loader)
    args.output_dir = os.path.join(args.output_dir, str(config["percentage"]))

    # Use the `datetime` module imported at the top of the file. The original
    # did `from datetime import datetime` here, rebinding the module-global
    # name to the class and breaking `datetime.timedelta(...)` inside main().
    args.output_dir = os.path.join(
        args.output_dir, datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    )
    args.model_path = os.path.join(args.output_dir, "checkpoint_state.pth")

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Persist the effective config alongside the run outputs.
    yaml.dump(config, open(os.path.join(args.output_dir, "config.yaml"), "w"))

    # Restrict visible GPUs before any CUDA context is created.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    torch.cuda.current_device()
    torch.cuda._initialized = True

    main(args, config)
| |
|