import argparse
import datetime
import os
import shutil
import sys
import time
import warnings
from functools import partial

import cv2
import deepspeed
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data as data
import wandb
from deepspeed.runtime.lr_schedules import WarmupLR
from loguru import logger
from torch.optim.lr_scheduler import MultiStepLR

import utils.config as config
from engine.engine import train, validate
from model import build_segmenter
from utils.dataset import RefDataset
from utils.misc import (build_scheduler, init_random_seed, set_random_seed,
                        setup_logger, worker_init_fn)

warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore", category=UserWarning)
cv2.setNumThreads(0)

torch.cuda.empty_cache()
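

# Build the CLI parser, load the YAML config, and merge --opts overrides.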
def get_parser():
    parser = argparse.ArgumentParser(
        description='PyTorch Referring Expression Segmentation')
    parser.add_argument('--config',
                        default='path to xxx.yaml',
                        type=str,
                        help='config file')
    parser.add_argument('--opts',
                        default=None,
                        nargs=argparse.REMAINDER,
                        help='override some settings in the config.')
    parser.add_argument('--local_rank', type=int, default=0)
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()
    assert args.config is not None
    cfg = config.load_cfg_from_cfg_file(args.config)
    if args.opts is not None:
        cfg = config.merge_cfg_from_list(cfg, args.opts)
    return cfg
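

# A typical launch (GPU count and config path are placeholders): a launcher
# such as `deepspeed` or `torchrun` is assumed to export RANK, WORLD_SIZE,
# LOCAL_RANK, MASTER_ADDR and MASTER_PORT, which main() reads below.
#
#   deepspeed --num_gpus 8 train.py --config path/to/xxx.yaml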
def main(args):
    # Distributed settings injected by the launcher's environment.
    args.rank = int(os.environ['RANK'])
    args.world_size = int(os.environ['WORLD_SIZE'])
    args.gpu = int(os.environ['LOCAL_RANK'])
    args.output_dir = os.path.join(args.output_folder, args.exp_name)
    args.dist_url = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
    torch.cuda.set_device(args.gpu)
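
    # Per-rank logger writing to <output_dir>/train.log; setup_logger is
    # expected to mute output on non-zero ranks via distributed_rank.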
    setup_logger(args.output_dir,
                 distributed_rank=args.gpu,
                 filename="train.log",
                 mode="a")
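
    # Set up the default process group through DeepSpeed (NCCL backend on
    # GPU) using the rendezvous address assembled above.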
    deepspeed.init_distributed(init_method=args.dist_url,
                               world_size=args.world_size,
                               rank=args.rank)
    logger.info("dist init done")
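
    # wandb runs on rank 0 only, in offline mode; the barrier keeps the
    # other ranks in step.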
    if args.rank == 0:
        wandb.init(job_type="training",
                   mode="offline",
                   config=args,
                   project=args.exp_name,
                   name=args.exp_name,
                   tags=[args.dataset])
    dist.barrier()
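
    # DeepSpeed runtime config: dynamic-loss-scaled fp16, ZeRO stage 3 with
    # optimizer and parameter state offloaded to CPU, and an engine-built
    # AdamW optimizer. train_batch_size here is the *global* batch size.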
    deepspeed_config = {
        "train_batch_size": args.batch_size,
        "gradient_accumulation_steps": 2,
        "fp16": {
            "enabled": True,
            "auto_cast": False,
            "loss_scale": 0,
            "initial_scale_power": 16,
            "loss_scale_window": 1000,
            "hysteresis": 2,
            "consecutive_hysteresis": False,
            "min_loss_scale": 1
        },
        "zero_optimization": {
            "stage": 3,
            "contiguous_gradients": True,
            "overlap_comm": True,
            "reduce_scatter": True,
            "reduce_bucket_size": 5e8,
            "allgather_bucket_size": 5e8,
            "offload_optimizer": {
                "device": "cpu"
            },
            "offload_param": {
                "device": "cpu"
            }
        },
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": args.lr,
                "weight_decay": args.weight_decay,
                "betas": [0.9, 0.999],
                "eps": 1e-8,
                "amsgrad": False
            }
        },
    }
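
    # build_segmenter returns the network together with the parameter list
    # handed to the DeepSpeed optimizer below.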
    model, param_list = build_segmenter(args)

    logger.info(args)
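
    # Derive the per-GPU micro batch from the global train_batch_size:
    # global = micro_batch * n_gpus * gradient_accumulation_steps. This
    # split assumes a single node, where ngpus_per_node == world_size.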
    grad_acc_steps = deepspeed_config['gradient_accumulation_steps']
    args.batch_size = int(args.batch_size / args.ngpus_per_node / grad_acc_steps)
    args.batch_size_val = int(args.batch_size_val / args.ngpus_per_node)
    args.workers = int(
        (args.workers + args.ngpus_per_node - 1) / args.ngpus_per_node)

    train_data = RefDataset(lmdb_dir=args.train_lmdb,
                            mask_dir=args.mask_root,
                            dataset=args.dataset,
                            split=args.train_split,
                            mode='train',
                            input_size=args.input_size,
                            word_length=args.word_len,
                            args=args)
    val_data = RefDataset(lmdb_dir=args.val_lmdb,
                          mask_dir=args.mask_root,
                          dataset=args.dataset,
                          split=args.val_split,
                          mode='val',
                          input_size=args.input_size,
                          word_length=args.word_len,
                          args=args)
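
    # Deterministic worker seeding plus per-rank sharding of both splits.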
    init_fn = partial(worker_init_fn,
                      num_workers=args.workers,
                      rank=args.rank,
                      seed=args.manual_seed)
    train_sampler = data.distributed.DistributedSampler(train_data,
                                                        shuffle=True)
    val_sampler = data.distributed.DistributedSampler(val_data, shuffle=False)

    train_loader = data.DataLoader(train_data,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=args.workers,
                                   pin_memory=True,
                                   worker_init_fn=init_fn,
                                   sampler=train_sampler,
                                   drop_last=True)
    val_loader = data.DataLoader(val_data,
                                 batch_size=args.batch_size_val,
                                 shuffle=False,
                                 num_workers=args.workers_val,
                                 pin_memory=True,
                                 sampler=val_sampler,
                                 drop_last=False)
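
    # No torch.cuda.amp GradScaler is needed: DeepSpeed's fp16 engine does
    # its own loss scaling, so a None placeholder is forwarded to train().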
    scaler = None
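
    # Polynomial LR decay (power 0.9) over the estimated total number of
    # steps. deepspeed.initialize accepts a callable that builds the
    # scheduler from the optimizer it creates, hence the partial over
    # LambdaLR.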
    scheduler = partial(torch.optim.lr_scheduler.LambdaLR,
                        lr_lambda=lambda x: (1 - x / (len(train_data) / args.batch_size * args.epochs)) ** 0.9)
    model, optimizer, _, scheduler = deepspeed.initialize(
        model=model,
        config_params=deepspeed_config,
        model_parameters=param_list,
        lr_scheduler=scheduler,
        dist_init_required=True)

    best_IoU = 0.0
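
    # Optionally resume from a DeepSpeed checkpoint.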
    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            _, client_sd = model.load_checkpoint(args.resume)
        else:
            raise ValueError(
                "=> resume failed! No checkpoint found at '{}'. "
                "Please check args.resume again!".format(args.resume))
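
    # Train for args.epochs epochs, validating after each one; epoch_log is
    # the 1-based epoch number used for logging and the sampler seed.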
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1

        # Reshuffle deterministically per epoch.
        train_sampler.set_epoch(epoch_log)

        train(train_loader, model, optimizer, scheduler, scaler, epoch_log,
              args)

        iou, prec_dict = validate(val_loader, model, epoch_log, args)
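
        # Rank 0 decides whether this epoch is a new best and broadcasts the
        # decision so every rank takes the same save_checkpoint path (ZeRO-3
        # checkpointing must be called collectively on all ranks).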
        if dist.get_rank() == 0:
            found_best_epoch = [iou >= best_IoU and epoch_log < 50]
        else:
            found_best_epoch = [False]
        dist.broadcast_object_list(found_best_epoch, src=0,
                                   device=model.device)

        model.save_checkpoint(args.output_dir, "last_model")
        if found_best_epoch[0]:
            best_IoU = max(best_IoU, iou)  # keep the running best for the summary log
            model.save_checkpoint(args.output_dir, "best_model")

        torch.cuda.empty_cache()

    time.sleep(2)
    if dist.get_rank() == 0:
        wandb.finish()

    logger.info("* Best IoU={} *".format(best_IoU))
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('* Training time {} *'.format(total_time_str))
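

# Entry point: parse the config, seed all RNGs, then launch training.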
if __name__ == '__main__':
    args = get_parser()
    args.manual_seed = init_random_seed(args.manual_seed)
    set_random_seed(args.manual_seed, deterministic=True)
    # world_size itself is read from the WORLD_SIZE env variable in main();
    # only the per-node GPU count is needed here for the batch-size split.
    args.ngpus_per_node = torch.cuda.device_count()

    main(args)
    sys.exit(0)