File size: 7,502 Bytes
6202bfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
"""
3D ResNet for tornado detection + prediction.

Configs:
  ResNet3D-18:  BasicBlock, [2,2,2,2]  (~33M params)
  ResNet3D-34:  BasicBlock, [3,4,6,3]  (~64M params)
  ResNet3D-50:  Bottleneck, [3,4,6,3]  (~46M params)

Input: (B, 24, 8, 128, 128) — 24 dual-pol channels, 8 time frames, 128x128 grid
Output: (B, 4) — [det_neg, det_pos, pred_neg, pred_pos]
"""
import torch
import torch.nn as nn


class BasicBlock3D(nn.Module):
    """Residual block made of two 3x3x3 convolutions, each followed by BatchNorm.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels (``expansion`` is 1 for this block).
        stride: Stride of the first convolution (int or 3-tuple, as accepted
            by ``nn.Conv3d``). The second convolution always uses stride 1.
        downsample: Optional module applied to the block input to build the
            shortcut branch. When ``None`` the input itself is the shortcut.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()
        # Both convolutions share kernel size, padding and the bias-free setup.
        shared = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv3d(in_planes, planes, stride=stride, **shared)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(planes, planes, stride=1, **shared)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        y = self.conv1(x)
        y = self.bn1(y)
        y = self.relu(y)
        y = self.conv2(y)
        y = self.bn2(y)

        # Shortcut branch: raw input unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)
        y += shortcut
        return self.relu(y)


class Bottleneck3D(nn.Module):
    """Bottleneck residual block: 1x1x1 -> 3x3x3 -> 1x1x1 convolutions.

    The final convolution widens the channel count to ``planes * expansion``.

    Args:
        in_planes: Number of input channels.
        planes: Channel width of the two inner convolutions.
        stride: Stride of the middle 3x3x3 convolution (int or 3-tuple, as
            accepted by ``nn.Conv3d``).
        downsample: Optional module applied to the block input to build the
            shortcut branch. When ``None`` the input itself is the shortcut.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()
        out_planes = planes * self.expansion
        self.conv1 = nn.Conv3d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = nn.Conv3d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        # Shortcut branch: raw input unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)
        y += shortcut
        return self.relu(y)


