Spaces:
Sleeping
Sleeping
File size: 6,823 Bytes
8960670 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 | #!/usr/bin/env python3
"""
LUNA16 PyTorch Dataset for DCA-Net training.
Loads preprocessed .npz patches and applies on-the-fly augmentation.
"""
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
class LunaDataset(Dataset):
    """
    PyTorch Dataset for LUNA16 preprocessed patches.

    Loads nodule (64³) and context (48³) patches from .npz files.
    Applies random augmentation during training.

    Each metadata row must provide ``nodule_path`` and ``context_path``
    (paths to .npz files containing a ``patch`` array) and ``label``.
    An optional ``is_hard_negative`` column enables curriculum filtering.

    Args:
        csv_path: Path to metadata CSV (train_samples.csv, etc.)
        augment: Whether to apply data augmentation
        aug_config: Augmentation configuration dict (keys read: 'rotation',
            'flip', 'noise', 'noise_std', 'intensity_shift')
        curriculum_stage: Optional curriculum stage. Stage 1 keeps positives
            plus non-hard negatives; any other value (2, 3, None) keeps all
            samples. Ignored when the CSV has no ``is_hard_negative`` column.

    Raises:
        FileNotFoundError: If the first sample's nodule or context patch
            file does not exist.
    """

    def __init__(self, csv_path, augment=False, aug_config=None, curriculum_stage=None):
        self.metadata = pd.read_csv(csv_path)
        self.augment = augment
        self.aug_config = aug_config or {}

        # Apply curriculum filtering
        if curriculum_stage is not None and 'is_hard_negative' in self.metadata.columns:
            original_len = len(self.metadata)
            if curriculum_stage == 1:
                # Stage 1: easy samples only — positives + non-hard negatives.
                # NaN flags compare unequal to False, so negatives with a
                # missing flag are dropped here.
                self.metadata = self.metadata[
                    (self.metadata['label'] == 1)
                    | (self.metadata['is_hard_negative'] == False)  # noqa: E712
                ].reset_index(drop=True)
            # Stage 2 / 3: all samples are used (no filtering).
            filtered_len = len(self.metadata)
            if filtered_len != original_len:
                pos = (self.metadata['label'] == 1).sum()
                neg = filtered_len - pos
                print(f" Curriculum stage {curriculum_stage}: {original_len} → {filtered_len} samples ({pos} pos, {neg} neg)")

        # Fail fast on a wrong preprocessed_data/ location: probe the first
        # sample's patch files here rather than erroring later inside a
        # DataLoader worker.
        if len(self.metadata) > 0:
            sample = self.metadata.iloc[0]
            for key in ('nodule_path', 'context_path'):
                if not Path(sample[key]).exists():
                    raise FileNotFoundError(
                        f"Patch file not found: {sample[key]}. "
                        "Check that preprocessed_data/ paths are correct."
                    )

    def __len__(self):
        return len(self.metadata)

    @staticmethod
    def _load_patch(path):
        """Return the ``patch`` array stored in the .npz file at *path* as float32.

        ``np.load`` on an .npz archive returns an ``NpzFile`` that holds the
        file open until closed; the ``with`` block closes it deterministically
        instead of relying on garbage collection, once per sample per worker.
        """
        with np.load(path) as archive:
            return archive['patch'].astype(np.float32)

    def __getitem__(self, idx):
        """Return ``(nodule, context, label)`` tensors for sample *idx*.

        Patch tensors get a leading channel dim: ``(1, D, H, W)``; the label
        is a 0-dim float32 tensor.
        """
        row = self.metadata.iloc[idx]

        # Load patches
        nodule_patch = self._load_patch(row['nodule_path'])
        context_patch = self._load_patch(row['context_path'])
        label = np.float32(row['label'])

        # Apply augmentation
        if self.augment:
            nodule_patch, context_patch = self._augment(
                nodule_patch, context_patch
            )

        # Convert to tensors: add channel dim → (1, D, H, W)
        nodule_tensor = torch.from_numpy(nodule_patch).unsqueeze(0)
        context_tensor = torch.from_numpy(context_patch).unsqueeze(0)
        label_tensor = torch.tensor(label)

        return nodule_tensor, context_tensor, label_tensor

    def _augment(self, nodule, context):
        """Apply random augmentations to both patches.

        Rotation, flips and the intensity shift use the same random draw for
        both patches so they stay aligned; Gaussian noise is sampled
        independently for each patch. Defaults when a key is absent from
        ``aug_config``: rotation on, flip on, noise on (std 0.05), intensity
        shift off. Results are clipped to [-1, 1].
        """
        cfg = self.aug_config

        # Random rotation (k × 90° in one randomly chosen axis plane).
        # np.rot90 / np.flip return strided views; .copy() makes them
        # contiguous because torch.from_numpy rejects negative strides.
        if cfg.get('rotation', True):
            k = np.random.randint(0, 4)
            axes = [(0, 1), (0, 2), (1, 2)]
            ax = axes[np.random.randint(0, 3)]
            nodule = np.rot90(nodule, k=k, axes=ax).copy()
            context = np.rot90(context, k=k, axes=ax).copy()

        # Random flip, independently per spatial axis
        if cfg.get('flip', True):
            for axis in range(3):
                if np.random.rand() > 0.5:
                    nodule = np.flip(nodule, axis=axis).copy()
                    context = np.flip(context, axis=axis).copy()

        # Gaussian noise (separate draw per patch)
        if cfg.get('noise', True):
            std = cfg.get('noise_std', 0.05)
            noise = np.random.normal(0, std, nodule.shape).astype(np.float32)
            nodule = nodule + noise
            noise_c = np.random.normal(0, std, context.shape).astype(np.float32)
            context = context + noise_c

        # Random intensity shift, shared by both patches
        if cfg.get('intensity_shift', 0) > 0:
            shift = np.random.uniform(
                -cfg['intensity_shift'], cfg['intensity_shift']
            )
            nodule = nodule + shift
            context = context + shift

        # Clamp back to [-1, 1] (assumes preprocessing normalized patches to
        # this range — TODO confirm against the preprocessing script).
        nodule = np.clip(nodule, -1.0, 1.0)
        context = np.clip(context, -1.0, 1.0)

        return nodule, context
