| import torch | |
| import utility | |
| import data | |
| import model | |
| import loss | |
| from option import args | |
| from trainer import Trainer | |
# Seed torch's random number generator from the parsed options before
# anything else is constructed.
torch.manual_seed(args.seed)

# Shared checkpoint object, built from the parsed options and used by main().
checkpoint = utility.checkpoint(args)
def main():
    """Run video testing or the train/test loop, depending on ``args``.

    When ``args.data_test == ['video']`` a ``model.Model`` is built and
    handed to ``VideoTester``, whose ``test()`` is called once.

    Otherwise, and only if ``checkpoint.ok`` is truthy, a data loader, a
    model, and a loss are built and passed to ``Trainer``. The loss is
    ``None`` when ``args.test_only`` is set. ``train()`` and ``test()``
    are then called in turn until ``terminate()`` returns a truthy value,
    after which ``checkpoint.done()`` is called.

    NOTE(review): the source this was recovered from had lost its
    indentation; the nesting of ``t.test()`` (inside the loop) and
    ``checkpoint.done()`` (after the loop, inside the ``checkpoint.ok``
    branch) is reconstructed -- confirm against the intended behaviour.
    """
    if args.data_test == ['video']:
        # Imported here, so ``videotester`` is only loaded on this branch.
        from videotester import VideoTester

        # Keep the instance in a local name. Rebinding the global ``model``
        # would replace the imported module with the instance, so a later
        # ``model.Model`` lookup (e.g. a second call to main()) would fail.
        _model = model.Model(args, checkpoint)
        tester = VideoTester(args, _model, checkpoint)
        tester.test()
        return

    if not checkpoint.ok:
        return

    loader = data.Data(args)
    _model = model.Model(args, checkpoint)
    # No loss object is built when only testing.
    _loss = None if args.test_only else loss.Loss(args, checkpoint)
    t = Trainer(args, loader, _model, _loss, checkpoint)
    while not t.terminate():
        t.train()
        t.test()

    checkpoint.done()
# Only run when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()