#!/usr/bin/env python3
"""
LUNA16 PyTorch Dataset for DCA-Net training.

Loads preprocessed .npz patches and applies on-the-fly augmentation.
"""

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path


class LunaDataset(Dataset):
    """
    PyTorch Dataset for LUNA16 preprocessed patches.

    Loads nodule (64³) and context (48³) patches from .npz files.
    Applies random augmentation during training.

    Args:
        csv_path: Path to metadata CSV (train_samples.csv, etc.). Must provide
            'nodule_path', 'context_path' and 'label' columns; an optional
            'is_hard_negative' column enables curriculum filtering.
        augment: Whether to apply data augmentation
        aug_config: Augmentation configuration dict. Keys read: 'rotation',
            'flip', 'noise', 'noise_std', 'intensity_shift'.
        curriculum_stage: Optional curriculum stage. Stage 1 keeps positives
            plus non-hard negatives; any other value keeps every sample.

    Raises:
        FileNotFoundError: if the first sample's patch files do not exist.
    """

    def __init__(self, csv_path, augment=False, aug_config=None,
                 curriculum_stage=None):
        self.metadata = pd.read_csv(csv_path)
        self.augment = augment
        self.aug_config = aug_config or {}

        if curriculum_stage is not None:
            self._apply_curriculum(curriculum_stage)

        # Fail fast on a wrong preprocessed_data/ location instead of erroring
        # later inside a DataLoader worker on the first batch.
        if len(self.metadata) > 0:
            sample = self.metadata.iloc[0]
            for column in ('nodule_path', 'context_path'):
                if not Path(sample[column]).exists():
                    raise FileNotFoundError(
                        f"Patch file not found: {sample[column]}. "
                        "Check that preprocessed_data/ paths are correct."
                    )

    def _apply_curriculum(self, stage):
        """Filter ``self.metadata`` for the given curriculum stage."""
        if 'is_hard_negative' not in self.metadata.columns:
            # Previously this case was ignored silently; make it visible.
            print(f" Curriculum stage {stage}: no 'is_hard_negative' column, "
                  "using all samples")
            return

        original_len = len(self.metadata)
        if stage == 1:
            # Stage 1: easy samples only — positives + non-hard negatives.
            # .eq(False) keeps the exact semantics of `== False` (True and NaN
            # are both excluded) without the E712 comparison.
            keep = ((self.metadata['label'] == 1)
                    | self.metadata['is_hard_negative'].eq(False))
            self.metadata = self.metadata[keep].reset_index(drop=True)
        # Stage 2 (same as stage 3 for our data), stage 3 or anything else:
        # use all samples.

        filtered_len = len(self.metadata)
        if filtered_len != original_len:
            pos = (self.metadata['label'] == 1).sum()
            neg = filtered_len - pos
            print(f" Curriculum stage {stage}: {original_len} → {filtered_len} samples ({pos} pos, {neg} neg)")

    def __len__(self):
        return len(self.metadata)

    @staticmethod
    def _load_patch(path):
        """Return the 'patch' array stored in the .npz at *path* as float32.

        np.load on an .npz returns an NpzFile that holds the file open; using
        it as a context manager closes the handle as soon as the array has
        been read, so workers do not accumulate open files.
        """
        with np.load(path) as archive:
            return archive['patch'].astype(np.float32)

    def __getitem__(self, idx):
        """Return ``(nodule, context, label)`` tensors for sample *idx*.

        Patches gain a leading channel dim → (1, D, H, W); the label is a
        0-dim float32 tensor.
        """
        row = self.metadata.iloc[idx]

        # Load patches
        nodule_patch = self._load_patch(row['nodule_path'])
        context_patch = self._load_patch(row['context_path'])
        label = np.float32(row['label'])

        # Apply augmentation
        if self.augment:
            nodule_patch, context_patch = self._augment(
                nodule_patch, context_patch
            )

        # Convert to tensors: add channel dim → (1, D, H, W)
        nodule_tensor = torch.from_numpy(nodule_patch).unsqueeze(0)
        context_tensor = torch.from_numpy(context_patch).unsqueeze(0)
        label_tensor = torch.tensor(label)

        return nodule_tensor, context_tensor, label_tensor

    def _augment(self, nodule, context):
        """Apply random augmentations to both patches.

        Rotation, flips and the intensity shift use the same random draw for
        both patches; Gaussian noise is sampled independently per patch.
        Outputs are clipped to [-1, 1].
        """
        cfg = self.aug_config

        # Random rotation (90° increments in one randomly chosen plane)
        if cfg.get('rotation', True):
            k = np.random.randint(0, 4)
            axes = [(0, 1), (0, 2), (1, 2)]
            ax = axes[np.random.randint(0, 3)]
            # .copy() removes the negative strides torch.from_numpy rejects.
            nodule = np.rot90(nodule, k=k, axes=ax).copy()
            context = np.rot90(context, k=k, axes=ax).copy()

        # Random flip, each axis independently with probability 0.5
        if cfg.get('flip', True):
            for axis in range(3):
                if np.random.rand() > 0.5:
                    nodule = np.flip(nodule, axis=axis).copy()
                    context = np.flip(context, axis=axis).copy()

        # Gaussian noise
        if cfg.get('noise', True):
            std = cfg.get('noise_std', 0.05)
            noise = np.random.normal(0, std, nodule.shape).astype(np.float32)
            nodule = nodule + noise
            noise_c = np.random.normal(0, std, context.shape).astype(np.float32)
            context = context + noise_c

        # Random intensity shift (same offset for both patches)
        if cfg.get('intensity_shift', 0) > 0:
            shift = np.random.uniform(
                -cfg['intensity_shift'], cfg['intensity_shift']
            )
            nodule = nodule + shift
            context = context + shift

        # Clamp back to [-1, 1]
        nodule = np.clip(nodule, -1.0, 1.0)
        context = np.clip(context, -1.0, 1.0)

        return nodule, context