class ResNet3D(nn.Module):
    """3D ResNet backbone. Returns feature vector of size 512 * block.expansion.

    Args:
        block: Residual block class. Must expose an ``expansion`` class
            attribute and accept ``(in_planes, planes, stride, downsample)``.
        layers: Four ints giving the number of blocks in each residual stage.
        in_channels: Number of channels of the input tensor (dim 1).

    Attributes:
        feat_dim: Length of the feature vector returned by ``forward``.
    """

    def __init__(self, block, layers, in_channels=24):
        super().__init__()
        self.in_planes = 64

        # Initial conv: don't downsample time aggressively
        self.conv1 = nn.Conv3d(in_channels, 64, kernel_size=(3, 7, 7),
                               stride=(1, 2, 2), padding=(1, 3, 3), bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))

        # Residual layers — spatial downsampling in layers 2-4, temporal in layers 3-4.
        # An int stride is applied by nn.Conv3d to ALL three dims (T, H, W), so
        # layer2 needs an explicit (1, 2, 2) to leave the time axis untouched.
        # NOTE: stride does not change weight shapes, so existing state_dicts
        # still load; weights trained with the former temporal stride in
        # layer2 will, however, see a different temporal resolution.
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=(1, 2, 2))   # spatial /2
        self.layer3 = self._make_layer(block, 256, layers[2], stride=(2, 2, 2))   # temporal /2, spatial /2
        self.layer4 = self._make_layer(block, 512, layers[3], stride=(2, 2, 2))   # temporal /2, spatial /2

        self.avgpool = nn.AdaptiveAvgPool3d(1)
        self.feat_dim = 512 * block.expansion

        # Weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, num_blocks, stride=1):
        """Build one residual stage of ``num_blocks`` blocks.

        Only the first block receives ``stride``; it also gets a 1x1x1
        conv + BatchNorm projection shortcut whenever the stride is not the
        int 1 or the channel count changes. Updates ``self.in_planes`` to the
        stage's output width.

        Args:
            block: Residual block class.
            planes: Base channel width of the stage.
            num_blocks: Number of blocks in the stage.
            stride: Stride of the first block (int or 3-tuple).

        Returns:
            ``nn.Sequential`` containing the stage's blocks.
        """
        out_planes = planes * block.expansion

        downsample = None
        if stride != 1 or self.in_planes != out_planes:
            # nn.Conv3d accepts both int and tuple strides, so pass it through.
            downsample = nn.Sequential(
                nn.Conv3d(self.in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(out_planes),
            )

        layers = [block(self.in_planes, planes, stride, downsample)]
        self.in_planes = out_planes
        for _ in range(1, num_blocks):
            layers.append(block(self.in_planes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a ``(B, C, T, H, W)`` tensor to a ``(B, feat_dim)`` feature vector.

        Shape comments below are for a ``(B, 24, 8, 128, 128)`` input; channel
        counts of layer1-4 are multiplied by ``block.expansion``.
        """
        x = self.relu(self.bn1(self.conv1(x)))   # (B, 64, 8, 64, 64)
        x = self.maxpool(x)                       # (B, 64, 8, 32, 32)
        x = self.layer1(x)                        # (B, 64, 8, 32, 32)
        x = self.layer2(x)                        # (B, 128, 8, 16, 16)
        x = self.layer3(x)                        # (B, 256, 4, 8, 8)
        x = self.layer4(x)                        # (B, 512, 2, 4, 4)
        x = self.avgpool(x)                       # (B, feat_dim, 1, 1, 1)
        return x.flatten(1)                        # (B, feat_dim)


class DualHeadResNet3D(nn.Module):
    """Shared ResNet3D backbone feeding two 2-way linear heads.

    One head produces detection logits, the other prediction logits; both read
    the same dropout-regularised feature vector.

    Args:
        block: Residual block class forwarded to ``ResNet3D``.
        layers: Per-stage block counts forwarded to ``ResNet3D``.
        in_channels: Number of input channels forwarded to ``ResNet3D``.
        drop_rate: Dropout probability applied to the backbone features.
    """

    def __init__(self, block, layers, in_channels=24, drop_rate=0.3):
        super().__init__()
        self.backbone = ResNet3D(block, layers, in_channels)
        n_features = self.backbone.feat_dim

        self.dropout = nn.Dropout(drop_rate)
        self.detect_head = nn.Linear(n_features, 2)
        self.predict_head = nn.Linear(n_features, 2)

        # Small-variance weights and zero biases for both heads.
        nn.init.normal_(self.detect_head.weight, std=0.01)
        nn.init.zeros_(self.detect_head.bias)
        nn.init.normal_(self.predict_head.weight, std=0.01)
        nn.init.zeros_(self.predict_head.bias)

    def forward(self, x):
        """Return ``(B, 4)`` logits: detection pair followed by prediction pair.

        Args:
            x: Input tensor laid out as ``(B, C, T, H, W)``.
        """
        # Heads run in FP32: cast first to prevent Inf grads under AMP.
        feats = self.backbone(x).float()          # (B, feat_dim)
        feats = self.dropout(feats)
        head_logits = (
            self.detect_head(feats),              # (B, 2)
            self.predict_head(feats),             # (B, 2)
        )
        return torch.cat(head_logits, dim=1)      # (B, 4)


# --- Factory functions ---

# Architecture presets keyed by name: the residual block class to use and the
# number of blocks in each of the four residual stages.
CONFIGS = dict(
    resnet18=dict(block=BasicBlock3D, layers=[2, 2, 2, 2]),
    resnet34=dict(block=BasicBlock3D, layers=[3, 4, 6, 3]),
    resnet50=dict(block=Bottleneck3D, layers=[3, 4, 6, 3]),
)


def build_resnet3d(config="resnet34", in_channels=24, drop_rate=0.3):
    """Build a DualHeadResNet3D from config name.

    Args:
        config: Key into ``CONFIGS`` selecting block type and stage depths.
        in_channels: Number of input channels passed to the model.
        drop_rate: Dropout probability passed to the model.

    Returns:
        A freshly constructed ``DualHeadResNet3D``.

    Raises:
        KeyError: If ``config`` is not a key of ``CONFIGS``. The message lists
            the available names.
    """
    try:
        cfg = CONFIGS[config]
    except KeyError:
        # Same exception type as a plain lookup, but with an actionable message.
        raise KeyError(
            f"Unknown ResNet3D config {config!r}; expected one of {sorted(CONFIGS)}"
        ) from None
    return DualHeadResNet3D(cfg["block"], cfg["layers"], in_channels, drop_rate)


if __name__ == "__main__":
    # Smoke test: print trainable-parameter counts for every preset, then run
    # a single forward pass and check the output shape.
    print("=== ResNet3D Model Configs ===\n")

    def _trainable(module):
        """Count the elements of all parameters that require gradients."""
        return sum(p.numel() for p in module.parameters() if p.requires_grad)

    for arch in ("resnet18", "resnet34", "resnet50"):
        net = build_resnet3d(arch)
        total = _trainable(net)
        backbone_total = _trainable(net.backbone)
        print(f"{arch}: {total:>12,} total params  ({backbone_total:,} backbone)")

    # Forward pass test
    print("\nForward pass test (resnet34)...")
    net = build_resnet3d("resnet34")
    batch = torch.randn(2, 24, 8, 128, 128)
    with torch.no_grad():
        logits = net(batch)
    print(f"  Input:  {tuple(batch.shape)}")
    print(f"  Output: {tuple(logits.shape)}  (expected (2, 4))")
    assert logits.shape == (2, 4), f"Expected (2, 4), got {logits.shape}"
    print("  PASSED")