# irene/convgru_ensemble/model.py
# Origin: Hugging Face upload "Add source code and examples" (commit df27dfb).
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualConvBlock(nn.Module):
    """
    Two-convolution residual block with an optional 1x1 skip projection.

    The main path is ``conv -> ReLU -> conv``; the skip path is the identity
    when input and output channel counts match, otherwise a 1x1 convolution
    that matches the channel count. A final ReLU is applied to the sum.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    kernel_size : int, optional
        Kernel size for both convolutions. Default is ``3``.
    padding : int, optional
        Padding for both convolutions. Default is ``1``.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, padding: int = 1):
        """
        Initialize ResidualConvBlock.

        Parameters
        ----------
        in_channels : int
            Number of input channels.
        out_channels : int
            Number of output channels.
        kernel_size : int, optional
            Kernel size for both convolutions. Default is ``3``.
        padding : int, optional
            Padding for both convolutions. Default is ``1``.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding)
        # Only project the skip path when the channel counts differ.
        self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the residual convolutional block.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(B, C_in, H, W)``.

        Returns
        -------
        out : torch.Tensor
            Output tensor of shape ``(B, C_out, H, W)``.
        """
        skip = x if self.proj is None else self.proj(x)
        main = self.conv2(F.relu(self.conv1(x)))
        return F.relu(main + skip)
class ConvGRUCell(nn.Module):
"""
Convolutional GRU cell operating on 2D spatial grids.
Implements a single-step GRU update where all linear projections are
replaced by 2D convolutions, preserving spatial structure.
Parameters
----------
input_size : int
Number of channels in the input tensor.
hidden_size : int
Number of channels in the hidden state.
kernel_size : int, optional
Kernel size for the convolutional gates. Default is ``3``.
conv_layer : nn.Module, optional
Convolutional layer class to use. Default is ``nn.Conv2d``.
"""
def __init__(self, input_size: int, hidden_size: int, kernel_size: int = 3, conv_layer: nn.Module = nn.Conv2d):
"""
Initialize ConvGRUCell.
Parameters
----------
input_size : int
Number of channels in the input tensor.
hidden_size : int
Number of channels in the hidden state.
kernel_size : int, optional
Kernel size for the convolutional gates. Default is ``3``.
conv_layer : nn.Module, optional
Convolutional layer class to use. Default is ``nn.Conv2d``.
"""
super().__init__()
padding = kernel_size // 2
self.input_size = input_size
self.hidden_size = hidden_size
# update and reset gates are combined for optimization
self.combined_gates = conv_layer(input_size + hidden_size, 2 * hidden_size, kernel_size, padding=padding)
self.out_gate = conv_layer(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
def forward(self, inpt: torch.Tensor | None = None, h_s: torch.Tensor | None = None) -> torch.Tensor:
"""
Forward the ConvGRU cell for a single timestep.
If either input is ``None``, it is initialized to zeros based on the
shape of the other. If both are ``None``, a ``ValueError`` is raised.
Parameters
----------
inpt : torch.Tensor or None, optional
Input tensor of shape ``(B, input_size, H, W)``. Default is
``None``.
h_s : torch.Tensor or None, optional
Hidden state tensor of shape ``(B, hidden_size, H, W)``. Default
is ``None``.
Returns
-------
new_state : torch.Tensor
Updated hidden state of shape ``(B, hidden_size, H, W)``.
Raises
------
ValueError
If both ``inpt`` and ``h_s`` are ``None``.
"""
if h_s is None and inpt is None:
raise ValueError("Both input and state can't be None")
elif h_s is None:
h_s = torch.zeros(
inpt.size(0), self.hidden_size, inpt.size(2), inpt.size(3), dtype=inpt.dtype, device=inpt.device
)
elif inpt is None:
inpt = torch.zeros(
h_s.size(0), self.input_size, h_s.size(2), h_s.size(3), dtype=h_s.dtype, device=h_s.device
)
gamma, beta = torch.chunk(self.combined_gates(torch.cat([inpt, h_s], dim=1)), 2, dim=1)
update = torch.sigmoid(gamma)
reset = torch.sigmoid(beta)
out_inputs = torch.tanh(self.out_gate(torch.cat([inpt, h_s * reset], dim=1)))
new_state = h_s * (1 - update) + out_inputs * update
return new_state
class ConvGRU(nn.Module):
    """
    Convolutional GRU that unrolls a :class:`ConvGRUCell` over a sequence.

    Parameters
    ----------
    input_size : int
        Number of channels in the input tensor.
    hidden_size : int
        Number of channels in the hidden state.
    kernel_size : int, optional
        Kernel size for the convolutional gates. Default is ``3``.
    conv_layer : type[nn.Module], optional
        Convolutional layer class to use. Default is ``nn.Conv2d``.
    """

    def __init__(self, input_size: int, hidden_size: int, kernel_size: int = 3, conv_layer: type[nn.Module] = nn.Conv2d):
        """
        Initialize ConvGRU.

        Parameters
        ----------
        input_size : int
            Number of channels in the input tensor.
        hidden_size : int
            Number of channels in the hidden state.
        kernel_size : int, optional
            Kernel size for the convolutional gates. Default is ``3``.
        conv_layer : type[nn.Module], optional
            Convolutional layer class to use. Default is ``nn.Conv2d``.
        """
        super().__init__()
        self.cell = ConvGRUCell(input_size, hidden_size, kernel_size, conv_layer)

    def forward(self, x: torch.Tensor | None = None, h: torch.Tensor | None = None) -> torch.Tensor:
        """
        Unroll the ConvGRU cell over the sequence (time) dimension.

        .. code-block:: text

                  x[:, 0]               x[:, 1]
                    |                     |
                    v                     v
                 *------*              *------*
            h -> | Cell | --> h_0 -->  | Cell | --> h_1 ...
                 *------*              *------*

        Parameters
        ----------
        x : torch.Tensor or None, optional
            Input tensor of shape ``(B, T, input_size, H, W)``. Must not be
            ``None``: the sequence length cannot be inferred without it.
            Default is ``None``.
        h : torch.Tensor or None, optional
            Initial hidden state of shape ``(B, hidden_size, H, W)``; if
            ``None``, the cell zero-initializes it. Default is ``None``.

        Returns
        -------
        hidden_states : torch.Tensor
            Stacked hidden states of shape ``(B, T, hidden_size, H, W)``,
            i.e. ``[h_0, h_1, h_2, ...]``.

        Raises
        ------
        ValueError
            If ``x`` is ``None``.
        """
        # Previously a None input crashed with AttributeError on x.size(1);
        # fail fast with a clear message instead.
        if x is None:
            raise ValueError("ConvGRU requires an input sequence `x`; only `h` may be None")
        h_s = []
        for i in range(x.size(1)):
            h = self.cell(x[:, i], h)
            h_s.append(h)
        return torch.stack(h_s, dim=1)
class EncoderBlock(nn.Module):
    """
    ConvGRU-based encoder block with spatial downsampling.

    Applies a :class:`ConvGRU` followed by ``nn.PixelUnshuffle(2)`` to
    halve spatial dimensions and quadruple channels.

    Parameters
    ----------
    input_size : int
        Number of input channels.
    kernel_size : int, optional
        Kernel size for the ConvGRU. Default is ``3``.
    conv_layer : type[nn.Module], optional
        Convolutional layer class to use. Default is ``nn.Conv2d``.
    """

    def __init__(self, input_size: int, kernel_size: int = 3, conv_layer: type[nn.Module] = nn.Conv2d):
        """
        Initialize EncoderBlock.

        Parameters
        ----------
        input_size : int
            Number of input channels.
        kernel_size : int, optional
            Kernel size for the ConvGRU. Default is ``3``.
        conv_layer : type[nn.Module], optional
            Convolutional layer class to use. Default is ``nn.Conv2d``.
        """
        super().__init__()
        # Hidden size equals input size; downsampling does the channel growth.
        self.convgru = ConvGRU(input_size, input_size, kernel_size, conv_layer)
        self.down = nn.PixelUnshuffle(2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward the encoder block.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(B, T, C, H, W)``.

        Returns
        -------
        out : torch.Tensor
            Downsampled tensor of shape ``(B, T, C*4, H/2, W/2)``.
        """
        x = self.convgru(x)
        x = self.down(x)
        return x
