| | from tqdm import tqdm |
| | import network |
| | import utils |
| | import os |
| | import random |
| | import argparse |
| | import numpy as np |
| | import json |
| |
|
| | from torch.utils import data |
| | from datasets import VOCSegmentation, Cityscapes |
| | from utils import ext_transforms as et |
| | from metrics import StreamSegMetrics |
| | from torch.utils.tensorboard import SummaryWriter |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | from PIL import Image |
| | import matplotlib |
| | import matplotlib.pyplot as plt |
| |
|
| |
|
def get_argparser():
    """Build the command-line argument parser for training / evaluation.

    Returns:
        argparse.ArgumentParser: parser with dataset, model, EOA-Net,
        training, checkpoint, and VOC-year option groups.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--out_dir", type=str, default="run_0")

    # Dataset options
    parser.add_argument("--data_root", type=str, default='',
                        help="path to Dataset")
    parser.add_argument("--dataset", type=str, default='voc',
                        choices=['voc'], help='Name of dataset')
    parser.add_argument("--num_classes", type=int, default=None,
                        help="num classes (default: None)")

    # Model options
    parser.add_argument("--model", type=str, default='deeplabv3plus_resnet101',
                        choices=['deeplabv3plus_resnet101', 'deeplabv3plus_resnet50', 'deeplabv3plus_mobilenet',
                                 'deeplabv3plus_xception', 'deeplabv3plus_hrnetv2_48', 'deeplabv3plus_hrnetv2_32',
                                 'deeplabv3_resnet101', 'deeplabv3_resnet50', 'deeplabv3_mobilenet',
                                 'deeplabv3_xception', 'deeplabv3_hrnetv2_48', 'deeplabv3_hrnetv2_32'],
                        help='model name')
    parser.add_argument("--separable_conv", action='store_true', default=False,
                        help="apply separable conv to decoder and aspp")
    parser.add_argument("--output_stride", type=int, default=16, choices=[8, 16])

    # EOA-Net options (paired flags so the default can be True yet disableable)
    parser.add_argument("--use_eoaNet", action='store_true', default=True,
                        help="Use Entropy-Optimized Attention Network")
    parser.add_argument("--no_eoaNet", action='store_false', dest='use_eoaNet',
                        help="Disable Entropy-Optimized Attention Network")
    parser.add_argument("--msa_scales", nargs='+', type=int, default=[1, 2, 4],
                        help="Scales for Multi-Scale Attention")
    parser.add_argument("--eog_beta", type=float, default=0.3,
                        help="Entropy threshold for Entropy-Optimized Gating")

    # Train options
    parser.add_argument("--test_only", action='store_true', default=False)
    parser.add_argument("--save_val_results", action='store_true', default=False,
                        help="save segmentation results to \"./results\"")
    # BUG FIX: default was 30e3, a float; argparse does not pass defaults
    # through type=int, so opts.total_itrs silently became 30000.0.
    parser.add_argument("--total_itrs", type=int, default=30000,
                        help="total training iterations (default: 30k)")
    parser.add_argument("--lr", type=float, default=0.02,
                        help="learning rate (default: 0.02)")
    parser.add_argument("--lr_policy", type=str, default='poly', choices=['poly', 'step'],
                        help="learning rate scheduler policy")
    parser.add_argument("--step_size", type=int, default=10000)
    parser.add_argument("--crop_val", action='store_true', default=True,
                        help='crop validation (default: True)')
    parser.add_argument("--batch_size", type=int, default=32,
                        help='batch size (default: 32)')
    parser.add_argument("--val_batch_size", type=int, default=4,
                        help='batch size for validation (default: 4)')
    parser.add_argument("--crop_size", type=int, default=513)

    # Checkpoint options
    parser.add_argument("--ckpt", default=None, type=str,
                        help="restore from checkpoint")
    parser.add_argument("--continue_training", action='store_true', default=False)

    parser.add_argument("--loss_type", type=str, default='cross_entropy',
                        choices=['cross_entropy', 'focal_loss'],
                        help="loss type (default: cross_entropy)")
    parser.add_argument("--gpu_id", type=str, default='0,1',
                        help="GPU ID")
    parser.add_argument("--weight_decay", type=float, default=1e-4,
                        help='weight decay (default: 1e-4)')
    parser.add_argument("--random_seed", type=int, default=1,
                        help="random seed (default: 1)")
    parser.add_argument("--print_interval", type=int, default=10,
                        help="print interval of loss (default: 10)")
    parser.add_argument("--val_interval", type=int, default=100,
                        help="iteration interval for eval (default: 100)")
    parser.add_argument("--download", action='store_true', default=False,
                        help="download datasets")

    # PASCAL VOC options
    parser.add_argument("--year", type=str, default='2012_aug',
                        choices=['2012_aug', '2012', '2011', '2009', '2008', '2007'],
                        help='year of VOC')
    return parser
| |
|
| |
|
def get_dataset(opts):
    """Build the train/val datasets with their augmentation pipelines.

    Args:
        opts: parsed options; reads dataset, data_root, crop_size, crop_val,
            year, and download.

    Returns:
        tuple: (train_dst, val_dst) dataset objects.

    Raises:
        ValueError: if opts.dataset is not a supported dataset name.
    """
    if opts.dataset == 'voc':
        train_transform = et.ExtCompose([
            et.ExtRandomScale((0.5, 2.0)),
            et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size), pad_if_needed=True),
            et.ExtRandomHorizontalFlip(),
            et.ExtToTensor(),
            et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
        ])
        if opts.crop_val:
            val_transform = et.ExtCompose([
                et.ExtResize(opts.crop_size),
                et.ExtCenterCrop(opts.crop_size),
                et.ExtToTensor(),
                et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225]),
            ])
        else:
            val_transform = et.ExtCompose([
                et.ExtToTensor(),
                et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225]),
            ])
        train_dst = VOCSegmentation(root=opts.data_root, year=opts.year,
                                    image_set='train', download=opts.download,
                                    transform=train_transform)
        val_dst = VOCSegmentation(root=opts.data_root, year=opts.year,
                                  image_set='val', download=False,
                                  transform=val_transform)
    elif opts.dataset == 'cityscapes':
        train_transform = et.ExtCompose([
            et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size)),
            et.ExtColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
            et.ExtRandomHorizontalFlip(),
            et.ExtToTensor(),
            et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
        ])
        val_transform = et.ExtCompose([
            et.ExtToTensor(),
            et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
        ])
        train_dst = Cityscapes(root=opts.data_root,
                               split='train', transform=train_transform)
        val_dst = Cityscapes(root=opts.data_root,
                             split='val', transform=val_transform)
    else:
        # BUG FIX: previously an unknown dataset fell through to the return
        # statement and raised a confusing UnboundLocalError.
        raise ValueError("Unknown dataset: %s" % opts.dataset)
    return train_dst, val_dst
| |
|
| |
|
def validate(opts, model, loader, device, metrics, ret_samples_ids=None):
    """Run evaluation over `loader` and return (score, ret_samples).

    Args:
        opts: parsed options; only `save_val_results` is read here.
        model: segmentation network; `model(images)` must return per-class
            logits with the class dimension at dim=1.
        loader: validation DataLoader yielding (images, labels) batches.
        device: torch device inputs are moved to.
        metrics: StreamSegMetrics-style accumulator; reset here, then fed
            with (targets, preds) and read via get_results().
        ret_samples_ids: optional collection of batch indices; for each,
            the batch's first (image, target, pred) triple is returned.

    Returns:
        tuple: (score dict from metrics.get_results(), list of ret_samples).
    """
    metrics.reset()
    ret_samples = []
    if opts.save_val_results:
        # makedirs(exist_ok=True) avoids the exists()/mkdir() race.
        os.makedirs('results', exist_ok=True)
        denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                                   std=[0.229, 0.224, 0.225])
        img_id = 0

    with torch.no_grad():
        for i, (images, labels) in tqdm(enumerate(loader)):
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            outputs = model(images)
            # argmax over the class dim -> per-pixel label map
            preds = outputs.detach().max(dim=1)[1].cpu().numpy()
            targets = labels.cpu().numpy()

            metrics.update(targets, preds)
            if ret_samples_ids is not None and i in ret_samples_ids:
                ret_samples.append(
                    (images[0].detach().cpu().numpy(), targets[0], preds[0]))

            if opts.save_val_results:
                # BUG FIX: inner loop used `i`, shadowing the batch index
                # from enumerate; renamed to `k`.
                for k in range(len(images)):
                    image = images[k].detach().cpu().numpy()
                    target = targets[k]
                    pred = preds[k]

                    # CHW float tensor -> HWC uint8 image; labels -> color maps
                    image = (denorm(image) * 255).transpose(1, 2, 0).astype(np.uint8)
                    target = loader.dataset.decode_target(target).astype(np.uint8)
                    pred = loader.dataset.decode_target(pred).astype(np.uint8)

                    Image.fromarray(image).save('results/%d_image.png' % img_id)
                    Image.fromarray(target).save('results/%d_target.png' % img_id)
                    Image.fromarray(pred).save('results/%d_pred.png' % img_id)

                    # Overlay prediction on the input, axes/ticks stripped.
                    plt.figure()
                    plt.imshow(image)
                    plt.axis('off')
                    plt.imshow(pred, alpha=0.7)
                    ax = plt.gca()
                    ax.xaxis.set_major_locator(matplotlib.ticker.NullLocator())
                    ax.yaxis.set_major_locator(matplotlib.ticker.NullLocator())
                    plt.savefig('results/%d_overlay.png' % img_id,
                                bbox_inches='tight', pad_inches=0)
                    plt.close()
                    img_id += 1

    score = metrics.get_results()
    return score, ret_samples
| |
|
def main(opts):
    """Train a segmentation model (or evaluate only with --test_only).

    Side effects: writes TensorBoard logs to ./logs, checkpoints to
    ./checkpoints, best scores to checkpoints/best_score.txt, and the best
    validation metrics to final_info.json.
    """
    # Make --gpu_id effective. setdefault keeps any value the caller already
    # exported; must run before the first CUDA call to take effect.
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', opts.gpu_id)

    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Seed every RNG the pipeline touches (CPU, CUDA, numpy, python).
    torch.manual_seed(opts.random_seed)
    torch.cuda.manual_seed_all(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    writer = SummaryWriter(log_dir='logs')

    # Un-cropped VOC val images vary in size, so they cannot be batched.
    if opts.dataset == 'voc' and not opts.crop_val:
        opts.val_batch_size = 1

    train_dst, val_dst = get_dataset(opts)

    # Clamp batch sizes so a tiny dataset still yields full batches.
    effective_batch_size = min(opts.batch_size, len(train_dst))
    effective_val_batch_size = min(opts.val_batch_size, len(val_dst))
    if effective_batch_size < opts.batch_size:
        print(f"Warning: Reducing batch size from {opts.batch_size} to {effective_batch_size} due to small dataset")

    # BUG FIX: the train loader previously ignored effective_batch_size and
    # used opts.batch_size, making the clamp (and its warning) a no-op.
    train_loader = data.DataLoader(
        train_dst, batch_size=effective_batch_size, shuffle=True, num_workers=2,
        drop_last=False)
    val_loader = data.DataLoader(
        val_dst, batch_size=effective_val_batch_size, shuffle=True, num_workers=2)
    print("Dataset: %s, Train set: %d, Val set: %d" %
          (opts.dataset, len(train_dst), len(val_dst)))

    # Set up model
    model = network.modeling.__dict__[opts.model](
        num_classes=opts.num_classes,
        output_stride=opts.output_stride,
        use_eoaNet=opts.use_eoaNet,
        msa_scales=opts.msa_scales,
        eog_beta=opts.eog_beta
    )
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    metrics = StreamSegMetrics(opts.num_classes)

    # Backbone is fine-tuned at 1/10 of the head learning rate.
    optimizer = torch.optim.SGD(params=[
        {'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
        {'params': model.classifier.parameters(), 'lr': opts.lr},
    ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)

    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=opts.step_size, gamma=0.1)

    # 255 is the "ignore" label in both datasets' ground truth.
    if opts.loss_type == 'focal_loss':
        criterion = utils.FocalLoss(ignore_index=255, size_average=True)
    elif opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')

    def save_ckpt(path):
        """Save the full training state (weights, optimizer, scheduler)."""
        torch.save({
            "cur_itrs": cur_itrs,
            # model is wrapped in DataParallel by this point, hence .module
            "model_state": model.module.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_score": best_score,
        }, path)
        print("Model saved as %s" % path)

    # FIX: was created twice (here and again before the training loop) via a
    # racy exists()/mkdir() pair.
    os.makedirs('checkpoints', exist_ok=True)

    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0

    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # Load on CPU first so a GPU-saved checkpoint works anywhere.
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    if opts.test_only:
        model.eval()
        val_score, ret_samples = validate(
            opts=opts, model=model, loader=val_loader, device=device, metrics=metrics)
        print(metrics.to_str(val_score))
        writer.close()
        return

    interval_loss = 0
    latest_checkpoints = []  # rolling window of the newest "latest" ckpts
    while True:
        model.train()
        cur_epochs += 1
        for (images, labels) in train_loader:
            cur_itrs += 1

            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            np_loss = loss.detach().cpu().numpy()
            interval_loss += np_loss
            writer.add_scalar('Loss/train', np_loss, cur_itrs)

            # FIX: honour --print_interval instead of a hard-coded 10
            # (default is still 10, so default behavior is unchanged).
            if cur_itrs % opts.print_interval == 0:
                interval_loss = interval_loss / opts.print_interval
                print("Epoch %d, Itrs %d/%d, Loss=%f" %
                      (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                interval_loss = 0.0

            if cur_itrs % opts.val_interval == 0:
                ckpt_path = f'checkpoints/latest_{cur_itrs}_{opts.model}_{opts.dataset}_os{opts.output_stride}.pth'
                save_ckpt(ckpt_path)
                latest_checkpoints.append(ckpt_path)

                # Keep only the two most recent "latest" checkpoints on disk.
                if len(latest_checkpoints) > 2:
                    oldest_ckpt_path = latest_checkpoints.pop(0)
                    try:
                        os.remove(oldest_ckpt_path)
                        print(f"Successfully removed old checkpoint: {oldest_ckpt_path}")
                    except FileNotFoundError:
                        print(f"Warning: Could not remove checkpoint because it was not found: {oldest_ckpt_path}")
                    except OSError as e:
                        print(f"Error removing checkpoint {oldest_ckpt_path}: {e}")

                print("validation...")
                model.eval()
                val_score, ret_samples = validate(
                    opts=opts, model=model, loader=val_loader, device=device, metrics=metrics)
                print(metrics.to_str(val_score))

                writer.add_scalar('Metrics/Mean_IoU', val_score['Mean IoU'], cur_itrs)
                writer.add_scalar('Metrics/Overall_Acc', val_score['Overall Acc'], cur_itrs)
                writer.add_scalar('Metrics/Mean_Acc', val_score['Mean Acc'], cur_itrs)

                if val_score['Mean IoU'] > best_score:
                    best_score = val_score['Mean IoU']
                    save_ckpt(f'checkpoints/best_{opts.model}_{opts.dataset}_os{opts.output_stride}.pth')
                    with open('checkpoints/best_score.txt', 'a') as f:
                        f.write(f"iter:{cur_itrs}\n{str(best_score)}\n")
                    with open("final_info.json", "w") as f:
                        final_info = {
                            "voc12_aug": {
                                "means": {
                                    "mIoU": val_score['Mean IoU'],
                                    "OA": val_score['Overall Acc'],
                                    # BUG FIX: was val_score['Mean IoU']
                                    "mAcc": val_score['Mean Acc']
                                }
                            }
                        }
                        json.dump(final_info, f, indent=4)

                model.train()
            scheduler.step()

            if cur_itrs >= opts.total_itrs:
                writer.close()
                return
| |
|
| |
|
if __name__ == '__main__':
    args = get_argparser().parse_args()
    try:
        main(args)
    except Exception:
        import traceback
        print("Original error in subprocess:", flush=True)
        # FIX: the original passed open(...) inline, leaking the file handle;
        # a with-block guarantees the log is flushed and closed before re-raise.
        with open("traceback.log", "w") as log_file:
            traceback.print_exc(file=log_file)
        raise
| |
|