|
|
from pathlib import Path |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from utils.torch_utilities import load_pretrained_model, merge_matched_keys |
|
|
|
|
|
|
|
|
class LoadPretrainedBase(nn.Module):
    """Mixin that adds pretrained-checkpoint loading to an ``nn.Module``.

    Subclasses may override :meth:`process_state_dict` to adapt checkpoint
    key names/shapes before they are loaded into the model.
    """

    def process_state_dict(
        self, model_dict: dict[str, torch.Tensor],
        state_dict: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        """
        Custom processing functions of each model that transforms `state_dict` loaded from
        checkpoints to the state that can be used in `load_state_dict`.
        Use `merge_matched_keys` to update parameters with matched names and shapes by
        default.

        Args:
            model_dict:
                The state dict of the current model, which is going to load pretrained parameters
            state_dict:
                A dictionary of parameters from a pre-trained model.

        Returns:
            dict[str, torch.Tensor]:
                The updated state dict, where parameters with matched keys and shape are
                updated with values in `state_dict`.
        """
        state_dict = merge_matched_keys(model_dict, state_dict)
        return state_dict

    def load_pretrained(self, ckpt_path: str | Path):
        """Load a checkpoint from ``ckpt_path`` into this model.

        Delegates to ``load_pretrained_model``, passing
        :meth:`process_state_dict` as the hook that filters/transforms the
        checkpoint's state dict before it is applied.

        Args:
            ckpt_path: Path to the checkpoint file.
        """
        load_pretrained_model(
            self, ckpt_path, state_dict_process_fn=self.process_state_dict
        )
|
|
|
|
|
|
|
|
class CountParamsBase(nn.Module):
    """Mixin that reports parameter counts for an ``nn.Module``."""

    def count_params(self):
        """Count model parameters.

        Returns:
            tuple[int, int]: the total number of parameter elements, and
            the number of elements in parameters with
            ``requires_grad=True``.
        """
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(
            p.numel() for p in self.parameters() if p.requires_grad
        )
        return total, trainable
|
|
|
|
|
|
|
|
class SaveTrainableParamsBase(nn.Module):
    """Mixin for models that checkpoint only trainable parameters and buffers.

    :attr:`param_names_to_save` lists the keys a checkpoint must contain;
    :meth:`load_state_dict` verifies those keys are present before
    delegating to ``nn.Module.load_state_dict``.
    """

    @property
    def param_names_to_save(self):
        """Names of all trainable parameters plus all buffers.

        Buffers are always included: they are not covered by
        ``requires_grad`` but still carry state that must be restored
        (e.g. running statistics).

        Returns:
            list[str]: parameter/buffer names required in a checkpoint.
        """
        names = [
            name for name, param in self.named_parameters()
            if param.requires_grad
        ]
        names.extend(name for name, _ in self.named_buffers())
        return names

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict`` after checking every required key is present.

        Args:
            state_dict: Mapping from parameter/buffer names to tensors.
            strict: Forwarded to ``nn.Module.load_state_dict``.

        Raises:
            RuntimeError: if any key in :attr:`param_names_to_save` is
                missing from ``state_dict``.  ``RuntimeError`` (instead of
                the previous bare ``Exception``) matches PyTorch's own
                missing-key errors and is still caught by callers that
                handle ``Exception``.
        """
        for key in self.param_names_to_save:
            if key not in state_dict:
                raise RuntimeError(
                    f"{key} not found in either pre-trained models (e.g. BERT)"
                    " or resumed checkpoints (e.g. epoch_40/model.pt)"
                )
        return super().load_state_dict(state_dict, strict)
|
|
|