"""Pulmo — 2.5D Concept-Bottleneck Multi-task model for lung nodule analysis.

Self-contained model definition. The weights in ``student_2p5d_best.pth`` were
produced by online knowledge distillation from a 3D teacher (see model card).

The module keys here MUST match the checkpoint exactly:

    cnn.*             -> 2D U-Net backbone (shared trunk)
    detection_head.*  -> binary nodule / non-nodule
    concept_head.*    -> 8 LIDC radiological concepts (regression)
    malignancy_head.* -> Linear(8 -> 2) (the concept bottleneck)
    cnn.final.*       -> segmentation logits of the middle slice

Input : (B, n_slices, 64, 64) float32 in [0, 1]  (n_slices = 7 axial slices)
Output: dict with keys 'detection', 'concepts', 'malignancy', 'segmentation'

Only ``torch`` is required.
"""

from typing import Dict, Tuple, Union

import torch
import torch.nn as nn

# LIDC radiological concept names. NOTE(review): presumably listed in the same
# order as the outputs of ``concept_head`` — confirm against the training
# label order before using this list to name the concept outputs.
CONCEPT_NAMES = [
    "subtlety",
    "internalStructure",
    "calcification",
    "sphericity",
    "margin",
    "lobulation",
    "spiculation",
    "texture",
]


class ResBlock2D(nn.Module):
    """Two 3x3 conv + InstanceNorm layers with a residual connection.

    Spatial size is preserved (3x3 kernels with ``padding=1``). The skip path
    is a bias-free 1x1 conv when the channel count changes, otherwise the
    identity.

    Args:
        i: number of input channels.
        o: number of output channels.
    """

    def __init__(self, i: int, o: int) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(i, o, 3, padding=1, bias=False)
        # InstanceNorm2d defaults to affine=False / track_running_stats=False,
        # so the norm layers add no keys to the state dict.
        self.norm1 = nn.InstanceNorm2d(o)
        self.conv2 = nn.Conv2d(o, o, 3, padding=1, bias=False)
        self.norm2 = nn.InstanceNorm2d(o)
        self.act = nn.LeakyReLU(0.1, inplace=True)
        # Projection only when shapes differ; this decides whether a
        # ``skip.weight`` key exists for the block, so it must stay as-is.
        self.skip = nn.Conv2d(i, o, 1, bias=False) if i != o else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        idt = self.skip(x)
        # In-place activation acts on freshly created norm/sum outputs,
        # never on ``x`` itself, so the identity path is not clobbered.
        out = self.act(self.norm1(self.conv1(x)))
        out = self.norm2(self.conv2(out))
        return self.act(out + idt)


class UNet2D(nn.Module):
    """Four-level residual 2D U-Net that also exposes a pooled bottleneck feature.

    Args:
        in_channels: number of input channels (the stacked slices).
        base: channel width of the first stage; stage k has ``base * 2**k``
            channels, down to ``base * 16`` at the bottleneck.

    ``forward(x)`` returns ``(gf, seg)``:
        gf:  globally averaged bottleneck feature, shape ``(B, base * 16)``.
        seg: single-channel logit map at input resolution, ``(B, 1, H, W)``.

    H and W must be divisible by 16 (four 2x max-poolings); otherwise the
    decoder's skip concatenations fail with a size mismatch.
    """

    def __init__(self, in_channels: int, base: int = 24) -> None:
        super().__init__()
        self.stem = ResBlock2D(in_channels, base)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), ResBlock2D(base, base * 2))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), ResBlock2D(base * 2, base * 4))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), ResBlock2D(base * 4, base * 8))
        self.bottom = nn.Sequential(nn.MaxPool2d(2), ResBlock2D(base * 8, base * 16))
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Decoder: each ConvTranspose2d(k=2, s=2) doubles H and W, then the
        # ResBlock fuses the upsampled map with the matching encoder skip.
        self.up4 = nn.ConvTranspose2d(base * 16, base * 8, 2, 2)
        self.dec4 = ResBlock2D(base * 16, base * 8)
        self.up3 = nn.ConvTranspose2d(base * 8, base * 4, 2, 2)
        self.dec3 = ResBlock2D(base * 8, base * 4)
        self.up2 = nn.ConvTranspose2d(base * 4, base * 2, 2, 2)
        self.dec2 = ResBlock2D(base * 4, base * 2)
        self.up1 = nn.ConvTranspose2d(base * 2, base, 2, 2)
        self.dec1 = ResBlock2D(base * 2, base)
        self.final = nn.Conv2d(base, 1, 1)

        # Width of ``gf``; read by the heads in Student2p5D.
        self.out_dim = base * 16

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        s0 = self.stem(x)    # (B, base,    H,    W)
        s1 = self.down1(s0)  # (B, 2*base,  H/2,  W/2)
        s2 = self.down2(s1)  # (B, 4*base,  H/4,  W/4)
        s3 = self.down3(s2)  # (B, 8*base,  H/8,  W/8)
        b = self.bottom(s3)  # (B, 16*base, H/16, W/16)
        gf = self.global_pool(b).flatten(1)

        d4 = self.dec4(torch.cat([self.up4(b), s3], 1))
        d3 = self.dec3(torch.cat([self.up3(d4), s2], 1))
        d2 = self.dec2(torch.cat([self.up2(d3), s1], 1))
        d1 = self.dec1(torch.cat([self.up1(d2), s0], 1))
        return gf, self.final(d1)


