| """
|
| Copyright (c) 2022, salesforce.com, inc.
|
| All rights reserved.
|
| SPDX-License-Identifier: BSD-3-Clause
|
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| """
|
|
|
| import datetime
|
| import logging
|
| import os
|
| import time
|
|
|
| import torch
|
| import torch.distributed as dist
|
| import webdataset as wds
|
| from lavis.common.dist_utils import download_cached_file, is_main_process, main_process
|
| from lavis.common.registry import registry
|
| from lavis.common.utils import is_url
|
| from lavis.datasets.data_utils import concat_datasets, reorg_datasets_by_split
|
| from lavis.runners.runner_base import RunnerBase
|
| from torch.utils.data.dataset import ChainDataset
|
|
|
|
|
@registry.register_runner("runner_iter")
class RunnerIter(RunnerBase):
| """
|
| Run training based on the number of iterations. This is common when
|
| the training dataset size is large. Underhood logic is similar to
|
| epoch-based training by considering every #iters_per_inner_epoch as an
|
| inner epoch.
|
|
|
| In iter-based runner, after every #iters_per_inner_epoch steps, we
|
|
|
| 1) do a validation epoch;
|
| 2) schedule the learning rate;
|
| 3) save the checkpoint.
|
|
|
| We refer every #iters_per_inner_epoch steps as an inner epoch.
|
| """
|
|
|
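    # Illustrative run_cfg sketch for this runner (values are made up; the
    # batch-size/worker keys are read by RunnerBase elsewhere in this class):
    #
    #   run:
    #     runner: runner_iter
    #     max_iters: 100000
    #     iters_per_inner_epoch: 5000
    #     batch_size_train: 32
    #     batch_size_eval: 16
    #     num_workers: 4
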
    def __init__(self, cfg, task, model, datasets, job_id):
        super().__init__(cfg, task, model, datasets, job_id)

        self.start_iters = 0

        self.max_iters = int(self.config.run_cfg.get("max_iters", -1))
        assert self.max_iters > 0, "max_iters must be greater than 0."

        self.iters_per_inner_epoch = int(
            self.config.run_cfg.get("iters_per_inner_epoch", -1)
        )
        assert (
            self.iters_per_inner_epoch > 0
        ), "iters_per_inner_epoch must be greater than 0."

    @property
    def max_epoch(self):
        return int(self.max_iters / self.iters_per_inner_epoch)

    @property
    def cur_epoch(self):
        try:
            return self.train_loader.epoch
        except AttributeError:
            # streaming pipelines (e.g. wds.DataPipeline) have no epoch attribute
            return 0

    def _progress(self, cur_iters):
        return "{}_iters={}".format(self.cur_epoch, cur_iters)

    def train(self):
        start_time = time.time()
        best_agg_metric = 0
        best_iters = 0

        self.log_config()

        # resume from checkpoint if specified
        if not self.evaluate_only and self.resume_ckpt_path is not None:
            self._load_checkpoint(self.resume_ckpt_path)

        for start_iters in range(
            self.start_iters, self.max_iters, self.iters_per_inner_epoch
        ):
            end_iters = start_iters + self.iters_per_inner_epoch

            # training phase
            if not self.evaluate_only:
                logging.info(
                    "Start training, max_iters={}, in total {} inner epochs.".format(
                        self.max_iters, int(self.max_iters / self.iters_per_inner_epoch)
                    )
                )
                if start_iters == self.start_iters:
                    self.task.before_training(
                        model=self.unwrap_dist_model(self.model),
                        dataset=self.datasets,
                    )
                train_stats = self.train_iters(self.cur_epoch, start_iters)
                self.log_stats(split_name="train", stats=train_stats)

            # evaluation phase
            if len(self.valid_splits) > 0:
                for split_name in self.valid_splits:
                    logging.info("Evaluating on {}.".format(split_name))

                    val_log = self.eval_epoch(
                        split_name=split_name, cur_epoch=self._progress(end_iters)
                    )
                    if val_log is not None:
                        if is_main_process():
                            assert (
                                "agg_metrics" in val_log
                            ), "No agg_metrics found in validation log."

                            agg_metrics = val_log["agg_metrics"]
                            if agg_metrics > best_agg_metric and split_name == "val":
                                best_iters, best_agg_metric = end_iters, agg_metrics

                                self._save_checkpoint(end_iters, is_best=True)

                            val_log.update({"best_iters": best_iters})
                            self.log_stats(val_log, split_name)
            else:
                # no validation split given: save a checkpoint at the end of each inner epoch
                if not self.evaluate_only:
                    self._save_checkpoint(end_iters, is_best=False)

            if self.evaluate_only:
                break
            dist.barrier()

        # testing phase
        self.evaluate(cur_epoch=self.cur_epoch)

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        logging.info("Training time {}".format(total_time_str))

    def train_iters(self, epoch, start_iters):
        self.model.train()

        return self.task.train_iters(
            epoch=epoch,
            start_iters=start_iters,
            iters_per_inner_epoch=self.iters_per_inner_epoch,
            model=self.model,
            data_loader=self.train_loader,
            optimizer=self.optimizer,
            scaler=self.scaler,
            lr_scheduler=self.lr_scheduler,
            cuda_enabled=self.cuda_enabled,
            log_freq=self.log_freq,
            accum_grad_iters=self.accum_grad_iters,
        )

    @main_process
    def _save_checkpoint(self, cur_iters, is_best=False):
        model_no_ddp = self.unwrap_dist_model(self.model)
        param_grad_dic = {
            k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()
        }

        state_dict = model_no_ddp.state_dict()
        for k in list(state_dict.keys()):
            if k in param_grad_dic.keys() and not param_grad_dic[k]:
                # drop parameters that do not require gradients (e.g. frozen modules)
                del state_dict[k]

        save_obj = {
            "model": state_dict,
            "optimizer": self.optimizer.state_dict(),
            "config": self.config.to_dict(),
            "scaler": self.scaler.state_dict() if self.scaler else None,
            "iters": cur_iters,
        }
        save_to = os.path.join(
            self.output_dir,
            "checkpoint_{}.pth".format("best" if is_best else cur_iters),
        )
        logging.info("Saving checkpoint at iters {} to {}.".format(cur_iters, save_to))
        torch.save(save_obj, save_to)

    def _load_checkpoint(self, url_or_filename):
        """
        Resume from a checkpoint.
        """
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location=self.device)
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location=self.device)
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]
        self.unwrap_dist_model(self.model).load_state_dict(state_dict)

        self.optimizer.load_state_dict(checkpoint["optimizer"])
        if self.scaler and "scaler" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler"])

        self.start_iters = checkpoint["iters"] + 1
        logging.info("Resume checkpoint from {}".format(url_or_filename))

    @property
    def dataloaders(self) -> dict:
        """
        A property to get and create dataloaders by split on demand.

        If no train_dataset_ratios is provided, map-style datasets are
        concatenated and wds.DataPipeline datasets are chained separately. The
        training set becomes a tuple (ConcatDataset, ChainDataset); both are
        optional, but at least one of them is required. The resulting
        ConcatDataset and ChainDataset will be sampled evenly.

        If train_dataset_ratios is provided, a MultiIterLoader is created to
        sample each dataset by the given ratios during training.

        Multiple datasets for validation and test are currently not supported.

        Returns:
            dict: {split_name: (tuples of) dataloader}
        """
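
        # Illustrative train_dataset_ratios entry (dataset names here are
        # hypothetical; ratios are relative sampling weights for MultiIterLoader):
        #
        #   run:
        #     train_dataset_ratios:
        #       laion: 0.7
        #       coco_caption: 0.3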
        if self._dataloaders is None:
            # reorganize datasets by split and concatenate/chain if necessary
            dataset_ratios = self.config.run_cfg.get("train_dataset_ratios", None)

            if dataset_ratios is None:
                logging.info(
                    "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)."
                )

                datasets = reorg_datasets_by_split(self.datasets)
                self.datasets = concat_datasets(datasets)
            else:
                missing_keys = [k for k in dataset_ratios if k not in self.datasets]
                if len(missing_keys) > 0:
                    raise ValueError(
                        "Datasets with the following split names are not found: {}".format(
                            missing_keys
                        )
                    )

                unexpected_keys = [k for k in self.datasets if k not in dataset_ratios]
                if len(unexpected_keys) > 0:
                    raise ValueError(
                        "Datasets with the following split names are not expected: {}".format(
                            unexpected_keys
                        )
                    )

                dataset_ratios = [float(dataset_ratios[k]) for k in self.datasets]
                self.datasets = reorg_datasets_by_split(self.datasets)
                # keep the same structure as the return value of concat_datasets
                self.datasets = {
                    k: v[0] if len(v) == 1 else v for k, v in self.datasets.items()
                }

            # log dataset statistics after concatenation/chaining
            for split_name in self.datasets:
                if isinstance(self.datasets[split_name], (tuple, list)):
                    # a mix of wds.DataPipeline and torch.utils.data.Dataset;
                    # streaming pipelines contribute no record count
                    num_records = sum(
                        [
                            len(d)
                            if not type(d) in [wds.DataPipeline, ChainDataset]
                            else 0
                            for d in self.datasets[split_name]
                        ]
                    )
                else:
                    try:
                        # a single map-style dataset
                        num_records = len(self.datasets[split_name])
                    except TypeError:
                        # a single wds.DataPipeline or ChainDataset has no __len__
                        num_records = -1
                        logging.info(
                            "Only a single wds.DataPipeline dataset, no __len__ attribute."
                        )

                if num_records >= 0:
                    logging.info(
                        "Loaded {} records for {} split from the dataset.".format(
                            num_records, split_name
                        )
                    )

            # create dataloaders for each split
            split_names = sorted(self.datasets.keys())

            datasets = [self.datasets[split] for split in split_names]
            is_trains = [split in self.train_splits for split in split_names]

            batch_sizes = [
                self.config.run_cfg.batch_size_train
                if split == "train"
                else self.config.run_cfg.batch_size_eval
                for split in split_names
            ]

            collate_fns = []
            for dataset in datasets:
                if isinstance(dataset, (tuple, list)):
                    collate_fns.append([getattr(d, "collater", None) for d in dataset])
                else:
                    collate_fns.append(getattr(dataset, "collater", None))

            dataloaders = self.create_loaders(
                datasets=datasets,
                num_workers=self.config.run_cfg.num_workers,
                batch_sizes=batch_sizes,
                is_trains=is_trains,
                collate_fns=collate_fns,
                dataset_ratios=dataset_ratios,
            )

            self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)}

        return self._dataloaders
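

# A minimal usage sketch, mirroring how LAVIS's train.py dispatches runners via
# the registry (cfg/task/model/datasets construction is omitted and config-dependent):
#
#   runner_cls = registry.get_runner_class(cfg.run_cfg.get("runner", "runner_base"))
#   runner = runner_cls(cfg=cfg, task=task, model=model, datasets=datasets, job_id=job_id)
#   runner.train()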