| """ |
| chexpert_classifier.py |
| ---------------------- |
| Multi-label, multi-CLASS CheXpert pathology classifier (U-MultiClass). |
| |
| Each of the 14 pathologies is predicted as one of THREE classes — |
| negative / positive / uncertain — via a per-pathology softmax, mirroring |
| META-CXR's MHCAC head and the CheXpert "U-MultiClass" uncertainty policy. |
| |
| The structured findings injected into the LLM prompt use the PNU |
| (Positive / Negative / Uncertain) 3-section format. `format_pnu()` is the |
| single source of truth for that string so the oracle path |
| (data/mimic_cxr_builder.py, GT from chexpert.csv) and the learned path |
| (this classifier at inference) produce byte-identical prompts. |
| |
| Trained separately (Stage 0) on MIMIC-CXR CheXbert labels; frozen during |
| Stage 1 / Stage 2 of the main VLM. |
| |
| Reference: RaDialog (Pellegrini et al., 2023) for the prompt-conditioning |
| idea; META-CXR (Edirisinghe et al., 2025) for the explicit uncertain class. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from typing import Optional, List, Dict, Sequence |
|
|
|
|
# CheXpert observation names. The list index of each name is the row of the
# classifier output (CheXpertClassifier) that scores that pathology.
PATHOLOGIES = [
    "No Finding",
    "Enlarged Cardiomediastinum",
    "Cardiomegaly",
    "Lung Opacity",
    "Lung Lesion",
    "Edema",
    "Consolidation",
    "Pneumonia",
    "Atelectasis",
    "Pneumothorax",
    "Pleural Effusion",
    "Pleural Other",
    "Fracture",
    "Support Devices",
]

# Index of each state along the last (3-way softmax) dimension of the logits.
CLASS_NEGATIVE, CLASS_POSITIVE, CLASS_UNCERTAIN = 0, 1, 2

# Human-readable name for each state index (insertion order 0, 1, 2).
CLASS_NAMES = dict(
    zip(
        (CLASS_NEGATIVE, CLASS_POSITIVE, CLASS_UNCERTAIN),
        ("negative", "positive", "uncertain"),
    )
)

# Number of states per pathology, kept in lockstep with CLASS_NAMES.
NUM_STATES = len(CLASS_NAMES)
|
|
|
|
def format_pnu(positive: Sequence[str],
               negative: Sequence[str],
               uncertain: Sequence[str]) -> str:
    """
    Render the three-line PNU structured-findings block (META-CXR format).

    Example output:

        Positive Abnormalities: Cardiomegaly, Pleural Effusion
        Negative Abnormalities: No Finding, Edema, ...
        Uncertain Abnormalities: Atelectasis

    Names inside a section are joined with ", " in the order given. A section
    whose sequence is empty is rendered as "None", so all three lines are
    always emitted. Lines are joined with "\\n"; there is no trailing newline.
    """
    sections = (
        ("Positive Abnormalities: ", positive),
        ("Negative Abnormalities: ", negative),
        ("Uncertain Abnormalities: ", uncertain),
    )
    rendered = []
    for header, names in sections:
        if names:
            body = ", ".join(names)
        else:
            body = "None"
        rendered.append(header + body)
    return "\n".join(rendered)
|
|
|
|
def buckets_to_pnu(class_by_pathology: Dict[str, int]) -> str:
    """
    Convert a {pathology_name: class_index} mapping into the PNU string.

    Names are grouped by comparing each value with CLASS_POSITIVE,
    CLASS_NEGATIVE and CLASS_UNCERTAIN. Within each group they keep the
    mapping's iteration order. Values matching none of the three class
    indices are left out of every section.
    """
    def _names_in_state(state: int) -> list:
        # Keep dict order so the prompt lists pathologies in a stable order.
        return [name for name, idx in class_by_pathology.items() if idx == state]

    return format_pnu(
        _names_in_state(CLASS_POSITIVE),
        _names_in_state(CLASS_NEGATIVE),
        _names_in_state(CLASS_UNCERTAIN),
    )
|
|
|
|
class CheXpertClassifier(nn.Module):
    """
    Multi-label, 3-class-per-label classifier on BioViL-T global embeddings.

    Output logits have shape (B, num_classes, 3); a per-pathology
    softmax/argmax yields negative / positive / uncertain.

    Args:
        input_dim:   global CXR embedding dim
        num_classes: number of pathologies (14)
        checkpoint:  path to trained weights (None = not loaded)
        pathologies: names of the ``num_classes`` outputs, in output order.
                     Defaults to the module-level ``PATHOLOGIES`` list. Its
                     length must equal ``num_classes`` for ``predict`` /
                     ``findings_to_text`` to work.
    """

    def __init__(
        self,
        input_dim: int = 512,
        num_classes: int = 14,
        checkpoint: Optional[str] = None,
        pathologies: Optional[Sequence[str]] = None,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.num_states = NUM_STATES
        # Store a copy so that mutating one instance's label list can never
        # alter the shared module-level PATHOLOGIES constant.
        self.pathologies = list(PATHOLOGIES if pathologies is None else pathologies)

        # Two-layer MLP head. The last layer emits num_classes * NUM_STATES
        # logits, which forward() reshapes to (B, num_classes, NUM_STATES).
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes * NUM_STATES),
        )

        if checkpoint is not None:
            self._load_checkpoint(checkpoint)

    def _load_checkpoint(self, path: str):
        """Load a plain state_dict from ``path`` onto CPU (strict key match)."""
        # NOTE(review): torch.load unpickles the file. Only load trusted
        # checkpoints; consider weights_only=True where the torch version
        # supports it.
        state_dict = torch.load(path, map_location="cpu")
        self.load_state_dict(state_dict)
        print(f"[CheXpertClassifier] Loaded weights from {path}")

    def forward(self, global_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            global_features: (B, input_dim)

        Returns:
            logits: (B, num_classes, 3) — softmax over the last dim gives
                    P(negative), P(positive), P(uncertain) per pathology.
                    Train with cross-entropy over the last dim (the natural
                    U-MultiClass objective).
        """
        flat = self.classifier(global_features)
        return flat.view(-1, self.num_classes, NUM_STATES)

    @torch.no_grad()
    def _predict_classes(self, global_features: torch.Tensor) -> List[List[int]]:
        """
        Return per-sample argmax class indices as nested Python ints, shape
        [B][num_classes].

        Inference always runs in eval mode, so Dropout is disabled and the
        predictions are deterministic. The previous train/eval flag of every
        submodule is restored afterwards, even if forward() raises.

        Raises:
            ValueError: if the number of pathology names differs from
                        ``num_classes`` (outputs could not be named correctly).
        """
        if len(self.pathologies) != self.num_classes:
            raise ValueError(
                f"CheXpertClassifier has {self.num_classes} outputs but "
                f"{len(self.pathologies)} pathology names; pass a matching "
                f"`pathologies` list."
            )
        # nn.Module starts in training mode; without this, Dropout(0.2) would
        # make predict()/findings_to_text() stochastic.
        saved_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            logits = self.forward(global_features)
        finally:
            for module, mode in saved_modes:
                module.training = mode
        # One device->host transfer instead of a .item() call per element.
        return logits.argmax(dim=-1).tolist()

    @torch.no_grad()
    def predict(self, global_features: torch.Tensor) -> List[Dict[str, str]]:
        """
        Returns a list (per sample) of {pathology: "negative"|"positive"|
        "uncertain"} using argmax over the 3-state softmax (eval mode).
        """
        return [
            {name: CLASS_NAMES[c] for name, c in zip(self.pathologies, row)}
            for row in self._predict_classes(global_features)
        ]

    @torch.no_grad()
    def findings_to_text(self, global_features: torch.Tensor) -> List[str]:
        """
        Per-sample PNU structured-findings string, identical in format to the
        GT oracle path (data/mimic_cxr_builder.py). One string per sample.
        Predictions are computed in eval mode (see ``_predict_classes``).
        """
        return [
            buckets_to_pnu(dict(zip(self.pathologies, row)))
            for row in self._predict_classes(global_features)
        ]
|
|