| """ |
| Pulmo — 2.5D Concept-Bottleneck Multi-task model for lung nodule analysis. |
| |
| Self-contained model definition. The weights in `student_2p5d_best.pth` were |
| produced by online knowledge distillation from a 3D teacher (see model card). |
| |
| The module keys here MUST match the checkpoint exactly: |
| cnn.* -> 2D U-Net backbone (shared trunk) |
| detection_head.* -> binary nodule / non-nodule |
| concept_head.* -> 8 LIDC radiological concepts (regression) |
| malignancy_head.*-> Linear(8 -> 2) (the concept bottleneck) |
| cnn.final.* -> segmentation logits of the middle slice |
| |
| Input : (B, n_slices, 64, 64) float32 in [0, 1] (n_slices = 7 axial slices) |
| Output: dict with keys 'detection', 'concepts', 'malignancy', 'segmentation' |
| |
| Only `torch` is required. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
# Names of the eight concepts, one per entry.
# NOTE(review): presumably index-aligned with the model's 'concepts' output
# (entry k names column k) — confirm against the training code.
CONCEPT_NAMES = [
    "subtlety",
    "internalStructure",
    "calcification",
    "sphericity",
    "margin",
    "lobulation",
    "spiculation",
    "texture",
]
|
|
|
|
class ResBlock2D(nn.Module):
    """Residual block made of two 3x3 convolutions with a skip connection.

    The main path is conv -> instance norm -> LeakyReLU(0.1) -> conv ->
    instance norm. The skip path is added to that result and a final
    LeakyReLU is applied to the sum. Both 3x3 convolutions use padding 1,
    so the spatial size of the input is preserved.

    Args:
        i: Number of input channels.
        o: Number of output channels.
    """

    def __init__(self, i, o):
        super().__init__()
        # Attribute names become state-dict keys, so they are kept as-is.
        conv3x3 = {"kernel_size": 3, "padding": 1, "bias": False}
        self.conv1 = nn.Conv2d(i, o, **conv3x3)
        self.norm1 = nn.InstanceNorm2d(o)
        self.conv2 = nn.Conv2d(o, o, **conv3x3)
        self.norm2 = nn.InstanceNorm2d(o)
        self.act = nn.LeakyReLU(0.1, inplace=True)
        # A 1x1 projection is only needed when the channel count changes.
        if i == o:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(i, o, kernel_size=1, bias=False)

    def forward(self, x):
        """Return the block output for `x` (same spatial size, `o` channels)."""
        residual = self.skip(x)

        hidden = self.conv1(x)
        hidden = self.norm1(hidden)
        hidden = self.act(hidden)

        hidden = self.conv2(hidden)
        hidden = self.norm2(hidden)

        # The sum is a fresh tensor, so the in-place activation never
        # writes into `x` or `residual`.
        return self.act(hidden + residual)
|
|
|
|
class UNet2D(nn.Module):
    """2D U-Net trunk that yields a pooled feature vector and a 1-channel map.

    The encoder is a stem block followed by four "max-pool then residual
    block" stages, doubling the channel count each time (base, 2x, 4x, 8x,
    16x). The decoder mirrors it with four stride-2 transposed convolutions,
    each followed by concatenation with the matching encoder output and a
    residual block. Because of the four 2x poolings and the skip
    concatenations, input height and width should be multiples of 16.

    Args:
        in_channels: Number of channels of the input tensor.
        base: Channel count of the first encoder level.

    Attributes:
        out_dim: Length of the pooled feature vector (``base * 16``).
    """

    def __init__(self, in_channels, base=24):
        super().__init__()
        c1, c2, c4, c8, c16 = (base * mult for mult in (1, 2, 4, 8, 16))

        # Encoder. Attribute names become state-dict keys, so they are kept.
        self.stem = ResBlock2D(in_channels, c1)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), ResBlock2D(c1, c2))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), ResBlock2D(c2, c4))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), ResBlock2D(c4, c8))
        self.bottom = nn.Sequential(nn.MaxPool2d(2), ResBlock2D(c8, c16))
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Decoder. Each dec* block sees the upsampled features concatenated
        # with the encoder skip, hence twice the upsampled channel count.
        self.up4 = nn.ConvTranspose2d(c16, c8, kernel_size=2, stride=2)
        self.dec4 = ResBlock2D(c16, c8)
        self.up3 = nn.ConvTranspose2d(c8, c4, kernel_size=2, stride=2)
        self.dec3 = ResBlock2D(c8, c4)
        self.up2 = nn.ConvTranspose2d(c4, c2, kernel_size=2, stride=2)
        self.dec2 = ResBlock2D(c4, c2)
        self.up1 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        self.dec1 = ResBlock2D(c2, c1)

        self.final = nn.Conv2d(c1, 1, kernel_size=1)
        self.out_dim = c16

    def forward(self, x):
        """Run the U-Net.

        Returns:
            A tuple ``(gf, seg)`` where ``gf`` is the globally average-pooled
            bottleneck flattened to ``(B, out_dim)`` and ``seg`` is the
            1-channel output of the final 1x1 convolution.
        """
        skip0 = self.stem(x)
        skip1 = self.down1(skip0)
        skip2 = self.down2(skip1)
        skip3 = self.down3(skip2)
        bottleneck = self.bottom(skip3)

        pooled = self.global_pool(bottleneck).flatten(1)

        # Walk back up, pairing each upsampler with its decoder block and
        # the encoder output of the same resolution.
        decoder_stages = (
            (self.up4, self.dec4, skip3),
            (self.up3, self.dec3, skip2),
            (self.up2, self.dec2, skip1),
            (self.up1, self.dec1, skip0),
        )
        features = bottleneck
        for upsample, decode, skip in decoder_stages:
            features = decode(torch.cat([upsample(features), skip], 1))

        return pooled, self.final(features)
|
|
|
|
class Student2p5D(nn.Module):
    """2.5D Concept-Bottleneck multi-task model (the released `Pulmo` model).

    A shared `UNet2D` trunk takes the slices as input channels and produces
    a pooled feature vector plus a segmentation map. Two MLP heads read the
    pooled features (detection and concepts); the malignancy head is a single
    linear layer that reads only the concept outputs.

    Args:
        n_slices: Number of input channels given to the trunk.
        n_concepts: Width of the concept head output.
        base: Base channel count forwarded to `UNet2D`.
        head_dropout: Dropout probability of the detection head. The concept
            head uses a fixed dropout of 0.3.
    """

    def __init__(self, n_slices=7, n_concepts=8, base=24, head_dropout=0.1):
        super().__init__()
        self.n_slices = n_slices
        self.n_concepts = n_concepts
        self.cnn = UNet2D(n_slices, base=base)
        feat_dim = self.cnn.out_dim

        def make_head(drop_p, n_out):
            # Layer order fixes the Sequential indices used in state-dict
            # keys (0: LayerNorm, 1: Linear, 4: Linear), so keep it stable.
            return nn.Sequential(
                nn.LayerNorm(feat_dim),
                nn.Linear(feat_dim, 256),
                nn.GELU(),
                nn.Dropout(drop_p),
                nn.Linear(256, n_out),
            )

        self.detection_head = make_head(head_dropout, 2)
        self.concept_head = make_head(0.3, n_concepts)

        # The bottleneck: malignancy is predicted from the concepts alone.
        self.malignancy_head = nn.Linear(n_concepts, 2)

    def forward(self, x):
        """Return a dict with 'detection', 'concepts', 'malignancy', 'segmentation'."""
        pooled, seg_logits = self.cnn(x)
        # Heads are evaluated in this order (concepts, detection, malignancy).
        concept_out = self.concept_head(pooled)
        detection_out = self.detection_head(pooled)
        malignancy_out = self.malignancy_head(concept_out)
        return {
            "detection": detection_out,
            "concepts": concept_out,
            "malignancy": malignancy_out,
            "segmentation": seg_logits,
        }
|
|
|
|
def load_pulmo(ckpt_path, device="cpu", n_slices=7, n_concepts=8, base=24):
    """Build the model and load weights from `student_2p5d_best.pth`.

    Args:
        ckpt_path: Path to the checkpoint file.
        device: Device the model is moved to and the checkpoint is mapped to.
        n_slices: Forwarded to `Student2p5D`.
        n_concepts: Forwarded to `Student2p5D`.
        base: Forwarded to `Student2p5D`.

    Returns:
        The `Student2p5D` instance with weights loaded, in eval mode.

    The checkpoint may be either a bare state dict or a dict that stores the
    state dict under the ``"model_state_dict"`` key. Loading is strict, so
    any missing or unexpected key makes `load_state_dict` raise.
    """
    model = Student2p5D(n_slices=n_slices, n_concepts=n_concepts, base=base)
    model = model.to(device)

    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # Python objects, which can execute code. Only load checkpoints from a
    # trusted source, or switch to weights_only=True once it is confirmed
    # that the released file contains nothing but tensors and plain types.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    state = ckpt
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state = ckpt["model_state_dict"]

    model.load_state_dict(state, strict=True)
    model.eval()
    return model
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the model with default settings, report its size,
    # and print the shape of every output for a random batch of two inputs.
    m = Student2p5D()
    n = sum(p.numel() for p in m.parameters()) / 1e6
    print(f"Pulmo (Student2p5D): {n:.2f}M params")
    # Run the check the way the model is used for inference: dropout is
    # disabled and no autograd graph is built. Output shapes are unaffected.
    m.eval()
    with torch.no_grad():
        out = m(torch.randn(2, 7, 64, 64))
    for k, v in out.items():
        print(f" {k:13s}: {tuple(v.shape)}")
|
|