def create_data_loaders(config, curriculum_stage=None):
    """Create train, validation, and test DataLoaders from config.

    Args:
        config: Full training configuration dict. Reads the 'data' section
            (preprocessed_dir, augmentation, num_workers, pin_memory,
            persistent_workers, prefetch_factor) and training.batch_size.
        curriculum_stage: Optional curriculum stage (1, 2, or 3) for train
            set filtering

    Returns:
        train_loader, val_loader, test_loader

    Raises:
        FileNotFoundError: if any of the metadata CSVs is missing.
    """
    data_cfg = config.get('data', {})
    preprocessed_dir = Path(data_cfg.get('preprocessed_dir', 'preprocessed_data'))
    metadata_dir = preprocessed_dir / 'metadata'
    aug_config = data_cfg.get('augmentation', {})

    train_csv = metadata_dir / 'train_samples.csv'
    val_csv = metadata_dir / 'val_samples.csv'
    test_csv = metadata_dir / 'test_samples.csv'

    # Check files exist
    for csv_path in [train_csv, val_csv, test_csv]:
        if not csv_path.exists():
            raise FileNotFoundError(
                f"Metadata CSV not found: {csv_path}. "
                "Run generate_metadata.py first."
            )

    train_dataset = LunaDataset(
        train_csv,
        augment=aug_config.get('enabled', True),
        aug_config=aug_config,
        curriculum_stage=curriculum_stage
    )
    # Evaluation splits are never augmented or curriculum-filtered.
    val_dataset = LunaDataset(val_csv, augment=False)
    test_dataset = LunaDataset(test_csv, augment=False)

    num_workers = data_cfg.get('num_workers', 4)
    loader_kwargs = {
        'num_workers': num_workers,
        'pin_memory': data_cfg.get('pin_memory', True),
        # persistent_workers requires worker processes to exist.
        'persistent_workers': data_cfg.get('persistent_workers', False) and num_workers > 0,
    }
    # prefetch_factor only valid when num_workers > 0
    if num_workers > 0:
        loader_kwargs['prefetch_factor'] = data_cfg.get('prefetch_factor', 2)

    batch_size = config.get('training', {}).get('batch_size', 16)

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        drop_last=True, **loader_kwargs
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, **loader_kwargs
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, **loader_kwargs
    )

    return train_loader, val_loader, test_loader