class Encoder(nn.Module):
    """
    ConvGRU-based encoder that stacks multiple :class:`EncoderBlock` layers.

    Each block runs a ConvGRU and then pixel-unshuffles, so every stage
    halves the spatial resolution and quadruples the channel count:
    ``(B,T,C,H,W) -> (B,T,C*4,H/2,W/2) -> (B,T,C*16,H/4,W/4) -> ...``

    Parameters
    ----------
    input_channels : int, optional
        Number of input channels. Default is ``1``.
    num_blocks : int, optional
        Number of encoder blocks to stack. Default is ``4``.
    **kwargs
        Additional keyword arguments forwarded to each :class:`EncoderBlock`.
    """

    def __init__(self, input_channels: int = 1, num_blocks: int = 4, **kwargs):
        """
        Initialize Encoder.

        Parameters
        ----------
        input_channels : int, optional
            Number of input channels. Default is ``1``.
        num_blocks : int, optional
            Number of encoder blocks to stack. Default is ``4``.
        **kwargs
            Additional keyword arguments forwarded to each
            :class:`EncoderBlock`.
        """
        super().__init__()
        # Pixel-unshuffle quadruples channels at every stage, e.g. [1, 4, 16, 64].
        self.channel_sizes = [input_channels * 4**stage for stage in range(num_blocks)]
        self.blocks = nn.ModuleList(EncoderBlock(size, **kwargs) for size in self.channel_sizes)

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """
        Forward the encoder through all blocks.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(B, T, C, H, W)``.

        Returns
        -------
        hidden_states : list of torch.Tensor
            Hidden state tensors from each block, with progressively reduced
            spatial dimensions:
            ``[(B, T, C*4, H/2, W/2), (B, T, C*16, H/4, W/4), ...]``.
        """
        outputs: list[torch.Tensor] = []
        for block in self.blocks:
            x = block(x)
            outputs.append(x)
        return outputs
