import torch import torch.nn as nn import torch.nn.functional as F class ResidualConvBlock(nn.Module): """ Residual convolutional block with two convolutions and a skip connection. Applies two 2D convolutions with a ReLU activation in between. If the input and output channel counts differ, a 1x1 projection is used for the residual path. Parameters ---------- in_channels : int Number of input channels. out_channels : int Number of output channels. kernel_size : int, optional Kernel size for both convolutions. Default is ``3``. padding : int, optional Padding for both convolutions. Default is ``1``. """ def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, padding: int = 1): """ Initialize ResidualConvBlock. Parameters ---------- in_channels : int Number of input channels. out_channels : int Number of output channels. kernel_size : int, optional Kernel size for both convolutions. Default is ``3``. padding : int, optional Padding for both convolutions. Default is ``1``. """ super().__init__() self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding) if in_channels != out_channels: self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=1) else: self.proj = None def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass through the residual convolutional block. Parameters ---------- x : torch.Tensor Input tensor of shape ``(B, C_in, H, W)``. Returns ------- out : torch.Tensor Output tensor of shape ``(B, C_out, H, W)``. """ residual = x out = F.relu(self.conv1(x)) out = self.conv2(out) if self.proj is not None: residual = self.proj(residual) out += residual out = F.relu(out) return out class ConvGRUCell(nn.Module): """ Convolutional GRU cell operating on 2D spatial grids. Implements a single-step GRU update where all linear projections are replaced by 2D convolutions, preserving spatial structure. Parameters ---------- input_size : int Number of channels in the input tensor. hidden_size : int Number of channels in the hidden state. kernel_size : int, optional Kernel size for the convolutional gates. Default is ``3``. conv_layer : nn.Module, optional Convolutional layer class to use. Default is ``nn.Conv2d``. """ def __init__(self, input_size: int, hidden_size: int, kernel_size: int = 3, conv_layer: nn.Module = nn.Conv2d): """ Initialize ConvGRUCell. Parameters ---------- input_size : int Number of channels in the input tensor. hidden_size : int Number of channels in the hidden state. kernel_size : int, optional Kernel size for the convolutional gates. Default is ``3``. conv_layer : nn.Module, optional Convolutional layer class to use. Default is ``nn.Conv2d``. """ super().__init__() padding = kernel_size // 2 self.input_size = input_size self.hidden_size = hidden_size # update and reset gates are combined for optimization self.combined_gates = conv_layer(input_size + hidden_size, 2 * hidden_size, kernel_size, padding=padding) self.out_gate = conv_layer(input_size + hidden_size, hidden_size, kernel_size, padding=padding) def forward(self, inpt: torch.Tensor | None = None, h_s: torch.Tensor | None = None) -> torch.Tensor: """ Forward the ConvGRU cell for a single timestep. If either input is ``None``, it is initialized to zeros based on the shape of the other. If both are ``None``, a ``ValueError`` is raised. Parameters ---------- inpt : torch.Tensor or None, optional Input tensor of shape ``(B, input_size, H, W)``. Default is ``None``. h_s : torch.Tensor or None, optional Hidden state tensor of shape ``(B, hidden_size, H, W)``. Default is ``None``. Returns ------- new_state : torch.Tensor Updated hidden state of shape ``(B, hidden_size, H, W)``. Raises ------ ValueError If both ``inpt`` and ``h_s`` are ``None``. """ if h_s is None and inpt is None: raise ValueError("Both input and state can't be None") elif h_s is None: h_s = torch.zeros( inpt.size(0), self.hidden_size, inpt.size(2), inpt.size(3), dtype=inpt.dtype, device=inpt.device ) elif inpt is None: inpt = torch.zeros( h_s.size(0), self.input_size, h_s.size(2), h_s.size(3), dtype=h_s.dtype, device=h_s.device ) gamma, beta = torch.chunk(self.combined_gates(torch.cat([inpt, h_s], dim=1)), 2, dim=1) update = torch.sigmoid(gamma) reset = torch.sigmoid(beta) out_inputs = torch.tanh(self.out_gate(torch.cat([inpt, h_s * reset], dim=1))) new_state = h_s * (1 - update) + out_inputs * update return new_state class ConvGRU(nn.Module): """ Convolutional GRU that unrolls a :class:`ConvGRUCell` over a sequence. Parameters ---------- input_size : int Number of channels in the input tensor. hidden_size : int Number of channels in the hidden state. kernel_size : int, optional Kernel size for the convolutional gates. Default is ``3``. conv_layer : nn.Module, optional Convolutional layer class to use. Default is ``nn.Conv2d``. """ def __init__(self, input_size: int, hidden_size: int, kernel_size: int = 3, conv_layer: nn.Module = nn.Conv2d): """ Initialize ConvGRU. Parameters ---------- input_size : int Number of channels in the input tensor. hidden_size : int Number of channels in the hidden state. kernel_size : int, optional Kernel size for the convolutional gates. Default is ``3``. conv_layer : nn.Module, optional Convolutional layer class to use. Default is ``nn.Conv2d``. """ super().__init__() self.cell = ConvGRUCell(input_size, hidden_size, kernel_size, conv_layer) def forward(self, x: torch.Tensor | None = None, h: torch.Tensor | None = None) -> torch.Tensor: """ Unroll the ConvGRU cell over the sequence (time) dimension. .. code-block:: text x[:, 0] x[:, 1] | | v v *------* *------* h --> | Cell | --> h_0 --> | Cell | --> h_1 ... *------* *------* If either input is ``None``, it is initialized to zeros based on the shape of the other. If both are ``None``, a ``ValueError`` is raised. Parameters ---------- x : torch.Tensor or None, optional Input tensor of shape ``(B, T, input_size, H, W)``. Default is ``None``. h : torch.Tensor or None, optional Initial hidden state of shape ``(B, hidden_size, H, W)``. Default is ``None``. Returns ------- hidden_states : torch.Tensor Stacked hidden states of shape ``(B, T, hidden_size, H, W)``, i.e. ``[h_0, h_1, h_2, ...]``. """ h_s = [] for i in range(x.size(1)): h = self.cell(x[:, i], h) h_s.append(h) return torch.stack(h_s, dim=1) class EncoderBlock(nn.Module): """ ConvGRU-based encoder block with spatial downsampling. Applies a :class:`ConvGRU` followed by ``nn.PixelUnshuffle(2)`` to halve spatial dimensions and quadruple channels. Parameters ---------- input_size : int Number of input channels. kernel_size : int, optional Kernel size for the ConvGRU. Default is ``3``. conv_layer : nn.Module, optional Convolutional layer class to use. Default is ``nn.Conv2d``. """ def __init__(self, input_size: int, kernel_size: int = 3, conv_layer: nn.Module = nn.Conv2d): """ Initialize EncoderBlock. Parameters ---------- input_size : int Number of input channels. kernel_size : int, optional Kernel size for the ConvGRU. Default is ``3``. conv_layer : nn.Module, optional Convolutional layer class to use. Default is ``nn.Conv2d``. """ super().__init__() self.convgru = ConvGRU(input_size, input_size, kernel_size, conv_layer) self.down = nn.PixelUnshuffle(2) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward the encoder block. Parameters ---------- x : torch.Tensor Input tensor of shape ``(B, T, C, H, W)``. Returns ------- out : torch.Tensor Downsampled tensor of shape ``(B, T, C*4, H/2, W/2)``. """ x = self.convgru(x) x = self.down(x) return x class Encoder(nn.Module): r""" ConvGRU-based encoder that stacks multiple :class:`EncoderBlock` layers. After each block the spatial resolution is halved via pixel-unshuffle. .. code-block:: text /// Encoder Block 1 \\\ /// Encoder Block 2 \\\ /--------------------------------------------\ /---------------------------------------\ | | | * *---------* *-----------------* * *---------* *-----------------* * X -> | ConvGRU | ---> | Pixel Unshuffle | ---> | ConvGRU | ---> | Pixel Unshuffle | ---> ... | *---------* | *-----------------* | *---------* | *-----------------* | v v v v v (b,t,c,h,w) (b,t,c,h,w) (b,t,c*4,h/2,w/2) (b,t,c*4,h/2,w/2) (b,t,c*16,h/4,w/4) Parameters ---------- input_channels : int, optional Number of input channels. Default is ``1``. num_blocks : int, optional Number of encoder blocks to stack. Default is ``4``. **kwargs Additional keyword arguments forwarded to each :class:`EncoderBlock`. """ def __init__(self, input_channels: int = 1, num_blocks: int = 4, **kwargs): """ Initialize Encoder. Parameters ---------- input_channels : int, optional Number of input channels. Default is ``1``. num_blocks : int, optional Number of encoder blocks to stack. Default is ``4``. **kwargs Additional keyword arguments forwarded to each :class:`EncoderBlock`. """ super().__init__() self.channel_sizes = [input_channels * 4**i for i in range(num_blocks)] # [1, 4, 16, 64] self.blocks = nn.ModuleList([EncoderBlock(self.channel_sizes[i], **kwargs) for i in range(num_blocks)]) def forward(self, x: torch.Tensor) -> list[torch.Tensor]: """ Forward the encoder through all blocks. Parameters ---------- x : torch.Tensor Input tensor of shape ``(B, T, C, H, W)``. Returns ------- hidden_states : list of torch.Tensor Hidden state tensors from each block, with progressively reduced spatial dimensions: ``[(B, T, C*4, H/2, W/2), (B, T, C*16, H/4, W/4), ...]``. """ hidden_states = [] for block in self.blocks: x = block(x) hidden_states.append(x) return hidden_states class DecoderBlock(nn.Module): """ ConvGRU-based decoder block with spatial upsampling. Applies a :class:`ConvGRU` followed by ``nn.PixelShuffle(2)`` to double spatial dimensions and quarter channels. Parameters ---------- input_size : int Number of input channels. hidden_size : int Number of hidden channels for the ConvGRU. kernel_size : int, optional Kernel size for the ConvGRU. Default is ``3``. conv_layer : nn.Module, optional Convolutional layer class to use. Default is ``nn.Conv2d``. """ def __init__(self, input_size: int, hidden_size: int, kernel_size: int = 3, conv_layer: nn.Module = nn.Conv2d): """ Initialize DecoderBlock. Parameters ---------- input_size : int Number of input channels. hidden_size : int Number of hidden channels for the ConvGRU. kernel_size : int, optional Kernel size for the ConvGRU. Default is ``3``. conv_layer : nn.Module, optional Convolutional layer class to use. Default is ``nn.Conv2d``. """ super().__init__() self.convgru = ConvGRU(input_size, hidden_size, kernel_size, conv_layer) self.up = nn.PixelShuffle(2) def forward(self, x: torch.Tensor, hidden_state: torch.Tensor) -> torch.Tensor: """ Forward the decoder block. Parameters ---------- x : torch.Tensor Input tensor of shape ``(B, T, C, H, W)``. hidden_state : torch.Tensor Hidden state from the corresponding encoder block, of shape ``(B, hidden_size, H, W)``. Returns ------- out : torch.Tensor Upsampled tensor of shape ``(B, T, hidden_size // 4, H*2, W*2)``. """ x = self.convgru(x, hidden_state) x = self.up(x) return x class Decoder(nn.Module): r""" ConvGRU-based decoder that stacks multiple :class:`DecoderBlock` layers. After each block the spatial resolution is doubled via pixel-shuffle. Hidden sizes are computed from the desired output channels. Parameters ---------- output_channels : int, optional Number of output channels. Default is ``1``. num_blocks : int, optional Number of decoder blocks to stack. Default is ``4``. **kwargs Additional keyword arguments forwarded to each :class:`DecoderBlock`. """ def __init__(self, output_channels: int = 1, num_blocks: int = 4, **kwargs): """ Initialize Decoder. Parameters ---------- output_channels : int, optional Number of output channels. Default is ``1``. num_blocks : int, optional Number of decoder blocks to stack. Default is ``4``. **kwargs Additional keyword arguments forwarded to each :class:`DecoderBlock`. """ super().__init__() self.channel_sizes = [output_channels * 4 ** (i + 1) for i in reversed(range(num_blocks))] # [256, 64, 16, 4] self.blocks = nn.ModuleList( [DecoderBlock(self.channel_sizes[i], self.channel_sizes[i], **kwargs) for i in range(num_blocks)] ) def forward(self, x: torch.Tensor, hidden_states: list[torch.Tensor]) -> torch.Tensor: """ Forward the decoder through all blocks. Parameters ---------- x : torch.Tensor Input tensor of shape ``(B, T, C, H, W)``. hidden_states : list of torch.Tensor Hidden states from the encoder (in reverse order), one per block. Returns ------- out : torch.Tensor Output tensor of shape ``(B, T, output_channels, H * 2^num_blocks, W * 2^num_blocks)``. """ for block, hidden_state in zip(self.blocks, hidden_states, strict=True): x = block(x, hidden_state) return x class EncoderDecoder(nn.Module): """ Full encoder-decoder model for spatio-temporal forecasting. Encodes an input sequence into multi-scale hidden states and decodes them into a forecast sequence, optionally generating multiple ensemble members via noisy decoder inputs. Parameters ---------- channels : int, optional Number of input/output channels. Default is ``1``. num_blocks : int, optional Number of encoder and decoder blocks. Default is ``4``. **kwargs Additional keyword arguments forwarded to :class:`Encoder` and :class:`Decoder`. """ def __init__(self, channels: int = 1, num_blocks: int = 4, **kwargs): """ Initialize EncoderDecoder. Parameters ---------- channels : int, optional Number of input/output channels. Default is ``1``. num_blocks : int, optional Number of encoder and decoder blocks. Default is ``4``. **kwargs Additional keyword arguments forwarded to :class:`Encoder` and :class:`Decoder`. """ super().__init__() self.encoder = Encoder(channels, num_blocks, **kwargs) self.decoder = Decoder(channels, num_blocks, **kwargs) def forward(self, x: torch.Tensor, steps: int, noisy_decoder: bool = False, ensemble_size: int = 1) -> torch.Tensor: """ Forward the encoder-decoder model. Parameters ---------- x : torch.Tensor Input tensor of shape ``(B, T, C, H, W)``. steps : int Number of future timesteps to forecast. noisy_decoder : bool, optional If ``True``, feed random noise (instead of zeros) as input to the decoder. Default is ``False``. ensemble_size : int, optional Number of ensemble members to generate. When ``> 1``, the decoder is always run with noisy inputs. Default is ``1``. Returns ------- preds : torch.Tensor Forecast tensor. Shape is ``(B, steps, C, H, W)`` when ``ensemble_size == 1``, or ``(B, steps, ensemble_size * C, H, W)`` when ``ensemble_size > 1`` (for C=1, this is ``(B, steps, ensemble_size, H, W)``). """ # encode the input tensor into a sequence of hidden states encoded = self.encoder(x) # create a tensor with the same shape as the last hidden state of the encoder to use as a input for the decoder x_dec_shape = list(encoded[-1].shape) # set the desired number of timestep for the output x_dec_shape[1] = steps # collect all the last hidden states of the encoder blocks in reverse order last_hidden_per_block = [e[:, -1] for e in reversed(encoded)] if ensemble_size > 1: # Generate M ensemble members by running decoder M times with different noise preds = [] for _ in range(ensemble_size): # the input will be random noise for each ensemble member x_dec = torch.randn(x_dec_shape, dtype=encoded[-1].dtype, device=encoded[-1].device) # decode (unroll) the input hidden states into a forecast sequence of N timesteps decoded = self.decoder(x_dec, last_hidden_per_block) preds.append(decoded) # stack along channel/ensemble dimension: (B, T, M, H, W) return torch.cat(preds, dim=2) else: # the input will be of random values if noisy_decoder is True, otherwise with zeros x_dec_func = torch.randn if noisy_decoder else torch.zeros # create the input tensor for the decoder x_dec = x_dec_func(x_dec_shape, dtype=encoded[-1].dtype, device=encoded[-1].device) # decode (unroll) the input hidden states into a forecast sequence of N timesteps decoded = self.decoder(x_dec, last_hidden_per_block) return decoded