""" Multi-frame 2D CNN for weather forecasting. Stacks k consecutive spatial snapshots along the channel dimension, allowing the model to learn temporal patterns with standard 2D convolutions. Input: (B, k*C, H, W) Output: (B, 6) """ import torch import torch.nn as nn from .cnn_baseline import ResBlock class MultiFrameCNN(nn.Module): """ Multi-frame 2D CNN that concatenates consecutive frames along channels. Identical backbone to BaselineCNN but with an adapted stem for k*C input channels, plus a temporal mixing layer after the stem. """ def __init__(self, n_input_channels=42, n_targets=6, n_frames=4, base_channels=64): super().__init__() self.n_frames = n_frames in_ch = n_input_channels * n_frames ch = base_channels self.stem = nn.Sequential( nn.Conv2d(in_ch, ch * 2, 7, stride=2, padding=3, bias=False), nn.BatchNorm2d(ch * 2), nn.ReLU(inplace=True), nn.Conv2d(ch * 2, ch, 1, bias=False), nn.BatchNorm2d(ch), nn.ReLU(inplace=True), ) self.layer1 = ResBlock(ch, ch, stride=1) self.layer2 = ResBlock(ch, ch * 2, stride=2) self.layer3 = ResBlock(ch * 2, ch * 4, stride=2) self.layer4 = ResBlock(ch * 4, ch * 4, stride=2) self.layer5 = ResBlock(ch * 4, ch * 8, stride=2) self.layer6 = ResBlock(ch * 8, ch * 8, stride=2) self.pool = nn.AdaptiveAvgPool2d(1) self.head = nn.Sequential( nn.Flatten(), nn.Linear(ch * 8, ch * 2), nn.ReLU(inplace=True), nn.Dropout(0.3), nn.Linear(ch * 2, n_targets), ) def forward(self, x): x = self.stem(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.layer5(x) x = self.layer6(x) x = self.pool(x) return self.head(x)