# pgps-demo/core/worker.py
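"""Training/evaluation worker for the PGPS demo.

Builds the dataloaders, model, optimizer, scheduler, and criterion,
optionally resumes from a checkpoint (with an evaluate-only mode), then
runs DistributedDataParallel training with periodic checkpointing.
"""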
import torch
import torch.optim
import torch.utils.data
import torch.nn.parallel
from core.train import *
from core.test import *
from utils import *
from core.network import get_model
from loss import get_criterion
from datasets import get_dataloader
def main_worker(args):
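    """Per-process entry point for distributed training.

    ``args`` is expected to carry the run configuration used below
    (``local_rank``, ``max_epoch``, ``eval_epoch``, ``dump_path``,
    ``resume_model``, ``evaluate_only``); a logger is attached to it here.
    """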
    args.logger = initialize_logger(args)
    train_loader, train_sampler, val_loader, src_lang, tgt_lang = get_dataloader(args)
    model = get_model(args, src_lang, tgt_lang).cuda()
    optimizer = get_optimizer(args, model)
    scheduler = get_scheduler(args, optimizer)
    criterion = get_criterion(args)
    start_epoch = 0
    # Resume from a checkpoint if a path was given; otherwise train from scratch.
    if args.resume_model != '':
        resume_model_dict = model.load_model(args.resume_model)
        optimizer.load_state_dict(resume_model_dict['optimizer'])
        scheduler.load_state_dict(resume_model_dict['scheduler'])
        start_epoch = resume_model_dict['epoch'] + 1
        args.logger.info("The whole model has been loaded from " + args.resume_model)
        args.logger.info("The model resumes from epoch " + str(resume_model_dict['epoch']))
        # Evaluation-only mode: validate the restored model and exit.
        if args.evaluate_only:
            acc_ans, acc_eq = validate(args, val_loader, model, tgt_lang)
            args.logger.info("----------Epoch:{:>3d}, test answer_acc {:>5.4f}, equation_acc {:>5.4f} ---------"
                             .format(resume_model_dict["epoch"], acc_ans, acc_eq))
            return
    else:
        args.logger.info("The model is trained from scratch")
    # Wrap the model for distributed data-parallel training.
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.local_rank],
        output_device=args.local_rank,
        find_unused_parameters=True
    )
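    # Note: find_unused_parameters=True tolerates submodules that produce no
    # gradient in a given forward pass, at the cost of extra overhead per step.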
    min_loss = float('inf')
    for epoch in range(start_epoch, args.max_epoch):
        # Train for one epoch; set_epoch reshuffles the sampler across ranks.
        train_sampler.set_epoch(epoch)
        loss = train(args, epoch, train_loader, model, criterion, optimizer)
        args.logger.info("----------Epoch:{:>3d}, training loss is {:>5.4f} ---------"
                         .format(epoch, loss))
        # Save checkpoints (rank 0 only).
        if args.local_rank == 0:
            # The model is DDP-wrapped here, so the saved state_dict keys
            # carry a 'module.' prefix.
            checkpoint = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'scheduler': scheduler.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            # Periodic checkpoint: every eval_epoch epochs and over the last five epochs.
            if epoch % args.eval_epoch == 0 or epoch >= args.max_epoch - 5:
                save_checkpoint(checkpoint, False, args.dump_path)
            # Best checkpoint: lowest training loss seen so far.
            if loss < min_loss:
                min_loss = loss
                save_checkpoint(checkpoint, True, args.dump_path)
        # Step the learning-rate scheduler once per epoch.
        scheduler.step()
    args.logger.info("------------------- Train Finished -------------------")
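# A worker like this is typically launched once per GPU, e.g. via
# ``torchrun --nproc_per_node=N`` (or the legacy ``torch.distributed.launch``),
# which is what provides the local rank used above.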