Spaces:
Running
Running
| """PyTorch Dataset for NIH Chest X-ray14.""" | |
| from __future__ import annotations | |
| from pathlib import Path | |
| from typing import Any | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
| from src.data.transforms import get_eval_transforms, get_train_transforms | |
# The 14 pathology classes of NIH Chest X-ray14. The position of each name in
# this list is the column index used for that class in multi-hot targets, so
# the order must not change between training and inference.
PATHOLOGY_LABELS = (
    "Atelectasis Cardiomegaly Effusion Infiltration Mass Nodule "
    "Pneumonia Pneumothorax Consolidation Edema Emphysema Fibrosis "
    "Pleural_Thickening Hernia"
).split()
# Width of the multi-hot target vector.
NUM_CLASSES = len(PATHOLOGY_LABELS)
class ChestXrayDataset(Dataset):
    """NIH Chest X-ray14 dataset for multi-label and binary classification.

    Rows come from ``labels_csv``, which must contain an ``image_path`` column
    (resolved relative to ``image_dir``) and a ``labels`` column holding
    ``|``-separated pathology names or ``"No Finding"``. If the CSV also has a
    ``split`` column, only rows whose value equals ``split`` are kept.

    Each sample is a dict with:
        image: Tensor returned by the transform (expected shape (C, H, W))
        multilabel_target: float32 Tensor of shape (NUM_CLASSES,) — multi-hot
            encoded pathology labels; names not in PATHOLOGY_LABELS are ignored
        binary_target: float32 Tensor of shape (1,) — 0 for Normal
            ("No Finding" or a missing label), 1 for Abnormal (anything else)

    Args:
        image_dir: Directory that ``image_path`` entries are relative to.
        labels_csv: Path to the labels CSV file.
        split: Split to select; also chooses the default transforms
            (``"train"`` -> training transforms, anything else -> eval).
        image_size: Size passed to the default transform factories.
        transform: Optional callable applied to the RGB PIL image; when given
            it replaces the default transforms.

    Raises:
        KeyError: If the CSV lacks the ``image_path`` or ``labels`` column.
    """

    def __init__(
        self,
        image_dir: Path | str,
        labels_csv: Path | str,
        split: str = "train",
        image_size: int = 224,
        transform: Any | None = None,
    ) -> None:
        self.image_dir = Path(image_dir)
        self.split = split
        self.image_size = image_size

        # Load labels CSV and fail early with a clear message if it is malformed.
        df = pd.read_csv(labels_csv)
        missing = {"image_path", "labels"} - set(df.columns)
        if missing:
            raise KeyError(
                f"labels CSV is missing required column(s): {sorted(missing)}"
            )
        if "split" in df.columns:
            df = df[df["split"] == split].reset_index(drop=True)
        self.image_paths = df["image_path"].tolist()
        self.labels_raw = df["labels"].tolist()

        # An explicit transform wins; otherwise pick defaults by split.
        if transform is not None:
            self.transform = transform
        elif split == "train":
            self.transform = get_train_transforms(image_size)
        else:
            self.transform = get_eval_transforms(image_size)

        # Precompute label vectors once so __getitem__ only indexes arrays.
        self._multilabel_targets = self._encode_multilabel(self.labels_raw)
        self._binary_targets = self._encode_binary(self.labels_raw)

    @staticmethod
    def _is_normal(labels_str: Any) -> bool:
        """Return True for a missing (NaN) label cell or the "No Finding" marker.

        Surrounding whitespace is ignored so both encoders agree on what
        counts as Normal.
        """
        return bool(pd.isna(labels_str)) or str(labels_str).strip() == "No Finding"

    def _encode_multilabel(self, labels_list: list[Any]) -> np.ndarray:
        """Convert ``|``-separated label strings to multi-hot float32 rows.

        Returns an array of shape (len(labels_list), NUM_CLASSES). Normal rows
        and unrecognised label names leave zeros.
        """
        # Dict lookup instead of a linear PATHOLOGY_LABELS.index() per label.
        label_to_index = {name: i for i, name in enumerate(PATHOLOGY_LABELS)}
        targets = np.zeros((len(labels_list), NUM_CLASSES), dtype=np.float32)
        for row, labels_str in enumerate(labels_list):
            if self._is_normal(labels_str):
                continue
            # str() guards against non-string, non-NaN cells (e.g. numbers).
            for label in str(labels_str).split("|"):
                col = label_to_index.get(label.strip())
                if col is not None:
                    targets[row, col] = 1.0
        return targets

    def _encode_binary(self, labels_list: list[Any]) -> np.ndarray:
        """Convert label strings to a (len(labels_list), 1) float32 array.

        0 = Normal ("No Finding" or missing), 1 = Abnormal (anything else).
        """
        targets = np.zeros((len(labels_list), 1), dtype=np.float32)
        for row, labels_str in enumerate(labels_list):
            if not self._is_normal(labels_str):
                targets[row, 0] = 1.0
        return targets

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        # A transform is always required (the defaults return a Tensor);
        # check before doing any disk I/O.
        if self.transform is None:
            raise ValueError(
                "ChestXrayDataset requires a transform. "
                "Use get_train_transforms() or get_eval_transforms()."
            )

        img_path = self.image_dir / self.image_paths[idx]
        # The context manager closes the file handle deterministically;
        # convert() returns a new, fully loaded image that outlives the block.
        with Image.open(img_path) as img:
            pil_image = img.convert("RGB")

        image: torch.Tensor = self.transform(pil_image)
        return {
            "image": image,
            # Copy the rows: torch.from_numpy shares memory, so an in-place
            # edit by a caller would otherwise corrupt the cached labels.
            "multilabel_target": torch.from_numpy(
                self._multilabel_targets[idx].copy()
            ),
            "binary_target": torch.from_numpy(self._binary_targets[idx].copy()),
        }

    def get_label_weights(self) -> torch.Tensor:
        """Compute per-class positive weights (``pos_weight`` for BCEWithLogitsLoss).

        Each weight is (#negatives / #positives) for that class, as a float32
        tensor of shape (NUM_CLASSES,).
        """
        pos_counts = self._multilabel_targets.sum(axis=0)
        neg_counts = len(self) - pos_counts
        # Clamp the denominator so a class with no positives gets weight
        # len(self) instead of a division by zero.
        pos_weights = neg_counts / np.maximum(pos_counts, 1.0)
        return torch.from_numpy(pos_weights.astype(np.float32))