Spaces:
Runtime error
Runtime error
| import os | |
| import numpy as np | |
| import pandas as pd | |
| import nibabel as nib | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
| import torchio as tio | |
| from typing import List, Dict, Tuple, Optional | |
| import random | |
class MultiModalPretrainDataset(Dataset):
    """Multi-modal 3D medical image pre-training dataset.

    Features:
    - Combines several datasets (A4, ADNIDOD, AIBL, BraTS, NACC). Each one is
      described by an Excel file whose columns hold per-modality file paths.
    - Missing modalities are zero-filled and reported through an ``observed``
      indicator vector.
    - Optional random spatial data augmentation via TorchIO.
    - A modality-dropout helper (``_apply_modality_dropout``) is provided, but
      ``__getitem__`` currently does NOT call it. ``modality_dropout_prob`` and
      ``min_modalities`` are therefore stored without affecting returned samples.
    """

    # Unified modality order: channel i of every returned image tensor holds
    # MODALITY_ORDER[i] (4 modalities in total).
    MODALITY_ORDER = ['T1', 'T2', 'Flair', 'PET']

    # Per Excel file: dataset-specific column name -> unified modality name.
    MODALITY_MAPPING = {
        'modality_data_A4.xlsx': {'T1': 'T1', 'T2': 'T2', 'Flair': 'Flair', 'Amy_PET': 'PET'},
        'modality_data_ADNIDOD.xlsx': {'T1': 'T1', 'T2': 'T2', 'Flair': 'Flair', 'PET': 'PET'},
        'modality_data_AIBL.xlsx': {'T1': 'T1', 'T2': 'T2', 'Flair': 'Flair', 'PET': 'PET'},
        'modality_data_BraTS.xlsx': {'T1w': 'T1', 'T2w': 'T2', 'Flair': 'Flair', 'PET': 'PET'},
        'modality_data_NACC.xlsx': {'T1': 'T1', 'T2': 'T2', 'Flair': 'Flair', 'Amyloid': 'PET'},
    }

    # Path prefix replacement: Excel paths use the old server prefix,
    # remap to the local data directory.
    OLD_PATH_PREFIX = "/home/data/Pretrain"
    NEW_PATH_PREFIX = "./data/Pretrain"

    def __init__(
        self,
        excel_dir: str = "./data/Match_data_path/pretraining_processed",
        image_size: Tuple[int, int, int] = (128, 128, 128),
        augmentation: bool = True,
        modality_dropout_prob: float = 0.3,
        min_modalities: int = 1,
        cache_data: bool = False,
    ):
        """
        Args:
            excel_dir: Directory containing the per-dataset Excel files.
            image_size: Expected volume size (D, H, W). A volume with any other
                shape is skipped, and its modality is reported as missing.
            augmentation: Whether to apply random spatial augmentation.
            modality_dropout_prob: Per-modality drop probability used by
                ``_apply_modality_dropout``. Not currently applied in
                ``__getitem__``.
            min_modalities: Minimum number of modalities kept by
                ``_apply_modality_dropout``.
            cache_data: Whether to keep loaded volumes in memory after their
                first access.
        """
        self.excel_dir = excel_dir
        # Stored as a tuple so the shape check against ``ndarray.shape`` (always
        # a tuple) also succeeds when a list is passed in.
        self.image_size = tuple(image_size)
        self.augmentation = augmentation
        self.modality_dropout_prob = modality_dropout_prob
        self.min_modalities = min_modalities
        self.cache_data = cache_data
        # NOTE(review): with DataLoader num_workers > 0, each worker process
        # holds its own copy of this dict, so cached entries are not shared
        # between workers.
        self.cache = {}

        # Load every sample listed in the Excel files.
        self.samples = self._load_all_samples()
        print(f"Loaded {len(self.samples)} samples from {len(self.MODALITY_MAPPING)} datasets")

        if self.augmentation:
            # OneOf picks exactly one of flip / affine / elastic using these
            # weights. The chosen transform is then applied with its own
            # probability (flip_probability / p).
            self.spatial_transform = tio.OneOf({
                tio.RandomFlip(axes=0, flip_probability=0.5): 0.33,
                tio.RandomAffine(scales=(0.9, 1.2), degrees=10, p=0.5): 0.33,
                tio.RandomElasticDeformation(
                    num_control_points=(10, 10, 10),
                    max_displacement=8,
                    locked_borders=2,
                    p=0.5
                ): 0.34,
            })

    def _resolve_path(self, value) -> Optional[str]:
        """Turn an Excel cell into an existing local file path, or None.

        Non-string cells (e.g. NaN for an empty cell) are rejected. Paths under
        OLD_PATH_PREFIX are rewritten to NEW_PATH_PREFIX. The resulting path must
        exist on disk.
        """
        if not isinstance(value, str):
            return None
        path = value
        if path.startswith(self.OLD_PATH_PREFIX):
            path = self.NEW_PATH_PREFIX + path[len(self.OLD_PATH_PREFIX):]
        return path if os.path.exists(path) else None

    def _load_all_samples(self) -> List[Dict]:
        """Read every Excel file in MODALITY_MAPPING and collect its samples.

        Each sample is a dict with keys ``dataset``, ``subject_id`` and
        ``modalities``. ``modalities`` maps a unified modality name to an
        existing local file path. Missing Excel files are reported and skipped.
        Rows without any existing modality file are dropped.
        """
        samples = []
        for excel_file, modality_map in self.MODALITY_MAPPING.items():
            excel_path = os.path.join(self.excel_dir, excel_file)
            if not os.path.exists(excel_path):
                print(f"Warning: Excel file not found: {excel_path}")
                continue

            df = pd.read_excel(excel_path)
            dataset_name = excel_file.replace('modality_data_', '').replace('.xlsx', '')
            # Only mapped columns that actually exist in this sheet are read
            # (computed once per file rather than once per row).
            present_columns = [
                (orig_col, unified_name)
                for orig_col, unified_name in modality_map.items()
                if orig_col in df.columns
            ]

            for idx, row in df.iterrows():
                subject_id = row.get('SubjectID')
                # Fall back to a synthetic id when the column is absent OR the
                # cell is empty (pandas yields NaN for empty cells).
                if subject_id is None or pd.isna(subject_id):
                    subject_id = f'{dataset_name}_{idx}'
                sample = {
                    'dataset': dataset_name,
                    'subject_id': subject_id,
                    'modalities': {},
                }
                for orig_col, unified_name in present_columns:
                    path = self._resolve_path(row[orig_col])
                    if path is not None:
                        sample['modalities'][unified_name] = path
                # Keep only samples with at least one available modality.
                if sample['modalities']:
                    samples.append(sample)
        return samples

    def _load_nifti(self, path: str) -> Optional[np.ndarray]:
        """Load a NIfTI file as a float32 array.

        Returns None (after printing the error) if loading fails.
        """
        try:
            nii = nib.load(path)
            data = nii.get_fdata().astype(np.float32)
            return data
        except Exception as e:
            # Deliberate best-effort: an unreadable file is treated as a
            # missing modality instead of aborting the whole batch.
            print(f"Error loading {path}: {e}")
            return None

    def _apply_modality_dropout(self, available_modalities: List[str]) -> List[str]:
        """Randomly drop modalities to diversify modality combinations.

        Each modality is kept with probability ``1 - modality_dropout_prob``.
        If fewer than ``min_modalities`` survive, a random subset of exactly
        ``min_modalities`` is returned instead. Inputs that already have at
        most ``min_modalities`` entries are returned unchanged.
        """
        if len(available_modalities) <= self.min_modalities:
            return available_modalities

        kept_modalities = [
            mod for mod in available_modalities
            if random.random() > self.modality_dropout_prob
        ]
        # Guarantee that at least min_modalities modalities remain.
        if len(kept_modalities) < self.min_modalities:
            kept_modalities = random.sample(available_modalities, self.min_modalities)
        return kept_modalities

    def __len__(self) -> int:
        return len(self.samples)

    def _load_images(self, sample: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load every modality of ``sample`` into a zero-filled tensor.

        Returns ``(images, observed)``:
        - ``images`` has shape (num_modalities, *image_size).
        - ``observed`` has shape (num_modalities,). It is 1.0 for each modality
          that was loaded with the expected shape and 0.0 otherwise.
        """
        num_modalities = len(self.MODALITY_ORDER)
        images = torch.zeros(num_modalities, *self.image_size, dtype=torch.float32)
        observed = torch.zeros(num_modalities, dtype=torch.float32)

        for i, modality in enumerate(self.MODALITY_ORDER):
            path = sample['modalities'].get(modality)
            if path is None:
                continue
            data = self._load_nifti(path)
            if data is None:
                continue
            if data.shape != self.image_size:
                print(f"Warning: Size mismatch for {path}, expected {self.image_size}, got {data.shape}")
                continue
            images[i] = torch.from_numpy(data)
            observed[i] = 1.0
        return images, observed

    def _augment(self, images: torch.Tensor, observed: torch.Tensor) -> torch.Tensor:
        """Apply the shared random spatial transform to the observed channels.

        All observed modalities are put in one TorchIO Subject so they receive
        the same random transform. ``images`` is updated in place and returned.
        Unobserved (zero) channels are left untouched.
        """
        # TorchIO images need a 4D tensor (C, D, H, W); slicing i:i+1 keeps C=1.
        subject_dict = {
            modality: tio.ScalarImage(tensor=images[i:i + 1])
            for i, modality in enumerate(self.MODALITY_ORDER)
            if observed[i] == 1.0
        }
        if not subject_dict:
            return images

        transformed = self.spatial_transform(tio.Subject(**subject_dict))
        for i, modality in enumerate(self.MODALITY_ORDER):
            if modality in subject_dict:
                images[i] = transformed[modality].data[0]
        return images

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return ``{'images': (M, D, H, W) tensor, 'observed': (M,) tensor}``."""
        sample = self.samples[idx]

        if self.cache_data and idx in self.cache:
            cached_data = self.cache[idx]
            # Clone so in-place augmentation never corrupts the cached copy.
            images = cached_data['images'].clone()
            original_observed = cached_data['observed'].clone()
        else:
            images, original_observed = self._load_images(sample)
            if self.cache_data:
                self.cache[idx] = {
                    'images': images.clone(),
                    'observed': original_observed.clone(),
                }

        # Modality dropout is intentionally not applied; the original observed
        # mask is returned as-is.
        observed = original_observed.clone()

        if self.augmentation:
            images = self._augment(images, observed)

        return {
            'images': images,      # (num_modalities, D, H, W)
            'observed': observed,  # (num_modalities,)
        }
def create_pretrain_dataloader(
    excel_dir: str = "/home/data/Match_data_path/pretraining_processed",
    batch_size: int = 4,
    num_workers: int = 8,
    augmentation: bool = True,
    modality_dropout_prob: float = 0.3,
    min_modalities: int = 1,
    shuffle: bool = True,
    pin_memory: bool = True,
    cache_data: bool = False,
    image_size: Tuple[int, int, int] = (128, 128, 128),
) -> DataLoader:
    """Create the pre-training DataLoader.

    Args:
        excel_dir: Directory containing the per-dataset Excel files.
        batch_size: Batch size.
        num_workers: Number of data-loading worker processes.
        augmentation: Whether to apply random spatial augmentation.
        modality_dropout_prob: Modality dropout probability, forwarded to the
            dataset.
        min_modalities: Minimum number of modalities kept by modality dropout,
            forwarded to the dataset.
        shuffle: Whether to shuffle the samples.
        pin_memory: Whether to use pinned memory.
        cache_data: Whether the dataset caches loaded volumes in memory.
            NOTE(review): with num_workers > 0 each worker keeps its own cache.
        image_size: Expected volume size (D, H, W), forwarded to the dataset.

    Returns:
        A DataLoader with ``drop_last=True``. If the dataset has fewer than
        ``batch_size`` samples, it yields no batches.

    Raises:
        ValueError: If no sample could be loaded from ``excel_dir``.
    """
    # NOTE(review): this default is the old server location, while
    # MultiModalPretrainDataset defaults to
    # "./data/Match_data_path/pretraining_processed"; confirm which applies.
    dataset = MultiModalPretrainDataset(
        excel_dir=excel_dir,
        image_size=image_size,
        augmentation=augmentation,
        modality_dropout_prob=modality_dropout_prob,
        min_modalities=min_modalities,
        cache_data=cache_data,
    )
    # Fail fast with an actionable message instead of an opaque sampler error
    # (shuffle=True) or a silently empty loader (shuffle=False).
    if len(dataset) == 0:
        raise ValueError(
            f"No samples found under excel_dir={excel_dir!r}; check that the "
            "Excel files and the image paths they reference exist."
        )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=True,
    )
