# MRaCL / CGFormer / train_refzom_sbert_test.py
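"""Evaluation-only variant of the Ref-ZOM (SBERT) training script: the
train() call inside the epoch loop is commented out, so each epoch only
runs validate() and checkpoints the (optionally resumed) model."""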
import argparse
import datetime
import os
import shutil
import sys
import time
import warnings
from functools import partial
import cv2
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data as data
from loguru import logger
from torch.optim.lr_scheduler import MultiStepLR
import utils.config as config
import wandb
from model import build_segmenter
# from engine.engine_verbonly import train, validate
# from engine.engine_verbonly_hardneg import train, validate
from utils.misc import (init_random_seed, set_random_seed, setup_logger,
                        worker_init_fn, build_scheduler, collate_fn)
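# Silence noisy library warnings and disable OpenCV's internal threading,
# which can oversubscribe CPUs when combined with multi-worker DataLoaders.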
warnings.filterwarnings("ignore")
cv2.setNumThreads(0)
def get_parser():
    parser = argparse.ArgumentParser(
        description='Pytorch Referring Expression Segmentation')
    parser.add_argument('--config',
                        default='path to xxx.yaml',
                        type=str,
                        help='config file')
    parser.add_argument('--opts',
                        default=None,
                        nargs=argparse.REMAINDER,
                        help='override some settings in the config.')
    parser.add_argument('--local-rank',
                        type=int,
                        default=0,
                        help='local rank for distributed training')
    args = parser.parse_args()
    assert args.config is not None
    cfg = config.load_cfg_from_cfg_file(args.config)
    if args.opts is not None:
        cfg = config.merge_cfg_from_list(cfg, args.opts)
    return cfg
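# Example invocation (config path and override keys are illustrative):
#   python train_refzom_sbert_test.py --config config/ref-zom/cgformer.yaml \
#       --opts lr 0.0001 batch_size 32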
@logger.catch
def main():
    args = get_parser()
    args.manual_seed = init_random_seed(args.manual_seed)
    set_random_seed(args.manual_seed, deterministic=True)

    if 'LOCAL_RANK' in os.environ:
        # Launched by torchrun / torch.distributed.launch: one process per GPU
        # already exists, so read the local rank from the environment.
        args.local_rank = int(os.environ['LOCAL_RANK'])
        logger.info(f"LOCAL_RANK from env: {args.local_rank}")
        main_worker_ddp(args)
    else:
        # Plain `python` launch: spawn one worker process per visible GPU.
        args.ngpus_per_node = torch.cuda.device_count()
        args.world_size = args.ngpus_per_node * getattr(args, 'world_size', 1)
        mp.spawn(main_worker_mp, nprocs=args.ngpus_per_node, args=(args,))
def main_worker_ddp(args):
    args.output_dir = os.path.join(args.output_folder, args.exp_name)
    # Under torchrun on a single node, local rank and global rank coincide.
    args.gpu = args.local_rank
    args.rank = args.local_rank
    args.world_size = int(os.environ.get('WORLD_SIZE', 1))
    torch.cuda.set_device(args.gpu)
    setup_logger(args.output_dir,
                 distributed_rank=args.gpu,
                 filename="train.log",
                 mode="a")
    logger.info(f"Starting with GPU: {args.gpu}, Rank: {args.rank}, World Size: {args.world_size}")
    # torchrun already exports MASTER_ADDR/MASTER_PORT, so env:// picks them up.
    dist.init_process_group(backend=getattr(args, 'dist_backend', 'nccl'),
                            init_method='env://',
                            world_size=args.world_size,
                            rank=args.rank)
    run_training(args)
def main_worker_mp(gpu, args):
    args.output_dir = os.path.join(args.output_folder, args.exp_name)
    # local rank & global rank
    args.gpu = gpu
    rank = getattr(args, 'rank', 0)
    args.rank = rank * args.ngpus_per_node + gpu
    torch.cuda.set_device(args.gpu)
    setup_logger(args.output_dir,
                 distributed_rank=args.gpu,
                 filename="train.log",
                 mode="a")
    dist_url = getattr(args, 'dist_url', 'tcp://localhost:12355')
    dist.init_process_group(backend=getattr(args, 'dist_backend', 'nccl'),
                            init_method=dist_url,
                            world_size=args.world_size,
                            rank=args.rank)
    run_training(args)
def run_training(args):
    # wandb (rank 0 only)
    if args.rank == 0:
        wandb.init(job_type="training",
                   mode="offline",
                   config=args,
                   project=args.exp_name,
                   name=args.exp_name,
                   tags=[args.dataset])
    dist.barrier()

    # build model; the engine and dataset modules are dataset-specific,
    # so they are imported lazily here
    if args.dataset == 'ref-zom':
        from engine.engine_refzom_sbert import train, validate
        from utils.dataset_zom_sbert import RefZom_FilterDataset, Refzom_DistributedSampler
    else:
        raise ValueError(f"Unsupported dataset for this script: {args.dataset}")
    model, param_list = build_segmenter(args)
    model = model.cuda(args.gpu)
    if hasattr(model, 'text_encoder'):
        model.text_encoder = model.text_encoder.cuda(args.gpu)
    logger.info(f"Model moved to GPU: {args.gpu}")
    logger.info(args)
    # build optimizer & lr scheduler
    optimizer = torch.optim.AdamW(param_list,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay,
                                  amsgrad=args.amsgrad)

    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu],
        find_unused_parameters=True)

    scaler = amp.GradScaler()

    # the configured sizes are global; split them across processes
    args.batch_size = int(args.batch_size / dist.get_world_size())
    args.batch_size_val = int(args.batch_size_val / dist.get_world_size())
    args.workers = int((args.workers + dist.get_world_size() - 1) / dist.get_world_size())
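    # e.g. a configured global batch_size of 64 on 4 GPUs yields a per-process
    # batch_size of 16; workers use a ceiling division so every process keeps
    # at least one DataLoader worker.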
    # build dataset
    train_data = RefZom_FilterDataset(lmdb_dir=args.train_lmdb,
                                      mask_dir=args.mask_root,
                                      dataset=args.dataset,
                                      split=args.train_split,
                                      mode='train',
                                      input_size=args.input_size,
                                      word_length=args.word_len,
                                      args=args)
    val_data = RefZom_FilterDataset(lmdb_dir=args.val_lmdb,
                                    mask_dir=args.mask_root,
                                    dataset=args.dataset,
                                    split=args.val_split,
                                    mode='test',
                                    input_size=args.input_size,
                                    word_length=args.word_len,
                                    args=args)
    # build dataloader
    init_fn = partial(worker_init_fn,
                      num_workers=args.workers,
                      rank=args.rank,
                      seed=args.manual_seed)
    train_sampler = Refzom_DistributedSampler(
        train_data, num_replicas=args.world_size, rank=args.rank, shuffle=True)
    val_sampler = data.distributed.DistributedSampler(val_data, shuffle=False)
    train_loader = data.DataLoader(train_data,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=args.workers,
                                   pin_memory=True,
                                   worker_init_fn=init_fn,
                                   sampler=train_sampler,
                                   collate_fn=collate_fn,
                                   drop_last=True)
    val_loader = data.DataLoader(val_data,
                                 batch_size=args.batch_size_val,
                                 shuffle=False,
                                 num_workers=args.workers_val,
                                 pin_memory=True,
                                 sampler=val_sampler,
                                 drop_last=False)
    # polynomial decay: lr(t) = base_lr * (1 - t / T) ** 0.9,
    # where T = len(train_loader) * args.epochs total iterations
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda x: (1 - x / (len(train_loader) * args.epochs)) ** 0.9)

    best_IoU = 0.0
    best_oIoU = 0.0
    # resume
    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(
                args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            best_IoU = checkpoint["best_iou"]
            checkpoint['model_state_dict'].pop('decoder.tokens.weight')
            # load the remaining weights non-strictly, since one key was popped
            model.module.load_state_dict(checkpoint['model_state_dict'], strict=False)
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            raise ValueError(
                "=> resume failed! no checkpoint found at '{}'. Please check args.resume again!"
                .format(args.resume))
    # start training
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1

        # shuffle loader
        train_sampler.set_epoch(epoch_log)

        # train (disabled in this evaluation-only variant)
        # train(train_loader, model, optimizer, scheduler, scaler, epoch_log, args)
        # torch.cuda.empty_cache()
        # dist.barrier()

        # evaluation
        logger.info(f"Start Evaluation : epoch {epoch_log}")
        iou, oiou, prec_dict, mean_acc = validate(val_loader, model, epoch_log, args)

        # save model
        if dist.get_rank() == 0:
            lastname = os.path.join(args.output_dir, "last_model.pth")
            torch.save(
                {
                    'epoch': epoch_log,
                    'cur_iou': iou,
                    'best_iou': best_IoU,
                    'best_oiou': best_oIoU,
                    'prec': prec_dict,
                    'mean_acc': mean_acc,
                    'model_state_dict': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict()
                }, lastname)
            if iou >= best_IoU and epoch_log < 50:
                best_IoU = iou
                bestname = os.path.join(args.output_dir, "best_model_miou.pth")
                shutil.copyfile(lastname, bestname)
            if oiou >= best_oIoU and epoch_log < 50:
                best_oIoU = oiou
                bestname_oiou = os.path.join(args.output_dir, "best_model_oiou.pth")
                shutil.copyfile(lastname, bestname_oiou)
        torch.cuda.empty_cache()
        time.sleep(2)
    if dist.get_rank() == 0:
        try:
            wandb.finish()
        except AttributeError:
            logger.warning("Failed to properly finish wandb run due to StreamToLoguru conflict")

    logger.info("* Best IoU={} *".format(best_IoU))
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('* Training time {} *'.format(total_time_str))
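# Typical multi-GPU launch (GPU count illustrative; torchrun exports LOCAL_RANK,
# WORLD_SIZE, MASTER_ADDR and MASTER_PORT, so the env:// DDP path above is taken):
#   torchrun --nproc_per_node=4 train_refzom_sbert_test.py --config <config.yaml>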
if __name__ == '__main__':
    main()
    sys.exit(0)