"""
chexpert_classifier.py
----------------------
Multi-label, multi-CLASS CheXpert pathology classifier (U-MultiClass).
Each of the 14 pathologies is predicted as one of THREE classes —
negative / positive / uncertain — via a per-pathology softmax, mirroring
META-CXR's MHCAC head and the CheXpert "U-MultiClass" uncertainty policy.
The structured findings injected into the LLM prompt use the PNU
(Positive / Negative / Uncertain) 3-section format. `format_pnu()` is the
single source of truth for that string so the oracle path
(data/mimic_cxr_builder.py, GT from chexpert.csv) and the learned path
(this classifier at inference) produce byte-identical prompts.
Trained separately (Stage 0) on MIMIC-CXR CheXbert labels; frozen during
Stage 1 / Stage 2 of the main VLM.
Reference: RaDialog (Pellegrini et al., 2023) for the prompt-conditioning
idea; META-CXR (Edirisinghe et al., 2025) for the explicit uncertain class.
"""
import torch
import torch.nn as nn
from typing import Optional, List, Dict, Sequence
# The 14 CheXpert pathology labels. The position of a name in this list is the
# column it occupies along the pathology axis of the classifier output.
PATHOLOGIES = [
    "No Finding",
    "Enlarged Cardiomediastinum",
    "Cardiomegaly",
    "Lung Opacity",
    "Lung Lesion",
    "Edema",
    "Consolidation",
    "Pneumonia",
    "Atelectasis",
    "Pneumothorax",
    "Pleural Effusion",
    "Pleural Other",
    "Fracture",
    "Support Devices",
]
# Index of each state along the per-pathology softmax axis. These values must
# not be renumbered: both the trained checkpoint and the ground-truth label
# mapping in mimic_cxr_builder.py depend on this exact order.
CLASS_NEGATIVE = 0
CLASS_POSITIVE = 1
CLASS_UNCERTAIN = 2

# Number of states predicted per pathology (negative / positive / uncertain).
NUM_STATES = 3

# Human-readable name for each state index.
CLASS_NAMES = {
    CLASS_NEGATIVE: "negative",
    CLASS_POSITIVE: "positive",
    CLASS_UNCERTAIN: "uncertain",
}
def format_pnu(positive: Sequence[str],
               negative: Sequence[str],
               uncertain: Sequence[str]) -> str:
    """
    Render the three-line PNU structured-findings prompt (META-CXR format).

    Example output:

        Positive Abnormalities: Cardiomegaly, Pleural Effusion
        Negative Abnormalities: No Finding, Edema
        Uncertain Abnormalities: Atelectasis

    Args:
        positive: names listed on the "Positive" line.
        negative: names listed on the "Negative" line.
        uncertain: names listed on the "Uncertain" line.

    Returns:
        Exactly three newline-separated lines, always in the order
        Positive / Negative / Uncertain, with no trailing newline. A section
        with no names is written as "None" so the layout never changes.
    """
    sections = (
        ("Positive Abnormalities: ", positive),
        ("Negative Abnormalities: ", negative),
        ("Uncertain Abnormalities: ", uncertain),
    )
    rendered_lines = []
    for prefix, names in sections:
        # An empty section is spelled out instead of leaving the line blank.
        listing = ", ".join(names) if names else "None"
        rendered_lines.append(prefix + listing)
    return "\n".join(rendered_lines)
def buckets_to_pnu(class_by_pathology: Dict[str, int]) -> str:
    """
    Convert a {pathology: class_idx} mapping into the PNU prompt string.

    Pathologies are split by class index into positive / negative / uncertain
    lists, keeping the mapping's iteration order inside each list, and the
    result is rendered by `format_pnu`. Entries whose class index matches
    none of the three known states do not appear in the output.
    """
    positives = []
    negatives = []
    uncertains = []
    for pathology, state in class_by_pathology.items():
        # Three independent checks (not an elif chain) so each list is built
        # by its own equality filter.
        if state == CLASS_POSITIVE:
            positives.append(pathology)
        if state == CLASS_NEGATIVE:
            negatives.append(pathology)
        if state == CLASS_UNCERTAIN:
            uncertains.append(pathology)
    return format_pnu(positives, negatives, uncertains)