def collate_fn_with_info(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Stack per-sample dicts into one batched dict.

    Each item must provide ``'images'`` (num_modalities, D, H, W) and
    ``'observed'`` (num_modalities,). The result holds ``'images'`` as
    (B, num_modalities, D, H, W) and ``'observed'`` as (B, num_modalities).
    """
    batched_keys = ('images', 'observed')
    return {
        key: torch.stack([entry[key] for entry in batch])
        for key in batched_keys
    }
# ============== Usage example ==============
if __name__ == '__main__':
    print("=" * 60)
    print("多模态3D医学图像预训练数据加载器")
    print("=" * 60)

    # Build the data loader.
    # NOTE(review): this excel_dir is the old server path; the dataset class
    # defaults to "./data/Match_data_path/pretraining_processed" — confirm
    # which location exists in this environment.
    dataloader = create_pretrain_dataloader(
        excel_dir="/home/data/Match_data_path/pretraining_processed",
        batch_size=2,
        num_workers=4,
        augmentation=True,
        modality_dropout_prob=0.3,
        min_modalities=1,
        shuffle=True,
    )

    print(f"\n数据集大小: {len(dataloader.dataset)}")
    print(f"批量数: {len(dataloader)}")
    print(f"模态顺序: {MultiModalPretrainDataset.MODALITY_ORDER}")

    # Smoke test: load a single batch and report its shapes.
    print("\n测试加载一个批量...")
    for batch in dataloader:
        images = batch['images']
        observed = batch['observed']
        print(f"\n批量数据形状:")
        print(f" images: {images.shape}")  # (B, 4, 128, 128, 128)
        print(f" observed: {observed.shape}")  # (B, 4)
        # One batch is enough for this check; without the break the loop
        # would load (and augment) the entire dataset.
        break

    print("\n" + "=" * 60)
    print("数据加载测试完成!")
    print("=" * 60)