File size: 4,576 Bytes
8960670
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import torch
import torch.nn as nn
import torchvision.models as models

class ResNet3D18(nn.Module):
    """3D ResNet-18 baseline classifier for single-channel volumetric patches.

    Wraps torchvision's ``r3d_18`` built with ``weights=None`` (random init).
    The stem convolution is replaced so it accepts 1 input channel. The final
    fully connected layer is resized to ``num_classes`` outputs.

    Args:
        num_classes: number of logits produced by the final linear layer.
    """

    def __init__(self, num_classes=1):
        super().__init__()
        # Use torchvision's 3D ResNet-18 (no pretrained weights).
        self.model = models.video.r3d_18(weights=None)

        # Replace the first conv so it takes 1 channel instead of 3, copying
        # every other hyper-parameter from the original stem conv.
        old_conv = self.model.stem[0]
        self.model.stem[0] = nn.Conv3d(
            1, old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            bias=False
        )

        # Resize the classification layer to the requested number of outputs.
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs, num_classes)

    def forward(self, nodule_patch, context_patch=None):
        """Return the raw logits produced by the wrapped network.

        Args:
            nodule_patch: input volume, expected shape (B, 1, D, H, W).
            context_patch: accepted for interface compatibility with other
                models; ignored by this baseline.
        """
        return self.model(nodule_patch)

    @torch.no_grad()
    def predict_with_uncertainty(self, nodule_patch, context_patch=None, mc_passes=5):
        """Estimate a probability and a confidence score via Monte Carlo dropout.

        Runs ``mc_passes`` stochastic forward passes with ONLY the dropout
        layers in training mode. Every other layer (notably BatchNorm) stays in
        eval mode, so running statistics are used and never modified by these
        inference passes. Each submodule's original train/eval flag is
        restored on exit, even if a pass raises.

        Args:
            nodule_patch: (B, 1, D, H, W) input volume.
            context_patch: (B, 1, D, H, W) or None; forwarded to ``forward``,
                which ignores it.
            mc_passes: number of stochastic forward passes (must be >= 1).

        Returns:
            (mean_prob, confidence):
            - ``mean_prob`` is the mean sigmoid probability over the passes;
              its shape is (B,) when num_classes == 1.
            - ``confidence`` is ``1 - var / 0.25`` clamped to [0, 1]. 0.25 is
              the largest population variance of a [0, 1]-valued quantity.
              With a single pass the variance is taken to be 0.

        Raises:
            ValueError: if ``mc_passes`` < 1.

        NOTE(review): torchvision's r3d_18 may contain no dropout layers at
        all. In that case every pass is identical and confidence is always
        1.0. Confirm before relying on this estimate for this model.
        """
        if mc_passes < 1:
            raise ValueError(f"mc_passes must be >= 1, got {mc_passes}")

        # Snapshot every submodule's mode so the caller's state is restored exactly.
        saved_modes = [(module, module.training) for module in self.modules()]
        try:
            # Eval everywhere (BatchNorm uses and keeps its running stats),
            # then re-enable only dropout so the passes are stochastic.
            self.eval()
            for module in self.modules():
                # _DropoutNd is the common base of all torch dropout layers.
                if isinstance(module, nn.modules.dropout._DropoutNd):
                    module.train()

            preds = torch.stack(
                [
                    torch.sigmoid(self(nodule_patch, context_patch).squeeze(-1))
                    for _ in range(mc_passes)
                ],
                dim=0,
            )
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

        mean_prob = preds.mean(dim=0)
        if mc_passes > 1:
            variance = preds.var(dim=0)
        else:
            # Sample variance of a single draw is undefined (NaN); treat as 0.
            variance = torch.zeros_like(mean_prob)
        confidence = 1.0 - (variance / 0.25).clamp(0, 1)
        return mean_prob, confidence

