File size: 2,624 Bytes
8e23aec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
Pulmo inference example.

Shows how to:
  1. download the weights from the Hub,
  2. build the model,
  3. run a forward pass on a 64^3 CT patch, reduced to its 7 central axial slices,
  4. read the concept-bottleneck explanation for the malignancy decision.

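Requires: torch, numpy, huggingface_hub, plus the local `modeling` module
(which provides Student2p5D and CONCEPT_NAMES) on the import path.
    pip install torch numpy huggingface_hub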
"""

import numpy as np
import torch
from huggingface_hub import hf_hub_download

from modeling import Student2p5D, CONCEPT_NAMES

# Hugging Face Hub repository hosting the released checkpoint.
REPO_ID: str = "ariyul/Pulmo"
# Hounsfield-unit window (min, max); values are clipped to it, then rescaled to [0, 1].
HU_CLIP: tuple = (-1000, 1000)
# Edge length, in voxels, of the cubic CT patch the demo is built around (64^3).
ROI: int = 64
# Number of central axial slices stacked as input channels for the model.
N_SLICES: int = 7


def preprocess_patch(patch_3d):
    """Convert a raw-HU CT patch (Z, Y, X) into a (1, N_SLICES, Y, X) float32 tensor.

    HU values are clipped to ``HU_CLIP`` and rescaled linearly to [0, 1].
    Then the ``N_SLICES`` axial slices centred on the patch's Z midpoint are
    kept, matching how the model was trained. For the released model the input
    is a ROI^3 (64^3) crop, which yields a (1, 7, 64, 64) tensor.

    Args:
        patch_3d: array-like of shape (Z, Y, X) in Hounsfield units.

    Returns:
        torch.Tensor of shape (1, N_SLICES, Y, X), dtype float32, values in [0, 1].

    Raises:
        ValueError: if the input is not 3-D, or has fewer than ``N_SLICES``
            axial slices (slicing would otherwise silently return a short or
            empty stack).
    """
    p = np.asarray(patch_3d, dtype=np.float32)
    if p.ndim != 3:
        raise ValueError(f"expected a 3-D (Z, Y, X) patch, got shape {p.shape}")
    depth = p.shape[0]
    if depth < N_SLICES:
        raise ValueError(
            f"patch has {depth} axial slices; at least {N_SLICES} are required"
        )

    lo, hi = HU_CLIP
    # np.clip returns a new array, so the caller's patch is never modified.
    p = (np.clip(p, lo, hi) - lo) / (hi - lo)

    # Centre the slab on the real Z midpoint (== ROI // 2 for a 64^3 patch)
    # and take exactly N_SLICES slices from `start`, which stays correct for
    # other depths and for an even N_SLICES.
    start = depth // 2 - N_SLICES // 2
    slab = p[start:start + N_SLICES]                 # (N_SLICES, Y, X)
    return torch.from_numpy(slab[None])              # (1, N_SLICES, Y, X)


def _load_model(device):
    """Download the released checkpoint and return the model, in eval mode, on ``device``."""
    ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="student_2p5d_best.pth")
    model = Student2p5D(n_slices=N_SLICES, n_concepts=8, base=24).to(device)
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, so a tampered checkpoint can execute code when loaded. Only use
    # files from a source you trust. Switch to weights_only=True if this
    # checkpoint loads with it (it may contain non-tensor entries; unverified).
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    # Accept either a training checkpoint wrapping "model_state_dict" or a
    # bare state_dict.
    state = ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt
    model.load_state_dict(state, strict=True)
    model.eval()
    return model


def _print_malignancy_explanation(model, out):
    """Print each concept's additive share of the malignant-vs-benign log-odds.

    For a two-row linear head ``logits = W @ c + b``, the log-odds is
    ``(W[1] - W[0]) . c + (b[1] - b[0])``. The bias difference is reported on
    its own line so that the listed contributions add up to the log-odds.

    NOTE(review): this assumes ``malignancy_head`` is applied directly to
    ``out["concepts"]``. The last line prints both the reconstruction and the
    model's own log-odds, so any mismatch is visible.
    """
    concepts = out["concepts"][0].cpu().numpy()
    head = model.malignancy_head
    weight = head.weight.detach().cpu().numpy()
    w_net = weight[1] - weight[0]
    contrib = w_net * concepts

    # A Linear head built with bias=False has .bias set to None.
    bias = getattr(head, "bias", None)
    b_net = 0.0 if bias is None else (bias[1] - bias[0]).item()
    mal_logits = out["malignancy"][0]
    model_log_odds = (mal_logits[1] - mal_logits[0]).item()

    print("\nPer-concept contribution to the malignancy decision:")
    for i in np.argsort(contrib)[::-1]:
        print(f"  {CONCEPT_NAMES[i]:18s}  value={concepts[i]:+.2f}  contribution={contrib[i]:+.3f}")
    print(f"  {'(bias)':18s}  {'':11s}  contribution={b_net:+.3f}")
    print(
        f"  reconstructed log-odds = {contrib.sum() + b_net:+.3f}"
        f"  (model malignancy log-odds = {model_log_odds:+.3f})"
    )


def main():
    """Run the demo: load the model, score one (dummy) patch, explain the result."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = _load_model(device)

    # --- Replace this with a real ROI^3 (64^3) raw-HU patch cropped around a candidate ---
    dummy_patch = np.random.randint(-1000, 400, size=(ROI, ROI, ROI)).astype(np.int16)
    x = preprocess_patch(dummy_patch).to(device)

    with torch.no_grad():
        out = model(x)

    # Detection/malignancy heads are read as 2-class logits; index 1 is the
    # class reported as "nodule" / "malignant" below.
    det_p = torch.softmax(out["detection"][0], 0)[1].item()
    mal_p = torch.softmax(out["malignancy"][0], 0)[1].item()
    seg = torch.sigmoid(out["segmentation"][0, 0]).cpu().numpy()

    print(f"Nodule probability    : {det_p:.3f}")
    print(f"Malignancy probability: {mal_p:.3f}")
    print(f"Segmented voxels (>0.5): {(seg > 0.5).sum()}")

    _print_malignancy_explanation(model, out)


if __name__ == "__main__":
    # Run the demo only when executed as a script, never on import.
    main()