class Student2p5D(nn.Module):
    """2.5D Concept-Bottleneck multi-task model (the released `Pulmo` model).

    A shared :class:`UNet2D` trunk consumes ``n_slices`` stacked slices as
    input channels. Its pooled bottleneck feature feeds the detection and
    concept heads. Malignancy is predicted ONLY from the concept outputs (the
    concept bottleneck). The U-Net decoder output is returned as segmentation.

    Args:
        n_slices: number of input slices (U-Net input channels).
        n_concepts: number of concept outputs / malignancy-head inputs.
        base: U-Net base channel width.
        head_dropout: dropout probability in the detection head.
        concept_dropout: dropout probability in the concept head. The default
            of 0.3 is the value previously hard-coded here. Dropout has no
            parameters, so this does not affect checkpoint keys.
    """

    def __init__(
        self,
        n_slices: int = 7,
        n_concepts: int = 8,
        base: int = 24,
        head_dropout: float = 0.1,
        concept_dropout: float = 0.3,
    ) -> None:
        super().__init__()
        self.n_slices = n_slices
        self.n_concepts = n_concepts
        self.cnn = UNet2D(n_slices, base=base)
        cd = self.cnn.out_dim
        self.detection_head = nn.Sequential(
            nn.LayerNorm(cd),
            nn.Linear(cd, 256),
            nn.GELU(),
            nn.Dropout(head_dropout),
            nn.Linear(256, 2),
        )
        self.concept_head = nn.Sequential(
            nn.LayerNorm(cd),
            nn.Linear(cd, 256),
            nn.GELU(),
            nn.Dropout(concept_dropout),
            nn.Linear(256, n_concepts),
        )
        # Concept bottleneck: malignancy is predicted ONLY from the concepts.
        self.malignancy_head = nn.Linear(n_concepts, 2)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run the trunk and all heads.

        Args:
            x: ``(B, n_slices, H, W)`` input; H and W must be divisible by 16
                (the released model uses 64x64).

        Returns:
            Dict of raw (pre-activation) outputs:
            ``detection`` (B, 2), ``concepts`` (B, n_concepts),
            ``malignancy`` (B, 2) and ``segmentation`` (B, 1, H, W).
        """
        gf, seg = self.cnn(x)
        concepts = self.concept_head(gf)
        return {
            "detection": self.detection_head(gf),
            "concepts": concepts,
            "malignancy": self.malignancy_head(concepts),
            "segmentation": seg,
        }


def load_pulmo(
    ckpt_path,
    device: Union[str, torch.device] = "cpu",
    n_slices: int = 7,
    n_concepts: int = 8,
    base: int = 24,
    *,
    weights_only: bool = False,
) -> Student2p5D:
    """Build :class:`Student2p5D` and load weights from ``student_2p5d_best.pth``.

    The checkpoint may be a bare ``state_dict``. It may also be a dict that
    wraps one under ``"model_state_dict"``, or, failing that, under
    ``"state_dict"``. Keys must match the model exactly (``strict=True``).

    Args:
        ckpt_path: path (or file-like object) accepted by :func:`torch.load`.
        device: device to build the model on and to map the weights to.
        n_slices, n_concepts, base: architecture hyper-parameters; must match
            those used to produce the checkpoint.
        weights_only: forwarded to :func:`torch.load`. Defaults to ``False``
            so that checkpoints carrying arbitrary metadata keep loading. Pass
            ``True`` for any checkpoint that is not from a trusted source.

    Returns:
        The model on ``device``, in eval mode.

    Raises:
        RuntimeError: from ``load_state_dict`` on missing/unexpected keys or
            shape mismatches.
    """
    model = Student2p5D(n_slices=n_slices, n_concepts=n_concepts, base=base).to(device)
    # SECURITY: with weights_only=False, torch.load unpickles arbitrary Python
    # objects, which can execute code. Never load an untrusted file this way.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=weights_only)
    state = ckpt
    if isinstance(ckpt, dict):
        # A bare state_dict is keyed by parameter names such as
        # "cnn.stem.conv1.weight", so these wrapper keys cannot collide.
        for key in ("model_state_dict", "state_dict"):
            if key in ckpt:
                state = ckpt[key]
                break
    model.load_state_dict(state, strict=True)
    model.eval()
    return model


if __name__ == "__main__":
    m = Student2p5D()
    n = sum(p.numel() for p in m.parameters()) / 1e6
    print(f"Pulmo (Student2p5D): {n:.2f}M params")
    # Smoke test in inference mode: dropout off, no autograd graph built.
    m.eval()
    with torch.no_grad():
        out = m(torch.randn(2, 7, 64, 64))
    for k, v in out.items():
        print(f" {k:13s}: {tuple(v.shape)}")