# Scraped page header (not part of the module): Spaces: Sleeping; File size: 4,576 Bytes; 8960670
import torch
import torch.nn as nn
import torchvision.models as models
class ResNet3D18(nn.Module):
    """3D ResNet-18 baseline for single-channel volumetric patches.

    Wraps torchvision's ``r3d_18`` (built with ``weights=None``), swapping
    the first stem convolution for one that takes 1 input channel and the
    final fully connected layer for one that emits ``num_classes`` logits.
    """

    # Layer types switched to training mode for Monte Carlo Dropout sampling.
    _MC_DROPOUT_TYPES = (nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.AlphaDropout)

    def __init__(self, num_classes=1):
        """Build the network.

        Args:
            num_classes: number of output logits of the final linear layer.
        """
        super().__init__()
        # Use torchvision's 3D ResNet-18
        self.model = models.video.r3d_18(weights=None)

        # Modify first conv layer to accept 1 channel instead of 3; every
        # other hyper-parameter is copied from the conv being replaced.
        old_conv = self.model.stem[0]
        self.model.stem[0] = nn.Conv3d(
            1,
            old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            bias=False,
        )

        # Modify final fully connected layer
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs, num_classes)

    def forward(self, nodule_patch, context_patch=None):
        """Return the logits for ``nodule_patch``.

        Args:
            nodule_patch: (B, 1, D, H, W)
            context_patch: accepted for interface compatibility and ignored;
                the baseline only uses the nodule patch.
        """
        return self.model(nodule_patch)

    @torch.no_grad()
    def predict_with_uncertainty(self, nodule_patch, context_patch=None, mc_passes=5):
        """Monte Carlo Dropout uncertainty estimation.

        Only dropout layers are put in training mode; every other module
        (notably batch normalisation) runs in eval mode so its running
        statistics are neither used as batch statistics nor updated. The
        train/eval state of every module is restored afterwards, even if a
        forward pass raises.

        NOTE(review): the stock ``r3d_18`` does not appear to contain any
        dropout layers; if so, all passes agree and confidence is 1.0.

        Args:
            nodule_patch: (B, 1, D, H, W)
            context_patch: (B, 1, D, H, W) or None
            mc_passes: int, number of forward passes (must be >= 1)

        Returns:
            Tuple ``(mean_prob, confidence)``: the mean sigmoid probability
            over the passes, and ``1 - variance / 0.25`` clamped to [0, 1].

        Raises:
            ValueError: if ``mc_passes`` is smaller than 1.
        """
        if mc_passes < 1:
            raise ValueError(f"mc_passes must be >= 1, got {mc_passes}")

        # Remember each module's mode so the caller's state can be restored.
        saved_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        for module in self.modules():
            if isinstance(module, self._MC_DROPOUT_TYPES):
                module.train()  # Enable dropout only

        try:
            preds = torch.stack(
                [
                    torch.sigmoid(
                        self.forward(nodule_patch, context_patch).squeeze(-1)
                    )
                    for _ in range(mc_passes)
                ],
                dim=0,
            )
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

        mean_prob = preds.mean(dim=0)
        if mc_passes > 1:
            variance = preds.var(dim=0)
        else:
            # The sample variance of a single pass is undefined (NaN).
            variance = torch.zeros_like(mean_prob)
        # 0.25 is the largest variance a quantity bounded to [0, 1] can have.
        confidence = 1.0 - (variance / 0.25).clamp(0, 1)
        return mean_prob, confidence
