| |
| |
| |
|
|
import os
from typing import Any, Optional

import torch
import torch.nn as nn

from src.efficientvit.apps.data_provider import DataProvider, parse_image_size
from src.efficientvit.apps.trainer.run_config import RunConfig
from src.efficientvit.apps.utils import (EMA, dist_barrier, get_dist_local_rank,
                                         is_master)
from src.efficientvit.models.nn.norm import reset_bn
from src.efficientvit.models.utils import is_parallel, load_state_dict_from_file
|
|
__all__ = ["Trainer"]


class Trainer:
    """Base trainer: run directories, checkpointing, validation and optimizer steps.

    Subclasses provide ``_validate``, ``run_step``, ``_train_one_epoch`` and
    ``train``. The attributes ``run_config``, ``optimizer``, ``lr_scheduler``,
    ``fp16`` and ``scaler`` are only created by :meth:`prep_for_training`, so
    methods that use them require it to have been called first.
    """

    def __init__(self, path: str, model: nn.Module, data_provider: DataProvider):
        """Set up the run directory layout and move ``model`` to CUDA.

        Args:
            path: Run root directory; ``~`` is expanded and the path resolved.
                ``<path>/checkpoint`` and ``<path>/logs`` are created if missing.
            model: Network to train; moved to the GPU with ``.cuda()``.
            data_provider: Supplies the train/valid/test data loaders.
        """
        self.path = os.path.realpath(os.path.expanduser(path))
        self.model = model.cuda()
        self.data_provider = data_provider

        # Replaced by an EMA wrapper in prep_for_training when ema_decay is given.
        self.ema = None

        self.checkpoint_path = os.path.join(self.path, "checkpoint")
        self.logs_path = os.path.join(self.path, "logs")
        # Distinct loop name so the ``path`` argument is not shadowed.
        for dir_path in (self.path, self.checkpoint_path, self.logs_path):
            os.makedirs(dir_path, exist_ok=True)

        self.best_val = 0.0
        self.start_epoch = 0

    @property
    def network(self) -> nn.Module:
        """The training model, unwrapped via ``.module`` when ``is_parallel`` reports a wrapper."""
        return self.model.module if is_parallel(self.model) else self.model

    @property
    def eval_network(self) -> nn.Module:
        """Model used for evaluation: the EMA shadow model when EMA is enabled,
        otherwise the training model, unwrapped via ``.module`` if parallel."""
        model = self.model if self.ema is None else self.ema.shadows
        return model.module if is_parallel(model) else model

    def write_log(self, log_str, prefix="valid", print_log=True, mode="a") -> None:
        """Write ``log_str`` plus a newline to ``<logs_path>/<prefix>.log``.

        The file is opened with ``mode`` (append by default). Only the master
        process writes (and, if ``print_log``, prints); other ranks do nothing.
        """
        if not is_master():
            return
        # ``with`` guarantees the handle is closed even if the write fails.
        with open(os.path.join(self.logs_path, f"{prefix}.log"), mode) as fout:
            fout.write(log_str + "\n")
        if print_log:
            print(log_str)

    def save_model(
        self,
        checkpoint=None,
        only_state_dict=True,
        epoch=0,
        model_name=None,
        update_latest=True,
    ) -> None:
        """Save a checkpoint to ``<checkpoint_path>/<model_name>`` (master only).

        Args:
            checkpoint: Dict to save as-is. When ``None`` one is built from the
                current training state.
            only_state_dict: When building the checkpoint, store only the network
                weights; otherwise also store epoch, best_val, optimizer,
                lr_scheduler, ema and scaler state (requires prep_for_training).
                ``ema``/``scaler`` are stored as ``None`` when EMA/fp16 are off.
            epoch: Epoch number recorded in a full checkpoint.
            model_name: File name; defaults to ``"checkpoint.pt"``.
            update_latest: Record the saved path in ``latest.txt`` so that
                :meth:`load_model` resumes from it. Pass ``False`` for
                temporary files such as the one written by :meth:`sync_model`.
        """
        if not is_master():
            return
        if checkpoint is None:
            if only_state_dict:
                checkpoint = {"state_dict": self.network.state_dict()}
            else:
                checkpoint = {
                    "state_dict": self.network.state_dict(),
                    "epoch": epoch,
                    "best_val": self.best_val,
                    "optimizer": self.optimizer.state_dict(),
                    "lr_scheduler": self.lr_scheduler.state_dict(),
                    "ema": self.ema.state_dict() if self.ema is not None else None,
                    "scaler": self.scaler.state_dict() if self.fp16 else None,
                }

        model_path = os.path.join(self.checkpoint_path, model_name or "checkpoint.pt")
        # Save first and only then repoint latest.txt, so a failed torch.save
        # does not leave latest.txt naming a file that was never written.
        torch.save(checkpoint, model_path)
        if update_latest:
            latest_fname = os.path.join(self.checkpoint_path, "latest.txt")
            with open(latest_fname, "w") as fout:
                fout.write(model_path + "\n")

    def _resolve_checkpoint_file(self, model_fname: Optional[str]) -> str:
        """Choose the checkpoint file to load.

        Returns ``model_fname`` if it is an existing file; otherwise a file with
        the same base name inside ``checkpoint_path`` (e.g. when the run
        directory was moved) if that exists; otherwise
        ``<checkpoint_path>/checkpoint.pt``. ``None`` and ``""`` (an empty
        ``latest.txt``) go straight to ``checkpoint.pt``.
        """
        default_fname = os.path.join(self.checkpoint_path, "checkpoint.pt")
        if not model_fname:
            return default_fname
        if os.path.isfile(model_fname):
            return model_fname
        local_fname = os.path.join(self.checkpoint_path, os.path.basename(model_fname))
        if os.path.isfile(local_fname):
            return local_fname
        return default_fname

    def _load_optim_state(self, checkpoint: dict[str, Any]) -> list[str]:
        """Load optimizer, lr_scheduler, ema and scaler state found in ``checkpoint``.

        Entries that are missing or ``None`` are skipped; ``ema``/``scaler``
        are also skipped when EMA/fp16 are disabled for this run. Returns the
        names of the parts that were restored, in load order.
        """
        loaded = []
        if checkpoint.get("optimizer") is not None:
            self.optimizer.load_state_dict(checkpoint["optimizer"])
            loaded.append("optimizer")
        if checkpoint.get("lr_scheduler") is not None:
            self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            loaded.append("lr_scheduler")
        # save_model writes None for ema/scaler when they were disabled at save
        # time; passing None to load_state_dict would fail.
        if checkpoint.get("ema") is not None and self.ema is not None:
            self.ema.load_state_dict(checkpoint["ema"])
            loaded.append("ema")
        if checkpoint.get("scaler") is not None and self.fp16:
            self.scaler.load_state_dict(checkpoint["scaler"])
            loaded.append("scaler")
        return loaded

    def load_model(self, model_fname=None) -> None:
        """Restore training state from a checkpoint, best-effort.

        When ``model_fname`` is ``None`` the path stored in
        ``<checkpoint_path>/latest.txt`` is used (if that file exists); see
        :meth:`_resolve_checkpoint_file` for the fallbacks. If locating or
        reading the checkpoint raises, the error is logged and the current
        state is left untouched. Otherwise the network weights are loaded
        (non-strict), and epoch / best_val / optimizer / lr_scheduler / ema /
        scaler state is restored where present.
        """
        latest_fname = os.path.join(self.checkpoint_path, "latest.txt")
        if model_fname is None and os.path.exists(latest_fname):
            with open(latest_fname, "r") as fin:
                model_fname = fin.readline().rstrip("\n")
        try:
            model_fname = self._resolve_checkpoint_file(model_fname)
            print(f"=> loading checkpoint {model_fname}")
            checkpoint = load_state_dict_from_file(model_fname, False)
        except Exception as e:
            # Deliberately best-effort: report the cause instead of crashing.
            self.write_log(f"fail to load checkpoint from {self.checkpoint_path}: {e}")
            return

        self.network.load_state_dict(checkpoint["state_dict"], strict=False)
        log = []
        if "epoch" in checkpoint:
            # Resume from the epoch after the one that was saved.
            self.start_epoch = checkpoint["epoch"] + 1
            self.run_config.update_global_step(self.start_epoch)
            log.append(f"epoch={self.start_epoch - 1}")
        if "best_val" in checkpoint:
            self.best_val = checkpoint["best_val"]
            log.append(f"best_val={self.best_val:.2f}")
        log.extend(self._load_optim_state(checkpoint))
        self.write_log("Loaded: " + ", ".join(log))

    """ validate """

    def reset_bn(
        self,
        network: Optional[nn.Module] = None,
        subset_size: int = 16000,
        subset_batch_size: int = 100,
        data_loader=None,
        progress_bar=False,
    ) -> None:
        """Recompute BatchNorm statistics of ``network`` on training inputs.

        Args:
            network: Model to update; defaults to :attr:`network`.
            subset_size: Passed to ``build_sub_train_loader`` as the subset size.
            subset_batch_size: Passed to ``build_sub_train_loader`` as the batch size.
            data_loader: Pre-built iterable of input batches. When ``None``,
                inputs are extracted from a sub-train loader (see
                :meth:`_extract_inputs`).
            progress_bar: Forwarded to the module-level ``reset_bn`` helper.

        Raises:
            NotImplementedError: if a sub-loader batch is not a list/tuple,
                dict or tensor.
        """
        if network is None:
            network = self.network
        if data_loader is None:
            data_loader = [
                self._extract_inputs(batch)
                for batch in self.data_provider.build_sub_train_loader(
                    subset_size, subset_batch_size
                )
            ]

        network.eval()
        # This name resolves to the module-level helper imported from
        # models.nn.norm, not to this method.
        reset_bn(
            network,
            data_loader,
            sync=True,
            progress_bar=progress_bar,
        )

    @staticmethod
    def _extract_inputs(batch):
        """Return the model input of a loader batch: element 0 of a list/tuple,
        ``batch["data"]`` of a dict, or the tensor itself."""
        if isinstance(batch, (list, tuple)):
            return batch[0]
        if isinstance(batch, dict):
            return batch["data"]
        if isinstance(batch, torch.Tensor):
            return batch
        raise NotImplementedError(f"unsupported batch type: {type(batch).__name__}")

    def _validate(self, model, data_loader, epoch) -> dict[str, Any]:
        """Run evaluation of ``model`` on ``data_loader``; implemented by subclasses."""
        raise NotImplementedError

    def validate(
        self, model=None, data_loader=None, is_test=True, epoch=0
    ) -> dict[str, Any]:
        """Put ``model`` in eval mode and delegate to :meth:`_validate`.

        ``model`` defaults to :attr:`eval_network`. ``data_loader`` defaults to
        the provider's ``test`` loader when ``is_test`` is true, else ``valid``.
        """
        if model is None:
            model = self.eval_network
        if data_loader is None:
            data_loader = self.data_provider.test if is_test else self.data_provider.valid

        model.eval()
        return self._validate(model, data_loader, epoch)

    def multires_validate(
        self,
        model=None,
        data_loader=None,
        is_test=True,
        epoch=0,
        eval_image_size=None,
    ) -> dict[str, dict[str, Any]]:
        """Validate once per evaluation image size.

        Sizes come from ``eval_image_size``, else ``run_config.eval_image_size``,
        else ``data_provider.image_size``; a non-list value is treated as a
        single size. For each size the provider's active image size is set,
        BatchNorm statistics are optionally recomputed (``run_config.reset_bn``),
        and the :meth:`validate` result is stored under the key ``f"r{size}"``.
        """
        eval_image_size = eval_image_size or self.run_config.eval_image_size
        eval_image_size = eval_image_size or self.data_provider.image_size
        if model is None:
            model = self.eval_network

        if not isinstance(eval_image_size, list):
            eval_image_size = [eval_image_size]

        output_dict = {}
        for r in eval_image_size:
            self.data_provider.assign_active_image_size(parse_image_size(r))
            if self.run_config.reset_bn:
                self.reset_bn(
                    network=model,
                    subset_size=self.run_config.reset_bn_size,
                    subset_batch_size=self.run_config.reset_bn_batch_size,
                    progress_bar=True,
                )
            output_dict[f"r{r}"] = self.validate(model, data_loader, is_test, epoch)
        return output_dict

    """ training """

    def prep_for_training(
        self, run_config: RunConfig, ema_decay: Optional[float] = None, fp16=False
    ) -> None:
        """Wrap the model in DDP and build the optimizer, LR scheduler, EMA and scaler.

        Args:
            run_config: Training configuration. Its ``global_step`` is reset to 0
                and ``batch_per_epoch`` is set to the train loader length.
            ema_decay: When not ``None``, keep an EMA of the network with this decay.
            fp16: Stored as ``self.fp16`` and used to enable the ``GradScaler``.

        Raises:
            AssertionError: if the training loader is empty.
        """
        self.run_config = run_config
        self.model = nn.parallel.DistributedDataParallel(
            self.model.cuda(),
            device_ids=[get_dist_local_rank()],
            static_graph=True,
        )

        self.run_config.global_step = 0
        self.run_config.batch_per_epoch = len(self.data_provider.train)
        assert self.run_config.batch_per_epoch > 0, "Training set is empty"

        self.optimizer, self.lr_scheduler = self.run_config.build_optimizer(self.model)

        if ema_decay is not None:
            self.ema = EMA(self.network, ema_decay)

        self.fp16 = fp16
        # A disabled GradScaler passes scale/step/update straight through, so
        # after_step can use it unconditionally.
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)

    def sync_model(self):
        """Make every rank load the master rank's network weights.

        The master writes a temporary ``sync.pt`` (leaving ``latest.txt``
        untouched), every rank loads it between barriers, and the master then
        deletes it.
        """
        print("Sync model")
        sync_fname = os.path.join(self.checkpoint_path, "sync.pt")
        self.save_model(model_name="sync.pt", update_latest=False)
        dist_barrier()
        checkpoint = torch.load(sync_fname, map_location="cpu")
        dist_barrier()
        if is_master():
            os.remove(sync_fname)
        dist_barrier()

        self.network.load_state_dict(checkpoint["state_dict"], strict=False)
        # With save_model's default only_state_dict=True the file holds only
        # "state_dict", so this restores nothing unless that default changes.
        self._load_optim_state(checkpoint)

    def before_step(self, feed_dict: dict[str, Any]) -> dict[str, Any]:
        """Move every tensor value of ``feed_dict`` to CUDA, in place, and return it."""
        for key in feed_dict:
            if isinstance(feed_dict[key], torch.Tensor):
                feed_dict[key] = feed_dict[key].cuda()
        return feed_dict

    def run_step(self, feed_dict: dict[str, Any]) -> dict[str, Any]:
        """Forward/backward for one batch; implemented by subclasses."""
        raise NotImplementedError

    def after_step(self) -> None:
        """Apply accumulated gradients and advance all schedules by one step.

        Gradients are unscaled first so that value clipping (when
        ``run_config.grad_clip`` is set) acts on true gradient magnitudes. Then
        the optimizer is stepped through the scaler, the scaler, LR scheduler and
        run_config are advanced, and the EMA (if any) is updated.
        """
        self.scaler.unscale_(self.optimizer)
        if self.run_config.grad_clip is not None:
            torch.nn.utils.clip_grad_value_(
                self.model.parameters(), self.run_config.grad_clip
            )

        self.scaler.step(self.optimizer)
        self.scaler.update()

        self.lr_scheduler.step()
        self.run_config.step()

        if self.ema is not None:
            self.ema.step(self.network, self.run_config.global_step)

    def _train_one_epoch(self, epoch: int) -> dict[str, Any]:
        """Train over one epoch; implemented by subclasses."""
        raise NotImplementedError

    def train_one_epoch(self, epoch: int) -> dict[str, Any]:
        """Switch to train mode, tell the data provider the epoch, and run
        :meth:`_train_one_epoch`, returning its result."""
        self.model.train()
        self.data_provider.set_epoch(epoch)
        return self._train_one_epoch(epoch)

    def train(self) -> None:
        """Full training loop; implemented by subclasses."""
        raise NotImplementedError
|
|