"""
chexpert_classifier.py
----------------------
Multi-label, multi-class CheXpert pathology classifier (U-MultiClass).

Each of the 14 pathologies is predicted as one of THREE classes —
negative / positive / uncertain — via a per-pathology softmax, mirroring
META-CXR's MHCAC head and the CheXpert "U-MultiClass" uncertainty policy.

The structured findings injected into the LLM prompt use the PNU
(Positive / Negative / Uncertain) 3-section format. `format_pnu()` is the
single source of truth for that string, so the oracle path
(data/mimic_cxr_builder.py, GT from chexpert.csv) and the learned path
(this classifier at inference) produce byte-identical prompts.

Trained separately (Stage 0) on MIMIC-CXR CheXbert labels; frozen during
Stage 1 / Stage 2 of the main VLM.

Reference: RaDialog (Pellegrini et al., 2023) for the prompt-conditioning
idea; META-CXR (Edirisinghe et al., 2025) for the explicit uncertain class.
"""

import torch
import torch.nn as nn
from typing import Optional, List, Dict, Sequence


PATHOLOGIES = [
    "No Finding",
    "Enlarged Cardiomediastinum",
    "Cardiomegaly",
    "Lung Opacity",
    "Lung Lesion",
    "Edema",
    "Consolidation",
    "Pneumonia",
    "Atelectasis",
    "Pneumothorax",
    "Pleural Effusion",
    "Pleural Other",
    "Fracture",
    "Support Devices",
]

# Per-pathology class indices (softmax dim order). Keep this stable: the
# trained checkpoint and the GT-label mapping in mimic_cxr_builder.py both
# rely on it.
CLASS_NEGATIVE = 0
CLASS_POSITIVE = 1
CLASS_UNCERTAIN = 2
NUM_STATES = 3

CLASS_NAMES = {
    CLASS_NEGATIVE: "negative",
    CLASS_POSITIVE: "positive",
    CLASS_UNCERTAIN: "uncertain",
}


def format_pnu(positive: Sequence[str],
               negative: Sequence[str],
               uncertain: Sequence[str]) -> str:
    """
    Build the PNU structured-findings string (META-CXR prompt format).

        Positive Abnormalities: Cardiomegaly, Pleural Effusion
        Negative Abnormalities: No Finding, Edema, ...
        Uncertain Abnormalities: Atelectasis

    Items are joined with ", " in the order given. Empty sections render
    as "None", so all three lines are always present and the LLM sees a
    fixed structure regardless of the case.
    """
    def _fmt(xs: Sequence[str]) -> str:
        return ", ".join(xs) if xs else "None"

    return (f"Positive Abnormalities: {_fmt(positive)}\n"
            f"Negative Abnormalities: {_fmt(negative)}\n"
            f"Uncertain Abnormalities: {_fmt(uncertain)}")


def buckets_to_pnu(class_by_pathology: Dict[str, int]) -> str:
    """
    Group a {pathology: class_idx} dict into the PNU string.

    Pathologies keep the dict's iteration order within each section.
    Entries whose value is not one of CLASS_NEGATIVE / CLASS_POSITIVE /
    CLASS_UNCERTAIN appear in no section.
    """
    pos = [p for p, c in class_by_pathology.items() if c == CLASS_POSITIVE]
    neg = [p for p, c in class_by_pathology.items() if c == CLASS_NEGATIVE]
    unc = [p for p, c in class_by_pathology.items() if c == CLASS_UNCERTAIN]
    return format_pnu(pos, neg, unc)


class CheXpertClassifier(nn.Module):
    """
    Multi-label, 3-class-per-label classifier on BioViL-T global embeddings.

    Output logits have shape (B, num_classes, 3); a per-pathology
    softmax/argmax over the last dim yields negative / positive / uncertain.

    Args:
        input_dim: global CXR embedding dim
        num_classes: number of pathologies; must equal len(PATHOLOGIES)
            (14) for `predict` / `findings_to_text` to map logits to names
        checkpoint: path to trained weights (None = not loaded)
    """

    def __init__(
        self,
        input_dim: int = 512,
        num_classes: int = 14,
        checkpoint: Optional[str] = None,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.num_states = NUM_STATES
        self.pathologies = PATHOLOGIES

        # MLP head -> num_classes * 3 logits, reshaped to (B, num_classes, 3)
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes * NUM_STATES),
        )

        if checkpoint is not None:
            self._load_checkpoint(checkpoint)

    def _load_checkpoint(self, path: str):
        """Load a state_dict from `path` (mapped to CPU) into this module."""
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code — only load checkpoints from trusted sources
        # (consider weights_only=True on torch versions that support it).
        state_dict = torch.load(path, map_location="cpu")
        self.load_state_dict(state_dict)
        print(f"[CheXpertClassifier] Loaded weights from {path}")

    def forward(self, global_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            global_features: (B, input_dim)

        Returns:
            logits: (B, num_classes, 3) — softmax over the last dim gives
                P(negative), P(positive), P(uncertain) per pathology.
                Train with cross-entropy over the last dim (the natural
                U-MultiClass objective).
        """
        flat = self.classifier(global_features)              # (B, C*3)
        return flat.view(-1, self.num_classes, NUM_STATES)  # (B, C, 3)

    @torch.no_grad()
    def _predict_classes(self, global_features: torch.Tensor) -> List[List[int]]:
        """
        Per-sample argmax class index for every pathology, as nested lists.

        Inference always runs with Dropout disabled so the learned findings
        are deterministic (the PNU prompt must be reproducible even when a
        parent model has put this module in train mode). The original
        train/eval flag of every submodule is restored afterwards.

        Raises:
            ValueError: if num_classes does not match the number of
                pathology names, so logits cannot be mapped to names.
        """
        if self.num_classes != len(self.pathologies):
            raise ValueError(
                f"num_classes={self.num_classes} does not match "
                f"{len(self.pathologies)} pathology names; cannot map "
                f"predictions to pathologies"
            )
        saved_modes = [(m, m.training) for m in self.modules()]
        self.eval()
        try:
            logits = self.forward(global_features)   # (B, C, 3)
        finally:
            for module, mode in saved_modes:
                module.training = mode
        return logits.argmax(dim=-1).cpu().tolist()  # B lists of C ints

    @torch.no_grad()
    def predict(self, global_features: torch.Tensor) -> List[Dict[str, str]]:
        """
        Returns a list (per sample) of {pathology: "negative"|"positive"|
        "uncertain"} using argmax over the 3-state softmax (eval mode).
        """
        return [
            {name: CLASS_NAMES[c] for name, c in zip(self.pathologies, row)}
            for row in self._predict_classes(global_features)
        ]

    @torch.no_grad()
    def findings_to_text(self, global_features: torch.Tensor) -> List[str]:
        """
        Per-sample PNU structured-findings string, identical in format to
        the GT oracle path (data/mimic_cxr_builder.py). One string per
        sample; pathologies appear in PATHOLOGIES order within a section.
        """
        return [
            buckets_to_pnu(dict(zip(self.pathologies, row)))
            for row in self._predict_classes(global_features)
        ]