File size: 2,978 Bytes
383bfb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import torch.optim
import torch.utils.data
import torch.nn.parallel
from core.train import *
from core.test import *
from utils import *
from core.network import get_model
from loss import get_criterion
from datasets import get_dataloader


def main_worker(args):
    """Run training (or evaluation only) for one distributed worker process.

    Builds the logger, data loaders, model, optimizer, LR scheduler and
    criterion from ``args``; optionally restores model/optimizer/scheduler
    state from ``args.resume_model``; wraps the model in
    ``DistributedDataParallel`` and trains from the starting epoch up to
    ``args.max_epoch``. Checkpoints are written only by the process whose
    ``args.local_rank`` is 0.

    Args:
        args: Run configuration object. Attributes read here are
            ``resume_model``, ``evaluate_only``, ``local_rank``,
            ``max_epoch``, ``eval_epoch`` and ``dump_path``; ``logger`` is
            assigned on it. The same object is forwarded to the project
            helpers (``get_dataloader``, ``get_model``, ``train`` ...).

    Returns:
        None. When resuming with ``args.evaluate_only`` set, the function
        returns right after logging the validation accuracies.
    """
    args.logger = initialize_logger(args)
    train_loader, train_sampler, val_loader, src_lang, tgt_lang = get_dataloader(args)
    model = get_model(args, src_lang, tgt_lang).cuda()
    optimizer = get_optimizer(args, model)
    scheduler = get_scheduler(args, optimizer)
    criterion = get_criterion(args)
    start_epoch = 0

    # Resume from a checkpoint when a path is given (empty/None means "no").
    if args.resume_model:
        resume_model_dict = model.load_model(args.resume_model)
        optimizer.load_state_dict(resume_model_dict['optimizer'])
        scheduler.load_state_dict(resume_model_dict['scheduler'])
        resumed_epoch = resume_model_dict["epoch"]
        # The checkpoint stores the last finished epoch; continue with the next.
        start_epoch = resumed_epoch + 1
        args.logger.info("The whole model has been loaded from {}".format(args.resume_model))
        args.logger.info("The model resumes from epoch {}".format(resumed_epoch))
        # NOTE(review): evaluate_only is honoured only when a checkpoint is
        # given; without one the function falls through to training.
        if args.evaluate_only:
            acc_ans, acc_eq = validate(args, val_loader, model, tgt_lang)
            args.logger.info(
                "----------Epoch:{:>3d}, test answer_acc {:>5.4f}, equation_acc {:>5.4f} ---------"
                .format(resumed_epoch, acc_ans, acc_eq)
            )
            return
    else:
        args.logger.info("The model is trained from scratch")

    # Distributed data-parallel training: one process per device.
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.local_rank],
        output_device=args.local_rank,
        find_unused_parameters=True,
    )

    # NOTE(review): the best loss is not stored in checkpoints, so tracking of
    # the "best" checkpoint restarts from scratch after a resume.
    min_loss = 1e10

    for epoch in range(start_epoch, args.max_epoch):
        # Give the sampler the epoch number so each epoch is shuffled differently.
        train_sampler.set_epoch(epoch)
        loss = train(args, epoch, train_loader, model, criterion, optimizer)
        args.logger.info(
            "----------Epoch:{:>3d}, training loss is {:>5.4f} ---------"
            .format(epoch, loss)
        )

        # Advance the LR schedule *before* checkpointing. A resumed run starts
        # at ``epoch + 1``, so the saved scheduler/optimizer state must already
        # describe that epoch; saving the pre-step state would replay this
        # epoch's learning rate once more after every resume.
        scheduler.step()

        # Only rank 0 writes checkpoints.
        if args.local_rank == 0:
            # Periodic snapshot: every ``eval_epoch`` epochs and for each of
            # the last five epochs.
            save_periodic = epoch % args.eval_epoch == 0 or epoch >= args.max_epoch - 5
            # Best snapshot: lowest training loss seen in this run so far.
            is_best = loss < min_loss

            if save_periodic or is_best:
                # Build the state once and reuse it for both kinds of save.
                checkpoint = {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                if save_periodic:
                    save_checkpoint(checkpoint, False, args.dump_path)
                if is_best:
                    min_loss = loss
                    save_checkpoint(checkpoint, True, args.dump_path)

    args.logger.info("------------------- Train Finished -------------------")