| """Base class for trainable models.""" |
|
|
| import logging |
| import re |
| from abc import ABCMeta, abstractmethod |
| from copy import copy |
|
|
| import omegaconf |
| import torch |
| from omegaconf import OmegaConf |
| from torch import nn |
|
|
| logger = logging.getLogger(__name__) |
|
|
| try: |
| import wandb |
| except ImportError: |
| logger.debug("Could not import wandb.") |
| wandb = None |
|
|
| |
| |
|
|
|
|
class MetaModel(ABCMeta):
    """Metaclass that accumulates default configurations along the bases.

    Before a class body runs, ``__prepare__`` seeds its namespace with a
    ``base_default_conf`` entry. That entry is the OmegaConf merge of every
    base class's ``base_default_conf`` and ``default_conf`` attributes. Bases
    are visited in the order they are listed, and within each base
    ``base_default_conf`` is merged before ``default_conf``. Later layers
    override earlier ones. Missing attributes contribute an empty layer.
    """

    def __prepare__(name, bases, **kwds):
        # Defined without @classmethod: Python fetches this hook from the
        # metaclass and calls it as ``__prepare__(name, bases, **kwds)``.
        merged = OmegaConf.create()
        layer_names = ("base_default_conf", "default_conf")
        for parent in bases:
            for attr in layer_names:
                layer = getattr(parent, attr, {})
                # Plain dicts (e.g. a class-level ``default_conf = {...}``)
                # are converted to OmegaConf configs before merging.
                merged = OmegaConf.merge(
                    merged,
                    OmegaConf.create(layer) if isinstance(layer, dict) else layer,
                )
        return {"base_default_conf": merged}
