English
climate
File size: 6,005 Bytes
ec86bf7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d09d703
ec86bf7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""Neural network architectures."""

from typing import Optional

import netCDF4 as nc  # type: ignore
import torch
from torch import nn, Tensor


class ANN(nn.Sequential):
    """Fully connected network used in the paper.

    Paper: https://doi.org/10.1029/2020GL091363

    The network is a stack of ``n_layers`` linear layers, with ReLU activations
    and dropout between consecutive linear layers (none after the last one).
    Optional normalization statistics are stored as buffers. When set, inputs
    are standardized with ``features_mean``/``features_std`` before the layers,
    and outputs are de-standardized with ``outputs_std``/``outputs_mean`` after.

    Parameters
    ----------
    n_in : int
        Number of input features.
    n_out : int
        Number of output features.
    n_layers : int
        Number of linear layers (``n_layers - 1`` hidden layers plus the
        output layer).
    neurons : int
        The number of neurons in the hidden layers.
    dropout : float
        The dropout probability to apply in the hidden layers.
    device : str
        The device to put the model on.
    features_mean : array-like or Tensor, optional
        The mean of the input features. Requires ``features_std``.
    features_std : array-like or Tensor, optional
        The standard deviation of the input features.
    outputs_mean : array-like or Tensor, optional
        The mean of the output features. When ``output_groups`` is given there
        is one value per group. Requires ``outputs_std``.
    outputs_std : array-like or Tensor, optional
        The standard deviation of the output features. When ``output_groups``
        is given there is one value per group.
    output_groups : list of int, optional
        The number of output features in each group of the output. Each
        group's mean/std is repeated that many times.

    Notes
    -----
    If you are doing inference, always remember to put the model in eval mode,
    by using ``model.eval()``, so the dropout layers are turned off.

    """

    def __init__(  # pylint: disable=too-many-arguments,too-many-locals
        self,
        n_in: int = 61,
        n_out: int = 148,
        n_layers: int = 5,
        neurons: int = 128,
        dropout: float = 0.0,
        device: str = "cpu",
        features_mean: Optional[Tensor] = None,
        features_std: Optional[Tensor] = None,
        outputs_mean: Optional[Tensor] = None,
        outputs_std: Optional[Tensor] = None,
        output_groups: Optional[list] = None,
    ):
        """Initialize the ANN model."""
        dims = [n_in] + [neurons] * (n_layers - 1) + [n_out]
        layers = []

        for i in range(n_layers):
            layers.append(nn.Linear(dims[i], dims[i + 1]))
            # The output layer gets no activation or dropout.
            if i < n_layers - 1:
                layers.append(nn.ReLU())  # type: ignore
                layers.append(nn.Dropout(dropout))  # type: ignore

        super().__init__(*layers)

        fmean = fstd = omean = ostd = None

        if features_mean is not None:
            assert features_std is not None, "features_std is required with features_mean"
            assert len(features_mean) == len(
                features_std
            ), "features_mean and features_std must have the same length"
            fmean = self._to_tensor(features_mean)
            fstd = self._to_tensor(features_std)

        if outputs_mean is not None:
            assert outputs_std is not None, "outputs_std is required with outputs_mean"
            assert len(outputs_mean) == len(
                outputs_std
            ), "outputs_mean and outputs_std must have the same length"
            if output_groups is not None:
                assert len(output_groups) == len(
                    outputs_mean
                ), "output_groups must have one entry per outputs_mean value"
            omean = self._to_tensor(outputs_mean)
            ostd = self._to_tensor(outputs_std)
            if output_groups is not None:
                # Repeat each group's statistic once per feature in that group.
                repeats = torch.as_tensor(output_groups, dtype=torch.long)
                omean = torch.repeat_interleave(omean, repeats)
                ostd = torch.repeat_interleave(ostd, repeats)

        # Registered even when None so the names always exist as buffers.
        # None buffers are left out of state_dict().
        self.register_buffer("features_mean", fmean)
        self.register_buffer("features_std", fstd)
        self.register_buffer("outputs_mean", omean)
        self.register_buffer("outputs_std", ostd)

        self.to(torch.device(device))

    @staticmethod
    def _to_tensor(values) -> Tensor:
        """Return a new tensor holding a copy of ``values``.

        ``torch.tensor`` on an existing Tensor warns and recommends
        ``clone().detach()``. That form is used for tensors; everything else
        goes through ``torch.tensor``.
        """
        if isinstance(values, Tensor):
            return values.clone().detach()
        return torch.tensor(values)

    def forward(self, input: Tensor):  # pylint: disable=redefined-builtin
        """Pass the input through the model.

        Override the forward method of nn.Sequential to add normalization
        to the input and denormalization to the output, when the corresponding
        statistics are set.

        Parameters
        ----------
        input : Tensor
            A mini-batch of inputs, with features along the last dimension.

        Returns
        -------
        Tensor
            The model output, de-standardized if output statistics are set.

        """
        if self.features_mean is not None:
            input = (input - self.features_mean) / self.features_std

        # pass the input through the layers using nn.Sequential.forward
        output = super().forward(input)

        if self.outputs_mean is not None:
            output = output * self.outputs_std + self.outputs_mean

        return output

    def load(self, path: str) -> "ANN":
        """Load the model from a checkpoint.

        A normalization buffer that is ``None`` on this model but present in
        the checkpoint is adopted from the checkpoint first. This lets a model
        built without statistics be restored from a checkpoint saved with them.

        Parameters
        ----------
        path : str
            The path to the checkpoint.

        Returns
        -------
        ANN
            This model, with parameters and buffers loaded in place.

        """
        # Map the checkpoint onto this model's device. Otherwise tensors saved
        # on a GPU would be restored there (or fail on CPU-only hosts), and the
        # buffers adopted via setattr below would sit on the wrong device.
        param = next(self.parameters(), None)
        device = param.device if param is not None else torch.device("cpu")
        state = torch.load(path, map_location=device)
        for key in ("features_mean", "features_std", "outputs_mean", "outputs_std"):
            if key in state and getattr(self, key) is None:
                setattr(self, key, state[key])
        self.load_state_dict(state)
        return self

    def save(self, path: str):
        """Save the model's state dict (parameters and non-None buffers).

        Parameters
        ----------
        path : str
            The path to save the checkpoint to.

        """
        torch.save(self.state_dict(), path)


def load_from_netcdf_params(nc_file: str, dtype: str = "float32") -> ANN:
    """Load the model with weights and biases from the netcdf file.

    The file must contain the normalization variables ``fscale_mean``,
    ``fscale_stnd``, ``oscale_mean`` and ``oscale_stnd``. It must also contain
    one ``w{i}``/``b{i}`` pair per linear layer, numbered from 1 in network
    order. The model is built with the default :class:`ANN` architecture.

    Parameters
    ----------
    nc_file : str
        The netcdf file containing the parameters.
    dtype : str
        The data type to cast the parameters to.

    Returns
    -------
    ANN
        The model, loaded with the file's statistics, weights and biases.

    Raises
    ------
    ValueError
        If a weight or bias in the file does not match the shape of the
        corresponding linear layer.

    """
    # The context manager closes the file handle, even if loading fails.
    with nc.Dataset(nc_file) as data_set:  # pylint: disable=no-member

        def read(name: str):
            """Read variable ``name`` in full, cast to ``dtype``."""
            return data_set[name][:].astype(dtype)

        model = ANN(
            features_mean=read("fscale_mean"),
            features_std=read("fscale_stnd"),
            outputs_mean=read("oscale_mean"),
            outputs_std=read("oscale_stnd"),
            output_groups=[30, 29, 29, 30, 30],
        )

        linears = (mod for mod in model.modules() if isinstance(mod, nn.Linear))
        for i, layer in enumerate(linears, start=1):
            weight = torch.tensor(read(f"w{i}"))
            bias = torch.tensor(read(f"b{i}"))
            # Assigning `.data` does not check shapes. A transposed or
            # mis-sized array would otherwise only fail later, in forward().
            if weight.shape != layer.weight.shape or bias.shape != layer.bias.shape:
                raise ValueError(
                    f"Layer {i}: expected weight {tuple(layer.weight.shape)} and "
                    f"bias {tuple(layer.bias.shape)}, got {tuple(weight.shape)} "
                    f"and {tuple(bias.shape)}"
                )
            layer.weight.data = weight
            layer.bias.data = bias

    return model


if __name__ == "__main__":
    # Convert the netCDF parameter file into a PyTorch state-dict checkpoint.
    load_from_netcdf_params("NN_weights_YOG_convection.nc").save("nn_state.pt")
    print("Model saved to nn_state.pt")