File size: 1,747 Bytes
8ec10cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
"""main file that does everything"""
from utils import interact
from option import args, setup, cleanup
from data import Data
from model import Model
from loss import Loss
from optim import Optimizer
from train import Trainer
def main_worker(rank, args):
    """Build all training components for one process and run the requested job.

    Args:
        rank: Index of this process. It is written to ``args.rank`` before
            ``setup(args)`` is called.
        args: Option namespace (``main`` passes the ``args`` imported from
            ``option``). ``setup`` returns the namespace used for the rest
            of the run.

    Depending on ``args`` this does exactly one of:

    * ``args.stay``: open ``interact`` over the local objects, then return.
    * ``args.demo``: call ``trainer.evaluate`` once in ``'demo'`` mode at
      ``args.start_epoch``, then return.
    * otherwise: call ``trainer.fill_evaluation`` for due epochs before
      ``args.start_epoch``, then train / validate / test for every epoch in
      ``[args.start_epoch, args.end_epoch]`` (inclusive), and finally call
      ``trainer.imsaver.join_background()`` and ``cleanup(args)``.

    Returns:
        None.
    """
    args.rank = rank
    args = setup(args)

    loaders = Data(args).get_loader()
    model = Model(args)
    model.parallelize()
    optimizer = Optimizer(args, model)
    criterion = Loss(args, model=model, optimizer=optimizer)
    trainer = Trainer(args, model, criterion, optimizer, loaders)

    # Both early exits previously called the site-module ``exit()``, which is
    # meant for the interactive shell only: it is missing under ``python -S``,
    # closes sys.stdin, and terminates the whole interpreter even when this
    # function is called from other code. Returning ends the run the same way
    # when invoked through ``main``.
    if args.stay:
        interact(local=locals())
        return

    if args.demo:
        trainer.evaluate(epoch=args.start_epoch, mode='demo')
        # NOTE(review): like the original, this path skips
        # trainer.imsaver.join_background() and cleanup(args) — confirm intended.
        return

    # Epochs below start_epoch are not re-run here; presumably this restores
    # their validation/test records from an earlier run — TODO confirm against
    # Trainer.fill_evaluation.
    for epoch in range(1, args.start_epoch):
        if args.do_validate and epoch % args.validate_every == 0:
            trainer.fill_evaluation(epoch, 'val')
        if args.do_test and epoch % args.test_every == 0:
            trainer.fill_evaluation(epoch, 'test')

    for epoch in range(args.start_epoch, args.end_epoch + 1):
        if args.do_train:
            trainer.train(epoch)

        if args.do_validate and epoch % args.validate_every == 0:
            # If the trainer's current epoch differs (e.g. no training ran
            # this iteration), load the state for this epoch first.
            if trainer.epoch != epoch:
                trainer.load(epoch)
            trainer.validate(epoch)

        if args.do_test and epoch % args.test_every == 0:
            if trainer.epoch != epoch:
                trainer.load(epoch)
            trainer.test(epoch)

        # Blank separator line: printed by rank 0, or always when not launched.
        if args.rank == 0 or not args.launched:
            print('')

    trainer.imsaver.join_background()
    cleanup(args)
def main():
    """Script entry point: run one worker with the rank already stored on ``args``."""
    worker_rank = args.rank
    main_worker(worker_rank, args)
# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()