| from __future__ import annotations | |
| from typing import Optional, Tuple | |
| from torch.utils.data import DataLoader, DistributedSampler | |
| from flexibrain.config import DataConfig, TrainingConfig | |
| from flexibrain.data.nifti import NiftiTxtDataset | |
| from flexibrain.data.classification import ClassificationDataset, custom_collate_fn as downstream_collate | |
| from flexibrain.data.collate import custom_collate_fn as pretrain_collate | |
def build_pretrain_dataloaders(data: DataConfig, training: TrainingConfig, rank: int = 0, world_size: int = 1) -> Tuple[DataLoader, DataLoader]:
    """Create the training and validation loaders used for pretraining.

    Both splits are read through ``NiftiTxtDataset`` with the same options;
    only the list file differs. When ``world_size`` is greater than one, each
    split gets a ``DistributedSampler`` seeded with ``training.seed``
    (shuffled for training, unshuffled for validation). Otherwise no sampler
    is passed and the training loader shuffles on its own.

    Args:
        data: Supplies the list files, dataset options, batch size and
            worker count.
        training: Only ``training.seed`` is read here.
        rank: Index of the current process, forwarded to the samplers.
        world_size: Number of participating processes.

    Returns:
        ``(train_loader, val_loader)``. The training loader drops the last
        incomplete batch; the validation loader keeps it.
    """

    def _open_split(list_file):
        # Single place for the dataset options shared by both splits.
        return NiftiTxtDataset(
            list_file,
            return_torch=True,
            memory_map=data.memory_map,
            cache_meta=True,
            T_prime=data.T_prime,
            tau_seconds=data.tau_seconds,
            default_tr=data.default_tr,
        )

    train_set = _open_split(data.train_list)
    val_set = _open_split(data.val_list)

    distributed = world_size > 1
    if distributed:
        train_sampler = DistributedSampler(
            train_set,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
            seed=training.seed,
        )
        val_sampler = DistributedSampler(
            val_set,
            num_replicas=world_size,
            rank=rank,
            shuffle=False,
            seed=training.seed,
        )
    else:
        train_sampler = None
        val_sampler = None

    train_loader = DataLoader(
        train_set,
        batch_size=data.batch_size,
        sampler=train_sampler,
        # With a sampler in place the loader itself must not shuffle.
        shuffle=not distributed,
        num_workers=data.num_workers,
        pin_memory=True,
        drop_last=True,
        collate_fn=pretrain_collate,
    )
    val_loader = DataLoader(
        val_set,
        batch_size=data.batch_size,
        sampler=val_sampler,
        shuffle=False,
        num_workers=data.num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=pretrain_collate,
    )
    return train_loader, val_loader
def _classification_dataset(txt_file: Optional[str], data: DataConfig):
    """Build a ``ClassificationDataset`` for one split, if the split exists.

    Args:
        txt_file: List file for the split. A falsy value (``None`` or an
            empty string) means the split is not configured.
        data: Supplies the label CSV, the label/id column settings and the
            dataset loading options.

    Returns:
        The constructed ``ClassificationDataset``, or ``None`` when
        ``txt_file`` is falsy.

    Raises:
        ValueError: If a list file is given but ``data.csv`` is falsy.
    """
    # The split check comes first so an unconfigured split never needs a CSV.
    if not txt_file:
        return None
    if not data.csv:
        raise ValueError("data.csv is required for downstream classification")

    dataset = ClassificationDataset(
        txt_files=txt_file,
        # Label lookup settings.
        csv_path=data.csv,
        id_column=data.id_column,
        label_column=data.label_column,
        label_mode=data.label_mode,
        path_id_mode=data.path_id_mode,
        normal_label=data.normal_label,
        # Loading options.
        return_torch=True,
        memory_map=data.memory_map,
        cache_meta=True,
        T_prime=data.T_prime,
        tau_seconds=data.tau_seconds,
        default_tr=data.default_tr,
    )
    return dataset
def build_downstream_dataloaders(data: DataConfig, training: TrainingConfig, rank: int = 0, world_size: int = 1):
    """Build the train/val/test loaders for downstream classification.

    Each split is created with ``_classification_dataset``, which returns
    ``None`` when the corresponding list file is not set. The training split
    is mandatory. Validation and test are optional, and their loader is
    ``None`` when the split is absent, so a missing list never reaches a
    sampler or ``DataLoader`` as a ``None`` dataset.

    When ``world_size`` is greater than one, every existing split gets a
    ``DistributedSampler`` seeded with ``training.seed`` (shuffled for
    training only). Otherwise no sampler is used and the training loader
    shuffles on its own.

    Args:
        data: Supplies the list files, batch size, worker count and the
            dataset options consumed by ``_classification_dataset``.
        training: Only ``training.seed`` is read here.
        rank: Index of the current process, forwarded to the samplers.
        world_size: Number of participating processes.

    Returns:
        ``(train_loader, val_loader, test_loader)``. ``val_loader`` and
        ``test_loader`` are ``None`` when their list file is not set. The
        training loader drops the last incomplete batch; the others keep it.

    Raises:
        ValueError: If ``data.train_list`` is not set, or (from
            ``_classification_dataset``) if a list is set but ``data.csv``
            is not.
    """
    train_set = _classification_dataset(data.train_list, data)
    if train_set is None:
        # Fail with a clear message instead of an opaque TypeError raised
        # when a sampler is built over a None dataset.
        raise ValueError("data.train_list is required for downstream classification")
    val_set = _classification_dataset(data.val_list, data)
    test_set = _classification_dataset(data.test_list, data)

    distributed = world_size > 1

    def _eval_loader(dataset):
        # Ordered, non-dropping loader for an optional evaluation split.
        if dataset is None:
            return None
        sampler = None
        if distributed:
            # NOTE(review): per the PyTorch docs DistributedSampler pads the
            # index list to be evenly divisible across replicas, so
            # distributed evaluation may see a few repeated samples.
            sampler = DistributedSampler(
                dataset,
                num_replicas=world_size,
                rank=rank,
                shuffle=False,
                seed=training.seed,
            )
        return DataLoader(
            dataset,
            batch_size=data.batch_size,
            sampler=sampler,
            shuffle=False,
            num_workers=data.num_workers,
            pin_memory=True,
            drop_last=False,
            collate_fn=downstream_collate,
        )

    train_sampler = None
    if distributed:
        train_sampler = DistributedSampler(
            train_set,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
            seed=training.seed,
        )
    train_loader = DataLoader(
        train_set,
        batch_size=data.batch_size,
        sampler=train_sampler,
        # DataLoader rejects shuffle=True together with an explicit sampler.
        shuffle=train_sampler is None,
        num_workers=data.num_workers,
        pin_memory=True,
        drop_last=True,
        collate_fn=downstream_collate,
    )

    val_loader = _eval_loader(val_set)
    test_loader = _eval_loader(test_set)
    return train_loader, val_loader, test_loader