File size: 6,050 Bytes
28b13fc
 
 
215ecd6
28b13fc
215ecd6
 
 
28b13fc
215ecd6
 
 
 
 
 
 
 
 
 
 
28b13fc
 
 
 
215ecd6
28b13fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215ecd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28b13fc
 
 
215ecd6
28b13fc
215ecd6
 
28b13fc
 
215ecd6
 
 
28b13fc
 
 
 
215ecd6
28b13fc
215ecd6
28b13fc
 
 
 
215ecd6
28b13fc
 
215ecd6
28b13fc
 
 
 
215ecd6
28b13fc
 
 
 
 
 
 
 
 
 
 
 
 
215ecd6
28b13fc
 
215ecd6
 
 
 
28b13fc
215ecd6
 
28b13fc
 
 
 
215ecd6
 
28b13fc
215ecd6
 
 
 
 
 
 
 
 
28b13fc
215ecd6
 
28b13fc
215ecd6
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""
chexpert_classifier.py
----------------------
Multi-label, multi-CLASS CheXpert pathology classifier (U-MultiClass).

Each of the 14 pathologies is predicted as one of THREE classes —
negative / positive / uncertain — via a per-pathology softmax, mirroring
META-CXR's MHCAC head and the CheXpert "U-MultiClass" uncertainty policy.

The structured findings injected into the LLM prompt use the PNU
(Positive / Negative / Uncertain) 3-section format. `format_pnu()` is the
single source of truth for that string so the oracle path
(data/mimic_cxr_builder.py, GT from chexpert.csv) and the learned path
(this classifier at inference) produce byte-identical prompts.

Trained separately (Stage 0) on MIMIC-CXR CheXbert labels; frozen during
Stage 1 / Stage 2 of the main VLM.