|
|
|
|
class BaseModel(nn.Module, metaclass=MetaModel):
    """
    What the child model is expected to declare:
        default_conf: dictionary of the default configuration of the model.
        It recursively updates the default_conf of all parent classes
        (collected by MetaModel into ``base_default_conf``), and it is updated
        by the user-provided configuration passed to __init__.
        Configurations can be nested. An optional ``weights`` entry, which is
        not part of the defaults, names a checkpoint loaded after ``_init``.

        required_data_keys: list of expected keys in the input data
        dictionary. It may also be a nested dict whose keys are required and
        whose values list the required sub-keys.

        strict_conf (optional): boolean. If false, BaseModel does not raise
        an error when the user provides an unknown configuration entry.

        _init(self, conf): initialization method, where conf is the final
        configuration object (also accessible with `self.conf`). Accessing
        unknown configuration entries will raise an error.

        _forward(self, data): method that returns a dictionary of batched
        prediction tensors based on a dictionary of batched input data tensors.

        loss(self, pred, data): method that returns a dictionary of losses,
        computed from model predictions and input data. Each loss is a batch
        of scalars, i.e. a torch.Tensor of shape (B,).
        The total loss to be optimized has the key `'total'`.

        metrics(self, pred, data): method that returns a dictionary of metrics,
        each as a batch of scalars.
    """

    default_conf = {
        "name": None,
        "trainable": True,
        "freeze_batch_normalization": False,
        "timeit": False,
        "watch": False,
    }
    required_data_keys = []
    strict_conf = False

    def __init__(self, conf):
        """Build the final configuration and call the child's ``_init``.

        Args:
            conf: user configuration, as a plain ``dict`` or an OmegaConf
                config. It is merged on top of the class defaults, and the
                result is stored read-only and in struct mode in ``self.conf``.

        After ``_init``, this optionally loads ``conf.weights``, freezes all
        parameters when ``conf.trainable`` is false, and calls ``wandb.watch``
        when ``conf.watch`` is true.
        """
        super().__init__()
        default_conf = OmegaConf.merge(
            self.base_default_conf, OmegaConf.create(self.default_conf)
        )
        if self.strict_conf:
            # In struct mode, merging a user key absent from the defaults raises.
            OmegaConf.set_struct(default_conf, True)

        # Convert plain dicts first: the legacy migration below uses
        # omegaconf.read_write / open_dict, which require an OmegaConf node.
        if isinstance(conf, dict):
            conf = OmegaConf.create(conf)

        # Backward compatibility: when the defaults do not declare "pad", a
        # top-level "pad" entry is moved to conf.interpolation.pad (replacing
        # any existing "interpolation" entry).
        if "pad" in conf and "pad" not in default_conf:
            with omegaconf.read_write(conf):
                with omegaconf.open_dict(conf):
                    conf["interpolation"] = {"pad": conf.pop("pad")}

        self.conf = conf = OmegaConf.merge(default_conf, conf)
        OmegaConf.set_readonly(conf, True)
        OmegaConf.set_struct(conf, True)
        # Per-instance copy so _init can extend it without mutating the
        # class-level list.
        self.required_data_keys = copy(self.required_data_keys)
        self._init(conf)

        if "weights" in conf and conf.weights is not None:
            logger.info(f"Loading checkpoint {conf.weights}")
            # SECURITY: weights_only=False unpickles arbitrary Python objects;
            # only load checkpoints from trusted sources.
            ckpt = torch.load(str(conf.weights), map_location="cpu", weights_only=False)
            weights_key = "model" if "model" in ckpt else "state_dict"
            self.flexible_load(ckpt[weights_key])

        if not conf.trainable:
            for p in self.parameters():
                p.requires_grad = False

        if conf.watch:
            if wandb is None:
                # The optional wandb import failed; don't crash on a missing module.
                logger.warning(
                    f"Could not watch {self.__class__.__name__}: wandb is not installed."
                )
            else:
                try:
                    wandb.watch(self, log="all", log_graph=True, log_freq=10)
                    logger.info(f"Watching {self.__class__.__name__}.")
                except ValueError:
                    logger.warning(f"Could not watch {self.__class__.__name__}.")

        n_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        logger.info(f"Creating model {self.__class__.__name__} ({n_trainable/1e6:.2f} Mio)")

    def flexible_load(self, state_dict):
        """Load a checkpoint state dict, tolerating legacy and partial layouts.

        Keys are first migrated from older naming schemes: "gravity" becomes
        "up", and "<x>_head.<...>" becomes "<x>_head.decoder.<...>". If the
        resulting keys exactly match the model's parameter names, the load is
        strict. Otherwise the matching subset is loaded with ``strict=False``.
        When no key matches at all, the second dotted component of every key
        is dropped and matching is retried once.

        Args:
            state_dict: mapping from parameter/buffer names to tensors. It is
                not modified; a shallow copy is migrated instead.

        Raises:
            ValueError: if no checkpoint key matches a model parameter, even
                after prefix stripping.
        """
        # NOTE(review): the original docstring flagged "a probable nasty bug"
        # without detail. One candidate: model_params below only lists
        # named_parameters(), not buffers. So a checkpoint that contains
        # buffers (e.g. BatchNorm running stats) never takes the strict
        # "all parameters" path. Confirm before changing.

        # Shallow copy so the caller's mapping is untouched. copy() keeps the
        # mapping type and any ``_metadata`` attribute used by load_state_dict.
        state_dict = copy(state_dict)

        # Legacy naming: "gravity" was renamed to "up".
        for key in list(state_dict):
            if "gravity" in key:
                state_dict[key.replace("gravity", "up")] = state_dict.pop(key)

        # Legacy layout: "<x>_head.<...>" moved under "<x>_head.decoder.<...>".
        # Skipped: the latitude/up linear predictors, keys already under a
        # decoder, and indexed heads ("_head.<digits>").
        for key in list(state_dict):
            if "linear_pred_latitude" in key or "linear_pred_up" in key:
                continue
            if "_head" in key and "_head.decoder" not in key:
                if re.search(r"_head\.\d+", key):
                    continue
                new_key = key.replace("_head.", "_head.decoder.")
                state_dict[new_key] = state_dict.pop(key)

        dict_params = set(state_dict)
        model_params = {name for name, _ in self.named_parameters()}

        if dict_params == model_params:
            logger.info("Loading all parameters of the checkpoint.")
            self.load_state_dict(state_dict, strict=True)
            return

        if not dict_params & model_params:
            # No overlap at all: retry after dropping the second dotted
            # component of every key.
            def strip_prefix(param_name):
                parts = param_name.split(".")
                return ".".join(parts[:1] + parts[2:])

            state_dict = {strip_prefix(n): p for n, p in state_dict.items()}
            dict_params = set(state_dict)
            if not dict_params & model_params:
                raise ValueError(
                    "Could not manage to load the checkpoint with parameters:\n\t"
                    + "\n\t".join(sorted(dict_params))
                )

        common_params = dict_params & model_params
        # model_params holds parameters only, so BatchNorm buffer names
        # ("running_*", "num_batches_tracked") never match. They are not
        # reported; load_state_dict(strict=False) still loads them by name.
        left_params = [
            p
            for p in dict_params - model_params
            if "running" not in p and "num_batches_tracked" not in p
        ]
        logger.debug("Loading parameters:\n\t" + "\n\t".join(sorted(common_params)))
        if left_params:
            logger.warning("Could not load parameters:\n\t" + "\n\t".join(sorted(left_params)))
        self.load_state_dict(state_dict, strict=False)

    def train(self, mode=True):
        """Set the training mode and return ``self``.

        When ``conf.freeze_batch_normalization`` is set, every BatchNorm
        submodule is put back into eval mode after the mode switch.
        """
        super().train(mode)

        def freeze_bn(module):
            if isinstance(module, nn.modules.batchnorm._BatchNorm):
                module.eval()

        if self.conf.freeze_batch_normalization:
            self.apply(freeze_bn)

        return self

    def forward(self, data):
        """Check that ``data`` has every required key, then call ``_forward``.

        Raises:
            AssertionError: if a key from ``required_data_keys`` (searched
                recursively when it is a dict) is missing from ``data``.
        """

        def recursive_key_check(expected, given):
            for key in expected:
                assert key in given, f"Missing key {key} in data: {list(given.keys())}"
                if isinstance(expected, dict):
                    recursive_key_check(expected[key], given[key])

        recursive_key_check(self.required_data_keys, data)
        return self._forward(data)

    @abstractmethod
    def _init(self, conf):
        """To be implemented by the child class."""
        raise NotImplementedError

    @abstractmethod
    def _forward(self, data):
        """To be implemented by the child class."""
        raise NotImplementedError

    @abstractmethod
    def loss(self, pred, data):
        """To be implemented by the child class."""
        raise NotImplementedError
|
|