weather_predict / models /cnn_multi_frame.py
jeffliulab's picture
Initial deploy: real-time weather forecast demo
e22f65c
"""
Multi-frame 2D CNN for weather forecasting.
Stacks k consecutive spatial snapshots along the channel dimension,
allowing the model to learn temporal patterns with standard 2D convolutions.
Input: (B, k*C, H, W), where k = n_frames and C = n_input_channels
Output: (B, n_targets), with n_targets = 6 by default
"""
import torch
import torch.nn as nn
from .cnn_baseline import ResBlock
class MultiFrameCNN(nn.Module):
    """
    2D CNN regressor operating on several frames stacked along the channel axis.

    The network is a stem, six ``ResBlock`` stages, global average pooling and
    a two-layer MLP head. The stem first applies a 7x7 stride-2 convolution
    over all ``n_input_channels * n_frames`` stacked channels, then a 1x1
    convolution that mixes them down to ``base_channels``.

    NOTE(review): the original author describes the backbone as identical to
    ``BaselineCNN``; that class is defined elsewhere and was not checked here.

    Args:
        n_input_channels: Channels in a single frame.
        n_targets: Number of values predicted per sample.
        n_frames: Number of frames stacked along the channel axis. Stored on
            the instance as ``self.n_frames``; ``forward`` does not read it.
        base_channels: Width of the first residual stage; later stages use
            2x, 4x and 8x this value.
    """

    def __init__(self, n_input_channels=42, n_targets=6, n_frames=4, base_channels=64):
        super().__init__()
        self.n_frames = n_frames

        # Channel widths used throughout the network.
        stacked_channels = n_input_channels * n_frames
        width = base_channels
        width_x2 = width * 2
        width_x4 = width * 4
        width_x8 = width * 8

        # Stem: wide 7x7 downsampling conv, then a 1x1 conv that squeezes the
        # stacked-frame features down to the base width.
        stem_layers = [
            nn.Conv2d(stacked_channels, width_x2, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(width_x2),
            nn.ReLU(inplace=True),
            nn.Conv2d(width_x2, width, kernel_size=1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
        ]
        self.stem = nn.Sequential(*stem_layers)

        # Residual stages; every stage except the first is built with stride 2.
        self.layer1 = ResBlock(width, width, stride=1)
        self.layer2 = ResBlock(width, width_x2, stride=2)
        self.layer3 = ResBlock(width_x2, width_x4, stride=2)
        self.layer4 = ResBlock(width_x4, width_x4, stride=2)
        self.layer5 = ResBlock(width_x4, width_x8, stride=2)
        self.layer6 = ResBlock(width_x8, width_x8, stride=2)

        # Collapse the spatial grid to 1x1 before the fully connected head.
        self.pool = nn.AdaptiveAvgPool2d(1)

        head_layers = [
            nn.Flatten(),
            nn.Linear(width_x8, width_x2),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(width_x2, n_targets),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x):
        """
        Run the network on a batch of channel-stacked frames.

        Args:
            x: Tensor whose channel dimension equals
                ``n_input_channels * n_frames`` (the first convolution is
                built for that width), e.g. shape ``(B, k*C, H, W)``.

        Returns:
            Output of the head; ``(B, n_targets)`` for a batched input.
        """
        # A single name is rebound at every step so earlier activations are
        # not kept alive by extra local references.
        feats = self.stem(x)
        feats = self.layer1(feats)
        feats = self.layer2(feats)
        feats = self.layer3(feats)
        feats = self.layer4(feats)
        feats = self.layer5(feats)
        feats = self.layer6(feats)
        pooled = self.pool(feats)
        return self.head(pooled)