| """ |
| P1 & P4 Articles - Data Loading System |
| |
| Complete implementation for brain segmentation experiments |
| |
| Specialized Gray Matter (GM) Segmentation with U-Net Models - Journal Paper Implementation |
| Binary segmentation: Background vs Specialized GM |
| Professional results saving and visualization for publication |
| |
| This relates to our articles: |
| "Specialized gray matter segmentation via a generative adversarial network: |
| application on brain white matter hyperintensities classification" |
| |
| "Deep Learning-Based Neuroanatomical Profiling Reveals Detailed Brain Changes: |
| A Large-Scale Multiple Sclerosis Study" |
| |
| Features: |
| - Load FLAIR images and individual mask files from Cohort directory |
| - Support both Local_SAI_GM_sp dataset |
| - Handle standard and zoomed preprocessing variants |
| - Combine masks into 2-class format |
| - Create paired inputs: [FLAIR | mask] concatenated (256x512) |
| - Patient-stratified K-fold cross-validation |
| - TensorFlow dataset creation with proper batching |
| |
| Authors: |
| "Mahdi Bashiri Bawil, Mousa Shamsi, Abolhassan Shakeri Bavil" |
| |
| Developer: |
| "Mahdi Bashiri Bawil" |
| """ |
|
|
import json
from pathlib import Path
from typing import Tuple, List, Dict

import cv2 as cv
import numpy as np
import tensorflow as tf
from sklearn.model_selection import KFold
|
|
class DataConfig:
    """Data configuration for P4 experiments"""

    def __init__(self):
        # Root directory of the cohort data
        self.cohort_dir = Path("/mnt/e/MBashiri/ours_articles/Paper#2/Data/Cohort")

        # Dataset definitions
        self.datasets = {
            'Local_SAI_GM_sp': {
                'base_path': self.cohort_dir / 'Local_SAI_GM_sp',
                'slice_range': (1, 20),
                'patient_prefix_length': 6
            }
        }

        # Preprocessing variants
        self.preprocessing_types = ['standard', 'zoomed']

        # Class scenarios (binary only for this study)
        self.class_scenarios = {
            'binary': {
                'num_classes': 2,
                'class_names': ['Background', 'Specialized GM'],
                'description': 'Binary: Background, Specialized GM',
                'class_mapping': {
                    'background': 0,
                    'specialized_gm': 1,
                }
            }
        }

        # Cross-validation settings
        self.k_folds = 5
        self.test_split = 0.1
        self.random_state = 42

        # Image geometry: single-image size and paired [FLAIR | mask] width
        self.target_size = (256, 256)
        self.paired_width = 512

        # Where fold assignments are persisted
        self.splits_dir = Path("data_splits_sp_gm")
        self.splits_file = self.splits_dir / "SP_GM_fold_assignments.json"
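
# Example override (a sketch; the path below is a placeholder -- DataConfig
# hard-codes a machine-specific cohort_dir that you will likely need to change):
#   config = DataConfig()
#   config.cohort_dir = Path('/path/to/Cohort')
#   config.datasets['Local_SAI_GM_sp']['base_path'] = config.cohort_dir / 'Local_SAI_GM_sp'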
|
|
def extract_patient_id(filename: str, prefix_length: int = 6) -> str:
    """
    Extract patient ID from filename

    Args:
        filename: e.g., "101228_5.npy" or "c01p01_25.png"
        prefix_length: Number of characters in the patient ID

    Returns:
        Patient ID: e.g., "101228" or "c01p01"
    """
    return filename.split('_')[0][:prefix_length]
|
|
|
|
def extract_slice_number(filename: str) -> int:
    """
    Extract slice number from filename

    Args:
        filename: e.g., "101228_5.npy" or "c01p01_25.png"

    Returns:
        Slice number as integer
    """
    basename = filename.split('.')[0]
    slice_num = basename.split('_')[-1]
    return int(slice_num)
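
# Quick illustration (values taken from the docstring examples above):
#   extract_patient_id("101228_5.npy")    -> "101228"
#   extract_slice_number("c01p01_25.png") -> 25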
|
|
|
|
def load_flair_image(flair_path: Path, normalize: bool = False, of_z_score: bool = False) -> np.ndarray:
    """
    Load a FLAIR image

    Args:
        flair_path: Path to the .png file
        normalize: Whether to apply an additional z-score normalization
        of_z_score: If True, load the z-scored .npy file that sits next to the
            .png (same name, .npy extension) instead of the raw .png

    Returns:
        FLAIR image (256, 256, 1) as float32, min-max scaled to [-1, 1]
    """
    if of_z_score:
        # Load the pre-computed z-scored array
        flair = np.load(str(flair_path).replace('.png', '.npy')).astype(np.float32)
    else:
        # Load the raw grayscale PNG
        flair = cv.imread(str(flair_path), cv.IMREAD_GRAYSCALE).astype(np.float32)

    # Min-max scale to [-1, 1]
    flair = (flair - np.min(flair)) / (np.max(flair) - np.min(flair))
    flair = (2 * flair) - 1

    # Ensure a trailing channel dimension
    if len(flair.shape) == 2:
        flair = np.expand_dims(flair, axis=-1)

    # Optional z-score pass for inputs that still look unnormalized
    if normalize and (np.std(flair) > 2.0 or np.abs(np.mean(flair)) > 1.0):
        flair = (flair - np.mean(flair)) / (np.std(flair) + 1e-7)

    return flair
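
# Sanity-check sketch (illustrative; 'example.png' is a placeholder path):
#   img = load_flair_image(Path('example.png'))
#   assert img.shape[-1] == 1 and -1.0 <= img.min() <= img.max() <= 1.0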
|
|
|
|
def load_mask_image(mask_path: Path) -> np.ndarray:
    """
    Load mask image (.png format)

    Args:
        mask_path: Path to .png file

    Returns:
        Binary mask (256, 256) as uint8
    """
    mask = cv.imread(str(mask_path), cv.IMREAD_GRAYSCALE)

    if mask is None:
        raise FileNotFoundError(f"Could not load mask: {mask_path}")

    # Binarize: any non-zero pixel counts as foreground
    mask = (mask > 0).astype(np.uint8)

    return mask
|
|
|
|
def combine_masks(gm_mask: np.ndarray,
                  class_scenario: str,
                  preprocess: bool = False) -> np.ndarray:
    """
    Combine individual masks into multi-class format

    Args:
        gm_mask: Specialized GM mask (256, 256)
        class_scenario: 'binary'
        preprocess: Whether to apply morphological preprocessing

    Returns:
        Combined mask (256, 256) with class labels
    """
    if preprocess:
        from skimage.morphology import remove_small_objects, binary_erosion, binary_closing, disk
        min_object_size = 5
        closing_kernel_size = 2
        erosion_kernel_size = 1

        gm_mask = gm_mask > 0

        # Close small gaps, erode boundaries, then drop tiny components
        gm_mask = binary_closing(gm_mask, disk(closing_kernel_size))
        gm_mask = binary_erosion(gm_mask, disk(erosion_kernel_size))
        gm_mask = remove_small_objects(gm_mask, min_size=min_object_size)

    combined = np.zeros_like(gm_mask, dtype=np.uint8)
    combined[gm_mask > 0] = 1

    return combined
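
# For the binary scenario the combined mask is simply the (optionally cleaned)
# foreground relabeled as class 1, matching
# DataConfig.class_scenarios['binary']['class_mapping']. A multi-class variant
# would assign further labels the same way, e.g.:
#   combined[other_mask > 0] = 2   # hypothetical extra class, not used here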
|
|
|
|
def is_valid_slice(gm_mask: np.ndarray) -> bool:
    """
    Check whether a slice carries enough foreground to train on

    Args:
        gm_mask: Specialized GM mask (256, 256)

    Returns:
        True if the mask has more than 50 foreground pixels
    """
    has_specialized_gm = np.sum(gm_mask) > 50

    return has_specialized_gm
|
|
|
|
def create_paired_input(flair: np.ndarray,
                        mask: np.ndarray,
                        brain_mask: np.ndarray,
                        num_classes: int,
                        if_bet: bool = False) -> Tuple[np.ndarray, np.ndarray]:
    """
    Create paired input: [FLAIR | mask] concatenated horizontally

    Args:
        flair: FLAIR image (256, 256, 1) float32
        mask: Combined mask (256, 256) uint8
        brain_mask: Brain mask (256, 256) uint8
        num_classes: Highest class label, used to scale the mask to [-1, 1]
        if_bet: If True, suppress everything outside the brain mask (BET-style)

    Returns:
        Tuple of (paired_input, mask):
        - paired_input: (256, 512, 1) float32
        - mask: (256, 256) uint8, brain-masked when if_bet is True
    """
    brain_mask = brain_mask > 0

    if if_bet:
        # Suppress non-brain pixels in both the image and the labels
        flair[~brain_mask] = np.min(flair)
        mask[~brain_mask] = 0

    if len(flair.shape) == 2:
        flair = np.expand_dims(flair, axis=-1)

    # Scale class labels to [-1, 1] so both halves share the same value range
    max_class = num_classes
    mask_normalized = mask.astype(np.float32)
    if max_class > 0:
        mask_normalized = mask_normalized / max_class
    mask_normalized = (2 * mask_normalized) - 1

    mask_3d = np.expand_dims(mask_normalized, axis=-1)

    # Concatenate along the width: FLAIR on the left, mask on the right
    paired = np.concatenate([flair, mask_3d], axis=1)

    return paired, mask
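
# Complementary helper (an added sketch, not part of the original pipeline):
# code that consumes the pair usually needs to split it back into its halves.
# `split_paired_input` is a hypothetical name.
def split_paired_input(paired: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Split a (256, 512, 1) paired image into its (FLAIR, mask) halves."""
    half = paired.shape[1] // 2
    return paired[:, :half, :], paired[:, half:, :]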
|
|
|
|
|
|
class PatientStratifiedSplitter:
    """
    Create patient-stratified train/val/test splits
    Similar to the P6 implementation but adapted for the P1 data structure
    """

    def __init__(self, config: DataConfig):
        self.config = config
        self.config.splits_dir.mkdir(exist_ok=True)

    def collect_all_patients(self) -> Dict[str, List[str]]:
        """
        Collect all unique patient IDs from the configured datasets

        Returns:
            Dictionary mapping dataset_name -> list of patient IDs
        """
        all_patients = {}

        for dataset_name, dataset_config in self.config.datasets.items():
            patients = set()

            flair_dir = dataset_config['base_path'] / 'FLAIR' / 'Preprocessed' / 'images'

            if not flair_dir.exists():
                print(f"Warning: {flair_dir} does not exist. Skipping {dataset_name}.")
                continue

            for flair_file in flair_dir.glob('*.png'):
                patient_id = extract_patient_id(
                    flair_file.name,
                    dataset_config['patient_prefix_length']
                )
                patients.add(patient_id)

            all_patients[dataset_name] = sorted(patients)
            print(f"{dataset_name}: {len(all_patients[dataset_name])} patients")

        return all_patients

    def create_patient_stratified_splits(self,
                                         save: bool = True) -> Dict:
        """
        Create patient-stratified K-fold splits

        Returns:
            Dictionary containing fold assignments
        """
        all_patients = self.collect_all_patients()

        # Pool patients across datasets
        combined_patients = []
        for patients in all_patients.values():
            combined_patients.extend(patients)

        combined_patients = np.array(combined_patients)
        total_patients = len(combined_patients)

        print(f"\nTotal unique patients: {total_patients}")

        # Hold out a fixed test set first
        np.random.seed(self.config.random_state)
        test_size = int(total_patients * self.config.test_split)

        test_indices = np.random.choice(
            total_patients,
            size=test_size,
            replace=False
        )

        test_patients = combined_patients[test_indices]
        train_val_indices = np.setdiff1d(np.arange(total_patients), test_indices)
        train_val_patients = combined_patients[train_val_indices]

        print(f"Test patients: {len(test_patients)}")
        print(f"Train+Val patients: {len(train_val_patients)}")

        # K-fold split over the remaining patients
        kfold = KFold(
            n_splits=self.config.k_folds,
            shuffle=True,
            random_state=self.config.random_state
        )

        fold_assignments = {
            'metadata': {
                'total_patients': total_patients,
                'test_patients': len(test_patients),
                'trainval_patients': len(train_val_patients),
                'n_folds': self.config.k_folds,
                'random_seed': self.config.random_state,
                'datasets': list(all_patients.keys())
            },
            'test_set': {
                'patients': test_patients.tolist(),
                'n_patients': len(test_patients)
            },
            'folds': {}
        }

        for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(train_val_patients)):
            train_patients_fold = train_val_patients[train_idx]
            val_patients_fold = train_val_patients[val_idx]

            fold_assignments['folds'][f'fold_{fold_idx}'] = {
                'train_patients': train_patients_fold.tolist(),
                'val_patients': val_patients_fold.tolist(),
                'n_train': len(train_patients_fold),
                'n_val': len(val_patients_fold)
            }

            print(f"Fold {fold_idx}: Train={len(train_patients_fold)}, Val={len(val_patients_fold)}")

        if save:
            with open(self.config.splits_file, 'w') as f:
                json.dump(fold_assignments, f, indent=2)
            print(f"\n✅ Fold assignments saved to: {self.config.splits_file}")

        return fold_assignments

    def load_fold_assignments(self) -> Dict:
        """Load existing fold assignments from JSON"""
        if not self.config.splits_file.exists():
            raise FileNotFoundError(
                f"Fold assignments not found: {self.config.splits_file}\n"
                f"Run create_patient_stratified_splits() first."
            )

        with open(self.config.splits_file, 'r') as f:
            fold_assignments = json.load(f)

        return fold_assignments

    def verify_patient_separation(self, fold_assignments: Dict) -> bool:
        """
        Verify that no patient appears in multiple folds or in both train/val
        Similar to P6's verification logic
        """
        print("\n" + "="*60)
        print("VERIFYING PATIENT SEPARATION")
        print("="*60)

        all_issues = []
        test_patients = set(fold_assignments['test_set']['patients'])

        # Check test vs. train/val overlap in every fold
        for fold_name, fold_data in fold_assignments['folds'].items():
            train_patients = set(fold_data['train_patients'])
            val_patients = set(fold_data['val_patients'])

            test_train_overlap = test_patients.intersection(train_patients)
            test_val_overlap = test_patients.intersection(val_patients)

            if test_train_overlap:
                issue = f"{fold_name}: Test-Train overlap: {test_train_overlap}"
                all_issues.append(issue)
                print(f"❌ {issue}")

            if test_val_overlap:
                issue = f"{fold_name}: Test-Val overlap: {test_val_overlap}"
                all_issues.append(issue)
                print(f"❌ {issue}")

        # Check train vs. val overlap within each fold
        for fold_name, fold_data in fold_assignments['folds'].items():
            train_patients = set(fold_data['train_patients'])
            val_patients = set(fold_data['val_patients'])

            train_val_overlap = train_patients.intersection(val_patients)
            if train_val_overlap:
                issue = f"{fold_name}: Train-Val overlap: {train_val_overlap}"
                all_issues.append(issue)
                print(f"❌ {issue}")

        # Each patient should appear in validation exactly once across folds
        all_val_patients = []
        for fold_data in fold_assignments['folds'].values():
            all_val_patients.extend(fold_data['val_patients'])

        val_patient_counts = {}
        for patient in all_val_patients:
            val_patient_counts[patient] = val_patient_counts.get(patient, 0) + 1

        for patient, count in val_patient_counts.items():
            if count != 1:
                issue = f"Patient {patient} in validation {count} times (should be 1)"
                all_issues.append(issue)
                print(f"❌ {issue}")

        if not all_issues:
            print("✅ All patient separation checks passed")
            print("✅ No data leakage detected")
            return True
        else:
            print(f"\n❌ Found {len(all_issues)} issues")
            return False
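
# Convenience helper (an added sketch, not part of the original pipeline):
# load the persisted fold assignments if they exist, otherwise create them.
def get_or_create_fold_assignments(config: DataConfig) -> Dict:
    """Load existing fold assignments, or create and save new ones."""
    splitter = PatientStratifiedSplitter(config)
    if config.splits_file.exists():
        return splitter.load_fold_assignments()
    return splitter.create_patient_stratified_splits(save=True)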
|
|
|
|
|
|
class P1DataLoader:
    """
    Main data loader for P1 experiments
    Handles loading FLAIR images and masks, creating paired inputs, and
    building TensorFlow datasets
    """

    def __init__(self, config: DataConfig):
        self.config = config

    def get_file_paths(self,
                       patient_id: str,
                       slice_num: int,
                       dataset_name: str,
                       preprocessing: str) -> Dict[str, Path]:
        """
        Construct file paths for a given patient-slice

        Args:
            patient_id: e.g., "101228" or "c01p01"
            slice_num: Slice number
            dataset_name: 'Local_SAI_GM_sp'
            preprocessing: 'standard' or 'zoomed'

        Returns:
            Dictionary with paths to the FLAIR and mask files
        """
        dataset_config = self.config.datasets[dataset_name]
        base_path = dataset_config['base_path']

        # The zoomed variant lives one directory level deeper
        if preprocessing == 'standard':
            flair_subdir = 'images'
            gt_subdir = 'images'
        else:
            flair_subdir = 'zoomed/images'
            gt_subdir = 'zoomed/images'

        flair_path = base_path / 'FLAIR' / 'Preprocessed' / flair_subdir / f'{patient_id}_{slice_num}.png'
        gm_path = base_path / 'GroundTruth' / gt_subdir / 'GM_Masks' / f'{patient_id}_{slice_num}.png'
        brain_path = base_path / 'GroundTruth' / gt_subdir / 'Brain_Masks' / f'{patient_id}_{slice_num}.png'

        # Per-patient zoom factors are only stored for the zoomed variant
        zoom_factors_path = None
        if preprocessing == 'zoomed':
            zoom_factors_path = base_path / 'FLAIR' / 'Preprocessed' / 'zoomed' / 'images' / f'{patient_id}_zooming_factors.npy'

        return {
            'flair': flair_path,
            'gm_mask': gm_path,
            'brain_mask': brain_path,
            'zoom_factors': zoom_factors_path
        }
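
    # Expected on-disk layout (reconstructed from the path construction above;
    # 'standard' variant shown, the 'zoomed' variant inserts a zoomed/ level):
    #   <base_path>/FLAIR/Preprocessed/images/<patient>_<slice>.png
    #   <base_path>/GroundTruth/images/GM_Masks/<patient>_<slice>.png
    #   <base_path>/GroundTruth/images/Brain_Masks/<patient>_<slice>.png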

    def load_single_slice(self,
                          patient_id: str,
                          slice_num: int,
                          dataset_name: str,
                          preprocessing: str,
                          class_scenario: str,
                          of_z_score: bool = True,
                          if_bet: bool = True,
                          pre_morph: bool = False) -> Tuple[np.ndarray, np.ndarray]:
        """
        Load a single patient-slice and create the paired input

        Args:
            patient_id: Patient identifier
            slice_num: Slice number
            dataset_name: 'Local_SAI_GM_sp'
            preprocessing: 'standard' or 'zoomed'
            class_scenario: 'binary'
            of_z_score: Load the z-scored .npy variant of the FLAIR image
            if_bet: Apply the brain mask to image and labels (BET-style)
            pre_morph: Apply morphological cleanup to the GM mask

        Returns:
            Tuple of (paired_input, combined_mask)
            - paired_input: (256, 512, 1) FLAIR + mask concatenated
            - combined_mask: (256, 256) class labels
        """
        # Binary scenario: the highest class label is 1
        num_classes = 1

        paths = self.get_file_paths(patient_id, slice_num, dataset_name, preprocessing)

        flair = load_flair_image(paths['flair'], of_z_score=of_z_score)

        gm_mask = load_mask_image(paths['gm_mask'])
        brain_mask = load_mask_image(paths['brain_mask'])

        combined_mask = combine_masks(gm_mask, class_scenario, preprocess=pre_morph)

        paired_input, combined_mask = create_paired_input(
            flair, combined_mask, brain_mask, num_classes=num_classes, if_bet=if_bet
        )

        return paired_input, combined_mask
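
    # Single-slice usage sketch (mirrors test_data_loading below; the patient
    # ID and slice number are illustrative):
    #   loader = P1DataLoader(DataConfig())
    #   pair, mask = loader.load_single_slice('101228', 10, 'Local_SAI_GM_sp',
    #                                         'standard', 'binary')
    #   # pair: (256, 512, 1) float32 in [-1, 1]; mask: (256, 256) uint8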

    def collect_patient_slices(self,
                               patient_list: List[str],
                               dataset_name: str,
                               preprocessing: str) -> List[Tuple[str, int, str]]:
        """
        Collect all valid slice files for the given patients,
        filtering out slices with (near-)empty masks

        Args:
            patient_list: List of patient IDs
            dataset_name: 'Local_SAI_GM_sp'
            preprocessing: 'standard' or 'zoomed'

        Returns:
            List of tuples (patient_id, slice_num, dataset_name)
        """
        dataset_config = self.config.datasets[dataset_name]
        slice_min, slice_max = dataset_config['slice_range']

        patient_slices = []
        skipped_empty = 0

        for patient_id in patient_list:
            for slice_num in range(slice_min, slice_max + 1):
                paths = self.get_file_paths(patient_id, slice_num, dataset_name, preprocessing)

                # All three files must exist for the slice to be usable
                if (paths['flair'].exists() and
                        paths['gm_mask'].exists() and
                        paths['brain_mask'].exists()):

                    try:
                        gm_mask = load_mask_image(paths['gm_mask'])
                        # Loading the brain mask here also validates its readability
                        brain_mask = load_mask_image(paths['brain_mask'])

                        if is_valid_slice(gm_mask):
                            patient_slices.append((patient_id, slice_num, dataset_name))
                        else:
                            skipped_empty += 1

                    except Exception as e:
                        print(f"Warning: Could not validate {patient_id}_{slice_num}: {e}")
                        skipped_empty += 1

        if skipped_empty > 0:
            print(f"  ⚠️ Skipped {skipped_empty} slices with empty masks")

        return patient_slices
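
    # Each collected entry is a (patient_id, slice_num, dataset_name) tuple,
    # e.g. ('101228', 5, 'Local_SAI_GM_sp'); the tf.data generator below feeds
    # them to load_single_slice one at a time.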

    def create_dataset_for_fold(self,
                                fold_id: int,
                                split: str,
                                preprocessing: str,
                                class_scenario: str,
                                batch_size: int = 1,
                                shuffle: bool = True,
                                use_z_scored: bool = True,
                                bet: bool = False) -> tf.data.Dataset:
        """
        Create a TensorFlow dataset for a specific fold and split

        Args:
            fold_id: Fold number (0-4)
            split: 'train', 'val', or 'test'
            preprocessing: 'standard' or 'zoomed'
            class_scenario: 'binary'
            batch_size: Batch size
            shuffle: Whether to shuffle the training data
            use_z_scored: Load the z-scored .npy FLAIR variants
            bet: Apply the brain mask to image and labels

        Returns:
            tf.data.Dataset yielding (paired_input, combined_mask,
            patient_id, slice_num) batches
        """
        splitter = PatientStratifiedSplitter(self.config)
        fold_assignments = splitter.load_fold_assignments()

        # Resolve the patient list for the requested split
        if split == 'test':
            patient_list = fold_assignments['test_set']['patients']
        else:
            fold_key = f'fold_{fold_id}'
            if split == 'train':
                patient_list = fold_assignments['folds'][fold_key]['train_patients']
            elif split == 'val':
                patient_list = fold_assignments['folds'][fold_key]['val_patients']
            else:
                raise ValueError(f"Unknown split: {split}")

        print(f"\nCreating dataset for fold {fold_id}, split '{split}'")
        print(f"Patients: {len(patient_list)}")

        # Collect valid slices across all configured datasets
        all_patient_slices = []

        for dataset_name in self.config.datasets.keys():
            dataset_patients = list(patient_list)

            patient_slices = self.collect_patient_slices(
                dataset_patients,
                dataset_name,
                preprocessing
            )
            all_patient_slices.extend(patient_slices)

        print(f"Total slices: {len(all_patient_slices)}")

        if len(all_patient_slices) == 0:
            raise ValueError(f"No data found for fold {fold_id}, split '{split}'")

        def data_generator():
            """Generator function for tf.data.Dataset"""
            for patient_id, slice_num, dataset_name in all_patient_slices:
                try:
                    paired_input, combined_mask = self.load_single_slice(
                        patient_id, slice_num, dataset_name,
                        preprocessing, class_scenario,
                        of_z_score=use_z_scored, if_bet=bet
                    )
                    yield paired_input, combined_mask, patient_id, slice_num
                except Exception as e:
                    print(f"Error loading {patient_id}_{slice_num}: {e}")
                    continue

        dataset = tf.data.Dataset.from_generator(
            data_generator,
            output_signature=(
                tf.TensorSpec(shape=(256, 512, 1), dtype=tf.float32),
                tf.TensorSpec(shape=(256, 256), dtype=tf.uint8),
                tf.TensorSpec(shape=(), dtype=tf.string),
                tf.TensorSpec(shape=(), dtype=tf.int32)
            )
        )

        # Cache decoded slices in memory so the generator only runs once
        dataset = dataset.cache()

        if shuffle and split == 'train':
            dataset = dataset.shuffle(
                buffer_size=len(all_patient_slices),
                reshuffle_each_iteration=True
            )

        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        return dataset
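
# Consumption sketch (illustrative; `model` and the training step are
# placeholders, not part of this module):
#   loader = P1DataLoader(DataConfig())
#   train_ds = loader.create_dataset_for_fold(0, 'train', 'standard', 'binary', batch_size=4)
#   for paired, masks, patient_ids, slice_nums in train_ds:
#       flair_half = paired[:, :, :256, :]   # left half of the pair
#       mask_half = paired[:, :, 256:, :]    # right half of the pair
#       ...                                  # e.g., one training step of `model`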
|
|
|
|
|
|
def test_data_loading():
    """Test data loading functionality"""
    print("\n" + "="*60)
    print("TESTING DATA LOADING")
    print("="*60)

    config = DataConfig()

    # Test 1: Create and verify patient-stratified splits
    print("\n[TEST 1] Creating patient-stratified splits...")
    splitter = PatientStratifiedSplitter(config)
    fold_assignments = splitter.create_patient_stratified_splits(save=True)

    is_valid = splitter.verify_patient_separation(fold_assignments)

    if not is_valid:
        print("❌ Patient separation verification failed!")
        return False

    # Test 2: Load a single slice
    print("\n[TEST 2] Loading single slice...")
    loader = P1DataLoader(config)

    test_patient = fold_assignments['folds']['fold_0']['train_patients'][0]

    # Patient IDs starting with '1' belong to Local_SAI_GM_sp
    if test_patient.startswith('1'):
        test_dataset = 'Local_SAI_GM_sp'
        test_slice = 10
    else:
        raise ValueError(f"Unexpected patient ID format: {test_patient}")

    try:
        paired_input, combined_mask = loader.load_single_slice(
            test_patient, test_slice, test_dataset,
            'standard', 'binary'
        )

        print(f"✅ Loaded slice {test_patient}_{test_slice}")
        print(f"   Paired input shape: {paired_input.shape}")
        print(f"   Combined mask shape: {combined_mask.shape}")
        print(f"   Mask unique values: {np.unique(combined_mask)}")

    except Exception as e:
        print(f"❌ Failed to load slice: {e}")
        return False

    # Test 3: Build a TensorFlow dataset and pull one batch
    print("\n[TEST 3] Creating TensorFlow dataset...")
    try:
        dataset = loader.create_dataset_for_fold(
            fold_id=0,
            split='train',
            preprocessing='standard',
            class_scenario='binary',
            batch_size=2,
            shuffle=True
        )

        for batch_paired, batch_masks, _, _ in dataset.take(1):
            print("✅ Created dataset")
            print(f"   Batch paired input shape: {batch_paired.shape}")
            print(f"   Batch masks shape: {batch_masks.shape}")
            print(f"   Paired input dtype: {batch_paired.dtype}")
            print(f"   Masks dtype: {batch_masks.dtype}")

    except Exception as e:
        print(f"❌ Failed to create dataset: {e}")
        return False

    print("\n" + "="*60)
    print("✅ ALL TESTS PASSED")
    print("="*60)

    return True
|
|
|
|
|
|
| if __name__ == "__main__": |
| |
| success = test_data_loading() |
| |
| if success: |
| print("\n" + "="*60) |
| print("DATA LOADER READY FOR USE") |
| print("="*60) |
| print("\nNext steps:") |
| print("1. Verify fold_assignments.json created in data_splits/") |
| print("2. Check that all file paths are correct for your system") |
| print("3. Proceed to model implementation") |
| else: |
| print("\n" + "="*60) |
| print("β DATA LOADER TESTS FAILED") |
| print("="*60) |
| print("\nPlease fix the issues above before proceeding") |
|
|
|
|