| |
| from __future__ import division |
| import logging |
| from utils import CheckpointSaver |
| from tensorboardX import SummaryWriter |
|
|
| import torch |
| from tqdm import tqdm |
|
|
# Disable tqdm's background monitor thread; it can emit spurious warnings or
# interfere when many bars are created (e.g. distributed / multi-worker runs).
tqdm.monitor_interval = 0


# Module-level logger following the standard __name__ convention.
logger = logging.getLogger(__name__)
|
|
|
|
class BaseTrainer(object):
    """Base class for Trainer objects.

    Takes care of checkpointing/logging/resuming training. Subclasses are
    expected to implement ``init_fn``, ``train``, ``train_step``,
    ``train_summaries`` and ``visualize``; ``validate``/``test``/``evaluate``
    are optional hooks that default to no-ops.
    """

    def __init__(self, options):
        """Set up device, checkpoint saver, summary writer and resume state.

        Args:
            options: Namespace-like config object. Must provide
                ``multiprocessing_distributed``, ``gpu``, ``checkpoint_dir``,
                ``overwrite``, ``rank``, ``summary_dir``, ``resume`` and
                ``num_epochs``.
        """
        self.options = options
        if options.multiprocessing_distributed:
            # Each distributed process is pinned to its own GPU index.
            self.device = torch.device('cuda', options.gpu)
        else:
            self.device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')

        self.saver = CheckpointSaver(save_dir=options.checkpoint_dir,
                                     overwrite=options.overwrite)
        # Only the main process (rank 0) writes tensorboard summaries; other
        # ranks never touch self.summary_writer.
        if options.rank == 0:
            self.summary_writer = SummaryWriter(self.options.summary_dir)

        # Subclass hook — expected to populate self.models_dict and
        # self.optimizers_dict before any checkpoint is loaded below.
        self.init_fn()

        self.checkpoint = None
        if options.resume and self.saver.exists_checkpoint():
            self.checkpoint = self.saver.load_checkpoint(
                self.models_dict, self.optimizers_dict)

        # Restore training progress from the checkpoint, if any. (Merged from
        # two separate if/else blocks that tested the same condition.)
        if self.checkpoint is None:
            self.epoch_count = 0
            self.step_count = 0
            self.checkpoint_batch_idx = 0
        else:
            self.epoch_count = self.checkpoint['epoch']
            self.step_count = self.checkpoint['total_step_count']
            self.checkpoint_batch_idx = self.checkpoint['batch_idx']

        # Lower-is-better metric (e.g. validation loss); subclasses update it.
        self.best_performance = float('inf')

    def load_pretrained(self, checkpoint_file=None):
        """Load pretrained model weights from a checkpoint file.

        This is different from resuming training using --resume: only model
        state dicts are restored here, not optimizer/epoch state.

        Args:
            checkpoint_file: Path to a ``torch.save`` checkpoint whose keys
                match entries of ``self.models_dict``. No-op when None.
        """
        if checkpoint_file is None:
            return
        # map_location makes loading robust when the checkpoint was saved on
        # a different device (e.g. a GPU checkpoint on a CPU-only machine).
        checkpoint = torch.load(checkpoint_file, map_location=self.device)
        for model in self.models_dict:
            if model in checkpoint:
                self.models_dict[model].load_state_dict(checkpoint[model],
                                                        strict=True)
                # logger (not print) for consistency with the module logger.
                logger.info('Checkpoint %s loaded', model)

    def move_dict_to_device(self, dict, device, tensor2float=False):
        """Move all tensor values of ``dict`` onto ``device``, in place.

        Non-tensor values are left untouched.

        NOTE(review): the parameter shadows the ``dict`` builtin; the name is
        kept for backward compatibility with any keyword-argument callers.

        Args:
            dict: Mapping mutated in place.
            device: Target ``torch.device``.
            tensor2float: If True, cast tensors to float32 before moving.
        """
        for k, v in dict.items():
            if isinstance(v, torch.Tensor):
                dict[k] = v.float().to(device) if tensor2float else v.to(device)

    def train(self, epoch):
        """Run one training epoch. Must be implemented by subclasses."""
        raise NotImplementedError('You need to provide a train method')

    def init_fn(self):
        """Build models/optimizers/datasets. Must be implemented by subclasses."""
        raise NotImplementedError('You need to provide an init_fn method')

    def train_step(self, input_batch):
        """Run a single optimization step. Must be implemented by subclasses."""
        raise NotImplementedError('You need to provide a train_step method')

    def train_summaries(self, input_batch):
        """Write training summaries. Must be implemented by subclasses."""
        raise NotImplementedError(
            'You need to provide a train_summaries method')

    def visualize(self, input_batch):
        """Produce visualizations. Must be implemented by subclasses."""
        raise NotImplementedError('You need to provide a visualize method')

    def validate(self):
        """Optional validation hook; no-op by default."""
        pass

    def test(self):
        """Optional test hook; no-op by default."""
        pass

    def evaluate(self):
        """Optional evaluation hook; no-op by default."""
        pass

    def fit(self):
        """Run the training loop from the current epoch up to num_epochs.

        ``initial`` keeps the progress bar correct when resuming from a
        checkpoint mid-run.
        """
        for epoch in tqdm(range(self.epoch_count, self.options.num_epochs),
                          total=self.options.num_epochs,
                          initial=self.epoch_count):
            self.epoch_count = epoch
            self.train(epoch)
|
|