FlexiBrain / flexibrain /data /builders.py
OneMore1's picture
Sync from GitHub FlexiBrain main
6a51385 verified
Raw
History Blame Contribute Delete
3.92 kB
from __future__ import annotations
from typing import Optional, Tuple
from torch.utils.data import DataLoader, DistributedSampler
from flexibrain.config import DataConfig, TrainingConfig
from flexibrain.data.nifti import NiftiTxtDataset
from flexibrain.data.classification import ClassificationDataset, custom_collate_fn as downstream_collate
from flexibrain.data.collate import custom_collate_fn as pretrain_collate
def build_pretrain_dataloaders(data: DataConfig, training: TrainingConfig, rank: int = 0, world_size: int = 1) -> Tuple[DataLoader, DataLoader]:
train_set = NiftiTxtDataset(data.train_list, return_torch=True, memory_map=data.memory_map, cache_meta=True, T_prime=data.T_prime, tau_seconds=data.tau_seconds, default_tr=data.default_tr)
val_set = NiftiTxtDataset(data.val_list, return_torch=True, memory_map=data.memory_map, cache_meta=True, T_prime=data.T_prime, tau_seconds=data.tau_seconds, default_tr=data.default_tr)
train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True, seed=training.seed) if world_size > 1 else None
val_sampler = DistributedSampler(val_set, num_replicas=world_size, rank=rank, shuffle=False, seed=training.seed) if world_size > 1 else None
train_loader = DataLoader(train_set, batch_size=data.batch_size, sampler=train_sampler, shuffle=train_sampler is None, num_workers=data.num_workers, pin_memory=True, drop_last=True, collate_fn=pretrain_collate)
val_loader = DataLoader(val_set, batch_size=data.batch_size, sampler=val_sampler, shuffle=False, num_workers=data.num_workers, pin_memory=True, drop_last=False, collate_fn=pretrain_collate)
return train_loader, val_loader
def _classification_dataset(txt_file: Optional[str], data: DataConfig):
if not txt_file:
return None
if not data.csv:
raise ValueError("data.csv is required for downstream classification")
return ClassificationDataset(
txt_files=txt_file,
csv_path=data.csv,
id_column=data.id_column,
label_column=data.label_column,
label_mode=data.label_mode,
path_id_mode=data.path_id_mode,
normal_label=data.normal_label,
return_torch=True,
memory_map=data.memory_map,
cache_meta=True,
T_prime=data.T_prime,
tau_seconds=data.tau_seconds,
default_tr=data.default_tr,
)
def build_downstream_dataloaders(data: DataConfig, training: TrainingConfig, rank: int = 0, world_size: int = 1):
train_set = _classification_dataset(data.train_list, data)
val_set = _classification_dataset(data.val_list, data)
test_set = _classification_dataset(data.test_list, data)
train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True, seed=training.seed) if world_size > 1 else None
val_sampler = DistributedSampler(val_set, num_replicas=world_size, rank=rank, shuffle=False, seed=training.seed) if world_size > 1 else None
test_sampler = DistributedSampler(test_set, num_replicas=world_size, rank=rank, shuffle=False, seed=training.seed) if world_size > 1 and test_set is not None else None
train_loader = DataLoader(train_set, batch_size=data.batch_size, sampler=train_sampler, shuffle=train_sampler is None, num_workers=data.num_workers, pin_memory=True, drop_last=True, collate_fn=downstream_collate)
val_loader = DataLoader(val_set, batch_size=data.batch_size, sampler=val_sampler, shuffle=False, num_workers=data.num_workers, pin_memory=True, drop_last=False, collate_fn=downstream_collate)
test_loader = None
if test_set is not None:
test_loader = DataLoader(test_set, batch_size=data.batch_size, sampler=test_sampler, shuffle=False, num_workers=data.num_workers, pin_memory=True, drop_last=False, collate_fn=downstream_collate)
return train_loader, val_loader, test_loader