"""3D ResNet for tornado detection + prediction.

Configs (trainable parameters with the default 24 input channels, both heads
included; 3x3x3 kernels make these roughly 3x the 2D ResNet figures):
    ResNet3D-18: BasicBlock, [2, 2, 2, 2]  (~33M params)
    ResNet3D-34: BasicBlock, [3, 4, 6, 3]  (~64M params)
    ResNet3D-50: Bottleneck, [3, 4, 6, 3]  (~46M params)

Input:  (B, 24, 8, 128, 128) — 24 dual-pol channels, 8 time frames, 128x128 grid
Output: (B, 4) — [det_neg, det_pos, pred_neg, pred_pos]
"""

import torch
import torch.nn as nn


class BasicBlock3D(nn.Module):
    """Two 3x3x3 conv + BN layers with a residual connection.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels (``expansion`` is 1).
        stride: Stride of the first convolution; an int or a (T, H, W) tuple.
        downsample: Optional module applied to the input so that the shortcut
            matches the shape of the main branch.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv3d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        identity = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += identity
        return self.relu(out)


class Bottleneck3D(nn.Module):
    """1x1x1 -> 3x3x3 -> 1x1x1 bottleneck block with a residual connection.

    Args:
        in_planes: Number of input channels.
        planes: Width of the bottleneck; the block outputs
            ``planes * expansion`` channels.
        stride: Stride of the 3x3x3 convolution; an int or a (T, H, W) tuple.
        downsample: Optional module applied to the input so that the shortcut
            matches the shape of the main branch.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()
        out_planes = planes * self.expansion
        self.conv1 = nn.Conv3d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = nn.Conv3d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        identity = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += identity
        return self.relu(out)


