File size: 5,447 Bytes
538668e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import os
import glob
import re
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from typing import List, Tuple, Union, Literal
import torch.nn.functional as F
from .pretrain_dataset import fMRIDataset
import io
import nibabel as nib
class fMRITaskDataset(fMRIDataset):
    """fMRI dataset with per-subject labels for downstream tasks.

    Extends ``fMRIDataset`` by attaching a classification or regression
    label (loaded from a CSV) to each fMRI file, keyed on the subject ID
    embedded in the filename. Files whose subject has no label in the CSV
    are dropped from ``self.file_paths``.
    """

    def __init__(
        self,
        data_root: str,
        datasets: List[str],
        split_suffixes: List[str],
        crop_length: int,
        label_csv_path: str,
        task_type: Literal['classification', 'regression'] = 'classification',
        downstream=True,
    ):
        """
        Args:
            data_root: Root data directory, forwarded to ``fMRIDataset``.
            datasets: Dataset names to include, forwarded to ``fMRIDataset``.
            split_suffixes: Split suffixes to include, forwarded to ``fMRIDataset``.
            crop_length: Temporal crop length, forwarded to ``fMRIDataset``.
            label_csv_path: CSV with a 'Subject' column plus a label column
                ('Gender'/'gender'/'age_group' for classification, 'age'
                for regression).
            task_type: Kind of downstream task the labels encode.
            downstream: Forwarded to ``fMRIDataset``.
        """
        super().__init__(data_root, datasets, split_suffixes, crop_length, downstream)
        self.task_type = task_type
        self.labels_map = self._load_and_process_labels(label_csv_path)

        # Keep only files whose extracted subject ID has a label in the CSV.
        initial_file_count = len(self.file_paths)
        self.file_paths = [
            path for path in self.file_paths
            if self._extract_subject_id(path) in self.labels_map
        ]
        dropped = initial_file_count - len(self.file_paths)
        if dropped > 0:
            print(f"Warning: Dropped {dropped} files due to missing labels in CSV.")
        print(f"Task Dataset ready for {self.task_type}. Usable files: {len(self.file_paths)}")

    def _extract_subject_id(self, file_path: str) -> str:
        """Extract the 6-digit subject ID from the file's basename.

        The ID is returned with leading zeros stripped so it matches the
        string form of the CSV 'Subject' column. Returns "" when no
        6-digit run is present in the filename.
        """
        match = re.search(r'(\d{6})', os.path.basename(file_path))
        if match:
            # CSV 'Subject' values are plain integers rendered as strings,
            # i.e. without leading zeros — normalize to match.
            return match.group(1).lstrip('0')
        return ""

    def _load_and_process_labels(self, csv_path: str) -> dict:
        """Load per-subject labels from ``csv_path``.

        Returns:
            dict mapping subject-ID string -> label tensor
            (``torch.long`` scalar for classification,
            ``torch.float32`` of shape (1,) for regression).

        Raises:
            FileNotFoundError: if ``csv_path`` does not exist.
            ValueError: if the required label column is missing or
                ``self.task_type`` is unsupported.
        """
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"Label CSV file not found at: {csv_path}")
        print(f"Loading labels from {csv_path}...")
        df = pd.read_csv(csv_path)
        # Drop missing subjects BEFORE casting to str: astype(str) turns
        # NaN into the literal string 'nan', which a later dropna would
        # silently keep.
        df.dropna(subset=['Subject'], inplace=True)
        df['Subject'] = df['Subject'].astype(str)

        labels_map = {}
        if self.task_type == 'classification':
            # First matching candidate column wins.
            label_col = next(
                (col for col in ('Gender', 'gender', 'age_group') if col in df.columns),
                None,
            )
            if label_col is None:
                raise ValueError("CSV must contain 'Gender', 'gender' or 'age_group' column for classification.")
            print(f"Using column '{label_col}' as label.")
            sex_mapping = {'F': 0, 'M': 1, 'f': 0, 'm': 1}
            # Heuristic: if the first value looks like F/M, treat the column
            # as sex labels and encode to 0/1; otherwise coerce to integers.
            if df[label_col].dtype == object and df[label_col].astype(str).iloc[0].upper() in ['F', 'M']:
                print(f"Encoding {label_col} (F/M) to Integers (0/1)...")
                # .copy() so the assignment below is on an owned frame,
                # not a view (avoids SettingWithCopyWarning).
                df = df[df[label_col].isin(sex_mapping.keys())].copy()
                df[label_col] = df[label_col].map(sex_mapping)
            else:
                # Drop rows that fail numeric coercion before the int cast,
                # otherwise astype(int) raises on NaN.
                df[label_col] = pd.to_numeric(df[label_col], errors='coerce')
                df.dropna(subset=[label_col], inplace=True)
                df[label_col] = df[label_col].astype(int)
            for _, row in df.iterrows():
                labels_map[row['Subject']] = torch.tensor(row[label_col], dtype=torch.long)
        elif self.task_type == 'regression':
            label_col = 'age'
            if label_col not in df.columns:
                raise ValueError(f"Regression task requires '{label_col}' column.")
            df[label_col] = pd.to_numeric(df[label_col], errors='coerce')
            df.dropna(subset=[label_col], inplace=True)
            for _, row in df.iterrows():
                # Shape (1,) so regression targets stack cleanly in a batch.
                labels_map[row['Subject']] = torch.tensor(row[label_col], dtype=torch.float32).view(1)
        else:
            raise ValueError(f"Unsupported task_type: {self.task_type}")
        print(f"Successfully loaded {len(labels_map)} subjects' labels.")
        return labels_map

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(data, label)`` for the sample at ``idx``.

        Data gains a leading channel dimension via ``unsqueeze(0)``. On any
        loading failure a new random index is drawn and loading retried, so
        one corrupt file does not abort a training epoch; after
        ``max_retries`` failed attempts a RuntimeError is raised.
        """
        retries = 0
        max_retries = 100
        while retries < max_retries:
            try:
                data_tensor = super().__getitem__(idx)
                if data_tensor is None:
                    raise ValueError(f"Failed to load data at index {idx} (super returned None)")
                file_path = self.file_paths[idx]
                subject_id = self._extract_subject_id(file_path)
                data_tensor = data_tensor.unsqueeze(0)  # add channel dim
                if subject_id in self.labels_map:
                    return data_tensor, self.labels_map[subject_id]
                # Caught below and retried with a fresh random index.
                raise KeyError(f"Label not found for subject ID: {subject_id}")
            except Exception:
                # Best-effort recovery: re-sample a different index.
                idx = np.random.randint(0, len(self))
                retries += 1
        raise RuntimeError(f"Failed to load any valid data after {max_retries} retries.")
|