"""Vision Transformer (ViT) for weather forecasting.

Splits the spatial input into non-overlapping patches, projects each patch
into an embedding, and processes the token sequence through a Transformer
encoder.

Input:  (B, C, H, W) -- single frame with C channels
Output: (B, 6)
"""
import math

import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    """Convert spatial input into a sequence of patch embeddings.

    A single Conv2d with kernel_size == stride == patch_size linearly
    projects each non-overlapping patch to an `embed_dim` vector.
    """

    def __init__(self, in_channels, embed_dim, patch_size, img_h, img_w):
        super().__init__()
        self.patch_size = patch_size
        # Integer division: any remainder rows/cols are silently dropped
        # by the strided convolution below.
        self.n_patches_h = img_h // patch_size
        self.n_patches_w = img_w // patch_size
        self.n_patches = self.n_patches_h * self.n_patches_w
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # (B, C, H, W) -> (B, embed_dim, nH, nW) -> (B, n_patches, embed_dim)
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x


class TransformerBlock(nn.Module):
    """Standard Transformer encoder block with pre-norm residual paths."""

    def __init__(self, embed_dim, n_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, n_heads,
                                          dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        hidden = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        h = self.norm1(x)
        # need_weights=False skips computing/averaging the attention map,
        # which is never used here.
        h, _ = self.attn(h, h, h, need_weights=False)
        x = x + h
        x = x + self.mlp(self.norm2(x))
        return x


class WeatherViT(nn.Module):
    """Vision Transformer for weather forecasting.

    Input:  (B, C, H, W) with H <= 450 and W <= 450; smaller inputs are
            zero-padded (bottom/right) up to the 450x450 grid, so the
            expected 450x449 data gets one padded column.
    Output: (B, n_targets), default (B, 6)

    Patches the (padded) input into patch_size x patch_size patches
    (default patch_size=15: 30x30 = 900 tokens), prepends a learnable CLS
    token, adds learned positional embeddings, runs the sequence through
    the Transformer encoder, and predicts the targets from the CLS token.
    """

    def __init__(self, n_input_channels=42, n_targets=6, patch_size=15,
                 embed_dim=256, n_layers=6, n_heads=8, mlp_ratio=4.0,
                 dropout=0.1, **kwargs):
        super().__init__()
        self.patch_size = patch_size
        # Pad to a square grid so the patch grid is regular; forward()
        # pads any smaller input up to these dimensions.
        self.img_h, self.img_w = 450, 450
        self.patch_embed = PatchEmbedding(n_input_channels, embed_dim,
                                          patch_size, self.img_h, self.img_w)
        n_patches = self.patch_embed.n_patches
        # Small init (std ~0.02) keeps the token/positional embeddings in
        # the same range as typical transformer weight initialization.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)
        self.pos_embed = nn.Parameter(
            torch.randn(1, n_patches + 1, embed_dim) * 0.02)
        self.pos_drop = nn.Dropout(dropout)
        self.blocks = nn.Sequential(*[
            TransformerBlock(embed_dim, n_heads, mlp_ratio, dropout)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(embed_dim // 2, n_targets),
        )

    def forward(self, x):
        B, C, H, W = x.shape
        # Oversized input would otherwise fail later with an opaque shape
        # mismatch when adding pos_embed -- fail early with a clear message.
        if H > self.img_h or W > self.img_w:
            raise ValueError(
                f"input spatial size ({H}, {W}) exceeds the configured "
                f"maximum ({self.img_h}, {self.img_w})")
        # Zero-pad bottom/right so both dims match the configured grid
        # (previously only the width was padded, and 450 was hard-coded).
        if H < self.img_h or W < self.img_w:
            x = nn.functional.pad(x, (0, self.img_w - W, 0, self.img_h - H))
        patches = self.patch_embed(x)              # (B, n_patches, D)
        cls = self.cls_token.expand(B, -1, -1)     # (B, 1, D)
        x = torch.cat([cls, patches], dim=1)       # (B, n_patches+1, D)
        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        cls_out = x[:, 0]                          # (B, D) CLS representation
        return self.head(cls_out)