class ResNet3D(nn.Module):
    """3D ResNet backbone.

    Returns a feature vector of size ``512 * block.expansion`` (exposed as
    ``feat_dim``).

    Args:
        block: Residual block class (``BasicBlock3D`` or ``Bottleneck3D``).
        layers: Number of blocks in each of the four residual stages.
        in_channels: Number of channels of the input volume.
    """

    def __init__(self, block, layers, in_channels=24):
        super().__init__()
        self.in_planes = 64

        # Stem: strides of 1 along time, so only height/width are reduced here.
        self.conv1 = nn.Conv3d(
            in_channels,
            64,
            kernel_size=(3, 7, 7),
            stride=(1, 2, 2),
            padding=(1, 3, 3),
            bias=False,
        )
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(
            kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)
        )

        # Residual stages. An int stride is applied by Conv3d to all three
        # dims, so layer2 halves time as well as height/width, exactly like
        # layers 3 and 4.
        # NOTE(review): earlier comments described layer2 as spatial-only,
        # which would need stride=(1, 2, 2). Behaviour is left unchanged so
        # existing checkpoints keep their meaning — confirm the intent.
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=(2, 2, 2))
        self.layer4 = self._make_layer(block, 512, layers[3], stride=(2, 2, 2))

        self.avgpool = nn.AdaptiveAvgPool3d(1)
        self.feat_dim = 512 * block.expansion

        # Weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, num_blocks, stride=1):
        """Build one residual stage and advance ``self.in_planes``.

        Args:
            block: Residual block class.
            planes: Base width of the stage; it outputs
                ``planes * block.expansion`` channels.
            num_blocks: Number of blocks in the stage.
            stride: Stride of the first block; an int or a (T, H, W) tuple.
                The remaining blocks use stride 1.

        Returns:
            An ``nn.Sequential`` holding the stage's blocks.
        """
        out_planes = planes * block.expansion

        # A projection shortcut is needed whenever the first block changes the
        # resolution or the channel count. Any tuple stride compares unequal
        # to 1 and therefore always gets a projection.
        downsample = None
        if stride != 1 or self.in_planes != out_planes:
            downsample = nn.Sequential(
                nn.Conv3d(
                    self.in_planes, out_planes, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm3d(out_planes),
            )

        layers = [block(self.in_planes, planes, stride, downsample)]
        self.in_planes = out_planes
        layers.extend(block(self.in_planes, planes) for _ in range(1, num_blocks))
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (B, C, T, H, W) volume to a (B, feat_dim) feature vector.

        Shape comments trace the default (B, 24, 8, 128, 128) input, with
        ``e = block.expansion``.
        """
        x = self.relu(self.bn1(self.conv1(x)))  # (B, 64, 8, 64, 64)
        x = self.maxpool(x)  # (B, 64, 8, 32, 32)
        x = self.layer1(x)  # (B, 64*e, 8, 32, 32)
        x = self.layer2(x)  # (B, 128*e, 4, 16, 16)
        x = self.layer3(x)  # (B, 256*e, 2, 8, 8)
        x = self.layer4(x)  # (B, 512*e, 1, 4, 4)
        x = self.avgpool(x)  # (B, 512*e, 1, 1, 1)
        return x.flatten(1)  # (B, feat_dim)


class DualHeadResNet3D(nn.Module):
    """Dual-head wrapper: detection + prediction heads on shared ResNet3D backbone.

    Args:
        block: Residual block class forwarded to ``ResNet3D``.
        layers: Blocks per stage forwarded to ``ResNet3D``.
        in_channels: Number of channels of the input volume.
        drop_rate: Dropout probability applied to the backbone features
            before both heads.
    """

    def __init__(self, block, layers, in_channels=24, drop_rate=0.3):
        super().__init__()
        self.backbone = ResNet3D(block, layers, in_channels)
        feat_dim = self.backbone.feat_dim
        self.dropout = nn.Dropout(drop_rate)
        self.detect_head = nn.Linear(feat_dim, 2)
        self.predict_head = nn.Linear(feat_dim, 2)

        # Init heads
        for head in (self.detect_head, self.predict_head):
            nn.init.normal_(head.weight, std=0.01)
            nn.init.zeros_(head.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return (B, 4) logits: detection pair followed by prediction pair."""
        # x: (B, C, T, H, W)
        features = self.backbone(x)  # (B, feat_dim)

        # FP32 cast before heads, intended to prevent Inf grads under AMP.
        # NOTE(review): when forward runs inside torch.autocast, nn.Linear is
        # itself autocast-eligible, so this cast alone may not keep the heads
        # in FP32 — confirm against the training loop.
        features = features.float()
        features = self.dropout(features)

        det = self.detect_head(features)  # (B, 2)
        pred = self.predict_head(features)  # (B, 2)
        return torch.cat([det, pred], dim=1)  # (B, 4)


# --- Factory functions ---

CONFIGS = {
    "resnet18": {"block": BasicBlock3D, "layers": [2, 2, 2, 2]},
    "resnet34": {"block": BasicBlock3D, "layers": [3, 4, 6, 3]},
    "resnet50": {"block": Bottleneck3D, "layers": [3, 4, 6, 3]},
}


def build_resnet3d(config="resnet34", in_channels=24, drop_rate=0.3):
    """Build a DualHeadResNet3D from config name.

    Args:
        config: A key of ``CONFIGS``.
        in_channels: Number of channels of the input volume.
        drop_rate: Dropout probability applied before the heads.

    Returns:
        A freshly initialised ``DualHeadResNet3D``.

    Raises:
        KeyError: If ``config`` is not a key of ``CONFIGS``; the message lists
            the valid names.
    """
    try:
        cfg = CONFIGS[config]
    except KeyError:
        valid = ", ".join(sorted(CONFIGS))
        raise KeyError(
            f"Unknown config {config!r}; expected one of: {valid}"
        ) from None
    return DualHeadResNet3D(cfg["block"], cfg["layers"], in_channels, drop_rate)


def _count_trainable(module: nn.Module) -> int:
    """Return the number of trainable parameters in *module*."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


def main():
    """Print parameter counts for every config and run a forward smoke test."""
    print("=== ResNet3D Model Configs ===\n")
    for name in CONFIGS:
        model = build_resnet3d(name)
        n_params = _count_trainable(model)
        n_backbone = _count_trainable(model.backbone)
        print(f"{name}: {n_params:>12,} total params ({n_backbone:,} backbone)")

    # Forward pass test
    print("\nForward pass test (resnet34)...")
    # Eval mode: an inference-only smoke test should not update BatchNorm
    # running statistics or apply dropout.
    model = build_resnet3d("resnet34").eval()
    x = torch.randn(2, 24, 8, 128, 128)
    with torch.no_grad():
        out = model(x)
    print(f" Input: {tuple(x.shape)}")
    print(f" Output: {tuple(out.shape)} (expected (2, 4))")
    # Raised explicitly so the check is not stripped under ``python -O``.
    if out.shape != (2, 4):
        raise AssertionError(f"Expected (2, 4), got {out.shape}")
    print(" PASSED")


if __name__ == "__main__":
    main()