# cxr-vlm-code/model/chexpert_classifier.py
# convitom
# feat(chexpert): U-MultiClass PNU abnormality guidance + abnormality-guided VQA
# 215ecd6
"""
chexpert_classifier.py
----------------------
Multi-label, multi-CLASS CheXpert pathology classifier (U-MultiClass).
Each of the 14 pathologies is predicted as one of THREE classes —
negative / positive / uncertain — via a per-pathology softmax, mirroring
META-CXR's MHCAC head and the CheXpert "U-MultiClass" uncertainty policy.
The structured findings injected into the LLM prompt use the PNU
(Positive / Negative / Uncertain) 3-section format. `format_pnu()` is the
single source of truth for that string so the oracle path
(data/mimic_cxr_builder.py, GT from chexpert.csv) and the learned path
(this classifier at inference) produce byte-identical prompts.
Trained separately (Stage 0) on MIMIC-CXR CheXbert labels; frozen during
Stage 1 / Stage 2 of the main VLM.
Reference: RaDialog (Pellegrini et al., 2023) for the prompt-conditioning
idea; META-CXR (Edirisinghe et al., 2025) for the explicit uncertain class.
"""
import torch
import torch.nn as nn
from typing import Optional, List, Dict, Sequence
# The 14 pathology labels, in the order used to index the pathology axis of
# the classifier output (position j in this list <-> logits[:, j, :]).
PATHOLOGIES = [
    "No Finding",
    "Enlarged Cardiomediastinum",
    "Cardiomegaly",
    "Lung Opacity",
    "Lung Lesion",
    "Edema",
    "Consolidation",
    "Pneumonia",
    "Atelectasis",
    "Pneumothorax",
    "Pleural Effusion",
    "Pleural Other",
    "Fracture",
    "Support Devices",
]
# Index of each state along the per-pathology softmax axis. This ordering is
# a contract: per the note that accompanied these constants, both the trained
# checkpoint and the GT-label mapping in mimic_cxr_builder.py depend on it,
# so do not renumber.
CLASS_NEGATIVE, CLASS_POSITIVE, CLASS_UNCERTAIN = 0, 1, 2

# Number of states predicted per pathology (size of the softmax axis).
NUM_STATES = 3

# Human-readable name for each state index.
CLASS_NAMES = {
    CLASS_NEGATIVE: "negative",
    CLASS_POSITIVE: "positive",
    CLASS_UNCERTAIN: "uncertain",
}
def format_pnu(positive: Sequence[str],
               negative: Sequence[str],
               uncertain: Sequence[str]) -> str:
    """
    Render the three-line PNU structured-findings string.

    Example output:

        Positive Abnormalities: Cardiomegaly, Pleural Effusion
        Negative Abnormalities: No Finding, Edema, ...
        Uncertain Abnormalities: Atelectasis

    Args:
        positive: pathology names to list on the "Positive" line.
        negative: pathology names to list on the "Negative" line.
        uncertain: pathology names to list on the "Uncertain" line.

    Returns:
        Three newline-separated lines (no trailing newline). Names are joined
        with ", " in the order given; an empty group is rendered as "None",
        so all three lines are always present.
    """
    # Render each group in positive / negative / uncertain order.
    rendered = [
        ", ".join(group) if group else "None"
        for group in (positive, negative, uncertain)
    ]
    template = ("Positive Abnormalities: {}\n"
                "Negative Abnormalities: {}\n"
                "Uncertain Abnormalities: {}")
    return template.format(*rendered)
def buckets_to_pnu(class_by_pathology: Dict[str, int]) -> str:
    """
    Convert a {pathology: class_idx} mapping into the PNU string.

    Pathologies are split by class index into positive / negative / uncertain
    groups, keeping the mapping's iteration order within each group, and the
    groups are passed to `format_pnu`. Entries whose class index matches none
    of the three CLASS_* constants are left out.
    """
    positives: List[str] = []
    negatives: List[str] = []
    uncertains: List[str] = []
    for pathology, state in class_by_pathology.items():
        # Independent equality checks (not a dict lookup / elif chain) so the
        # comparison semantics are the same as three separate filters.
        if state == CLASS_POSITIVE:
            positives.append(pathology)
        if state == CLASS_NEGATIVE:
            negatives.append(pathology)
        if state == CLASS_UNCERTAIN:
            uncertains.append(pathology)
    return format_pnu(positives, negatives, uncertains)