class ResNet2D18SliceLevel(nn.Module):
    """Slice-level 2D ResNet-18 baseline for single-channel volumetric patches.

    Each depth slice of the input volume goes through a shared 2D ResNet-18
    backbone (``weights=None``, stem conv replaced to accept 1 channel, ``fc``
    replaced by identity). The per-slice feature vectors are averaged over
    depth and classified by a small MLP head that contains Dropout(0.3).

    Args:
        num_classes: number of logits produced by the classification head.
    """

    def __init__(self, num_classes=1):
        super().__init__()
        # Use torchvision's 2D ResNet-18 (no pretrained weights).
        self.backbone = models.resnet18(weights=None)

        # Replace the first conv so it takes 1 channel instead of 3, copying
        # every other hyper-parameter from the original conv.
        old_conv = self.backbone.conv1
        self.backbone.conv1 = nn.Conv2d(
            1, old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            bias=False
        )

        # Drop the original fc layer so the backbone emits its pooled features.
        num_ftrs = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        # Classification head applied after averaging the slice features.
        self.head = nn.Sequential(
            nn.Linear(num_ftrs, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, nodule_patch, context_patch=None):
        """Return logits of shape (B, num_classes).

        Args:
            nodule_patch: input volume of shape (B, 1, D, H, W).
            context_patch: accepted for interface compatibility; ignored.

        Raises:
            ValueError: if the channel dimension is not 1.
        """
        B, C, D, H, W = nodule_patch.shape
        if C != 1:
            raise ValueError(
                f"ResNet2D18SliceLevel expects 1 input channel, got {C}"
            )

        # (B, 1, D, H, W) -> (B*D, 1, H, W): each depth slice becomes one 2D
        # image. reshape (not view) so non-contiguous inputs are copied
        # instead of raising.
        slices = nodule_patch.reshape(B * D, 1, H, W)

        # Per-slice features: (B*D, num_ftrs).
        slice_feats = self.backbone(slices)

        # Back to (B, D, num_ftrs), then average over the slices -> (B, num_ftrs).
        pooled_feats = slice_feats.reshape(B, D, -1).mean(dim=1)

        return self.head(pooled_feats)

    @torch.no_grad()
    def predict_with_uncertainty(self, nodule_patch, context_patch=None, mc_passes=5):
        """Estimate a probability and a confidence score via Monte Carlo dropout.

        Runs ``mc_passes`` stochastic forward passes with ONLY the dropout
        layers (e.g. the head's Dropout(0.3)) in training mode. The backbone's
        BatchNorm layers stay in eval mode, so their running statistics are
        used and never modified by these inference passes. Each submodule's
        original train/eval flag is restored on exit, even if a pass raises.

        Args:
            nodule_patch: (B, 1, D, H, W) input volume.
            context_patch: (B, 1, D, H, W) or None; forwarded to ``forward``,
                which ignores it.
            mc_passes: number of stochastic forward passes (must be >= 1).

        Returns:
            (mean_prob, confidence):
            - ``mean_prob`` is the mean sigmoid probability over the passes;
              its shape is (B,) when num_classes == 1.
            - ``confidence`` is ``1 - var / 0.25`` clamped to [0, 1]. 0.25 is
              the largest population variance of a [0, 1]-valued quantity.
              With a single pass the variance is taken to be 0.

        Raises:
            ValueError: if ``mc_passes`` < 1.
        """
        if mc_passes < 1:
            raise ValueError(f"mc_passes must be >= 1, got {mc_passes}")

        # Snapshot every submodule's mode so the caller's state is restored exactly.
        saved_modes = [(module, module.training) for module in self.modules()]
        try:
            # Eval everywhere (BatchNorm uses and keeps its running stats),
            # then re-enable only dropout so the passes are stochastic.
            self.eval()
            for module in self.modules():
                # _DropoutNd is the common base of all torch dropout layers.
                if isinstance(module, nn.modules.dropout._DropoutNd):
                    module.train()

            preds = torch.stack(
                [
                    torch.sigmoid(self(nodule_patch, context_patch).squeeze(-1))
                    for _ in range(mc_passes)
                ],
                dim=0,
            )
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

        mean_prob = preds.mean(dim=0)
        if mc_passes > 1:
            variance = preds.var(dim=0)
        else:
            # Sample variance of a single draw is undefined (NaN); treat as 0.
            variance = torch.zeros_like(mean_prob)
        confidence = 1.0 - (variance / 0.25).clamp(0, 1)
        return mean_prob, confidence

def get_baseline_model(model_name):
    """Build a freshly initialised single-output baseline model by name.

    Args:
        model_name: ``'resnet3d18'`` for the 3D ResNet-18, or ``'resnet2d18'``
            for the slice-level 2D ResNet-18.

    Returns:
        The corresponding model instance, constructed with ``num_classes=1``.

    Raises:
        ValueError: if ``model_name`` is not one of the names above.
    """
    # Resolve the class first; construct it in exactly one place below.
    if model_name == 'resnet3d18':
        model_cls = ResNet3D18
    elif model_name == 'resnet2d18':
        model_cls = ResNet2D18SliceLevel
    else:
        raise ValueError(f"Unknown baseline model: {model_name}")
    return model_cls(num_classes=1)