chexvision-demo / src /data /dataset.py
arudaev's picture
fix(types): resolve all 8 mypy errors across 5 files
91347e8
"""PyTorch Dataset for NIH Chest X-ray14."""
from __future__ import annotations
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from src.data.transforms import get_eval_transforms, get_train_transforms
# The 14 pathology classes annotated in NIH Chest X-ray14.
# The position of a name in this list is its column index in the
# multi-hot target vectors, so the order must stay stable.
PATHOLOGY_LABELS: list[str] = [
    "Atelectasis",
    "Cardiomegaly",
    "Effusion",
    "Infiltration",
    "Mass",
    "Nodule",
    "Pneumonia",
    "Pneumothorax",
    "Consolidation",
    "Edema",
    "Emphysema",
    "Fibrosis",
    "Pleural_Thickening",
    "Hernia",
]

# Width of the multi-label target vector; derived so it cannot drift from the list.
NUM_CLASSES: int = len(PATHOLOGY_LABELS)
class ChestXrayDataset(Dataset):
    """NIH Chest X-ray14 dataset for multi-label and binary classification.

    Each sample is a dict with:
        image: Tensor of shape (C, H, W), as produced by the transform
        multilabel_target: Tensor of shape (14,) — multi-hot encoded pathology labels
        binary_target: Tensor of shape (1,) — 0 for Normal, 1 for Abnormal

    The labels CSV must contain an ``image_path`` column (path relative to
    ``image_dir``) and a ``labels`` column holding ``|``-separated pathology
    names, or ``"No Finding"``. An optional ``split`` column restricts the
    rows to those matching the requested split.
    """

    # Literal used in the labels column for studies without any pathology.
    _NO_FINDING = "No Finding"

    def __init__(
        self,
        image_dir: Path | str,
        labels_csv: Path | str,
        split: str = "train",
        image_size: int = 224,
        transform: Any | None = None,
    ) -> None:
        """Read the labels CSV and precompute all target vectors.

        Args:
            image_dir: Directory that the CSV's ``image_path`` values are relative to.
            labels_csv: CSV file with ``image_path`` and ``labels`` columns.
            split: Split name; used to filter rows (when the CSV has a ``split``
                column) and to pick the default transform.
            image_size: Size forwarded to the default transform factories.
            transform: Callable applied to the loaded PIL image. When omitted,
                ``get_train_transforms`` is used for ``split == "train"`` and
                ``get_eval_transforms`` for every other split.

        Raises:
            KeyError: If the CSV lacks the ``image_path`` or ``labels`` column.
        """
        self.image_dir = Path(image_dir)
        self.split = split
        self.image_size = image_size

        # Load labels CSV; the split filter is applied only if the column exists.
        df = pd.read_csv(labels_csv)
        if "split" in df.columns:
            df = df[df["split"] == split].reset_index(drop=True)
        self.image_paths = df["image_path"].tolist()
        self.labels_raw = df["labels"].tolist()

        # Set transforms: an explicit transform wins over the split defaults.
        if transform is not None:
            self.transform = transform
        elif split == "train":
            self.transform = get_train_transforms(image_size)
        else:
            self.transform = get_eval_transforms(image_size)

        # Precompute label vectors once so __getitem__ only has to index them.
        self._multilabel_targets = self._encode_multilabel(self.labels_raw)
        self._binary_targets = self._encode_binary(self.labels_raw)

    @classmethod
    def _is_normal(cls, labels_str: Any) -> bool:
        """Return True when a labels entry denotes a study without findings.

        A missing value (NaN/None) or the literal "No Finding" counts as
        normal. Surrounding whitespace is ignored so that the binary and the
        multi-label encodings always agree on which rows are normal.
        """
        if pd.isna(labels_str):
            return True
        return isinstance(labels_str, str) and labels_str.strip() == cls._NO_FINDING

    def _encode_multilabel(self, labels_list: list[str]) -> np.ndarray:
        """Convert string labels to multi-hot vectors.

        Returns a float32 array of shape ``(len(labels_list), NUM_CLASSES)``.
        Names that are not in ``PATHOLOGY_LABELS`` are ignored.
        """
        # Build the name -> column lookup once instead of scanning the list per label.
        label_to_index = {name: col for col, name in enumerate(PATHOLOGY_LABELS)}
        targets = np.zeros((len(labels_list), NUM_CLASSES), dtype=np.float32)
        for row, labels_str in enumerate(labels_list):
            if self._is_normal(labels_str):
                continue
            for label in labels_str.split("|"):
                col = label_to_index.get(label.strip())
                if col is not None:
                    targets[row, col] = 1.0
        return targets

    def _encode_binary(self, labels_list: list[str]) -> np.ndarray:
        """Convert labels to binary: 0=Normal, 1=Abnormal.

        Returns a float32 array of shape ``(len(labels_list), 1)``. Every
        entry that is not normal (see ``_is_normal``) is marked abnormal.
        """
        targets = np.zeros((len(labels_list), 1), dtype=np.float32)
        for row, labels_str in enumerate(labels_list):
            if not self._is_normal(labels_str):
                targets[row, 0] = 1.0
        return targets

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Load, transform and return the sample at position ``idx``.

        Raises:
            ValueError: If ``self.transform`` has been set to ``None``.
        """
        # Load image. The context manager closes the underlying file right
        # away; convert() returns a fully loaded copy that outlives the handle.
        img_path = self.image_dir / self.image_paths[idx]
        with Image.open(img_path) as raw_image:
            pil_image = raw_image.convert("RGB")

        # Apply transforms — always required. __init__ always assigns one, so
        # this guard only triggers if the attribute was reset afterwards.
        if self.transform is None:
            raise ValueError(
                "ChestXrayDataset requires a transform. "
                "Use get_train_transforms() or get_eval_transforms()."
            )
        image: torch.Tensor = self.transform(pil_image)

        # NOTE: torch.from_numpy shares memory with the precomputed arrays, so
        # callers must not modify the returned targets in place.
        return {
            "image": image,
            "multilabel_target": torch.from_numpy(self._multilabel_targets[idx]),
            "binary_target": torch.from_numpy(self._binary_targets[idx]),
        }

    def get_label_weights(self) -> torch.Tensor:
        """Compute positive class weights for handling class imbalance.

        The result is meant to be used as ``pos_weight`` for
        ``BCEWithLogitsLoss``: a float32 tensor of shape ``(NUM_CLASSES,)``
        holding ``negatives / positives`` per class.
        """
        pos_counts = self._multilabel_targets.sum(axis=0)
        neg_counts = len(self) - pos_counts
        # Avoid division by zero: classes without positives divide by 1.
        pos_weights = neg_counts / np.maximum(pos_counts, 1.0)
        return torch.from_numpy(pos_weights.astype(np.float32))