class DecoderBlock(nn.Module):
    """
    ConvGRU-based decoder block with spatial upsampling.

    Applies a :class:`ConvGRU` followed by ``nn.PixelShuffle(2)`` to double
    spatial dimensions and quarter channels.

    Parameters
    ----------
    input_size : int
        Number of input channels.
    hidden_size : int
        Number of hidden channels for the ConvGRU.
    kernel_size : int, optional
        Kernel size for the ConvGRU. Default is ``3``.
    conv_layer : type[nn.Module], optional
        Convolutional layer class to use. Default is ``nn.Conv2d``.
    """

    def __init__(self, input_size: int, hidden_size: int, kernel_size: int = 3, conv_layer: type[nn.Module] = nn.Conv2d):
        """
        Initialize DecoderBlock.

        Parameters
        ----------
        input_size : int
            Number of input channels.
        hidden_size : int
            Number of hidden channels for the ConvGRU.
        kernel_size : int, optional
            Kernel size for the ConvGRU. Default is ``3``.
        conv_layer : type[nn.Module], optional
            Convolutional layer class to use. Default is ``nn.Conv2d``.
        """
        super().__init__()
        self.convgru = ConvGRU(input_size, hidden_size, kernel_size, conv_layer)
        self.up = nn.PixelShuffle(2)

    def forward(self, x: torch.Tensor, hidden_state: torch.Tensor) -> torch.Tensor:
        """
        Forward the decoder block.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(B, T, C, H, W)``.
        hidden_state : torch.Tensor
            Hidden state from the corresponding encoder block, of shape
            ``(B, hidden_size, H, W)``.

        Returns
        -------
        out : torch.Tensor
            Upsampled tensor of shape ``(B, T, hidden_size // 4, H*2, W*2)``.
        """
        x = self.convgru(x, hidden_state)
        x = self.up(x)
        return x
