"""PyTorch Dataset for NIH Chest X-ray14."""
from __future__ import annotations
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from src.data.transforms import get_eval_transforms, get_train_transforms
# The 14 pathology classes annotated in NIH Chest X-ray14.
# The position of a name in this list is its column index in the
# multi-hot target vectors, so the order must stay stable.
PATHOLOGY_LABELS = [
    "Atelectasis",
    "Cardiomegaly",
    "Effusion",
    "Infiltration",
    "Mass",
    "Nodule",
    "Pneumonia",
    "Pneumothorax",
    "Consolidation",
    "Edema",
    "Emphysema",
    "Fibrosis",
    "Pleural_Thickening",
    "Hernia",
]

# Width of the multi-label target vector.
NUM_CLASSES = len(PATHOLOGY_LABELS)
class ChestXrayDataset(Dataset):
"""NIH Chest X-ray14 dataset for multi-label and binary classification.
Each sample returns:
image: Tensor of shape (C, H, W)
multilabel_target: Tensor of shape (14,) — one-hot encoded pathology labels
binary_target: Tensor of shape (1,) — 0 for Normal, 1 for Abnormal
"""
def __init__(
self,
image_dir: Path | str,
labels_csv: Path | str,
split: str = "train",
image_size: int = 224,
transform: Any | None = None,
) -> None:
self.image_dir = Path(image_dir)
self.split = split
self.image_size = image_size
# Load labels CSV
df = pd.read_csv(labels_csv)
if "split" in df.columns:
df = df[df["split"] == split].reset_index(drop=True)
self.image_paths = df["image_path"].tolist()
self.labels_raw = df["labels"].tolist()
# Set transforms
if transform is not None:
self.transform = transform
elif split == "train":
self.transform = get_train_transforms(image_size)
else:
self.transform = get_eval_transforms(image_size)
# Precompute label vectors
self._multilabel_targets = self._encode_multilabel(self.labels_raw)
self._binary_targets = self._encode_binary(self.labels_raw)
def _encode_multilabel(self, labels_list: list[str]) -> np.ndarray:
"""Convert string labels to multi-hot vectors."""
targets = np.zeros((len(labels_list), NUM_CLASSES), dtype=np.float32)
for i, labels_str in enumerate(labels_list):
if labels_str == "No Finding" or pd.isna(labels_str):
continue
for label in labels_str.split("|"):
label = label.strip()
if label in PATHOLOGY_LABELS:
targets[i, PATHOLOGY_LABELS.index(label)] = 1.0
return targets
def _encode_binary(self, labels_list: list[str]) -> np.ndarray:
"""Convert labels to binary: 0=Normal, 1=Abnormal."""
targets = np.zeros((len(labels_list), 1), dtype=np.float32)
for i, labels_str in enumerate(labels_list):
if labels_str != "No Finding" and not pd.isna(labels_str):
targets[i, 0] = 1.0
return targets
def __len__(self) -> int:
return len(self.image_paths)
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
# Load image
img_path = self.image_dir / self.image_paths[idx]
pil_image = Image.open(img_path).convert("RGB")
# Apply transforms — always required; get_train/eval_transforms() return a Tensor
if self.transform is None:
raise ValueError(
"ChestXrayDataset requires a transform. "
"Use get_train_transforms() or get_eval_transforms()."
)
image: torch.Tensor = self.transform(pil_image)
return {
"image": image,
"multilabel_target": torch.from_numpy(self._multilabel_targets[idx]),
"binary_target": torch.from_numpy(self._binary_targets[idx]),
}
def get_label_weights(self) -> torch.Tensor:
"""Compute positive class weights for handling class imbalance (pos_weight for BCEWithLogitsLoss)."""
pos_counts = self._multilabel_targets.sum(axis=0)
neg_counts = len(self) - pos_counts
# Avoid division by zero
pos_weights = neg_counts / np.maximum(pos_counts, 1.0)
return torch.from_numpy(pos_weights.astype(np.float32))
|