File size: 2,713 Bytes
c14d03d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from pathlib import Path
from typing import Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.torch_utilities import (
    load_pretrained_model, merge_matched_keys, create_mask_from_length,
    loss_with_mask, create_alignment_path
)


class LoadPretrainedBase(nn.Module):
    """Mixin that lets a module initialise itself from a checkpoint file."""

    def process_state_dict(
        self, model_dict: dict[str, torch.Tensor],
        state_dict: dict[str, torch.Tensor]
    ):
        """
        Adapt a checkpoint `state_dict` so it can be passed to `load_state_dict`.

        Subclasses may override this hook to apply model-specific
        transformations. The default implementation simply delegates to
        `merge_matched_keys`, which is expected to keep the entries whose
        names and shapes match the current model.

        Args:
            model_dict:
                The state dict of the current model, i.e. the model that is
                going to receive the pre-trained parameters.
            state_dict:
                A dictionary of parameters from a pre-trained model.

        Returns:
            dict[str, torch.Tensor]:
                Whatever `merge_matched_keys(model_dict, state_dict)` returns;
                intended to be the state dict updated with the matching
                values from `state_dict`.
        """
        return merge_matched_keys(model_dict, state_dict)

    def load_pretrained(self, ckpt_path: str | Path):
        """
        Load pre-trained parameters from `ckpt_path` into this module.

        The work is done by `load_pretrained_model`, which receives this
        module, the checkpoint path, and `self.process_state_dict` as the
        hook used to post-process the loaded state dict.

        Args:
            ckpt_path:
                Location of the checkpoint to load.
        """
        load_pretrained_model(
            self,
            ckpt_path,
            state_dict_process_fn=self.process_state_dict,
        )


class CountParamsBase(nn.Module):
    """Mixin that reports how many parameter elements a module holds."""

    def count_params(self):
        """
        Count the scalar elements of all parameters of this module.

        Returns:
            tuple[int, int]:
                `(total, trainable)`, where `total` sums `numel()` over every
                parameter yielded by `self.parameters()` and `trainable`
                sums only those with `requires_grad` set.
        """
        # Walk the parameters once, recording size and trainability together.
        sizes = [(p.numel(), p.requires_grad) for p in self.parameters()]
        total = sum(size for size, _ in sizes)
        trainable = sum(size for size, needs_grad in sizes if needs_grad)
        return total, trainable


class SaveTrainableParamsBase(nn.Module):
    """
    Mixin for modules whose checkpoints only need the trainable parameters
    and the persistent buffers.
    """

    @property
    def param_names_to_save(self):
        """
        Names of the state-dict entries that must be present in a checkpoint.

        Returns:
            list[str]:
                Names of all parameters with `requires_grad` set, followed by
                the names of all persistent buffers.
        """
        # Buffers registered with `persistent=False` are never written to
        # `state_dict()`, so they must not be required when loading either.
        # `_non_persistent_buffers_set` is the per-module bookkeeping
        # `nn.Module` keeps for them; `getattr` guards against its absence.
        transient = {
            f"{prefix}.{buf_name}" if prefix else buf_name
            for prefix, module in self.named_modules()
            for buf_name in getattr(module, "_non_persistent_buffers_set", ())
        }
        names = [
            name for name, param in self.named_parameters()
            if param.requires_grad
        ]
        names.extend(
            name for name, _ in self.named_buffers() if name not in transient
        )
        return names

    def load_state_dict(self, state_dict, strict=True, **kwargs):
        """
        Load `state_dict`, first checking that every entry listed in
        `param_names_to_save` is present.

        Args:
            state_dict:
                Mapping from parameter/buffer names to tensors.
            strict:
                If true, a missing required key raises; otherwise a warning
                is printed. The flag is also forwarded to
                `nn.Module.load_state_dict`.
            **kwargs:
                Extra keyword arguments (e.g. `assign` in recent PyTorch
                versions) forwarded unchanged to `nn.Module.load_state_dict`.

        Returns:
            The result of `nn.Module.load_state_dict`.

        Raises:
            RuntimeError: if `strict` is true and a required key is missing.
        """
        missing_keys = [
            key for key in self.param_names_to_save if key not in state_dict
        ]

        if missing_keys:
            if strict:
                raise RuntimeError(
                    f"{missing_keys} not found in either pre-trained models (e.g. BERT) or resumed checkpoints (e.g. epoch_40/model.pt)"
                )
            print(f"Warning: missing keys {missing_keys}, skipping them.")

        return super().load_state_dict(state_dict, strict, **kwargs)