| | |
| | from __future__ import division |
| | import logging |
| | from utils import CheckpointSaver |
| | from tensorboardX import SummaryWriter |
| |
|
| | import torch |
| | from tqdm import tqdm |
| |
|
# Disable tqdm's background monitor thread (0 = off) — presumably to avoid
# the monitor interfering with multi-process training; TODO confirm.
tqdm.monitor_interval = 0


# Module-level logger, standard `logging.getLogger(__name__)` idiom.
logger = logging.getLogger(__name__)
| |
|
| |
|
class BaseTrainer(object):
    """Base class for Trainer objects.

    Takes care of checkpointing/logging/resuming training. Subclasses must
    implement ``init_fn`` (expected to set ``self.models_dict`` and
    ``self.optimizers_dict``), ``train``, ``train_step``,
    ``train_summaries`` and ``visualize``.
    """

    def __init__(self, options):
        # `options` is a project config namespace; attributes read here:
        # multiprocessing_distributed, gpu, checkpoint_dir, overwrite, rank,
        # summary_dir, resume, num_epochs — TODO confirm full contract.
        self.options = options
        if options.multiprocessing_distributed:
            self.device = torch.device('cuda', options.gpu)
        else:
            self.device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')

        self.saver = CheckpointSaver(save_dir=options.checkpoint_dir,
                                     overwrite=options.overwrite)
        # Only the rank-0 process writes tensorboard summaries; other ranks
        # therefore never get a `summary_writer` attribute.
        if options.rank == 0:
            self.summary_writer = SummaryWriter(self.options.summary_dir)
        # Subclass hook; must define self.models_dict / self.optimizers_dict
        # before checkpoint loading below.
        self.init_fn()

        self.checkpoint = None
        if options.resume and self.saver.exists_checkpoint():
            self.checkpoint = self.saver.load_checkpoint(
                self.models_dict, self.optimizers_dict)

        # Restore training progress from the checkpoint (single branch
        # instead of the two redundant checks this previously used).
        if self.checkpoint is None:
            self.epoch_count = 0
            self.step_count = 0
            self.checkpoint_batch_idx = 0
        else:
            self.epoch_count = self.checkpoint['epoch']
            self.step_count = self.checkpoint['total_step_count']
            self.checkpoint_batch_idx = self.checkpoint['batch_idx']

        # Lower-is-better metric; subclasses update this during validation.
        self.best_performance = float('inf')

    def load_pretrained(self, checkpoint_file=None):
        """Load a pretrained checkpoint.

        This is different from resuming training using --resume.

        Args:
            checkpoint_file: Path to a checkpoint file; no-op when None.
        """
        if checkpoint_file is not None:
            # map_location lets a GPU-saved checkpoint load on a CPU-only
            # machine (and pins tensors to this trainer's device).
            checkpoint = torch.load(checkpoint_file,
                                    map_location=self.device)
            for model in self.models_dict:
                if model in checkpoint:
                    self.models_dict[model].load_state_dict(checkpoint[model],
                                                            strict=True)
                    # Module logger (lazy %-args) instead of print, for
                    # consistency with the rest of the file.
                    logger.info('Checkpoint %s loaded', model)

    def move_dict_to_device(self, dict, device, tensor2float=False):
        """Move every tensor value of `dict` to `device`, in place.

        Non-tensor values are left untouched. NOTE: the parameter shadows
        the `dict` builtin; the name is kept for backward compatibility
        with existing callers.

        Args:
            dict: Mapping whose tensor values are moved in place.
            device: Target torch.device.
            tensor2float: If True, also cast tensors to float.
        """
        for k, v in dict.items():
            if isinstance(v, torch.Tensor):
                dict[k] = (v.float().to(device) if tensor2float
                           else v.to(device))

    def train(self, epoch):
        # Abstract: one epoch of training.
        raise NotImplementedError('You need to provide a train method')

    def init_fn(self):
        # Abstract: build models/optimizers before checkpoint restore.
        raise NotImplementedError('You need to provide an init_fn method')

    def train_step(self, input_batch):
        # Abstract: a single optimization step.
        raise NotImplementedError('You need to provide a train_step method')

    def train_summaries(self, input_batch):
        # Abstract: write per-step summaries.
        raise NotImplementedError(
            'You need to provide a train_summaries method')

    def visualize(self, input_batch):
        # Abstract: render qualitative results.
        raise NotImplementedError('You need to provide a visualize method')

    def validate(self):
        # Optional hook; default is a no-op.
        pass

    def test(self):
        # Optional hook; default is a no-op.
        pass

    def evaluate(self):
        # Optional hook; default is a no-op.
        pass

    def fit(self):
        """Run training from the current epoch up to options.num_epochs.

        `epoch_count` is kept in sync so checkpoints resume correctly.
        """
        for epoch in tqdm(range(self.epoch_count, self.options.num_epochs),
                          total=self.options.num_epochs,
                          initial=self.epoch_count):
            self.epoch_count = epoch
            self.train(epoch)
| |
|