# Pulmo / modeling.py
# ariyul's picture
# Add Pulmo: 2.5D concept-bottleneck student model distilled from 3D teacher for lung nodule detection, malignancy classification, and segmentation on LUNA16/LIDC-IDRI
# 8e23aec verified
"""
Pulmo — 2.5D Concept-Bottleneck Multi-task model for lung nodule analysis.
Self-contained model definition. The weights in `student_2p5d_best.pth` were
produced by online knowledge distillation from a 3D teacher (see model card).
The module keys here MUST match the checkpoint exactly:
    cnn.*              -> 2D U-Net backbone (shared trunk)
    detection_head.*   -> binary nodule / non-nodule
    concept_head.*     -> 8 LIDC radiological concepts (regression)
    malignancy_head.*  -> Linear(8 -> 2) (the concept bottleneck)
    cnn.final.*        -> segmentation logits of the middle slice
Input:  (B, n_slices, 64, 64) float32 in [0, 1] (n_slices = 7 axial slices)
Output: dict with keys 'detection', 'concepts', 'malignancy', 'segmentation'
Only `torch` is required.
"""
import torch
import torch.nn as nn
# Names of the 8 LIDC-IDRI radiological concepts. Presumably this is the column
# order of the ``concepts`` model output -- TODO confirm against training code.
CONCEPT_NAMES = [
    "subtlety",
    "internalStructure",
    "calcification",
    "sphericity",
    "margin",
    "lobulation",
    "spiculation",
    "texture",
]
class ResBlock2D(nn.Module):
    """Residual block: two (3x3 conv -> InstanceNorm) stages plus a shortcut.

    Args:
        i: number of input channels.
        o: number of output channels.

    The shortcut is the identity when ``i == o``; otherwise it is a bias-free
    1x1 convolution projecting ``i`` -> ``o`` channels. All convolutions use
    stride 1 (3x3 with padding 1), so the spatial size is preserved.
    """

    def __init__(self, i, o):
        super().__init__()
        # NOTE: attribute names are state-dict (checkpoint) keys, and creation
        # order fixes parameter registration / RNG-init order -- keep both.
        self.conv1 = nn.Conv2d(i, o, 3, padding=1, bias=False)
        self.norm1 = nn.InstanceNorm2d(o)
        self.conv2 = nn.Conv2d(o, o, 3, padding=1, bias=False)
        self.norm2 = nn.InstanceNorm2d(o)
        self.act = nn.LeakyReLU(0.1, inplace=True)
        if i == o:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(i, o, 1, bias=False)

    def forward(self, x):
        """Return ``act(norm2(conv2(act(norm1(conv1(x))))) + skip(x))``."""
        shortcut = self.skip(x)
        h = self.norm1(self.conv1(x))
        h = self.act(h)
        h = self.norm2(self.conv2(h))
        # The sum is a fresh tensor, so the in-place activation cannot clobber x.
        h = h + shortcut
        return self.act(h)
class UNet2D(nn.Module):
    """2D U-Net trunk with four 2x down-sampling and four 2x up-sampling stages.

    Args:
        in_channels: number of input channels (the stacked slices).
        base: channel width of the first stage; encoder stage k (k = 0..4)
            has ``base * 2**k`` channels.

    ``forward(x)`` returns a tuple ``(global_feature, seg_logits)``:
    the bottleneck average-pooled to shape ``(B, base * 16)`` (also exposed
    as ``out_dim``), and the 1-channel output of ``final`` at the input
    resolution. Input height/width must be divisible by 16; otherwise the
    skip-connection concatenations fail on a spatial-size mismatch.
    """

    def __init__(self, in_channels, base=24):
        super().__init__()
        c1, c2, c4, c8, c16 = (base * m for m in (1, 2, 4, 8, 16))
        # NOTE: attribute names are checkpoint keys, and creation order fixes
        # parameter registration / RNG-init order -- keep both unchanged.
        # Encoder: full-resolution stem, then four (max-pool 2x -> ResBlock).
        self.stem = ResBlock2D(in_channels, c1)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), ResBlock2D(c1, c2))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), ResBlock2D(c2, c4))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), ResBlock2D(c4, c8))
        self.bottom = nn.Sequential(nn.MaxPool2d(2), ResBlock2D(c8, c16))
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        # Decoder: 2x transposed-conv upsample, concat the matching encoder
        # feature (doubling the channels), then fuse with a ResBlock.
        self.up4 = nn.ConvTranspose2d(c16, c8, 2, 2)
        self.dec4 = ResBlock2D(c16, c8)
        self.up3 = nn.ConvTranspose2d(c8, c4, 2, 2)
        self.dec3 = ResBlock2D(c8, c4)
        self.up2 = nn.ConvTranspose2d(c4, c2, 2, 2)
        self.dec2 = ResBlock2D(c4, c2)
        self.up1 = nn.ConvTranspose2d(c2, c1, 2, 2)
        self.dec1 = ResBlock2D(c2, c1)
        self.final = nn.Conv2d(c1, 1, 1)
        self.out_dim = c16

    def forward(self, x):
        """Return ``(pooled bottleneck features, segmentation logits)``."""
        # feats ends up as [s0, s1, s2, s3, bottleneck].
        feats = [self.stem(x)]
        for stage in (self.down1, self.down2, self.down3, self.bottom):
            feats.append(stage(feats[-1]))
        h = feats.pop()
        pooled = self.global_pool(h).flatten(1)
        decoder = (
            (self.up4, self.dec4),
            (self.up3, self.dec3),
            (self.up2, self.dec2),
            (self.up1, self.dec1),
        )
        # Each pop yields the deepest remaining skip: s3, s2, s1, s0.
        for upsample, fuse in decoder:
            h = fuse(torch.cat([upsample(h), feats.pop()], 1))
        return pooled, self.final(h)