class Decoder(nn.Module):
    """
    ConvGRU-based decoder that stacks multiple :class:`DecoderBlock` layers.

    Each block runs a ConvGRU (seeded by the matching encoder hidden state)
    and then pixel-shuffles, doubling the spatial resolution and quartering
    the channel count at every stage.

    Parameters
    ----------
    output_channels : int, optional
        Number of output channels. Default is ``1``.
    num_blocks : int, optional
        Number of decoder blocks to stack. Default is ``4``.
    **kwargs
        Additional keyword arguments forwarded to each :class:`DecoderBlock`.
    """

    def __init__(self, output_channels: int = 1, num_blocks: int = 4, **kwargs):
        """
        Initialize Decoder.

        Parameters
        ----------
        output_channels : int, optional
            Number of output channels. Default is ``1``.
        num_blocks : int, optional
            Number of decoder blocks to stack. Default is ``4``.
        **kwargs
            Additional keyword arguments forwarded to each
            :class:`DecoderBlock`.
        """
        super().__init__()
        # Deepest block first; each pixel-shuffle divides channels by 4,
        # e.g. [256, 64, 16, 4] for output_channels=1, num_blocks=4.
        self.channel_sizes = [output_channels * 4 ** (depth + 1) for depth in reversed(range(num_blocks))]
        self.blocks = nn.ModuleList(DecoderBlock(size, size, **kwargs) for size in self.channel_sizes)

    def forward(self, x: torch.Tensor, hidden_states: list[torch.Tensor]) -> torch.Tensor:
        """
        Forward the decoder through all blocks.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(B, T, C, H, W)``.
        hidden_states : list of torch.Tensor
            Hidden states from the encoder (in reverse order), one per block.

        Returns
        -------
        out : torch.Tensor
            Output tensor of shape
            ``(B, T, output_channels, H * 2^num_blocks, W * 2^num_blocks)``.
        """
        for block, state in zip(self.blocks, hidden_states, strict=True):
            x = block(x, state)
        return x
class EncoderDecoder(nn.Module):
    """
    Full encoder-decoder model for spatio-temporal forecasting.

    Encodes an input sequence into multi-scale hidden states and decodes
    them into a forecast sequence, optionally generating multiple ensemble
    members via noisy decoder inputs.

    Parameters
    ----------
    channels : int, optional
        Number of input/output channels. Default is ``1``.
    num_blocks : int, optional
        Number of encoder and decoder blocks. Default is ``4``.
    **kwargs
        Additional keyword arguments forwarded to :class:`Encoder` and
        :class:`Decoder`.
    """

    def __init__(self, channels: int = 1, num_blocks: int = 4, **kwargs):
        """
        Initialize EncoderDecoder.

        Parameters
        ----------
        channels : int, optional
            Number of input/output channels. Default is ``1``.
        num_blocks : int, optional
            Number of encoder and decoder blocks. Default is ``4``.
        **kwargs
            Additional keyword arguments forwarded to :class:`Encoder` and
            :class:`Decoder`.
        """
        super().__init__()
        self.encoder = Encoder(channels, num_blocks, **kwargs)
        self.decoder = Decoder(channels, num_blocks, **kwargs)

    def forward(self, x: torch.Tensor, steps: int, noisy_decoder: bool = False, ensemble_size: int = 1) -> torch.Tensor:
        """
        Forward the encoder-decoder model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(B, T, C, H, W)``.
        steps : int
            Number of future timesteps to forecast.
        noisy_decoder : bool, optional
            If ``True``, feed random noise (instead of zeros) as input to the
            decoder. Default is ``False``.
        ensemble_size : int, optional
            Number of ensemble members to generate. When ``> 1``, the decoder
            is always run with noisy inputs. Default is ``1``.

        Returns
        -------
        preds : torch.Tensor
            Forecast tensor. Shape is ``(B, steps, C, H, W)`` when
            ``ensemble_size == 1``, or
            ``(B, steps, ensemble_size * C, H, W)`` when ``ensemble_size > 1``
            (for C=1, this is ``(B, steps, ensemble_size, H, W)``).
        """
        # Multi-scale hidden-state sequences, one tensor per encoder block.
        encoded = self.encoder(x)
        deepest = encoded[-1]
        # The decoder seed mirrors the deepest encoder output, but unrolled
        # over the requested number of forecast timesteps.
        seed_shape = list(deepest.shape)
        seed_shape[1] = steps
        # Last hidden state of each encoder block, deepest stage first.
        init_states = [stage[:, -1] for stage in reversed(encoded)]

        def _decode(noisy: bool) -> torch.Tensor:
            # Seed with Gaussian noise (stochastic member) or zeros (deterministic).
            maker = torch.randn if noisy else torch.zeros
            seed = maker(seed_shape, dtype=deepest.dtype, device=deepest.device)
            return self.decoder(seed, init_states)

        if ensemble_size > 1:
            # Each member is decoded from fresh noise; members are joined
            # along the channel axis: (B, steps, ensemble_size * C, H, W).
            members = [_decode(True) for _ in range(ensemble_size)]
            return torch.cat(members, dim=2)
        return _decode(noisy_decoder)