class CheXpertClassifier(nn.Module):
    """
    Multi-label, 3-class-per-label classifier on BioViL-T global embeddings.

    The head is a small MLP whose output is reshaped to
    (B, num_classes, 3); a per-pathology softmax/argmax over the last
    dimension yields negative / positive / uncertain.

    Args:
        input_dim: global CXR embedding dim
        num_classes: number of pathologies (14)
        checkpoint: path to trained weights (None = not loaded)

    NOTE(review): `predict` and `findings_to_text` name the output columns
    from the module-level PATHOLOGIES list, so they assume `num_classes`
    matches its length.
    """

    def __init__(
        self,
        input_dim: int = 512,
        num_classes: int = 14,
        checkpoint: Optional[str] = None,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.num_states = NUM_STATES
        self.pathologies = PATHOLOGIES

        # MLP head -> num_classes * 3 logits, reshaped to (B, num_classes, 3)
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes * NUM_STATES),
        )

        if checkpoint is not None:
            self._load_checkpoint(checkpoint)

    def _load_checkpoint(self, path: str):
        """Load a state dict from `path` (mapped to CPU) into this module."""
        # NOTE(security): torch.load unpickles the file, so only load
        # checkpoints from a trusted source. Consider weights_only=True once
        # the minimum supported torch version provides it.
        state_dict = torch.load(path, map_location="cpu")
        self.load_state_dict(state_dict)
        print(f"[CheXpertClassifier] Loaded weights from {path}")

    def forward(self, global_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            global_features: (B, input_dim)

        Returns:
            logits: (B, num_classes, 3) — softmax over the last dim gives
            P(negative), P(positive), P(uncertain) per pathology.
            Train with cross-entropy over the last dim (the natural
            U-MultiClass objective).
        """
        flat = self.classifier(global_features)  # (B, num_classes * 3)
        return flat.view(-1, self.num_classes, self.num_states)

    @torch.no_grad()
    def _predict_classes(self, global_features: torch.Tensor) -> List[List[int]]:
        """
        Return the argmax class index per sample and pathology as nested
        Python lists of shape (B, num_classes).

        The forward pass is always run in eval mode: `torch.no_grad()` does
        not switch off dropout, so without this the predictions would be
        random whenever the module (or a parent that owns it) is in training
        mode. The previous mode of every submodule is restored afterwards.
        """
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            logits = self.forward(global_features)  # (B, num_classes, 3)
        finally:
            for module, was_training in previous_modes:
                module.training = was_training
        # One bulk conversion instead of a per-element .item() call.
        return logits.argmax(dim=-1).tolist()

    @torch.no_grad()
    def predict(self, global_features: torch.Tensor) -> List[Dict[str, str]]:
        """
        Returns a list (per sample) of {pathology: "negative"|"positive"|
        "uncertain"} using argmax over the 3-state softmax.

        Predictions are deterministic regardless of the module's current
        train/eval mode.
        """
        return [
            {name: CLASS_NAMES[row[j]] for j, name in enumerate(self.pathologies)}
            for row in self._predict_classes(global_features)
        ]

    @torch.no_grad()
    def findings_to_text(self, global_features: torch.Tensor) -> List[str]:
        """
        Per-sample PNU structured-findings string, identical in format to the
        GT oracle path (data/mimic_cxr_builder.py). One string per sample.

        Predictions are deterministic regardless of the module's current
        train/eval mode.
        """
        return [
            buckets_to_pnu(
                {name: row[j] for j, name in enumerate(self.pathologies)}
            )
            for row in self._predict_classes(global_features)
        ]
|