File size: 1,747 Bytes
8ec10cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""main file that does everything"""
from utils import interact

from option import args, setup, cleanup
from data import Data
from model import Model
from loss import Loss
from optim import Optimizer
from train import Trainer

def main_worker(rank, args):
    """Build the pipeline for one worker and run the requested mode.

    The rank is recorded on ``args`` and ``args`` is then replaced by the
    value returned from ``setup(args)``. After the data loaders, model,
    optimizer, loss and trainer are constructed, one of three things happens:

    * ``args.stay`` is set: ``interact`` is called with the local variables
      of this function, then the process exits.
    * ``args.demo`` is set: ``trainer.evaluate`` is called in ``'demo'`` mode
      for ``args.start_epoch``, then the process exits.
    * otherwise: the epoch loop runs from ``args.start_epoch`` through
      ``args.end_epoch`` (inclusive), followed by
      ``trainer.imsaver.join_background()`` and ``cleanup(args)``.

    Args:
        rank: Rank of this worker; stored as ``args.rank``.
        args: Options object. This function reads ``stay``, ``demo``,
            ``start_epoch``, ``end_epoch``, ``do_train``, ``do_validate``,
            ``validate_every``, ``do_test``, ``test_every``, ``rank`` and
            ``launched`` from it (after ``setup``).

    Raises:
        SystemExit: After the interactive session (``args.stay``) or the
            demo evaluation (``args.demo``).
    """
    args.rank = rank
    args = setup(args)

    loaders = Data(args).get_loader()
    model = Model(args)
    model.parallelize()
    optimizer = Optimizer(args, model)
    criterion = Loss(args, model=model, optimizer=optimizer)
    trainer = Trainer(args, model, criterion, optimizer, loaders)

    # NOTE(review): both early exits below skip join_background() and
    # cleanup(args) -- confirm that this is intended.
    if args.stay:
        # The local names above form the namespace handed to the interactive
        # session, so they must not be renamed or added to casually.
        interact(local=locals())
        # Raise SystemExit directly instead of calling the ``exit`` helper:
        # that helper is injected by the ``site`` module for interactive use
        # and is missing under ``python -S`` or in frozen interpreters.
        raise SystemExit

    if args.demo:
        trainer.evaluate(epoch=args.start_epoch, mode='demo')
        raise SystemExit

    # Epochs before start_epoch are not run again; only fill_evaluation is
    # called for those that fall on the validation/test schedule.
    # NOTE(review): presumably this restores earlier evaluation records when
    # resuming -- confirm against Trainer.fill_evaluation.
    for epoch in range(1, args.start_epoch):
        if args.do_validate and epoch % args.validate_every == 0:
            trainer.fill_evaluation(epoch, 'val')
        if args.do_test and epoch % args.test_every == 0:
            trainer.fill_evaluation(epoch, 'test')

    for epoch in range(args.start_epoch, args.end_epoch + 1):
        if args.do_train:
            trainer.train(epoch)

        if args.do_validate and epoch % args.validate_every == 0:
            # Call trainer.load(epoch) when the trainer is not already at
            # this epoch (e.g. training was skipped above).
            if trainer.epoch != epoch:
                trainer.load(epoch)
            trainer.validate(epoch)

        if args.do_test and epoch % args.test_every == 0:
            if trainer.epoch != epoch:
                trainer.load(epoch)
            trainer.test(epoch)

        # One blank line per epoch, printed by rank 0 only, or by every
        # process when args.launched is false.
        if args.rank == 0 or not args.launched:
            print('')

    trainer.imsaver.join_background()

    cleanup(args)

def main():
    """Run ``main_worker`` with the module-level ``args`` and its rank."""
    worker_rank = args.rank
    main_worker(worker_rank, args)

# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()