import torch import torch.nn as nn import torchvision.models as models class ResNet3D18(nn.Module): def __init__(self, num_classes=1): super().__init__() # Use torchvision's 3D ResNet-18 self.model = models.video.r3d_18(weights=None) # Modify first conv layer to accept 1 channel instead of 3 old_conv = self.model.stem[0] self.model.stem[0] = nn.Conv3d( 1, old_conv.out_channels, kernel_size=old_conv.kernel_size, stride=old_conv.stride, padding=old_conv.padding, bias=False ) # Modify final fully connected layer num_ftrs = self.model.fc.in_features self.model.fc = nn.Linear(num_ftrs, num_classes) def forward(self, nodule_patch, context_patch=None): # We only use the nodule patch for the baseline # input shape: (B, 1, D, H, W) return self.model(nodule_patch) @torch.no_grad() def predict_with_uncertainty(self, nodule_patch, context_patch=None, mc_passes=5): """Monte Carlo Dropout uncertainty estimation. Args: nodule_patch: (B, 1, D, H, W) context_patch: (B, 1, D, H, W) or None mc_passes: int, number of forward passes """ self.train() # Enable dropout preds = [] for _ in range(mc_passes): logits = self.forward(nodule_patch, context_patch) prob = torch.sigmoid(logits.squeeze(-1)) preds.append(prob) preds = torch.stack(preds, dim=0) mean_prob = preds.mean(dim=0) variance = preds.var(dim=0) confidence = 1.0 - (variance / 0.25).clamp(0, 1) self.eval() return mean_prob, confidence class ResNet2D18SliceLevel(nn.Module): def __init__(self, num_classes=1): super().__init__() # Use torchvision's 2D ResNet-18 self.backbone = models.resnet18(weights=None) # Modify first conv to accept 1 channel instead of 3 old_conv = self.backbone.conv1 self.backbone.conv1 = nn.Conv2d( 1, old_conv.out_channels, kernel_size=old_conv.kernel_size, stride=old_conv.stride, padding=old_conv.padding, bias=False ) # Remove original fc layer, use GAP instead num_ftrs = self.backbone.fc.in_features self.backbone.fc = nn.Identity() # Final classification head after pooling slices self.head = nn.Sequential( nn.Linear(num_ftrs, 128), nn.ReLU(inplace=True), nn.Dropout(0.3), nn.Linear(128, num_classes) ) def forward(self, nodule_patch, context_patch=None): # input shape: (B, 1, D, H, W) B, C, D, H, W = nodule_patch.shape # Reshape to treat depth as batch: (B*D, 1, H, W) slices = nodule_patch.squeeze(1).view(B*D, 1, H, W) # Extract features per slice: (B*D, num_ftrs) slice_feats = self.backbone(slices) # Reshape back to (B, D, num_ftrs) slice_feats = slice_feats.view(B, D, -1) # Global Average Pooling over the depth dimension (slices) -> (B, num_ftrs) pooled_feats = slice_feats.mean(dim=1) # Classification return self.head(pooled_feats) @torch.no_grad() def predict_with_uncertainty(self, nodule_patch, context_patch=None, mc_passes=5): """Monte Carlo Dropout uncertainty estimation. Args: nodule_patch: (B, 1, D, H, W) context_patch: (B, 1, D, H, W) or None mc_passes: int, number of forward passes """ self.train() # Enable dropout preds = [] for _ in range(mc_passes): logits = self.forward(nodule_patch, context_patch) prob = torch.sigmoid(logits.squeeze(-1)) preds.append(prob) preds = torch.stack(preds, dim=0) mean_prob = preds.mean(dim=0) variance = preds.var(dim=0) confidence = 1.0 - (variance / 0.25).clamp(0, 1) self.eval() return mean_prob, confidence def get_baseline_model(model_name): if model_name == 'resnet3d18': return ResNet3D18(num_classes=1) elif model_name == 'resnet2d18': return ResNet2D18SliceLevel(num_classes=1) else: raise ValueError(f"Unknown baseline model: {model_name}")