# SLIM-Brain / data / pretrain_dataset.py
# Uploaded by OneMore1 ("Upload 12 files", commit 538668e, verified)
import os
import glob
import numpy as np
from typing import Any, Callable, Dict, Optional, Set, Tuple
import torch
from torch.utils.data import Dataset
import random
class fMRIDataset(Dataset):
    """Dataset of 4-D fMRI recordings stored as ``.npz`` files.

    Walks every ``{data_root}/{dataset}_{suffix}`` folder, keeps at most one
    ``.npz`` file per directory (the lexicographically first), and serves
    fixed-length temporal crops as float32 tensors of shape
    ``(crop_length, X, Y, Z)``.
    """

    def __init__(self, data_root, datasets, split_suffixes, crop_length=40,
                 downstream=False):
        """Collect the per-directory file list.

        Args:
            data_root: Root directory containing the ``{name}_{suffix}`` folders.
            datasets: Iterable of dataset names.
            split_suffixes: Iterable of split suffixes (e.g. ``"train"``).
            crop_length: Number of time frames per returned sample.
            downstream: Stored flag; unused in this class's visible code —
                presumably consumed by callers. TODO confirm.
        """
        self.file_paths = []
        self.crop_length = crop_length
        self.downstream = downstream
        for dataset_name in datasets:
            for suffix in split_suffixes:
                folder_name = f"{dataset_name}_{suffix}"
                folder_path = os.path.join(data_root, folder_name)
                if not os.path.exists(folder_path):
                    # Best-effort: skip missing splits rather than fail.
                    print(f"Warning: Folder not found: {folder_path}")
                    continue
                for root, _dirs, _files in os.walk(folder_path):
                    npz_files = glob.glob(os.path.join(root, "*.npz"))
                    # Keep only the lexicographically first file per directory.
                    # (The original ``len > 1`` guard was redundant: slicing a
                    # sorted list to one element behaves identically for
                    # lists of length 0 or 1.)
                    self.file_paths.extend(sorted(npz_files)[:1])
        print(f"Dataset loaded. Total files found: {len(self.file_paths)}")

    def __len__(self):
        """Return the number of collected ``.npz`` files."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load file ``idx`` and return a random temporal crop.

        Returns:
            ``torch.FloatTensor`` of shape ``(crop_length, X, Y, Z)``; the
            time axis is shorter when the recording has fewer than
            ``crop_length`` frames, and ``None`` is returned when the file
            cannot be read — callers need a ``collate_fn`` that pads/filters
            such samples.
        """
        file_path = self.file_paths[idx]
        try:
            with np.load(file_path) as data_file:
                # NpzFile.files preserves archive key order; take the first
                # stored array (the archive is assumed single-array).
                fmri_data = data_file[data_file.files[0]].astype(np.float32)
        except Exception as e:
            # Best-effort: report the bad file and let the collate_fn drop it.
            print(f"Error loading file {file_path}: {e}")
            return None
        # Last axis is time; leading axes are assumed spatial (X, Y, Z) —
        # the permute below requires exactly 4 dimensions.
        total_time_frames = fmri_data.shape[-1]
        if total_time_frames > self.crop_length:
            # Random start so crops vary across epochs.
            start_idx = np.random.randint(0, total_time_frames - self.crop_length + 1)
            cropped_data = fmri_data[..., start_idx:start_idx + self.crop_length]
        else:
            # NOTE(review): short recordings yield T < crop_length here;
            # default batch collation will fail on mixed lengths — confirm
            # the training collate_fn pads or filters these samples.
            cropped_data = fmri_data[..., :self.crop_length]
        # Move time to the leading axis: (X, Y, Z, T) -> (T, X, Y, Z).
        # .contiguous() gives identical values with a packed layout, so
        # downstream .view()/reshape calls cannot fail on strides.
        data_tensor = torch.from_numpy(cropped_data).permute(3, 0, 1, 2)
        return data_tensor.contiguous()