File size: 4,275 Bytes
ac0940b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91347e8
 
 
 
 
 
 
 
 
ac0940b
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""PyTorch Dataset for NIH Chest X-ray14."""

from __future__ import annotations

from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset

from src.data.transforms import get_eval_transforms, get_train_transforms

# The 14 pathology classes of NIH Chest X-ray14, in canonical order. A label's
# position in this list is the column it occupies in the multi-hot target
# vectors built by ChestXrayDataset, so reordering it changes the encoding.
PATHOLOGY_LABELS: list[str] = (
    "Atelectasis Cardiomegaly Effusion Infiltration Mass Nodule Pneumonia "
    "Pneumothorax Consolidation Edema Emphysema Fibrosis Pleural_Thickening Hernia"
).split()

# Width of each multi-hot pathology target vector.
NUM_CLASSES: int = len(PATHOLOGY_LABELS)


class ChestXrayDataset(Dataset):
    """NIH Chest X-ray14 dataset for multi-label and binary classification.

    Each item is a dict with the keys:
        "image": Tensor produced by ``self.transform`` from the RGB PIL image
            (shape (C, H, W) for the default train/eval transforms).
        "multilabel_target": float32 Tensor of shape (NUM_CLASSES,) — a
            multi-hot vector over ``PATHOLOGY_LABELS``.
        "binary_target": float32 Tensor of shape (1,) — 0.0 for Normal,
            1.0 for Abnormal.

    The labels CSV must provide ``image_path`` and ``labels`` columns, where
    ``labels`` holds pipe-separated finding names (e.g. ``"Effusion|Mass"``)
    or ``"No Finding"``. When a ``split`` column is present, only rows whose
    ``split`` equals the requested split are kept.
    """

    # Marker used in the labels column for images without any pathology.
    _NO_FINDING = "No Finding"

    def __init__(
        self,
        image_dir: Path | str,
        labels_csv: Path | str,
        split: str = "train",
        image_size: int = 224,
        transform: Any | None = None,
    ) -> None:
        """Load the label table for one split and pre-encode its targets.

        Args:
            image_dir: Directory that the CSV's ``image_path`` entries are
                resolved against.
            labels_csv: CSV with ``image_path`` and ``labels`` columns and an
                optional ``split`` column.
            split: Split to keep when the CSV has a ``split`` column. Also
                selects the default transforms: "train" uses
                ``get_train_transforms``, anything else ``get_eval_transforms``.
            image_size: Size passed to the default transform factories; unused
                when ``transform`` is given.
            transform: Callable applied to each RGB PIL image. When not None it
                overrides the split-based default.
        """
        self.image_dir = Path(image_dir)
        self.split = split
        self.image_size = image_size

        # Load labels CSV; without a "split" column every row is used.
        df = pd.read_csv(labels_csv)
        if "split" in df.columns:
            df = df[df["split"] == split].reset_index(drop=True)

        self.image_paths = df["image_path"].tolist()
        self.labels_raw = df["labels"].tolist()

        # An explicit transform wins; otherwise pick by split.
        if transform is not None:
            self.transform = transform
        elif split == "train":
            self.transform = get_train_transforms(image_size)
        else:
            self.transform = get_eval_transforms(image_size)

        # Encode all labels once so __getitem__ only has to index arrays.
        self._multilabel_targets = self._encode_multilabel(self.labels_raw)
        self._binary_targets = self._encode_binary(self.labels_raw)

    @classmethod
    def _split_labels(cls, labels_str: Any) -> list[str]:
        """Return the finding names listed in one ``labels`` cell.

        The cell is split on "|" and each token is whitespace-stripped. Empty
        tokens and the "No Finding" marker are dropped, and a missing cell
        (NaN/None) yields no tokens, so an empty result means "Normal".
        Both encoders use this so their notion of Normal always agrees.
        """
        if pd.isna(labels_str):
            return []
        tokens = (token.strip() for token in labels_str.split("|"))
        return [token for token in tokens if token and token != cls._NO_FINDING]

    def _encode_multilabel(self, labels_list: list[str]) -> np.ndarray:
        """Convert label strings to a (len(labels_list), NUM_CLASSES) float32 multi-hot array.

        Finding names that are not in ``PATHOLOGY_LABELS`` are ignored.
        """
        # Hoisted name -> column lookup instead of list.index per token.
        column_of = {name: col for col, name in enumerate(PATHOLOGY_LABELS)}
        targets = np.zeros((len(labels_list), NUM_CLASSES), dtype=np.float32)
        for row, labels_str in enumerate(labels_list):
            for label in self._split_labels(labels_str):
                col = column_of.get(label)
                if col is not None:
                    targets[row, col] = 1.0
        return targets

    def _encode_binary(self, labels_list: list[str]) -> np.ndarray:
        """Convert label strings to a (len(labels_list), 1) float32 array: 0=Normal, 1=Abnormal.

        A row is Abnormal when its cell names at least one finding other than
        "No Finding" (unrecognized names included). Missing, empty and
        "No Finding"-only cells are Normal.
        """
        targets = np.zeros((len(labels_list), 1), dtype=np.float32)
        for row, labels_str in enumerate(labels_list):
            if self._split_labels(labels_str):
                targets[row, 0] = 1.0
        return targets

    def __len__(self) -> int:
        """Return the number of images in this split."""
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Load image ``idx``, apply the transform and return it with its targets.

        Raises:
            ValueError: if ``self.transform`` is None.
        """
        # Fail fast, before doing any file I/O or decoding.
        if self.transform is None:
            raise ValueError(
                "ChestXrayDataset requires a transform. "
                "Use get_train_transforms() or get_eval_transforms()."
            )

        img_path = self.image_dir / self.image_paths[idx]
        # convert() loads the pixels into a new image, so the source file can
        # be closed deterministically when the with-block exits.
        with Image.open(img_path) as raw_image:
            pil_image = raw_image.convert("RGB")

        image: torch.Tensor = self.transform(pil_image)

        # Copy the target rows: torch.from_numpy shares memory, and an in-place
        # op on a returned target would otherwise corrupt the cached labels.
        return {
            "image": image,
            "multilabel_target": torch.from_numpy(self._multilabel_targets[idx].copy()),
            "binary_target": torch.from_numpy(self._binary_targets[idx].copy()),
        }

    def get_label_weights(self) -> torch.Tensor:
        """Compute per-class ``pos_weight`` values for ``BCEWithLogitsLoss``.

        Returns a float32 Tensor of shape (NUM_CLASSES,) holding
        ``negatives / positives`` for each class. The positive count is clamped
        to at least 1, so a class with no positives gets weight ``len(self)``
        instead of a division by zero.
        """
        pos_counts = self._multilabel_targets.sum(axis=0)
        neg_counts = len(self) - pos_counts
        pos_weights = neg_counts / np.maximum(pos_counts, 1.0)
        return torch.from_numpy(pos_weights.astype(np.float32))