class Student2p5D(nn.Module):
    """2.5D Concept-Bottleneck multi-task model (the released `Pulmo` model).

    A shared 2D U-Net trunk (``cnn``) consumes the stacked slices as input
    channels. Its pooled bottleneck feature feeds a detection head and a
    concept head. Malignancy is predicted ONLY from the concept outputs (the
    concept bottleneck).

    Args:
        n_slices: number of stacked slices, i.e. input channels of the trunk.
        n_concepts: number of concept scores output by ``concept_head``.
        base: base channel width of the U-Net trunk.
        head_dropout: dropout probability inside ``detection_head``.
        concept_dropout: dropout probability inside ``concept_head``. The
            default of 0.3 matches the value previously hard-coded.

    Dropout layers have no parameters, so changing either dropout rate does
    not change the state-dict keys (checkpoint compatibility is unaffected).
    """

    def __init__(self, n_slices=7, n_concepts=8, base=24, head_dropout=0.1,
                 concept_dropout=0.3):
        super().__init__()
        self.n_slices = n_slices
        self.n_concepts = n_concepts
        # NOTE: attribute names are checkpoint keys, and creation order fixes
        # parameter registration / RNG-init order -- keep both unchanged.
        self.cnn = UNet2D(n_slices, base=base)
        cd = self.cnn.out_dim
        self.detection_head = nn.Sequential(
            nn.LayerNorm(cd), nn.Linear(cd, 256), nn.GELU(),
            nn.Dropout(head_dropout), nn.Linear(256, 2),
        )
        self.concept_head = nn.Sequential(
            nn.LayerNorm(cd), nn.Linear(cd, 256), nn.GELU(),
            nn.Dropout(concept_dropout), nn.Linear(256, n_concepts),
        )
        # Concept bottleneck: malignancy is predicted ONLY from the concepts.
        self.malignancy_head = nn.Linear(n_concepts, 2)

    def forward(self, x):
        """Run all task heads.

        Args:
            x: tensor of shape ``(B, n_slices, H, W)``. The trunk needs H and
                W divisible by 16 (the model card documents 64x64 inputs).

        Returns:
            dict with ``"detection"`` ``(B, 2)`` logits, ``"concepts"``
            ``(B, n_concepts)`` raw concept outputs, ``"malignancy"``
            ``(B, 2)`` logits computed from those concepts, and
            ``"segmentation"`` ``(B, 1, H, W)`` logits.
        """
        gf, seg = self.cnn(x)
        concepts = self.concept_head(gf)
        return {
            "detection": self.detection_head(gf),           # (B, 2)
            "concepts": concepts,                           # (B, n_concepts)
            "malignancy": self.malignancy_head(concepts),   # (B, 2)
            "segmentation": seg,                            # (B, 1, H, W)
        }
def load_pulmo(ckpt_path, device="cpu", n_slices=7, n_concepts=8, base=24,
               *, weights_only=False, strict=True):
    """Build the model and load weights from `student_2p5d_best.pth`.

    Args:
        ckpt_path: path to the checkpoint file.
        device: device the model is moved to; also passed to ``torch.load``
            as ``map_location``.
        n_slices, n_concepts, base: architecture hyper-parameters. They must
            match the values used to train the checkpoint.
        weights_only: forwarded to ``torch.load``. The default ``False``
            keeps compatibility with checkpoints that pickle extra Python
            objects. Pass ``True`` for files you do not fully trust.
        strict: forwarded to ``load_state_dict``.

    The file may contain either a raw state dict or a dict that wraps it
    under the ``"model_state_dict"`` key.

    Returns:
        The ``Student2p5D`` model with weights loaded, in eval mode.
    """
    model = Student2p5D(n_slices=n_slices, n_concepts=n_concepts, base=base).to(device)
    # SECURITY: with weights_only=False, torch.load unpickles arbitrary Python
    # objects, which can execute code. Only load checkpoints from trusted
    # sources, or call with weights_only=True.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=weights_only)
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state = ckpt["model_state_dict"]
    else:
        state = ckpt
    model.load_state_dict(state, strict=strict)
    model.eval()
    return model
if __name__ == "__main__":
    # Smoke test: build a randomly initialised model, report its parameter
    # count, and print the output shapes for a dummy batch of two samples.
    m = Student2p5D()
    n = sum(p.numel() for p in m.parameters()) / 1e6
    print(f"Pulmo (Student2p5D): {n:.2f}M params")
    # Only shapes are checked, so disable dropout and skip building an
    # autograd graph. Use the model's own slice count instead of a literal 7.
    m.eval()
    with torch.no_grad():
        out = m(torch.randn(2, m.n_slices, 64, 64))
    for k, v in out.items():
        print(f" {k:13s}: {tuple(v.shape)}")