File size: 5,149 Bytes
8e23aec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""
Pulmo — 2.5D Concept-Bottleneck Multi-task model for lung nodule analysis.

Self-contained model definition. The weights in `student_2p5d_best.pth` were
produced by online knowledge distillation from a 3D teacher (see model card).

The module keys here MUST match the checkpoint exactly:
    cnn.*            -> 2D U-Net backbone (shared trunk)
    detection_head.* -> binary nodule / non-nodule
    concept_head.*   -> 8 LIDC radiological concepts (regression)
    malignancy_head.*-> Linear(8 -> 2)  (the concept bottleneck)
    cnn.final.*      -> segmentation logits of the middle slice

Input : (B, n_slices, 64, 64) float32 in [0, 1]   (n_slices = 7 axial slices)
Output: dict with keys 'detection', 'concepts', 'malignancy', 'segmentation'

Only `torch` is required.
"""

import torch
import torch.nn as nn


# Names of the 8 LIDC radiological concepts.
# NOTE(review): presumably the order here matches the columns of the model's
# 'concepts' output -- confirm against the training code before relying on it.
CONCEPT_NAMES = [
    "subtlety",
    "internalStructure",
    "calcification",
    "sphericity",
    "margin",
    "lobulation",
    "spiculation",
    "texture",
]


class ResBlock2D(nn.Module):
    """Residual block: two 3x3 conv + InstanceNorm stages plus a skip path.

    The skip path is a bias-free 1x1 convolution when the channel count
    changes (``i != o``) and an identity mapping otherwise. The 3x3
    convolutions use padding 1, so the spatial size is preserved.

    Args:
        i: number of input channels.
        o: number of output channels.
    """

    def __init__(self, i, o):
        super().__init__()
        # Attribute names are kept as-is: they become the state-dict keys.
        self.conv1 = nn.Conv2d(i, o, kernel_size=3, padding=1, bias=False)
        self.norm1 = nn.InstanceNorm2d(o)
        self.conv2 = nn.Conv2d(o, o, kernel_size=3, padding=1, bias=False)
        self.norm2 = nn.InstanceNorm2d(o)
        self.act = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        if i == o:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(i, o, kernel_size=1, bias=False)

    def forward(self, x):
        """Return ``act(norm2(conv2(act(norm1(conv1(x))))) + skip(x))``."""
        residual = self.skip(x)
        hidden = self.conv1(x)
        hidden = self.norm1(hidden)
        hidden = self.act(hidden)
        hidden = self.conv2(hidden)
        hidden = self.norm2(hidden)
        return self.act(hidden + residual)


class UNet2D(nn.Module):
    """Four-level 2D U-Net returning a pooled bottleneck vector and a 1-channel map.

    Encoder: ``stem`` followed by four max-pool + ResBlock2D stages that
    double the channel width each time (``base`` up to ``base * 16``).
    Decoder: four stride-2 transposed convolutions, each concatenated with the
    matching encoder feature map and fused by a ResBlock2D, then a 1x1
    convolution down to one channel.

    Because of the four 2x poolings, input height and width should be
    divisible by 16 so the upsampled maps line up with the encoder maps.

    Args:
        in_channels: number of channels of the input tensor.
        base: channel width of the first stage.

    Attributes:
        out_dim: length of the pooled bottleneck vector (``base * 16``).
    """

    def __init__(self, in_channels, base=24):
        super().__init__()
        w2, w4, w8, w16 = base * 2, base * 4, base * 8, base * 16

        # Encoder path (attribute names are the state-dict keys; keep them).
        self.stem = ResBlock2D(in_channels, base)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), ResBlock2D(base, w2))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), ResBlock2D(w2, w4))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), ResBlock2D(w4, w8))
        self.bottom = nn.Sequential(nn.MaxPool2d(2), ResBlock2D(w8, w16))
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Decoder path: each dec* block sees upsampled + skip channels, hence
        # an input width twice its output width.
        self.up4 = nn.ConvTranspose2d(w16, w8, kernel_size=2, stride=2)
        self.dec4 = ResBlock2D(w16, w8)
        self.up3 = nn.ConvTranspose2d(w8, w4, kernel_size=2, stride=2)
        self.dec3 = ResBlock2D(w8, w4)
        self.up2 = nn.ConvTranspose2d(w4, w2, kernel_size=2, stride=2)
        self.dec2 = ResBlock2D(w4, w2)
        self.up1 = nn.ConvTranspose2d(w2, base, kernel_size=2, stride=2)
        self.dec1 = ResBlock2D(w2, base)
        self.final = nn.Conv2d(base, 1, kernel_size=1)
        self.out_dim = w16

    def forward(self, x):
        """Run the encoder and decoder.

        Returns:
            A tuple ``(gf, seg)``. ``gf`` is the globally average-pooled
            bottleneck, flattened from dim 1. ``seg`` is the output of the
            final 1x1 convolution applied to the last decoder feature map.
        """
        enc0 = self.stem(x)
        enc1 = self.down1(enc0)
        enc2 = self.down2(enc1)
        enc3 = self.down3(enc2)
        bottleneck = self.bottom(enc3)

        pooled = self.global_pool(bottleneck).flatten(1)

        # Walk back up: upsample, concatenate the encoder skip, fuse.
        stages = (
            (self.up4, self.dec4, enc3),
            (self.up3, self.dec3, enc2),
            (self.up2, self.dec2, enc1),
            (self.up1, self.dec1, enc0),
        )
        feat = bottleneck
        for upsample, fuse, skip in stages:
            feat = fuse(torch.cat([upsample(feat), skip], 1))
        return pooled, self.final(feat)


class Student2p5D(nn.Module):
    """2.5D Concept-Bottleneck multi-task model (the released `Pulmo` model).

    A shared `UNet2D` trunk produces a pooled feature vector and a
    segmentation map. Two MLP heads read the feature vector (detection and
    concepts); malignancy is a single linear layer on top of the concepts.

    Args:
        n_slices: number of input channels fed to the trunk.
        n_concepts: width of the concept vector.
        base: channel width of the trunk's first stage.
        head_dropout: dropout probability inside the detection head.
        concept_dropout: dropout probability inside the concept head.
            Defaults to 0.3, the value that used to be hard-coded. Dropout
            has no parameters, so changing it does not affect checkpoint keys.
    """

    def __init__(self, n_slices=7, n_concepts=8, base=24, head_dropout=0.1,
                 concept_dropout=0.3):
        super().__init__()
        self.n_slices = n_slices
        self.n_concepts = n_concepts
        self.cnn = UNet2D(n_slices, base=base)
        feat_dim = self.cnn.out_dim
        self.detection_head = self._make_head(feat_dim, 2, head_dropout)
        self.concept_head = self._make_head(feat_dim, n_concepts, concept_dropout)
        # Concept bottleneck: malignancy is predicted ONLY from the concepts.
        self.malignancy_head = nn.Linear(n_concepts, 2)

    @staticmethod
    def _make_head(in_dim, out_dim, dropout):
        """Build a LayerNorm -> Linear(256) -> GELU -> Dropout -> Linear head."""
        # The layer positions inside the Sequential (0, 1, 4 carry weights)
        # are part of the state-dict keys; do not reorder.
        return nn.Sequential(
            nn.LayerNorm(in_dim), nn.Linear(in_dim, 256), nn.GELU(),
            nn.Dropout(dropout), nn.Linear(256, out_dim),
        )

    def forward(self, x):
        """Run all task heads on one batch.

        Args:
            x: input batch; its channel dimension must match `n_slices`.

        Returns:
            dict with keys 'detection', 'concepts', 'malignancy' and
            'segmentation'.
        """
        gf, seg = self.cnn(x)
        # Concepts are computed first because malignancy depends on them.
        concepts = self.concept_head(gf)
        return {
            "detection": self.detection_head(gf),          # (B, 2)
            "concepts": concepts,                           # (B, n_concepts)
            "malignancy": self.malignancy_head(concepts),   # (B, 2)
            "segmentation": seg,                            # (B, 1, H, W)
        }


def load_pulmo(ckpt_path, device="cpu", n_slices=7, n_concepts=8, base=24):
    """Build a `Student2p5D` and load its weights from a checkpoint file.

    The checkpoint may be either a bare state dict or a dict that stores the
    state dict under the key ``"model_state_dict"``.

    Args:
        ckpt_path: path of the checkpoint (e.g. `student_2p5d_best.pth`).
        device: device the model is moved to and tensors are mapped onto.
        n_slices: forwarded to `Student2p5D`.
        n_concepts: forwarded to `Student2p5D`.
        base: forwarded to `Student2p5D`.

    Returns:
        The model on `device`, switched to eval mode.
    """
    model = Student2p5D(n_slices=n_slices, n_concepts=n_concepts, base=base)
    model = model.to(device)

    # SECURITY NOTE(review): weights_only=False allows full unpickling, which
    # can execute arbitrary code. Only load checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)

    is_wrapped = isinstance(checkpoint, dict) and "model_state_dict" in checkpoint
    if is_wrapped:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    # strict=True: any missing or unexpected key is an error.
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    return model


if __name__ == "__main__":
    # Smoke test: build an untrained model, report its parameter count, and
    # print the shape of every output for a random batch of two inputs.
    model = Student2p5D()
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Pulmo (Student2p5D): {total_params / 1e6:.2f}M params")
    outputs = model(torch.randn(2, 7, 64, 64))
    for name in outputs:
        print(f"  {name:13s}: {tuple(outputs[name].shape)}")