File size: 6,823 Bytes
8960670
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
#!/usr/bin/env python3
"""
LUNA16 PyTorch Dataset for DCA-Net training.
Loads preprocessed .npz patches and applies on-the-fly augmentation.
"""

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path


class LunaDataset(Dataset):
    """
    PyTorch Dataset for LUNA16 preprocessed patches.

    Loads nodule (64³) and context (48³) patches from .npz files (each file
    stores its array under the ``'patch'`` key) and optionally applies random
    augmentation during training.

    Args:
        csv_path: Path to metadata CSV (train_samples.csv, etc.). Must contain
            ``nodule_path``, ``context_path`` and ``label`` columns; an optional
            ``is_hard_negative`` column enables curriculum filtering.
        augment: Whether to apply data augmentation
        aug_config: Augmentation configuration dict. Keys read: ``rotation``,
            ``flip``, ``noise`` (bools, default True), ``noise_std`` (default
            0.05) and ``intensity_shift`` (default 0 = disabled).
        curriculum_stage: Optional curriculum stage. Stage 1 keeps positives
            and non-hard negatives only; any other value (2, 3, None) keeps all
            samples. Ignored when the CSV has no ``is_hard_negative`` column.

    Raises:
        FileNotFoundError: if the first row's nodule or context patch file
            does not exist (a cheap sanity check on the CSV's paths).
    """

    def __init__(self, csv_path, augment=False, aug_config=None, curriculum_stage=None):
        self.metadata = pd.read_csv(csv_path)
        self.augment = augment
        self.aug_config = aug_config or {}

        # Apply curriculum filtering
        if curriculum_stage is not None and 'is_hard_negative' in self.metadata.columns:
            original_len = len(self.metadata)
            if curriculum_stage == 1:
                # Stage 1: easy samples only — positives + non-hard negatives.
                # .eq(False) is the elementwise form of `== False` (a pandas
                # Series cannot be tested with `is False`).
                keep = (
                    self.metadata['label'].eq(1)
                    | self.metadata['is_hard_negative'].eq(False)
                )
                self.metadata = self.metadata[keep].reset_index(drop=True)
            # Stage 2 (same as stage 3 for our data), stage 3, or any other
            # value: use all samples.
            filtered_len = len(self.metadata)
            if filtered_len != original_len:
                pos = int((self.metadata['label'] == 1).sum())
                neg = filtered_len - pos
                print(
                    f"  Curriculum stage {curriculum_stage}: "
                    f"{original_len} → {filtered_len} samples ({pos} pos, {neg} neg)"
                )

        # Fail fast on a wrong preprocessed_data/ location by probing the first
        # row; both paths are needed by __getitem__.
        if len(self.metadata) > 0:
            sample = self.metadata.iloc[0]
            for column in ('nodule_path', 'context_path'):
                if not Path(sample[column]).exists():
                    raise FileNotFoundError(
                        f"Patch file not found: {sample[column]}. "
                        "Check that preprocessed_data/ paths are correct."
                    )

    def __len__(self):
        return len(self.metadata)

    @staticmethod
    def _load_patch(path):
        """Read the ``'patch'`` array from an .npz file as float32 and close the file."""
        # NpzFile holds an open file handle; the context manager releases it
        # immediately instead of waiting for garbage collection.
        with np.load(path) as data:
            return data['patch'].astype(np.float32)

    def __getitem__(self, idx):
        """Return ``(nodule, context, label)`` tensors for row ``idx``.

        Patches get a leading channel dim → (1, D, H, W); label is a float32
        scalar tensor.
        """
        row = self.metadata.iloc[idx]

        # Load patches
        nodule_patch = self._load_patch(row['nodule_path'])
        context_patch = self._load_patch(row['context_path'])
        label = np.float32(row['label'])

        # Apply augmentation
        if self.augment:
            nodule_patch, context_patch = self._augment(
                nodule_patch, context_patch
            )

        # Convert to tensors: add channel dim → (1, D, H, W)
        nodule_tensor = torch.from_numpy(nodule_patch).unsqueeze(0)
        context_tensor = torch.from_numpy(context_patch).unsqueeze(0)
        label_tensor = torch.tensor(label)

        return nodule_tensor, context_tensor, label_tensor

    def _augment(self, nodule, context):
        """Apply random augmentations to both patches consistently.

        Rotation, flips and intensity shift use the same random draw for both
        patches; Gaussian noise is sampled independently per patch. The result
        is always clipped to [-1, 1].
        """
        cfg = self.aug_config

        # Random rotation (90° increments in one randomly chosen axis plane)
        if cfg.get('rotation', True):
            k = np.random.randint(0, 4)
            axes = [(0, 1), (0, 2), (1, 2)]
            ax = axes[np.random.randint(0, 3)]
            nodule = np.rot90(nodule, k=k, axes=ax).copy()
            context = np.rot90(context, k=k, axes=ax).copy()

        # Random flip (each axis independently, p = 0.5)
        if cfg.get('flip', True):
            for axis in range(3):
                if np.random.rand() > 0.5:
                    nodule = np.flip(nodule, axis=axis).copy()
                    context = np.flip(context, axis=axis).copy()

        # Gaussian noise
        if cfg.get('noise', True):
            std = cfg.get('noise_std', 0.05)
            noise = np.random.normal(0, std, nodule.shape).astype(np.float32)
            nodule = nodule + noise
            noise_c = np.random.normal(0, std, context.shape).astype(np.float32)
            context = context + noise_c

        # Random intensity shift
        if cfg.get('intensity_shift', 0) > 0:
            shift = np.random.uniform(
                -cfg['intensity_shift'], cfg['intensity_shift']
            )
            nodule = nodule + shift
            context = context + shift

        # Clamp back to [-1, 1] (presumably the preprocessing range — confirm)
        nodule = np.clip(nodule, -1.0, 1.0)
        context = np.clip(context, -1.0, 1.0)

        return nodule, context


