"""
3D ResNet for tornado detection + prediction.

Configs (approximate trainable parameter counts with the default 24 input
channels; the 3x3x3 kernels make these roughly 3x larger than the 2D
ResNet counts for the BasicBlock variants):
    ResNet3D-18: BasicBlock, [2,2,2,2]  (~33M params)
    ResNet3D-34: BasicBlock, [3,4,6,3]  (~64M params)
    ResNet3D-50: Bottleneck, [3,4,6,3]  (~46M params)

Input:  (B, 24, 8, 128, 128) — 24 dual-pol channels, 8 time frames, 128x128 grid
Output: (B, 4) — [det_neg, det_pos, pred_neg, pred_pos], raw (unnormalized)
        scores from two 2-way linear heads
"""
| | import torch |
| | import torch.nn as nn |
| |
|
| |
|
class BasicBlock3D(nn.Module):
    """Residual block of two 3x3x3 convolutions plus a shortcut connection.

    Main branch: conv(stride) -> BN -> ReLU -> conv -> BN.
    The shortcut is added to that result, and a final ReLU is applied.

    Args:
        in_planes: Number of channels of the incoming tensor.
        planes: Number of channels produced by both convolutions. Because
            ``expansion`` is 1, this is also the block's output width.
        stride: Stride of the first convolution (int or 3-tuple).
        downsample: Optional module applied to the input to build the
            shortcut. Needed whenever ``stride`` or the channel count changes
            the shape; otherwise the residual addition would see mismatched
            shapes.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()
        # Attribute names and creation order define the state_dict keys and
        # parameter order, so they are kept fixed.
        self.conv1 = nn.Conv3d(
            in_planes, planes,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(
            planes, planes,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample

    def forward(self, x):
        branch = self.conv1(x)
        branch = self.relu(self.bn1(branch))
        branch = self.bn2(self.conv2(branch))
        # Project the input only when the block changes its shape.
        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut
        return self.relu(branch)
| |
|
| |
|
class Bottleneck3D(nn.Module):
    """Bottleneck residual block: 1x1x1 reduce -> 3x3x3 -> 1x1x1 expand.

    The output width is ``planes * expansion`` (``expansion`` = 4). The stride
    is applied by the middle 3x3x3 convolution.

    Args:
        in_planes: Number of channels of the incoming tensor.
        planes: Bottleneck width (channels of the first two convolutions).
        stride: Stride of the 3x3x3 convolution (int or 3-tuple).
        downsample: Optional module applied to the input to build the
            shortcut. Needed whenever the stride or channel count changes the
            shape; otherwise the residual addition would see mismatched shapes.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()
        out_planes = planes * self.expansion
        # Attribute names and creation order define the state_dict keys and
        # parameter order, so they are kept fixed.
        self.conv1 = nn.Conv3d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(
            planes, planes,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = nn.Conv3d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        h = x
        # The first two conv/BN stages are each followed by ReLU.
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            h = self.relu(bn(conv(h)))
        h = self.bn3(self.conv3(h))
        # Project the input only when the block changes its shape.
        shortcut = x if self.downsample is None else self.downsample(x)
        h += shortcut
        return self.relu(h)
| |
|
| |
|
class ResNet3D(nn.Module):
    """3D ResNet backbone. Returns a feature vector of size 512 * block.expansion.

    Stem: a (3, 7, 7) convolution with stride (1, 2, 2), then a (1, 3, 3)
    max-pool with stride (1, 2, 2). Both halve H and W and leave the time axis
    unchanged. Four residual stages follow. Stages 2-4 use stride 2 on all three
    axes (T, H, W). Global average pooling then yields a flat feature vector.

    For a (B, C, 8, 128, 128) input, (T, H) evolves as
    (8, 128) -> stem (8, 64) -> pool (8, 32) -> layer1 (8, 32)
    -> layer2 (4, 16) -> layer3 (2, 8) -> layer4 (1, 4) -> avgpool.

    Args:
        block: Residual block class. It must expose an ``expansion`` class
            attribute and accept ``(in_planes, planes, stride, downsample)``.
        layers: Number of blocks in each of the four stages.
        in_channels: Number of input channels (dim 1 of the input).

    Attributes:
        feat_dim: Length of the returned feature vector, ``512 * block.expansion``.
    """

    def __init__(self, block, layers, in_channels=24):
        super().__init__()
        # Running channel count consumed and updated by _make_layer.
        self.in_planes = 64

        # Stem: spatial downsampling only; the temporal resolution is kept.
        self.conv1 = nn.Conv3d(in_channels, 64, kernel_size=(3, 7, 7),
                               stride=(1, 2, 2), padding=(1, 3, 3), bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2),
                                    padding=(0, 1, 1))

        # Residual stages. An int stride of 2 and (2, 2, 2) are equivalent here.
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=(2, 2, 2))
        self.layer4 = self._make_layer(block, 512, layers[3], stride=(2, 2, 2))

        self.avgpool = nn.AdaptiveAvgPool3d(1)
        self.feat_dim = 512 * block.expansion

        # Weight init: Kaiming-normal (fan_out, ReLU) for convolutions;
        # BatchNorm scale = 1 and shift = 0.
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, num_blocks, stride=1):
        """Build one residual stage and advance ``self.in_planes``.

        Only the first block receives ``stride`` and, when shapes change, a
        1x1x1 projection shortcut. The remaining blocks use the defaults.

        Args:
            block: Residual block class (see the class docstring).
            planes: Base width of the stage. The output width is
                ``planes * block.expansion``.
            num_blocks: Number of blocks. The first block is always created.
            stride: int or 3-sequence ``(t, h, w)`` for the first block.

        Returns:
            ``nn.Sequential`` of the stage's blocks.
        """
        out_planes = planes * block.expansion
        # Normalise to a 3-tuple so that e.g. (1, 1, 1) counts as "no stride".
        # Comparing a tuple directly with the int 1 is always unequal, which
        # would add a spurious projection shortcut.
        stride3 = (stride,) * 3 if isinstance(stride, int) else tuple(stride)

        downsample = None
        if any(s != 1 for s in stride3) or self.in_planes != out_planes:
            # The shortcut must match the main branch's channels and resolution.
            downsample = nn.Sequential(
                nn.Conv3d(self.in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(out_planes),
            )

        layers = [block(self.in_planes, planes, stride, downsample)]
        self.in_planes = out_planes
        layers.extend(block(self.in_planes, planes) for _ in range(1, num_blocks))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` of shape (B, in_channels, T, H, W) to features (B, feat_dim)."""
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        return x.flatten(1)
| |
|
| |
|
class DualHeadResNet3D(nn.Module):
    """Dual-head wrapper: detection + prediction heads on a shared ResNet3D backbone.

    The pooled backbone features pass through one shared dropout layer. They
    then go to two independent 2-way linear heads. ``forward`` returns the
    concatenation, shape (B, 4): columns 0-1 come from the detection head and
    columns 2-3 from the prediction head. These are raw scores; no activation
    is applied.

    Args:
        block: Residual block class, forwarded to ``ResNet3D``.
        layers: Blocks per stage, forwarded to ``ResNet3D``.
        in_channels: Input channel count, forwarded to ``ResNet3D``.
        drop_rate: Dropout probability applied to the features before both heads.
    """

    def __init__(self, block, layers, in_channels=24, drop_rate=0.3):
        super().__init__()
        self.backbone = ResNet3D(block, layers, in_channels)
        feat_dim = self.backbone.feat_dim

        self.dropout = nn.Dropout(drop_rate)
        self.detect_head = nn.Linear(feat_dim, 2)
        self.predict_head = nn.Linear(feat_dim, 2)

        # Both heads: small-std normal weights and zero bias.
        for head in (self.detect_head, self.predict_head):
            nn.init.normal_(head.weight, std=0.01)
            nn.init.zeros_(head.bias)

    def forward(self, x):
        features = self.backbone(x)
        # Cast to the heads' parameter dtype.
        # - With the usual fp32 weights this is exactly the former ``.float()``:
        #   it upcasts a reduced-precision backbone output.
        # - Unlike a hard-coded ``.float()``, it does not fail after
        #   ``model.half()`` / ``model.to(torch.bfloat16)``, where fp32 inputs
        #   would hit fp16/bf16 weights.
        # NOTE(review): under torch.autocast the Linear heads may still run in
        # reduced precision regardless of this cast.
        features = features.to(self.detect_head.weight.dtype)
        features = self.dropout(features)
        det = self.detect_head(features)
        pred = self.predict_head(features)
        return torch.cat([det, pred], dim=1)
| |
|
| |
|
| | |
| |
|
# Registry of named backbone variants: each entry gives the residual block
# class and the number of blocks in each of the four stages.
CONFIGS = dict(
    resnet18=dict(block=BasicBlock3D, layers=[2, 2, 2, 2]),
    resnet34=dict(block=BasicBlock3D, layers=[3, 4, 6, 3]),
    resnet50=dict(block=Bottleneck3D, layers=[3, 4, 6, 3]),
)
| |
|
| |
|
def build_resnet3d(config="resnet34", in_channels=24, drop_rate=0.3):
    """Build a DualHeadResNet3D from a config name.

    Args:
        config: Key into ``CONFIGS`` that selects the block type and stage depths.
        in_channels: Number of input channels for the backbone stem.
        drop_rate: Dropout probability applied before the two heads.

    Returns:
        A freshly initialised ``DualHeadResNet3D``.

    Raises:
        KeyError: If ``config`` is not a key of ``CONFIGS``. The message lists
            the valid names.
    """
    try:
        cfg = CONFIGS[config]
    except KeyError:
        # Still a KeyError (same type callers may catch), but it says what is valid.
        raise KeyError(
            f"Unknown config {config!r}; expected one of {sorted(CONFIGS)}"
        ) from None
    return DualHeadResNet3D(cfg["block"], cfg["layers"], in_channels, drop_rate)
| |
|
| |
|
if __name__ == "__main__":
    print("=== ResNet3D Model Configs ===\n")

    # Iterate the registry itself so newly added configs are reported too.
    for name in CONFIGS:
        model = build_resnet3d(name)
        n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        n_backbone = sum(p.numel() for p in model.backbone.parameters() if p.requires_grad)
        print(f"{name}: {n_params:>12,} total params ({n_backbone:,} backbone)")

    # Shape smoke test.
    # eval() disables dropout and makes BatchNorm use its running statistics
    # instead of updating them from this random batch.
    print("\nForward pass test (resnet34)...")
    model = build_resnet3d("resnet34")
    model.eval()
    x = torch.randn(2, 24, 8, 128, 128)
    with torch.no_grad():
        out = model(x)
    print(f" Input: {tuple(x.shape)}")
    print(f" Output: {tuple(out.shape)} (expected (2, 4))")
    # Explicit check rather than `assert`, which `python -O` would strip.
    if out.shape != (2, 4):
        raise SystemExit(f"Expected (2, 4), got {out.shape}")
    print(" PASSED")
| |
|