class ResNet2D18SliceLevel(nn.Module):
    """Slice-level 2D ResNet-18 baseline for single-channel volumes.

    Each depth slice of the input volume is passed through a torchvision
    ``resnet18`` backbone (built with ``weights=None``, first conv replaced
    by a 1-channel conv, ``fc`` replaced by ``nn.Identity``). The per-slice
    features are averaged over depth and classified by a small MLP head.
    """

    # Layer types switched to training mode for Monte Carlo Dropout sampling.
    _MC_DROPOUT_TYPES = (nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.AlphaDropout)

    def __init__(self, num_classes=1):
        """Build the backbone and the classification head.

        Args:
            num_classes: number of output logits of the head.
        """
        super().__init__()
        # Use torchvision's 2D ResNet-18
        self.backbone = models.resnet18(weights=None)

        # Modify first conv to accept 1 channel instead of 3; every other
        # hyper-parameter is copied from the conv being replaced.
        old_conv = self.backbone.conv1
        self.backbone.conv1 = nn.Conv2d(
            1,
            old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            bias=False,
        )

        # Drop the original fc layer so the backbone returns its features.
        num_ftrs = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        # Final classification head after pooling slices
        self.head = nn.Sequential(
            nn.Linear(num_ftrs, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, nodule_patch, context_patch=None):
        """Return the logits for ``nodule_patch``.

        Args:
            nodule_patch: (B, 1, D, H, W)
            context_patch: accepted for interface compatibility and ignored.
        """
        B, C, D, H, W = nodule_patch.shape
        # Treat depth as batch: (B*D, 1, H, W). ``reshape`` (rather than
        # ``view``) also accepts non-contiguous inputs such as permuted volumes.
        slices = nodule_patch.squeeze(1).reshape(B * D, 1, H, W)
        # Extract features per slice: (B*D, num_ftrs)
        slice_feats = self.backbone(slices)
        # Back to (B, D, num_ftrs)
        slice_feats = slice_feats.reshape(B, D, -1)
        # Average over the depth dimension (slices) -> (B, num_ftrs)
        pooled_feats = slice_feats.mean(dim=1)
        # Classification
        return self.head(pooled_feats)

    @torch.no_grad()
    def predict_with_uncertainty(self, nodule_patch, context_patch=None, mc_passes=5):
        """Monte Carlo Dropout uncertainty estimation.

        Only dropout layers are put in training mode; every other module
        (notably batch normalisation) runs in eval mode so its running
        statistics are neither used as batch statistics nor updated. The
        train/eval state of every module is restored afterwards, even if a
        forward pass raises.

        Args:
            nodule_patch: (B, 1, D, H, W)
            context_patch: (B, 1, D, H, W) or None
            mc_passes: int, number of forward passes (must be >= 1)

        Returns:
            Tuple ``(mean_prob, confidence)``: the mean sigmoid probability
            over the passes, and ``1 - variance / 0.25`` clamped to [0, 1].

        Raises:
            ValueError: if ``mc_passes`` is smaller than 1.
        """
        if mc_passes < 1:
            raise ValueError(f"mc_passes must be >= 1, got {mc_passes}")

        # Remember each module's mode so the caller's state can be restored.
        saved_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        for module in self.modules():
            if isinstance(module, self._MC_DROPOUT_TYPES):
                module.train()  # Enable dropout only

        try:
            preds = torch.stack(
                [
                    torch.sigmoid(
                        self.forward(nodule_patch, context_patch).squeeze(-1)
                    )
                    for _ in range(mc_passes)
                ],
                dim=0,
            )
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

        mean_prob = preds.mean(dim=0)
        if mc_passes > 1:
            variance = preds.var(dim=0)
        else:
            # The sample variance of a single pass is undefined (NaN).
            variance = torch.zeros_like(mean_prob)
        # 0.25 is the largest variance a quantity bounded to [0, 1] can have.
        confidence = 1.0 - (variance / 0.25).clamp(0, 1)
        return mean_prob, confidence
def get_baseline_model(model_name):
    """Instantiate a baseline classifier by name.

    Args:
        model_name: ``'resnet3d18'`` for the 3D ResNet-18 baseline or
            ``'resnet2d18'`` for the slice-level 2D ResNet-18 baseline.

    Returns:
        A freshly constructed model with a single output logit.

    Raises:
        ValueError: if ``model_name`` is not one of the names above.
    """
    n_outputs = 1
    if model_name == 'resnet3d18':
        return ResNet3D18(num_classes=n_outputs)
    if model_name == 'resnet2d18':
        return ResNet2D18SliceLevel(num_classes=n_outputs)
    # Nothing matched: report the offending name to the caller.
    raise ValueError(f"Unknown baseline model: {model_name}")
|