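"""Distributed training script for referring expression segmentation (gref dataset)."""
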
import argparse
import datetime
import os
import shutil
import sys
import time
import warnings
from functools import partial
import cv2
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data as data
from loguru import logger
from torch.optim.lr_scheduler import MultiStepLR
import utils.config as config
import wandb
from utils.dataset_sbert import RefDataset_gref
from engine.engine_gref import train, validate
from model import build_segmenter
from utils.misc import (init_random_seed, set_random_seed, setup_logger,
                        worker_init_fn, build_scheduler, collate_fn)
warnings.filterwarnings("ignore")
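# Keep OpenCV single-threaded; this is commonly done to avoid thread oversubscription with DataLoader workers.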
cv2.setNumThreads(0)

def get_parser():
    parser = argparse.ArgumentParser(
        description='Pytorch Referring Expression Segmentation')
    parser.add_argument('--config',
                        default='path to xxx.yaml',
                        type=str,
                        help='config file')
    parser.add_argument('--opts',
                        default=None,
                        nargs=argparse.REMAINDER,
                        help='override some settings in the config.')
    parser.add_argument('--local-rank',
                        type=int,
                        default=0,
                        help='local rank for distributed training')
    args = parser.parse_args()
    assert args.config is not None
    cfg = config.load_cfg_from_cfg_file(args.config)
    if args.opts is not None:
        cfg = config.merge_cfg_from_list(cfg, args.opts)
    return cfg

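# Example launches (the config path is illustrative; adjust to your setup):
#   torchrun --nproc_per_node=4 train_gref.py --config config/gref.yaml
#   python train_gref.py --config config/gref.yaml   # falls back to mp.spawn over all visible GPUs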
@logger.catch
def main():
    args = get_parser()
    args.manual_seed = init_random_seed(args.manual_seed)
    set_random_seed(args.manual_seed, deterministic=True)
    if 'LOCAL_RANK' in os.environ:
        args.local_rank = int(os.environ['LOCAL_RANK'])
        logger.info(f"LOCAL_RANK from env: {args.local_rank}")
        main_worker_ddp(args)
    else:
        args.ngpus_per_node = torch.cuda.device_count()
        args.world_size = args.ngpus_per_node * getattr(args, 'world_size', 1)
        mp.spawn(main_worker_mp, nprocs=args.ngpus_per_node, args=(args,))

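# Worker used when the job is launched externally (e.g. via torchrun): rank and world size come from the environment.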
def main_worker_ddp(args):
    args.output_dir = os.path.join(args.output_folder, args.exp_name)
    args.gpu = args.local_rank
    args.rank = args.local_rank
    args.world_size = int(os.environ.get('WORLD_SIZE', 1))
    torch.cuda.set_device(args.gpu)
    setup_logger(args.output_dir,
                 distributed_rank=args.gpu,
                 filename="train.log",
                 mode="a")
    logger.info(f"Starting with GPU: {args.gpu}, Rank: {args.rank}, World Size: {args.world_size}")
    dist_url = os.environ.get('MASTER_ADDR', 'localhost') + ':' + os.environ.get('MASTER_PORT', '12355')
    dist.init_process_group(backend=getattr(args, 'dist_backend', 'nccl'),
                            init_method="env://",
                            world_size=args.world_size,
                            rank=args.rank)
    run_training(args)

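# Worker used with mp.spawn for single-node runs: `gpu` is the local device index passed in by spawn.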
def main_worker_mp(gpu, args):
    args.output_dir = os.path.join(args.output_folder, args.exp_name)
    # local rank & global rank
    args.gpu = gpu
    rank = getattr(args, 'rank', 0)
    args.rank = rank * args.ngpus_per_node + gpu
    torch.cuda.set_device(args.gpu)
    setup_logger(args.output_dir,
                 distributed_rank=args.gpu,
                 filename="train.log",
                 mode="a")
    dist_url = getattr(args, 'dist_url', 'tcp://localhost:12355')
    dist.init_process_group(backend=getattr(args, 'dist_backend', 'nccl'),
                            init_method=dist_url,
                            world_size=args.world_size,
                            rank=args.rank)
    run_training(args)

def run_training(args):
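    # Only rank 0 initialises wandb (offline mode); the dist.barrier() below keeps the other ranks in step.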
    # wandb
    if args.rank == 0:
        wandb.init(job_type="training",
                   mode="offline",
                   config=args,
                   project=args.exp_name,
                   name=args.exp_name,
                   tags=[args.dataset])
    dist.barrier()
    # build model
    model, param_list = build_segmenter(args)
    model = model.cuda(args.gpu)
    if hasattr(model, 'text_encoder'):
        model.text_encoder = model.text_encoder.cuda(args.gpu)
    logger.info(f"Model moved to GPU: {args.gpu}")
    logger.info(args)
    # build optimizer & lr scheduler
    optimizer = torch.optim.AdamW(param_list,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay,
                                  amsgrad=args.amsgrad)
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu],
        find_unused_parameters=True)
    scaler = amp.GradScaler()
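    # The configured batch sizes and worker count are global; divide them across the distributed processes.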
    args.batch_size = int(args.batch_size / dist.get_world_size())
    args.batch_size_val = int(args.batch_size_val / dist.get_world_size())
    args.workers = int((args.workers + dist.get_world_size() - 1) / dist.get_world_size())
    # build dataset
    train_data = RefDataset_gref(lmdb_dir=args.train_lmdb,
                                 mask_dir=args.mask_root,
                                 dataset=args.dataset,
                                 split=args.train_split,
                                 mode='train',
                                 input_size=args.input_size,
                                 word_length=args.word_len,
                                 args=args)
    val_data = RefDataset_gref(lmdb_dir=args.val_lmdb,
                               mask_dir=args.mask_root,
                               dataset=args.dataset,
                               split=args.val_split,
                               mode='val',
                               input_size=args.input_size,
                               word_length=args.word_len,
                               args=args)
    # build dataloader
    init_fn = partial(worker_init_fn,
                      num_workers=args.workers,
                      rank=args.rank,
                      seed=args.manual_seed)
    train_sampler = data.distributed.DistributedSampler(train_data,
                                                        shuffle=True)
    val_sampler = data.distributed.DistributedSampler(val_data, shuffle=False)
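    # Shuffling is handled by the DistributedSampler, so both DataLoaders keep shuffle=False.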
    train_loader = data.DataLoader(train_data,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=args.workers,
                                   pin_memory=True,
                                   worker_init_fn=init_fn,
                                   sampler=train_sampler,
                                   collate_fn=collate_fn,
                                   drop_last=True)
    val_loader = data.DataLoader(val_data,
                                 batch_size=args.batch_size_val,
                                 shuffle=False,
                                 num_workers=args.workers_val,
                                 pin_memory=True,
                                 sampler=val_sampler,
                                 drop_last=False,
                                 collate_fn=collate_fn)
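    # Polynomial decay of the learning rate (power 0.9) over the total number of training iterations.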
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda x: (1 - x / (len(train_loader) * args.epochs)) ** 0.9)
    best_IoU = 0.0
    # resume
    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(
                args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            best_IoU = checkpoint["best_iou"]
            checkpoint['model_state_dict'].pop('decoder.tokens.weight')
            # restore the remaining model weights (strict=False because the token embedding above was removed)
            model.module.load_state_dict(checkpoint['model_state_dict'], strict=False)
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            raise ValueError(
                "=> resume failed! no checkpoint found at '{}'. Please check args.resume again!"
                .format(args.resume))
    # start training
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1
        # shuffle loader
        train_sampler.set_epoch(epoch_log)
        # train
        train(train_loader, model, optimizer, scheduler, scaler, epoch_log, args)
        # evaluation
        iou, prec_dict = validate(val_loader, model, epoch_log, args)
        # save model
        if dist.get_rank() == 0:
            lastname = os.path.join(args.output_dir, "last_model.pth")
            torch.save(
                {
                    'epoch': epoch_log,
                    'cur_iou': iou,
                    'best_iou': best_IoU,
                    'prec': prec_dict,
                    'model_state_dict': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict()
                }, lastname)
            if iou >= best_IoU and epoch_log < 50:
                best_IoU = iou
                bestname = os.path.join(args.output_dir, "best_model.pth")
                shutil.copyfile(lastname, bestname)
        torch.cuda.empty_cache()
        time.sleep(2)
    if dist.get_rank() == 0:
        try:
            wandb.finish()
        except AttributeError:
            logger.warning("Failed to properly finish wandb run due to StreamToLoguru conflict")
    logger.info("* Best IoU={} * ".format(best_IoU))
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('* Training time {} *'.format(total_time_str))

if __name__ == '__main__':
    main()
    sys.exit(0)