# training_sem/runner/train.py
import shutil
import torch
import tqdm
import json
import os
import sys
sys.path.append('./')
sys.path.append('../')
import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingLR
from collections import defaultdict
from libs.utils.cal_f1 import pred_result_to_table, table_to_relations, evaluate_f1
from libs.utils.comm import distributed, synchronize
from libs.utils.checkpoint import load_checkpoint, save_checkpoint
from libs.data import create_train_dataloader, create_valid_dataloader
from libs.utils.model_synchronizer import ModelSynchronizer
from libs.utils.time_counter import TimeCounter
from libs.utils.utils import is_simple_table, cal_mean_lr
from libs.utils.counter import Counter
from libs.utils import logger
from libs.model import build_model
from libs.configs import cfg, setup_config
metrics_name = ['f1']
best_metrics = [0.0]
def init():
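    """Parse command-line arguments, load the config, initialize the
    distributed process group (when WORLD_SIZE > 1), and set up logging."""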
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--cfg', type=str, default='debug')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
setup_config(args.cfg)
os.environ['LOCAL_RANK'] = str(args.local_rank)
    num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    is_distributed = num_gpus > 1  # avoid shadowing libs.utils.comm.distributed
    if is_distributed:
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
synchronize()
logger.setup_logger('Line Detect Model', cfg.work_dir, 'train.log')
    logger.info('Use config: %s' % args.cfg)
def train(cfg, epoch, dataloader, model, optimizer, scheduler, time_counter, synchronizer=None):
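    """Run one training epoch: forward pass, loss backward, optimizer and
    scheduler step, periodic logging; optionally sync weights across ranks."""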
model.train()
counter = Counter(cache_nums=1000)
for it, data_batch in enumerate(dataloader):
ids = data_batch['ids']
images_size = data_batch['images_size']
images = data_batch['images'].to(cfg.device)
cls_labels = data_batch['cls_labels'].to(cfg.device)
labels_mask = data_batch['labels_mask'].to(cfg.device)
rows_fg_spans = data_batch['rows_fg_spans']
rows_bg_spans = data_batch['rows_bg_spans']
cols_fg_spans = data_batch['cols_fg_spans']
cols_bg_spans = data_batch['cols_bg_spans']
cells_spans = data_batch['cells_spans']
divide_labels = data_batch['divide_labels'].to(cfg.device)
layouts = data_batch['layouts'].to(cfg.device)
try:
optimizer.zero_grad()
pred_result, result_info = model(
images, images_size,
cls_labels, labels_mask, layouts,
rows_fg_spans, rows_bg_spans,
cols_fg_spans, cols_bg_spans,
cells_spans, divide_labels,
)
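            # Total loss is the sum of every entry in result_info whose key
            # contains 'loss'.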
loss = sum([val for key, val in result_info.items() if 'loss' in key])
loss.backward()
optimizer.step()
scheduler.step()
counter.update(result_info)
        except RuntimeError as e:
            # Catch RuntimeError (the type raised on CUDA OOM) instead of a
            # bare except, which would also swallow KeyboardInterrupt; free
            # the cached allocator memory before moving to the next batch.
            logger.info('Skip batch, possible CUDA Out Of Memory: %s' % e)
            torch.cuda.empty_cache()
if it % cfg.log_sep == 0:
logger.info(
                '[Train][Epoch %03d Iter %04d][Memory: %.0fMB][Mean LR: %f][Left: %s] %s' %
(
epoch,
it,
torch.cuda.max_memory_allocated()/1024/1024,
cal_mean_lr(optimizer),
time_counter.step(epoch, it + 1),
counter.format_mean(sync=False)
)
)
if synchronizer is not None:
synchronizer()
if synchronizer is not None:
synchronizer(final_align=True)
def valid(cfg, dataloader, model):
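    """Run validation: predict table structures, compare them against the
    ground-truth relation files, and return the F1 score as a 1-tuple."""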
model.eval()
total_label_relations = list()
total_pred_relations = list()
total_relations_metric = list()
for it, data_batch in enumerate(tqdm.tqdm(dataloader)):
ids = data_batch['ids']
images_size = data_batch['images_size']
images = data_batch['images'].to(cfg.device)
tables = data_batch['tables']
pred_result, _ = model(images, images_size)
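        # Convert the raw per-image predictions back into table structures
        # so they can be compared against the ground-truth relations.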
pred_tables = [
pred_result_to_table(tables[batch_idx],
(pred_result[0][batch_idx], pred_result[1][batch_idx],
pred_result[2][batch_idx], pred_result[3][batch_idx])
)
for batch_idx in range(len(ids))
]
pred_relations = [table_to_relations(table) for table in pred_tables]
total_pred_relations.extend(pred_relations)
# label
label_relations = []
for table in tables:
            label_path = os.path.join(cfg.valid_data_dir, table['label_path'])
            with open(label_path, 'r') as f:
label_relations.append(json.load(f))
total_label_relations.extend(label_relations)
    # Compute precision / recall / F1 over the full validation set.
    total_relations_metric = evaluate_f1(total_label_relations, total_pred_relations, num_workers=40)
    P, R, F1 = np.array(total_relations_metric).mean(0).tolist()
    # Recompute F1 as the harmonic mean of the averaged P and R rather than
    # keeping the mean of per-sample F1 scores.
    F1 = 2 * P * R / (P + R)
    logger.info('[Valid] Total Type Metric: Precision: %s, Recall: %s, F1-Score: %s' % (P, R, F1))
return (F1,)
def build_optimizer(cfg, model):
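    """Build an Adam optimizer with one param group per trainable parameter,
    all sharing cfg.base_lr and cfg.weight_decay."""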
params = list()
for _, value in model.named_parameters():
if not value.requires_grad:
continue
lr = cfg.base_lr
weight_decay = cfg.weight_decay
params += [{'params': [value], 'lr': lr, 'weight_decay': weight_decay}]
optimizer = torch.optim.Adam(params, cfg.base_lr)
return optimizer
def build_scheduler(cfg, optimizer, epoch_iters, start_epoch=0):
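    """Build a cosine-annealing LR schedule stepped per iteration, so T_max
    is measured in iterations rather than epochs. Note that PyTorch requires
    an 'initial_lr' key in each param group when last_epoch != -1, which
    matters when resuming."""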
scheduler = CosineAnnealingLR(
optimizer=optimizer,
T_max=cfg.num_epochs * epoch_iters,
eta_min=cfg.min_lr,
last_epoch=-1 if start_epoch == 0 else start_epoch * epoch_iters
)
return scheduler
def main():
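    """Entry point: build dataloaders, model, optimizer, and scheduler,
    optionally resume from the latest checkpoint, then train and validate
    each epoch, saving the best and latest checkpoints."""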
init()
train_dataloader = create_train_dataloader(
cfg.vocab,
cfg.train_lrcs_path,
cfg.train_num_workers,
cfg.train_max_batch_size,
cfg.train_max_pixel_nums,
cfg.train_bucket_seps,
cfg.train_data_dir
)
logger.info(
        'Train dataset has %d samples, %d batches' %
(
len(train_dataloader.dataset),
len(train_dataloader.batch_sampler)
)
)
valid_dataloader = create_valid_dataloader(
cfg.vocab,
cfg.valid_lrc_path,
cfg.valid_num_workers,
cfg.valid_batch_size,
cfg.valid_data_dir
)
logger.info(
        'Valid dataset has %d samples, %d batches with batch_size=%d' %
(
len(valid_dataloader.dataset),
len(valid_dataloader.batch_sampler),
valid_dataloader.batch_size
)
)
model = build_model(cfg)
model.cuda()
if distributed():
synchronizer = ModelSynchronizer(model, cfg.sync_rate)
else:
synchronizer = None
epoch_iters = len(train_dataloader.batch_sampler)
optimizer = build_optimizer(cfg, model)
    # metrics_name is never reassigned, so only best_metrics needs `global`.
    global best_metrics
start_epoch = 0
resume_path = os.path.join(cfg.work_dir, 'latest_model.pth')
if os.path.exists(resume_path):
best_metrics, start_epoch = load_checkpoint(resume_path, model, optimizer)
start_epoch += 1
logger.info('resume from: %s' % resume_path)
elif cfg.train_checkpoint is not None:
load_checkpoint(cfg.train_checkpoint, model)
logger.info('load checkpoint from: %s' % cfg.train_checkpoint)
scheduler = build_scheduler(cfg, optimizer, epoch_iters, start_epoch)
time_counter = TimeCounter(start_epoch, cfg.num_epochs, epoch_iters)
time_counter.reset()
for epoch in range(start_epoch, cfg.num_epochs):
if hasattr(train_dataloader.sampler, 'set_epoch'):
train_dataloader.sampler.set_epoch(epoch)
train(cfg, epoch, train_dataloader, model, optimizer, scheduler, time_counter, synchronizer)
with torch.no_grad():
metrics = valid(cfg, valid_dataloader, model)
for metric_idx in range(len(metrics_name)):
if metrics[metric_idx] > best_metrics[metric_idx]:
best_metrics[metric_idx] = metrics[metric_idx]
                save_checkpoint(
                    os.path.join(cfg.work_dir, 'best_%s_model.pth' % metrics_name[metric_idx]),
                    model, optimizer, best_metrics, epoch
                )
logger.info('Save current model as best_%s_model' % metrics_name[metric_idx])
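        # Always refresh the latest checkpoint so training can resume.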
save_checkpoint(os.path.join(cfg.work_dir, 'latest_model.pth'), model, optimizer, best_metrics, epoch)
if __name__ == '__main__':
main()