class CheXpertClassifier(nn.Module):
    """
    Multi-label, 3-class-per-label classifier on BioViL-T global embeddings.

    An MLP head maps a global CXR embedding to logits of shape
    (B, num_classes, 3); a per-pathology softmax/argmax over the last dim
    yields negative / positive / uncertain.

    The inference helpers (`predict`, `findings_to_text`) always run the
    forward pass with the module in eval mode, so Dropout is inactive and
    the result is deterministic even if the module (or a parent that
    contains it) is currently in training mode. The previous training flags
    are restored afterwards.

    Args:
        input_dim:   global CXR embedding dim
        num_classes: number of pathologies (14). `predict` and
                     `findings_to_text` name the outputs by position in
                     `PATHOLOGIES`, so this should equal len(PATHOLOGIES).
        checkpoint:  path to trained weights (None = not loaded)
    """

    def __init__(
        self,
        input_dim: int = 512,
        num_classes: int = 14,
        checkpoint: Optional[str] = None,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.num_states = NUM_STATES
        self.pathologies = PATHOLOGIES

        # MLP head -> num_classes * 3 logits, reshaped to (B, num_classes, 3)
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes * NUM_STATES),
        )

        if checkpoint is not None:
            self._load_checkpoint(checkpoint)

    def _load_checkpoint(self, path: str):
        """Load a state dict from `path` (mapped to CPU) into this module."""
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources. Consider passing weights_only=True where the
        # installed torch version supports it.
        state_dict = torch.load(path, map_location="cpu")
        self.load_state_dict(state_dict)
        print(f"[CheXpertClassifier] Loaded weights from {path}")

    def forward(self, global_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            global_features: (B, input_dim)
        Returns:
            logits: (B, num_classes, 3) — softmax over the last dim gives
                    P(negative), P(positive), P(uncertain) per pathology.
                    Train with cross-entropy over the last dim (the natural
                    U-MultiClass objective).
        """
        flat = self.classifier(global_features)                 # (B, num_classes*3)
        return flat.view(-1, self.num_classes, self.num_states)  # (B, num_classes, 3)

    @torch.no_grad()
    def _class_indices(self, global_features: torch.Tensor) -> torch.Tensor:
        """
        Argmax class index per pathology, as a CPU tensor of shape
        (B, num_classes), computed with the module in eval mode.
        """
        # Remember each submodule's mode so a surrounding training loop is
        # left exactly as it was, then switch Dropout off for inference.
        saved_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            logits = self.forward(global_features)              # (B, num_classes, 3)
        finally:
            for module, was_training in saved_modes:
                module.training = was_training
        return logits.argmax(dim=-1).cpu()                      # (B, num_classes)

    @torch.no_grad()
    def predict(self, global_features: torch.Tensor) -> List[Dict[str, str]]:
        """
        Returns a list (per sample) of {pathology: "negative"|"positive"|
        "uncertain"} using argmax over the 3-state logits (equivalent to
        argmax over the softmax). Dropout is disabled for the forward pass.
        """
        rows = self._class_indices(global_features).tolist()    # B x num_classes ints
        return [
            {name: CLASS_NAMES[row[j]]
             for j, name in enumerate(self.pathologies)}
            for row in rows
        ]

    @torch.no_grad()
    def findings_to_text(self, global_features: torch.Tensor) -> List[str]:
        """
        Per-sample PNU structured-findings string, built with
        `buckets_to_pnu` from the argmax class of each pathology. One string
        per sample. Dropout is disabled for the forward pass.
        """
        rows = self._class_indices(global_features).tolist()    # B x num_classes ints
        return [
            buckets_to_pnu({name: row[j]
                            for j, name in enumerate(self.pathologies)})
            for row in rows
        ]