# Scrape metadata (non-code viewer header, preserved as a comment):
# Spaces: Sleeping / Sleeping
# File size: 2,978 Bytes; revision 383bfb8; 74 lines
import torch
import torch.optim
import torch.utils.data
import torch.nn.parallel
from core.train import *
from core.test import *
from utils import *
from core.network import get_model
from loss import get_criterion
from datasets import get_dataloader
def main_worker(args):
    """Entry point for one distributed training worker.

    Builds the dataloaders, model, optimizer, scheduler and criterion from
    ``args``; optionally resumes from a checkpoint (or only evaluates it and
    returns); then wraps the model in DistributedDataParallel and runs the
    train / validate / checkpoint loop.

    Args:
        args: namespace with (at least) resume_model, evaluate_only,
            local_rank, max_epoch, eval_epoch, dump_path; ``args.logger``
            is attached here as a side effect.
    """
    args.logger = initialize_logger(args)
    train_loader, train_sampler, val_loader, src_lang, tgt_lang = get_dataloader(args)
    model = get_model(args, src_lang, tgt_lang).cuda()
    optimizer = get_optimizer(args, model)
    scheduler = get_scheduler(args, optimizer)
    criterion = get_criterion(args)
    start_epoch = 0

    def _checkpoint_state(epoch):
        # Bundle everything needed to resume training from `epoch`.
        # NOTE(review): by the time this runs, `model` is DDP-wrapped, so the
        # saved keys carry a "module." prefix while resume loads into the
        # *unwrapped* model via model.load_model — confirm load_model strips
        # the prefix, otherwise resume of these checkpoints will fail.
        return {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'scheduler': scheduler.state_dict(),
            'optimizer': optimizer.state_dict(),
        }

    # Resume from a checkpoint; an empty string means "train from scratch".
    # (Original spelled this `not args.resume_model == ''`, which only works
    # because `not` binds looser than `==`; `!=` says the same thing plainly.)
    if args.resume_model != '':
        resume_model_dict = model.load_model(args.resume_model)
        optimizer.load_state_dict(resume_model_dict['optimizer'])
        scheduler.load_state_dict(resume_model_dict['scheduler'])
        start_epoch = resume_model_dict["epoch"] + 1
        args.logger.info("The whole model has been loaded from " + args.resume_model)
        args.logger.info("The model resumes from epoch " + str(resume_model_dict["epoch"]))
        if args.evaluate_only:
            # Evaluation-only mode runs on the unwrapped (non-DDP) model
            # and exits before any training.
            acc_ans, acc_eq = validate(args, val_loader, model, tgt_lang)
            args.logger.info("----------Epoch:{:>3d}, test answer_acc {:>5.4f}, equation_acc {:>5.4f} ---------" \
                .format(resume_model_dict["epoch"], acc_ans, acc_eq))
            return
    else:
        args.logger.info("The model is trained from scratch")

    # distributed parallel training
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.local_rank],
        output_device=args.local_rank,
        find_unused_parameters=True
    )

    # NOTE(review): min_loss is not restored on resume, so the first eval
    # epoch after a resume always overwrites the "best" checkpoint.
    min_loss = 1e10
    for epoch in range(start_epoch, args.max_epoch):
        # Train for one epoch; set_epoch reseeds the sampler so each epoch
        # sees a different shard shuffle across ranks.
        train_sampler.set_epoch(epoch)
        loss = train(args, epoch, train_loader, model, criterion, optimizer)
        args.logger.info("----------Epoch:{:>3d}, training loss is {:>5.4f} ---------". \
            format(epoch, loss))
        # Checkpoint on rank 0 only, to avoid concurrent writes from workers.
        if args.local_rank == 0:
            if epoch % args.eval_epoch == 0 or epoch >= args.max_epoch - 5:
                save_checkpoint(_checkpoint_state(epoch), False, args.dump_path)
                if loss < min_loss:
                    min_loss = loss
                    # Second save flags this checkpoint as the current best.
                    save_checkpoint(_checkpoint_state(epoch), True, args.dump_path)
        # learning scheduler step
        scheduler.step()
    args.logger.info("------------------- Train Finished -------------------")