def create_data_loaders(config, curriculum_stage=None):
    """Create train, validation, and test DataLoaders from config.

    Reads ``config['data']`` (``preprocessed_dir``, ``augmentation``,
    ``num_workers``, ``pin_memory``, ``persistent_workers``,
    ``prefetch_factor``) and ``config['training']['batch_size']``. Metadata
    CSVs are expected at
    ``<preprocessed_dir>/metadata/{train,val,test}_samples.csv``.

    Args:
        config: Full training configuration dict
        curriculum_stage: Optional curriculum stage (1, 2, or 3) for train set filtering

    Returns:
        train_loader, val_loader, test_loader

    Raises:
        FileNotFoundError: if any metadata CSV is missing.
    """
    # `or {}` also covers sections that exist but are empty (None), e.g. a
    # bare `data:` key in a YAML config, which `.get(key, {})` would not.
    data_cfg = config.get('data') or {}
    training_cfg = config.get('training') or {}

    preprocessed_dir = Path(data_cfg.get('preprocessed_dir', 'preprocessed_data'))
    metadata_dir = preprocessed_dir / 'metadata'

    aug_config = data_cfg.get('augmentation') or {}

    train_csv = metadata_dir / 'train_samples.csv'
    val_csv = metadata_dir / 'val_samples.csv'
    test_csv = metadata_dir / 'test_samples.csv'

    # Check files exist
    for csv_path in (train_csv, val_csv, test_csv):
        if not csv_path.exists():
            raise FileNotFoundError(
                f"Metadata CSV not found: {csv_path}. "
                "Run generate_metadata.py first."
            )

    # Only the training split is augmented and curriculum-filtered.
    train_dataset = LunaDataset(
        train_csv, augment=aug_config.get('enabled', True),
        aug_config=aug_config, curriculum_stage=curriculum_stage
    )
    val_dataset = LunaDataset(val_csv, augment=False)
    test_dataset = LunaDataset(test_csv, augment=False)

    # Read once so every worker-dependent option agrees on the same value.
    num_workers = data_cfg.get('num_workers', 4)
    loader_kwargs = {
        'num_workers': num_workers,
        'pin_memory': data_cfg.get('pin_memory', True),
        # persistent_workers is only valid with worker processes
        'persistent_workers': data_cfg.get('persistent_workers', False)
            and num_workers > 0,
    }
    # prefetch_factor only valid when num_workers > 0
    if num_workers > 0:
        loader_kwargs['prefetch_factor'] = data_cfg.get('prefetch_factor', 2)

    batch_size = training_cfg.get('batch_size', 16)

    # drop_last=True discards the final incomplete training batch; note the
    # train loader yields no batches if the (filtered) train set is smaller
    # than batch_size. Evaluation loaders keep every sample.
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        drop_last=True, **loader_kwargs
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, **loader_kwargs
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, **loader_kwargs
    )

    return train_loader, val_loader, test_loader