File size: 2,257 Bytes
538668e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import glob
import numpy as np
from typing import Any, Callable, Dict, Optional, Set, Tuple
import torch
from torch.utils.data import Dataset
import random

class fMRIDataset(Dataset):
    """Map-style dataset of randomly time-cropped arrays stored in ``.npz`` files.

    Files are discovered recursively under ``<data_root>/<dataset>_<suffix>/``
    for every combination of ``datasets`` and ``split_suffixes``. Within each
    directory, ``.npz`` files are taken in sorted name order and at most
    ``max_files_per_dir`` of them are kept.

    Args:
        data_root: Directory that contains the ``<dataset>_<suffix>`` folders.
        datasets: Iterable of dataset name prefixes.
        split_suffixes: Iterable of suffixes joined to each dataset name with ``_``.
        crop_length: Maximum number of entries kept along the LAST axis of each
            array (treated as time); a random contiguous window is taken when
            the array is longer.
        downstream: Stored as ``self.downstream``; not read by this class.
        max_files_per_dir: Maximum number of ``.npz`` files kept per directory
            (sorted order); ``None`` keeps all. Defaults to 1, the historical
            behaviour.
    """

    def __init__(self,
                 data_root, datasets, split_suffixes, crop_length=40, downstream=False,
                 max_files_per_dir=1):

        self.file_paths = []
        self.crop_length = crop_length
        self.downstream = downstream
        for dataset_name in datasets:
            for suffix in split_suffixes:
                folder_name = f"{dataset_name}_{suffix}"
                folder_path = os.path.join(data_root, folder_name)
                # isdir (not exists): a regular file with this name would
                # otherwise be walked silently and contribute nothing.
                if not os.path.isdir(folder_path):
                    print(f"Warning: Folder not found: {folder_path}")
                    continue
                self.file_paths.extend(
                    self._collect_npz(folder_path, max_files_per_dir)
                )

        print(f"Dataset loaded. Total files found: {len(self.file_paths)}")

    @staticmethod
    def _collect_npz(folder_path, max_files_per_dir):
        """Return ``.npz`` paths under *folder_path*, capped per directory."""
        found = []
        for root, dirs, files in os.walk(folder_path):
            # Sorting in place fixes the order os.walk descends in, so the
            # index -> file mapping does not depend on the filesystem.
            dirs.sort()
            # Filter the names os.walk already listed rather than re-globbing
            # ``root``: glob would interpret ``[``, ``*`` or ``?`` inside a
            # directory name as wildcards and silently match nothing. Hidden
            # names are skipped and normcase is applied to mirror what glob's
            # ``*.npz`` pattern matched.
            npz_names = sorted(
                name for name in files
                if not name.startswith(".")
                and os.path.normcase(name).endswith(".npz")
            )
            if max_files_per_dir is not None:
                npz_names = npz_names[:max_files_per_dir]
            found.extend(os.path.join(root, name) for name in npz_names)
        return found

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load file ``idx``, randomly crop its last axis, and move that axis first.

        Returns:
            A float32 tensor whose first dimension is the (cropped) last axis of
            the stored array, followed by the remaining axes in their original
            order. For a 4-D array ``(A, B, C, T)`` this is ``(t, A, B, C)``
            with ``t = min(T, crop_length)``. Arrays not longer than
            ``crop_length`` are returned whole (not padded), so lengths may
            differ between samples.

            Returns ``None`` when the file cannot be loaded. The default
            DataLoader collate cannot batch ``None``; presumably a custom
            ``collate_fn`` filters these out — TODO confirm with the caller.
        """
        file_path = self.file_paths[idx]
        try:
            with np.load(file_path) as data_file:
                # Only the first array stored in the archive is used.
                key = data_file.files[0]
                fmri_data = data_file[key].astype(np.float32)
        except Exception as e:
            # Deliberate best-effort: a corrupt/empty archive yields None
            # rather than aborting the whole epoch.
            print(f"Error loading file {file_path}: {e}")
            return None

        total_time_frames = fmri_data.shape[-1]
        if total_time_frames > self.crop_length:
            start_idx = np.random.randint(0, total_time_frames - self.crop_length + 1)
            end_idx = start_idx + self.crop_length
            cropped_data = fmri_data[..., start_idx:end_idx]
        else:
            cropped_data = fmri_data

        data_tensor = torch.from_numpy(cropped_data)

        # movedim(-1, 0) equals permute(3, 0, 1, 2) for 4-D input and also
        # works for arrays of any other rank.
        return torch.movedim(data_tensor, -1, 0)