File size: 2,257 Bytes
538668e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import os
import glob
import numpy as np
from typing import Any, Callable, Dict, Optional, Set, Tuple
import torch
from torch.utils.data import Dataset
import random
class fMRIDataset(Dataset):
    """PyTorch dataset over 4-D fMRI volumes stored as ``.npz`` files.

    Files are discovered under ``data_root`` in folders named
    ``{dataset}_{suffix}`` for every (dataset, suffix) pair. Within each
    directory found by the recursive walk, only the lexicographically first
    ``.npz`` file is kept (deliberate subsampling: one recording per folder).

    Each item is a float32 tensor of shape ``(T, X, Y, Z)`` — the stored
    array is assumed to be ``(X, Y, Z, T)`` with time last, and is randomly
    cropped to ``crop_length`` frames along the time axis.
    """

    def __init__(self,
                 data_root, datasets, split_suffixes, crop_length: int = 40,
                 downstream: bool = False):
        """Scan ``data_root`` and collect the file paths to serve.

        Args:
            data_root: Root directory containing ``{dataset}_{suffix}`` folders.
            datasets: Iterable of dataset names.
            split_suffixes: Iterable of split suffixes (e.g. train/val).
            crop_length: Number of time frames per returned sample.
            downstream: Stored flag; not used in this class's own logic.
        """
        self.file_paths = []
        self.crop_length = crop_length
        self.downstream = downstream
        for dataset_name in datasets:
            for suffix in split_suffixes:
                folder_name = f"{dataset_name}_{suffix}"
                folder_path = os.path.join(data_root, folder_name)
                if not os.path.exists(folder_path):
                    print(f"Warning: Folder not found: {folder_path}")
                    continue
                for root, dirs, files in os.walk(folder_path):
                    npz_files = glob.glob(os.path.join(root, "*.npz"))
                    # Keep at most one file per directory, chosen
                    # deterministically (lexicographically first). For 0 or 1
                    # files this is a no-op, so no length guard is needed.
                    npz_files = sorted(npz_files)[:1]
                    self.file_paths.extend(npz_files)
        print(f"Dataset loaded. Total files found: {len(self.file_paths)}")

    def __len__(self) -> int:
        """Return the number of discovered ``.npz`` files."""
        return len(self.file_paths)

    def __getitem__(self, idx: int):
        """Load file ``idx``, crop the time axis, and return a tensor.

        Returns:
            A float32 tensor of shape ``(crop_length, X, Y, Z)``, or ``None``
            if the file could not be read.

        NOTE(review): returning ``None`` on a load failure will crash
        PyTorch's default ``collate_fn``; callers presumably use a custom
        collate that filters ``None`` — verify at the DataLoader site.
        """
        file_path = self.file_paths[idx]
        try:
            with np.load(file_path) as data_file:
                # .files[0]: take the first stored array, whatever its key.
                fmri_data = data_file[data_file.files[0]].astype(np.float32)
        except Exception as e:
            print(f"Error loading file {file_path}: {e}")
            return None

        total_time_frames = fmri_data.shape[-1]
        if total_time_frames > self.crop_length:
            # Random temporal crop: start is inclusive of the last position
            # that still leaves a full crop_length window.
            start_idx = np.random.randint(0, total_time_frames - self.crop_length + 1)
            cropped_data = fmri_data[..., start_idx:start_idx + self.crop_length]
        else:
            # NOTE(review): recordings shorter than crop_length come back
            # with fewer frames (no padding) — batching across such items
            # will fail with the default collate; confirm this is intended.
            cropped_data = fmri_data[..., :self.crop_length]

        # (X, Y, Z, T) -> (T, X, Y, Z): time-first layout for the model.
        data_tensor = torch.from_numpy(cropped_data)
        return data_tensor.permute(3, 0, 1, 2)
|