# MRaCL / CGFormer / train.py
import argparse
import datetime
import os
import shutil
import sys
import time
import warnings
from functools import partial
import cv2
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.multiprocessing as mp #https://blog.csdn.net/hxxjxw/article/details/119839548
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data as data
from loguru import logger # https://hanjunqiang.blog.csdn.net/article/details/124779625
from torch.optim.lr_scheduler import MultiStepLR
import utils.config as config
import wandb
from utils.dataset import RefDataset
from engine.engine import train, validate
from model import build_segmenter
from utils.misc import (init_random_seed, set_random_seed, setup_logger,
worker_init_fn, build_scheduler) #, collate_fn)
warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore", category=UserWarning)
cv2.setNumThreads(0)
torch.cuda.empty_cache()
import deepspeed
from deepspeed.runtime.lr_schedules import WarmupLR
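# Example launch (illustrative; the config path is a placeholder and the GPU
# count depends on your setup):
#   torchrun --nproc_per_node=4 train.py --config path/to/config.yaml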
def get_parser():
parser = argparse.ArgumentParser(
        description='PyTorch Referring Expression Segmentation')
parser.add_argument('--config',
default='path to xxx.yaml',
type=str,
help='config file')
parser.add_argument('--opts',
default=None,
nargs=argparse.REMAINDER,
help='override some settings in the config.')
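    # --local_rank is typically supplied by the deepspeed launcher; this script
    # reads LOCAL_RANK from the environment instead, so the flag is kept mainly
    # for launcher compatibility.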
parser.add_argument("--local_rank", type=int, default=0)
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()
assert args.config is not None
cfg = config.load_cfg_from_cfg_file(args.config)
if args.opts is not None:
cfg = config.merge_cfg_from_list(cfg, args.opts)
return cfg
# @logger.catch
# def main():
# args = get_parser()
# args.manual_seed = init_random_seed(args.manual_seed)
# set_random_seed(args.manual_seed, deterministic=True)
# args.ngpus_per_node = torch.cuda.device_count()
# args.world_size = args.ngpus_per_node * args.world_size
# #mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args, ))
def main(args):
# local rank & global rank
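    # RANK / WORLD_SIZE / LOCAL_RANK / MASTER_ADDR / MASTER_PORT are expected to
    # be set by the launcher (e.g. torchrun or the deepspeed launcher).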
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
args.output_dir = os.path.join(args.output_folder, args.exp_name)
args.dist_url = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
torch.cuda.set_device(args.gpu)
# logger
setup_logger(args.output_dir,
distributed_rank=args.gpu,
filename="train.log",
mode="a")
# dist init
# dist.init_process_group(backend=args.dist_backend,
# init_method=args.dist_url,
# world_size=args.world_size,
# rank=args.rank)
deepspeed.init_distributed(init_method=args.dist_url, #args.dist_backend,
world_size=args.world_size,
rank=args.rank)
print("dist init done")
# wandb
if args.rank == 0:
wandb.init(job_type="training",
mode="offline",
config=args,
project=args.exp_name,
name=args.exp_name,
tags=[args.dataset])
dist.barrier()
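    # DeepSpeed runtime config. Note that train_batch_size is the *global*
    # effective batch size (micro-batch x num GPUs x gradient accumulation),
    # fp16 with loss_scale=0 enables dynamic loss scaling, and ZeRO stage 3
    # partitions parameters/gradients/optimizer state and offloads them to CPU.
    # The AdamW optimizer is instantiated by DeepSpeed from this config, which
    # is why no torch.optim optimizer is built below.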
deepspeed_config = {
"train_batch_size": args.batch_size,
"gradient_accumulation_steps": 2,
"fp16": {
"enabled": True,
"auto_cast": False,
"loss_scale": 0,
"initial_scale_power": 16,
"loss_scale_window": 1000,
"hysteresis": 2,
"consecutive_hysteresis": False,
"min_loss_scale": 1
},
"zero_optimization": {
"stage": 3,
"contiguous_gradients": True,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 5e8,
"allgather_bucket_size": 5e8,
"offload_optimizer": {
"device": "cpu"
},
"offload_param": {
"device": "cpu"
}
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": args.lr,
"weight_decay": args.weight_decay,
"betas": [
0.9,
0.999
],
"eps": 1e-8,
"amsgrad": False
}
},
}
# build model
model, param_list = build_segmenter(args)
# logger.info(model)
logger.info(args)
# build optimizer & lr scheduler
# optimizer = torch.optim.AdamW(param_list,
# lr=args.lr,
# weight_decay=args.weight_decay,
# amsgrad=args.amsgrad
# )
# build dataset
grad_acc_steps = deepspeed_config['gradient_accumulation_steps']
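    # Derive the per-GPU micro-batch for the DataLoader from the global batch
    # size: global / (gpus_per_node * grad_acc_steps). This matches DeepSpeed's
    # train_micro_batch_size_per_gpu under the assumption of single-node
    # training (ngpus_per_node == world_size).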
args.batch_size = int(args.batch_size / args.ngpus_per_node / grad_acc_steps)
args.batch_size_val = int(args.batch_size_val / args.ngpus_per_node)
args.workers = int(
(args.workers + args.ngpus_per_node - 1) / args.ngpus_per_node)
train_data = RefDataset(lmdb_dir=args.train_lmdb,
mask_dir=args.mask_root,
dataset=args.dataset,
split=args.train_split,
mode='train',
input_size=args.input_size,
word_length=args.word_len,
args=args)
val_data = RefDataset(lmdb_dir=args.val_lmdb,
mask_dir=args.mask_root,
dataset=args.dataset,
split=args.val_split,
mode='val',
input_size=args.input_size,
word_length=args.word_len,
args=args)
# build dataloader
init_fn = partial(worker_init_fn,
num_workers=args.workers,
rank=args.rank,
seed=args.manual_seed)
train_sampler = data.distributed.DistributedSampler(train_data,
shuffle=True)
val_sampler = data.distributed.DistributedSampler(val_data, shuffle=False)
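    # DistributedSampler shards each dataset across ranks; shuffling is handled
    # by the sampler (via set_epoch in the training loop), so the loaders are
    # built with shuffle=False.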
train_loader = data.DataLoader(train_data,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.workers,
pin_memory=True,
worker_init_fn=init_fn,
sampler=train_sampler,
#collate_fn=collate_fn,
drop_last=True)
val_loader = data.DataLoader(val_data,
batch_size=args.batch_size_val,
shuffle=False,
num_workers=args.workers_val,
pin_memory=True,
sampler=val_sampler,
drop_last=False,
#collate_fn=collate_fn,
)
#scheduler = WarmupLR(optimizer)
scaler = None # amp.GradScaler()
# torch.optim.lr_scheduler.LambdaLR(optimizer, lambda x: (1 - x / (len(train_loader) * args.epochs)) ** 0.9)
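    # Polynomial ("poly") LR decay with power 0.9 over the estimated total
    # number of iterations. A partial is passed rather than an instance because
    # deepspeed.initialize also accepts a callable that it invokes with the
    # optimizer it constructs.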
scheduler = partial(torch.optim.lr_scheduler.LambdaLR,
lr_lambda=lambda x: (1 - x / (len(train_data)/args.batch_size * args.epochs)) ** 0.9) #len(train_loader)
model, optimizer, _, scheduler = deepspeed.initialize(model=model,
config_params=deepspeed_config,
model_parameters=param_list,
lr_scheduler=scheduler,
dist_init_required=True)
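    # deepspeed.initialize returns (engine, optimizer, dataloader, lr_scheduler);
    # the engine wraps the model and handles fp16/ZeRO, and train() is expected
    # to drive it through model.backward() / model.step().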
best_IoU = 0.0
# resume
if args.resume:
# if os.path.isfile(args.resume):
# logger.info("=> loading checkpoint '{}'".format(args.resume))
# checkpoint = torch.load(
# args.resume, map_location=lambda storage, loc: storage.cuda())
# args.start_epoch = checkpoint['epoch']
# best_IoU = checkpoint["best_iou"]
# checkpoint['model_state_dict'].pop('decoder.tokens.weight')
# optimizer.load_state_dict(checkpoint['optimizer'])
# scheduler.load_state_dict(checkpoint['scheduler'])
# logger.info("=> loaded checkpoint '{}' (epoch {})".format(
# args.resume, checkpoint['epoch']))
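        # With DeepSpeed, resuming goes through engine.load_checkpoint, which
        # restores the (ZeRO-sharded) model, optimizer and scheduler state; the
        # commented block above is the old torch.load-based resume path.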
if os.path.isfile(args.resume):
logger.info("=> loading checkpoint '{}'".format(args.resume, args.ckpt_id))
_, client_sd = model.load_checkpoint(args.resume)
else:
raise ValueError(
"=> resume failed! no checkpoint found at '{}'. Please check args.resume again!"
.format(args.resume))
# start training
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
epoch_log = epoch + 1
# shuffle loader
train_sampler.set_epoch(epoch_log)
# train
train(train_loader, model, optimizer, scheduler, scaler, epoch_log, args)
# evaluation
iou, prec_dict = validate(val_loader, model, epoch_log, args)
# save model
# if dist.get_rank() == 0:
# lastname = os.path.join(args.output_dir, "last_model.pth")
# torch.save(
# {
# 'epoch': epoch_log,
# 'cur_iou': iou,
# 'best_iou': best_IoU,
# 'prec': prec_dict,
# 'model_state_dict': model.module.state_dict(),
# 'optimizer': optimizer.state_dict(),
# 'scheduler': scheduler.state_dict()
# }, lastname)
# if iou >= best_IoU and epoch_log<50:
# best_IoU = iou
# bestname = os.path.join(args.output_dir, "best_model.pth")
# shutil.copyfile(lastname, bestname)
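        # Under ZeRO stage 3, save_checkpoint is a collective call that every
        # rank must enter, so the "best epoch so far" decision is made on rank 0
        # and broadcast to all ranks before saving.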
        if dist.get_rank() == 0:
            found_best_epoch = [iou >= best_IoU and epoch_log < 50]
            if found_best_epoch[0]:
                best_IoU = iou
        else:
            found_best_epoch = [False]
dist.broadcast_object_list(found_best_epoch, src=0, device=model.device)
model.save_checkpoint(args.output_dir, "last_model")
if found_best_epoch[0]:
model.save_checkpoint(args.output_dir, "best_model")
torch.cuda.empty_cache()
time.sleep(2)
if dist.get_rank() == 0:
wandb.finish()
logger.info("* Best IoU={} * ".format(best_IoU))
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
logger.info('* Training time {} *'.format(total_time_str))
if __name__ == '__main__':
args = get_parser()
args.manual_seed = init_random_seed(args.manual_seed)
set_random_seed(args.manual_seed, deterministic=True)
args.ngpus_per_node = torch.cuda.device_count()
args.world_size = args.ngpus_per_node * args.world_size
#mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args, ))
main(args)
sys.exit(0)