Reference: RaDialog (Pellegrini et al., 2023) for the prompt-conditioning
idea; META-CXR (Edirisinghe et al., 2025) for the explicit uncertain class.
"""

import torch
import torch.nn as nn
from typing import Optional, List, Dict, Sequence


# The 14 CheXpert observation labels. The position of a name in this list is
# the index of the matching pathology head in the classifier output.
PATHOLOGIES = [
    "No Finding", "Enlarged Cardiomediastinum", "Cardiomegaly",
    "Lung Opacity", "Lung Lesion", "Edema", "Consolidation",
    "Pneumonia", "Atelectasis", "Pneumothorax", "Pleural Effusion",
    "Pleural Other", "Fracture", "Support Devices",
]

# Index of each state along the per-pathology softmax axis. These values are
# a contract shared with the trained checkpoint and with the GT-label mapping
# in mimic_cxr_builder.py — do not renumber them.
CLASS_NEGATIVE, CLASS_POSITIVE, CLASS_UNCERTAIN = 0, 1, 2

# Number of states predicted per pathology (size of the softmax axis).
NUM_STATES = 3

# Human-readable name for every state index.
CLASS_NAMES = {
    CLASS_NEGATIVE: "negative",
    CLASS_POSITIVE: "positive",
    CLASS_UNCERTAIN: "uncertain",
}


def format_pnu(positive: Sequence[str],
               negative: Sequence[str],
               uncertain: Sequence[str]) -> str:
    """
    Render the three-line PNU findings block (META-CXR prompt format).

    The result always has exactly three newline-separated lines, in the
    order Positive / Negative / Uncertain, for example:

        Positive Abnormalities: Cardiomegaly, Pleural Effusion
        Negative Abnormalities: No Finding, Edema, ...
        Uncertain Abnormalities: Atelectasis

    Names within a section are joined with ", " in the order given. A
    section with no names is written as "None" instead of being omitted,
    so the LLM always sees the same fixed structure.
    """
    sections = (
        ("Positive Abnormalities", positive),
        ("Negative Abnormalities", negative),
        ("Uncertain Abnormalities", uncertain),
    )
    rendered_lines = []
    for header, names in sections:
        # Falsy (empty) sections fall back to the literal placeholder.
        if names:
            body = ", ".join(names)
        else:
            body = "None"
        rendered_lines.append(f"{header}: {body}")
    return "\n".join(rendered_lines)


def buckets_to_pnu(class_by_pathology: Dict[str, int]) -> str:
    """
    Turn a {pathology: class_idx} mapping into the PNU findings string.

    Pathologies are split by class index into positive / negative /
    uncertain groups, each keeping the mapping's iteration order, and the
    groups are handed to `format_pnu`. Entries whose index equals none of
    the three class constants do not appear in the output.
    """
    def _with_class(target: int) -> List[str]:
        # Names whose class index compares equal to `target`.
        return [name
                for name, idx in class_by_pathology.items()
                if idx == target]

    return format_pnu(
        positive=_with_class(CLASS_POSITIVE),
        negative=_with_class(CLASS_NEGATIVE),
        uncertain=_with_class(CLASS_UNCERTAIN),
    )


class CheXpertClassifier(nn.Module):
    """
    Multi-label, 3-class-per-label classifier on BioViL-T global embeddings.

    A small MLP maps one global embedding to `num_classes * 3` logits, which
    are reshaped to (B, num_classes, 3). A per-pathology softmax/argmax over
    the last axis yields negative / positive / uncertain.

    `predict()` and `findings_to_text()` always run the network in eval
    mode (dropout disabled) and restore the previous mode afterwards, so
    their output is deterministic even if an enclosing model has been put
    into training mode.

    Args:
        input_dim:   global CXR embedding dim
        num_classes: number of pathologies (14)
        checkpoint:  path to trained weights (None = not loaded)
    """

    def __init__(
        self,
        input_dim:   int = 512,
        num_classes: int = 14,
        checkpoint:  Optional[str] = None,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.num_states  = NUM_STATES
        # Per-instance copy: editing this attribute must not mutate the
        # module-level PATHOLOGIES list shared with other code.
        self.pathologies = list(PATHOLOGIES)

        # MLP head → num_classes * 3 logits, reshaped to (B, num_classes, 3)
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes * NUM_STATES),
        )

        if checkpoint is not None:
            self._load_checkpoint(checkpoint)

    def _load_checkpoint(self, path: str):
        """Load a state dict from `path` (mapped to CPU) into this module."""
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source (or pass weights_only=True where the installed
        # torch version supports it).
        state_dict = torch.load(path, map_location="cpu")
        self.load_state_dict(state_dict)
        print(f"[CheXpertClassifier] Loaded weights from {path}")

    def forward(self, global_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            global_features: (B, input_dim)

        Returns:
            logits: (B, num_classes, 3)  — softmax over the last dim gives
                    P(negative), P(positive), P(uncertain) per pathology.
                    Train with cross-entropy over the last dim (the natural
                    U-MultiClass objective).
        """
        flat = self.classifier(global_features)              # (B, 14*3)
        # Compute the leading size explicitly instead of using -1: view(-1, ...)
        # is ambiguous (and raises) for an empty batch with 0 elements.
        rows = flat.numel() // (self.num_classes * NUM_STATES)
        return flat.view(rows, self.num_classes, NUM_STATES)  # (B, 14, 3)

    @torch.no_grad()
    def _predict_class_indices(
        self, global_features: torch.Tensor
    ) -> List[List[int]]:
        """
        Return the argmax state index per sample and pathology as nested
        Python lists of shape (B, num_classes).

        Raises:
            ValueError: if the number of pathology names differs from
                `num_classes`, since names could not be matched to heads.
        """
        if len(self.pathologies) != self.num_classes:
            raise ValueError(
                f"CheXpertClassifier has {self.num_classes} pathology heads "
                f"but {len(self.pathologies)} pathology names; cannot map "
                f"predictions to names."
            )

        # Inference must not be perturbed by dropout. Remember each
        # submodule's mode so it can be restored exactly afterwards.
        saved_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            logits = self.forward(global_features)        # (B, 14, 3)
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

        return logits.argmax(dim=-1).tolist()             # (B, 14)

    @torch.no_grad()
    def predict(self, global_features: torch.Tensor) -> List[Dict[str, str]]:
        """
        Returns a list (per sample) of {pathology: "negative"|"positive"|
        "uncertain"} using argmax over the 3-state softmax.

        Raises:
            ValueError: if `self.pathologies` and `num_classes` disagree.
        """
        return [
            {name: CLASS_NAMES[idx]
             for name, idx in zip(self.pathologies, row)}
            for row in self._predict_class_indices(global_features)
        ]

    @torch.no_grad()
    def findings_to_text(self, global_features: torch.Tensor) -> List[str]:
        """
        Per-sample PNU structured-findings string, identical in format to the
        GT oracle path (data/mimic_cxr_builder.py). One string per sample.

        Raises:
            ValueError: if `self.pathologies` and `num_classes` disagree.
        """
        return [
            buckets_to_pnu(dict(zip(self.pathologies, row)))
            for row in self._predict_class_indices(global_features)
        ]