def create_data_loaders(config, curriculum_stage=None):
    """Create train, validation, and test DataLoaders from config.

    Reads ``<preprocessed_dir>/metadata/{train,val,test}_samples.csv``, where
    ``preprocessed_dir`` is ``config['data']['preprocessed_dir']`` (default
    ``'preprocessed_data'``). Only the training set is augmented and
    curriculum-filtered; validation and test sets are loaded without
    augmentation.

    Args:
        config: Full training configuration dict. Missing or empty (``None``)
            ``data``, ``data.augmentation`` and ``training`` sections fall
            back to the defaults below.
        curriculum_stage: Optional curriculum stage (1, 2, or 3) for train set filtering

    Returns:
        train_loader, val_loader, test_loader

    Raises:
        FileNotFoundError: If any of the three metadata CSVs is missing.
    """
    # `or {}` also covers sections that exist but are empty in a YAML config,
    # which parse as None rather than as a missing key.
    data_cfg = config.get('data') or {}
    preprocessed_dir = Path(data_cfg.get('preprocessed_dir', 'preprocessed_data'))
    metadata_dir = preprocessed_dir / 'metadata'
    aug_config = data_cfg.get('augmentation') or {}

    train_csv = metadata_dir / 'train_samples.csv'
    val_csv = metadata_dir / 'val_samples.csv'
    test_csv = metadata_dir / 'test_samples.csv'

    # Check files exist
    for csv_path in [train_csv, val_csv, test_csv]:
        if not csv_path.exists():
            raise FileNotFoundError(
                f"Metadata CSV not found: {csv_path}. "
                "Run generate_metadata.py first."
            )

    train_dataset = LunaDataset(
        train_csv, augment=aug_config.get('enabled', True),
        aug_config=aug_config, curriculum_stage=curriculum_stage
    )
    val_dataset = LunaDataset(val_csv, augment=False)
    test_dataset = LunaDataset(test_csv, augment=False)

    num_workers = data_cfg.get('num_workers', 4)
    loader_kwargs = {
        'num_workers': num_workers,
        'pin_memory': data_cfg.get('pin_memory', True),
        # persistent_workers is only meaningful with worker processes
        'persistent_workers': data_cfg.get('persistent_workers', False)
        and num_workers > 0,
    }
    # prefetch_factor only valid when num_workers > 0
    if num_workers > 0:
        loader_kwargs['prefetch_factor'] = data_cfg.get('prefetch_factor', 2)

    batch_size = (config.get('training') or {}).get('batch_size', 16)

    # drop_last=True: note that a (possibly curriculum-filtered) train set
    # smaller than batch_size yields no training batches at all.
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        drop_last=True, **loader_kwargs
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, **loader_kwargs
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, **loader_kwargs
    )

    return train_loader, val_loader, test_loader
|