import argparse
import datetime
import os
import shutil
import sys
import time
import warnings
from functools import partial

import cv2
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.parallel
import torch.optim
import torch.utils.data as data
from loguru import logger

import utils.config as config
import wandb
from engine.engine_gref import train, validate
from model import build_segmenter
from utils.dataset_sbert import RefDataset_gref
from utils.misc import (init_random_seed, set_random_seed, setup_logger,
                        worker_init_fn, collate_fn)

warnings.filterwarnings("ignore")
cv2.setNumThreads(0)  # avoid OpenCV oversubscribing CPUs once DataLoader workers fork

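# Hedged usage sketch (the script and config file names are placeholders, not
# from the source). With torchrun, which exports LOCAL_RANK/RANK/WORLD_SIZE,
# one process per GPU already exists:
#   torchrun --nproc_per_node=4 train_script.py --config path/to/config.yaml
# Run under plain `python` instead, and main() falls back to mp.spawn over
# every visible GPU.
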
def get_parser():
    parser = argparse.ArgumentParser(
        description='PyTorch Referring Expression Segmentation')
    parser.add_argument('--config',
                        default='path to xxx.yaml',
                        type=str,
                        help='config file')
    parser.add_argument('--opts',
                        default=None,
                        nargs=argparse.REMAINDER,
                        help='override some settings in the config')
    parser.add_argument('--local-rank',
                        type=int,
                        default=0,
                        help='local rank for distributed training')

    args = parser.parse_args()
    assert args.config is not None
    cfg = config.load_cfg_from_cfg_file(args.config)
    if args.opts is not None:
        cfg = config.merge_cfg_from_list(cfg, args.opts)
    return cfg

@logger.catch
def main():
    args = get_parser()
    args.manual_seed = init_random_seed(args.manual_seed)
    set_random_seed(args.manual_seed, deterministic=True)

    if 'LOCAL_RANK' in os.environ:
        # Launched by torchrun / torch.distributed.launch: one process per
        # GPU already exists, so enter the DDP worker directly.
        args.local_rank = int(os.environ['LOCAL_RANK'])
        logger.info(f"LOCAL_RANK from env: {args.local_rank}")
        main_worker_ddp(args)
    else:
        # Plain `python` launch: spawn one worker process per visible GPU.
        args.ngpus_per_node = torch.cuda.device_count()
        args.world_size = args.ngpus_per_node * getattr(args, 'world_size', 1)
        mp.spawn(main_worker_mp, nprocs=args.ngpus_per_node, args=(args,))

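# Background for the env:// rendezvous used below: torchrun exports
# MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE and LOCAL_RANK for every worker,
# and init_process_group(init_method="env://") reads the first four directly,
# so no connection URL has to be assembled by hand in this path.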
def main_worker_ddp(args):
    """Worker body when the process was created by torchrun."""
    args.output_dir = os.path.join(args.output_folder, args.exp_name)

    args.gpu = args.local_rank
    # Global rank comes from the launcher; fall back to the local rank for
    # single-node runs.
    args.rank = int(os.environ.get('RANK', args.local_rank))
    args.world_size = int(os.environ.get('WORLD_SIZE', 1))

    torch.cuda.set_device(args.gpu)

    setup_logger(args.output_dir,
                 distributed_rank=args.gpu,
                 filename="train.log",
                 mode="a")

    logger.info(f"Starting with GPU: {args.gpu}, Rank: {args.rank}, "
                f"World Size: {args.world_size}")

    dist.init_process_group(backend=getattr(args, 'dist_backend', 'nccl'),
                            init_method="env://",
                            world_size=args.world_size,
                            rank=args.rank)

    run_training(args)

def main_worker_mp(gpu, args):
    """Worker body when the process was created by mp.spawn (gpu is the local
    GPU index passed in by spawn)."""
    args.output_dir = os.path.join(args.output_folder, args.exp_name)

    args.gpu = gpu
    # Global rank = node rank * GPUs per node + local GPU index.
    rank = getattr(args, 'rank', 0)
    args.rank = rank * args.ngpus_per_node + gpu
    torch.cuda.set_device(args.gpu)

    setup_logger(args.output_dir,
                 distributed_rank=args.gpu,
                 filename="train.log",
                 mode="a")

    dist_url = getattr(args, 'dist_url', 'tcp://localhost:12355')
    dist.init_process_group(backend=getattr(args, 'dist_backend', 'nccl'),
                            init_method=dist_url,
                            world_size=args.world_size,
                            rank=args.rank)

    run_training(args)

def run_training(args):
    # Only rank 0 talks to wandb; offline mode keeps the run local. The
    # barrier stops the other ranks from racing ahead while the run is
    # being created.
    if args.rank == 0:
        wandb.init(job_type="training",
                   mode="offline",
                   config=args,
                   project=args.exp_name,
                   name=args.exp_name,
                   tags=[args.dataset])
    dist.barrier()

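    # Offline wandb runs can be uploaded to the server afterwards with the
    # standard CLI, e.g. `wandb sync <run directory>`.
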
    model, param_list = build_segmenter(args)

    model = model.cuda(args.gpu)
    if hasattr(model, 'text_encoder'):
        model.text_encoder = model.text_encoder.cuda(args.gpu)

    logger.info(f"Model moved to GPU: {args.gpu}")
    logger.info(args)

    optimizer = torch.optim.AdamW(param_list,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay,
                                  amsgrad=args.amsgrad)

    # find_unused_parameters=True tolerates branches that produce no gradient
    # in a given step, at the cost of an extra graph traversal.
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu],
        find_unused_parameters=True)

    # Gradient scaler for mixed-precision training.
    scaler = amp.GradScaler()

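    # The scaler is consumed inside engine.train (not shown here); the
    # canonical torch.cuda.amp pattern it is expected to follow is:
    #   with amp.autocast():
    #       loss = model(...)
    #   scaler.scale(loss).backward()
    #   scaler.step(optimizer)
    #   scaler.update()
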
    # The config specifies global sizes; convert them to per-process values.
    world_size = dist.get_world_size()
    args.batch_size = args.batch_size // world_size
    args.batch_size_val = args.batch_size_val // world_size
    args.workers = (args.workers + world_size - 1) // world_size  # ceil

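    # Example of the split above: batch_size=32 and workers=8 in the config
    # with 4 processes gives each rank a batch of 8 and 2 loader workers
    # (batch sizes that do not divide evenly are truncated).
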
    train_data = RefDataset_gref(lmdb_dir=args.train_lmdb,
                                 mask_dir=args.mask_root,
                                 dataset=args.dataset,
                                 split=args.train_split,
                                 mode='train',
                                 input_size=args.input_size,
                                 word_length=args.word_len,
                                 args=args)
    val_data = RefDataset_gref(lmdb_dir=args.val_lmdb,
                               mask_dir=args.mask_root,
                               dataset=args.dataset,
                               split=args.val_split,
                               mode='val',
                               input_size=args.input_size,
                               word_length=args.word_len,
                               args=args)

    # Seed each DataLoader worker deterministically from (num_workers, rank,
    # base seed) so augmentations are reproducible across runs.
    init_fn = partial(worker_init_fn,
                      num_workers=args.workers,
                      rank=args.rank,
                      seed=args.manual_seed)
    train_sampler = data.distributed.DistributedSampler(train_data,
                                                        shuffle=True)
    val_sampler = data.distributed.DistributedSampler(val_data, shuffle=False)

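    # With a DistributedSampler attached, DataLoader shuffle must stay False:
    # the sampler both shards the dataset (each rank sees a distinct
    # 1/world_size slice per epoch) and handles shuffling, re-seeded every
    # epoch through set_epoch() in the loop below.
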
    train_loader = data.DataLoader(train_data,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=args.workers,
                                   pin_memory=True,
                                   worker_init_fn=init_fn,
                                   sampler=train_sampler,
                                   collate_fn=collate_fn,
                                   drop_last=True)
    val_loader = data.DataLoader(val_data,
                                 batch_size=args.batch_size_val,
                                 shuffle=False,
                                 num_workers=args.workers_val,
                                 pin_memory=True,
                                 sampler=val_sampler,
                                 collate_fn=collate_fn,
                                 drop_last=False)

    # Polynomial learning-rate decay over all training steps.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda x: (1 - x / (len(train_loader) * args.epochs)) ** 0.9)

    best_IoU = 0.0

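    # The lambda above implements lr(t) = base_lr * (1 - t / T) ** 0.9 with
    # t the global step and T = len(train_loader) * epochs, i.e. the rate
    # decays smoothly towards zero by the final step (scheduler.step() is
    # assumed to be called once per iteration inside engine.train).
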
    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(
                args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            best_IoU = checkpoint["best_iou"]
            # The learnable decoder tokens are re-initialized rather than
            # restored, so drop them and load the remaining weights
            # non-strictly.
            checkpoint['model_state_dict'].pop('decoder.tokens.weight', None)
            model.module.load_state_dict(checkpoint['model_state_dict'],
                                         strict=False)
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            raise ValueError(
                "=> resume failed! no checkpoint found at '{}'. Please check args.resume again!"
                .format(args.resume))

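    # Checkpoint layout, matching what the save below writes: 'epoch',
    # 'cur_iou', 'best_iou', 'prec', 'model_state_dict' (unwrapped, no DDP
    # 'module.' prefix), plus 'optimizer' and 'scheduler' state dicts.
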
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1

        # Re-seed the sampler so each epoch draws a fresh shuffle order.
        train_sampler.set_epoch(epoch_log)

        train(train_loader, model, optimizer, scheduler, scaler, epoch_log,
              args)

        iou, prec_dict = validate(val_loader, model, epoch_log, args)

        # Only rank 0 writes checkpoints; model.module strips the DDP wrapper
        # so the weights can be reloaded without a 'module.' prefix.
        if dist.get_rank() == 0:
            lastname = os.path.join(args.output_dir, "last_model.pth")
            torch.save(
                {
                    'epoch': epoch_log,
                    'cur_iou': iou,
                    'best_iou': best_IoU,
                    'prec': prec_dict,
                    'model_state_dict': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict()
                }, lastname)
            if iou >= best_IoU and epoch_log < 50:
                best_IoU = iou
                bestname = os.path.join(args.output_dir, "best_model.pth")
                shutil.copyfile(lastname, bestname)

        torch.cuda.empty_cache()

    # Give all ranks a moment to flush logging before tearing down wandb.
    time.sleep(2)
    if dist.get_rank() == 0:
        try:
            wandb.finish()
        except AttributeError:
            logger.warning("Failed to properly finish wandb run due to "
                           "StreamToLoguru conflict")

    logger.info("* Best IoU={} *".format(best_IoU))
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('* Training time {} *'.format(total_time_str))

if __name__ == '__main__':
    main()
    sys.exit(0)