# osf/datasets/pretrain_dataset.py
# Sleep Epoch Dataset for pretraining and downstream tasks
import numpy as np
import pandas as pd
import torch
from pathlib import Path
from contextlib import suppress
from typing import Sequence, Optional, Dict, Union, List
from torch.utils.data import Dataset
from train_config import NEED_NORM_COL


def to_pm1(s: pd.Series) -> pd.Series:
    """Min-max scale a series to [-1, 1]; constant or all-NaN input maps to 0."""
    s = pd.to_numeric(s, errors="coerce")
    vmin, vmax = s.min(skipna=True), s.max(skipna=True)
    if pd.isna(vmin) or pd.isna(vmax) or vmax <= vmin:
        return pd.Series(0.0, index=s.index)
    return (2 * (s - vmin) / (vmax - vmin) - 1).fillna(0.0)
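
# Example: to_pm1(pd.Series([0.0, 5.0, 10.0])) -> [-1.0, 0.0, 1.0];
# a constant or all-NaN series maps to all zeros via the guard above.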


class SleepEpochDataset(Dataset):
    """Map-style dataset of fixed-length PSG segments.

    Each item is {"psg": FloatTensor[n_channels, window_size * sample_rate]},
    optionally extended with stage/event, regression, or patient-level labels
    and with (nsrrid, seg_id) identifiers.
    """

    def __init__(
        self,
csv_dir='/path/to/your/postprocessed/data',
split: str = "train",
*,
data_pct=1,
patient_cols: Optional[Union[str, Sequence[str]]] = None,
event_cols: Optional[Union[str, Sequence[str]]] = None,
train_edf_cols=None,
test_size: float = 0.15,
val_size: float = 0.15,
random_state: int = 1337,
sample_rate: int = 128,
window_size: int = 300,
epoch_length: int = 30,
cache_size: int = 8,
transform=None,
downstream_dataset_name=None,
data_source: str = "auto",
include_datasets: Optional[List[str]] = None,
regression_targets: Optional[List[str]] = None,
regression_filter_config: Optional[Dict] = None,
return_all_event_cols: bool = False,
return_nsrrid: bool = False,
):
assert split in {"pretrain", "pretrain-val", "pretrain-test", "train", "val", "test"}
assert data_source in {"auto", "pretrain", "downstream", "both"}
self.transform = transform
self.sample_rate = sample_rate
self.window_size = window_size
self.epoch_length = epoch_length
self.patient_cols = [patient_cols] if isinstance(patient_cols, str) else patient_cols
self.event_cols = [event_cols] if isinstance(event_cols, str) else event_cols
self.train_edf_cols = train_edf_cols
self.split = split
self.data_pct = float(data_pct)
self.data_source = data_source
self.regression_targets = regression_targets
self.regression_filter_config = regression_filter_config
self.return_all_event_cols = return_all_event_cols
self.return_nsrrid = return_nsrrid
patient_df, epoch_df = self._load_csvs(
csv_dir, split, data_source, include_datasets, self.event_cols,
regression_targets=self.regression_targets,
regression_filter_config=self.regression_filter_config,
return_all_event_cols=self.return_all_event_cols,
)
        # Optionally restrict to one downstream dataset, matched by name prefix
        if downstream_dataset_name and include_datasets is None:
            if downstream_dataset_name != "all":
                # Lower-case both sides so the prefix match is case-insensitive
                mask = epoch_df['dataset_name'].astype(str).str.lower().str.startswith(str(downstream_dataset_name).lower())
epoch_df = epoch_df.loc[mask].copy()
ids = epoch_df["nsrrid"].astype(str).unique()
patient_df = patient_df[patient_df["nsrrid"].astype(str).isin(ids)].copy()
# Determine num_classes
if self.event_cols:
if self.event_cols[0] in ['Hypopnea', 'Arousal', 'Oxygen Desaturation']:
self.num_classes = 2
            elif self.event_cols[0] == 'Stage':
                self.num_classes = 4
                # Merge stage codes 1 and 2 into one class: codes 0-4 -> classes 0-3
                mapping = {0: 0, 1: 1, 2: 1, 3: 2, 4: 3}
                epoch_df['Stage'] = epoch_df['Stage'].replace(mapping)
else:
self.num_classes = 2
else:
self.num_classes = 2
        # Drop epochs with an invalid stage label (Stage == -1)
if self.event_cols and ('Stage' in self.event_cols) and ('Stage' in epoch_df.columns):
epoch_df = epoch_df.loc[epoch_df['Stage'] != -1].copy()
# Build tables
if split in ("pretrain", "pretrain-val"):
sort_cols = [c for c in ['nsrrid', 'seg_id', 'epoch_id'] if c in epoch_df.columns]
self.all_epoch_df = epoch_df.sort_values(sort_cols).reset_index(drop=True)
idx_keep_cols = [c for c in ['nsrrid', 'seg_id', 'path_head'] if c in self.all_epoch_df.columns]
if self.regression_targets:
for t in self.regression_targets:
col = f"{t}_mean"
if col in self.all_epoch_df.columns:
idx_keep_cols.append(col)
self.epoch_df = (
self.all_epoch_df[idx_keep_cols]
.drop_duplicates(['nsrrid', 'seg_id'], keep='first')
.reset_index(drop=True)
)
else:
expected_len = self.window_size // self.epoch_length
grp = epoch_df.groupby(['nsrrid', 'seg_id']).size().rename('n').reset_index()
valid_keys = grp.loc[grp['n'] == expected_len, ['nsrrid', 'seg_id']]
epoch_df_valid = epoch_df.merge(valid_keys, on=['nsrrid', 'seg_id'], how='inner')
sort_cols = [c for c in ['nsrrid', 'seg_id', 'epoch_id'] if c in epoch_df_valid.columns]
self.all_epoch_df = epoch_df_valid.sort_values(sort_cols).reset_index(drop=True)
idx_keep_cols = [c for c in ['nsrrid', 'seg_id', 'path_head'] if c in self.all_epoch_df.columns]
if self.regression_targets:
for t in self.regression_targets:
col = f"{t}_mean"
if col in self.all_epoch_df.columns:
idx_keep_cols.append(col)
self.epoch_df = (
self.all_epoch_df[idx_keep_cols]
.drop_duplicates(['nsrrid', 'seg_id'], keep='first')
.reset_index(drop=True)
)
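        # A complete segment holds window_size // epoch_length = 300 // 30 = 10
        # thirty-second epochs at the defaults, i.e. one item spans 5 minutes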
# Patient-level sampling
if not (0 < self.data_pct <= 1.0):
raise ValueError(f"data_pct must be in (0,1], got {self.data_pct}")
if self.data_pct < 1.0:
eligible_patients = pd.Index(self.epoch_df['nsrrid'].unique())
n_keep = max(1, int(len(eligible_patients) * self.data_pct))
sampled_nsrrids = pd.Series(eligible_patients).sample(n=n_keep, random_state=random_state).to_list()
self.epoch_df = self.epoch_df.loc[self.epoch_df['nsrrid'].isin(sampled_nsrrids)].reset_index(drop=True)
self.all_epoch_df = self.all_epoch_df.loc[self.all_epoch_df['nsrrid'].isin(sampled_nsrrids)].reset_index(drop=True)
patient_df = patient_df.loc[patient_df['nsrrid'].isin(sampled_nsrrids)].copy()
self.patient_df = patient_df.set_index("nsrrid")
        # Build segment index: (nsrrid, seg_id) -> row positions in all_epoch_df,
        # ordered by epoch_id so per-epoch labels line up with time
self._seg_indices = None
if hasattr(self, "all_epoch_df") and {'nsrrid', 'seg_id'}.issubset(self.all_epoch_df.columns):
grp_indices = self.all_epoch_df.groupby(['nsrrid', 'seg_id'], sort=False).indices
self._seg_indices = {}
has_epoch_id = 'epoch_id' in self.all_epoch_df.columns
epoch_id_values = self.all_epoch_df['epoch_id'].to_numpy() if has_epoch_id else None
for key, idx_list in grp_indices.items():
idx_arr = np.fromiter(idx_list, dtype=np.int64)
if has_epoch_id:
order = np.argsort(epoch_id_values[idx_arr])
idx_arr = idx_arr[order]
self._seg_indices[key] = idx_arr
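        # e.g. self._seg_indices[("shhs1-200001", 3)] would hold ten consecutive
        # row positions, one per 30 s epoch of that segment (nsrrid hypothetical)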
# Compute class distribution
self._class_counts = None
if self.event_cols and self.event_cols[0] in self.all_epoch_df.columns:
label_col = self.event_cols[0]
value_counts = self.all_epoch_df[label_col].value_counts().sort_index()
class_counts = np.zeros(self.num_classes, dtype=np.int64)
for cls_idx, count in value_counts.items():
if 0 <= int(cls_idx) < self.num_classes:
class_counts[int(cls_idx)] = int(count)
self._class_counts = class_counts

    def _load_csvs(self, csv_dir, split, data_source, include_datasets, event_cols,
                   regression_targets=None, regression_filter_config=None, return_all_event_cols=False):
        """Read patient/epoch CSVs for the split, concatenate sources, and apply
        column selection plus regression-label filtering."""
split_suffix_map = {
"pretrain": "train", "pretrain-val": "valid", "pretrain-test": "test",
"train": "train", "val": "valid", "test": "test"
}
split_suffix = split_suffix_map[split]
if data_source == "auto":
sources = ["pretrain"] if split.startswith("pretrain") else ["downstream"]
elif data_source == "both":
sources = ["pretrain", "downstream"]
else:
sources = [data_source]
patient_dfs = []
epoch_dfs = []
csv_prefix = "epoch_regression" if regression_targets else "epoch"
for source in sources:
patient_csv = f"{csv_dir}/patient_{source}_{split_suffix}.csv"
epoch_csv = f"{csv_dir}/{csv_prefix}_{source}_{split_suffix}.csv"
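            # With a hypothetical csv_dir="/data/csv", source="pretrain" and
            # split_suffix="train", this resolves to /data/csv/patient_pretrain_train.csv
            # and /data/csv/epoch_pretrain_train.csv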
            if Path(patient_csv).is_file() and Path(epoch_csv).is_file():
                patient_dfs.append(pd.read_csv(patient_csv))
                epoch_dfs.append(pd.read_csv(epoch_csv))
        if not epoch_dfs:
            # Fail loudly instead of letting pd.concat raise an opaque error
            raise FileNotFoundError(
                f"No patient/{csv_prefix} CSV pairs found under {csv_dir} "
                f"for sources={sources}, split suffix='{split_suffix}'"
            )
        patient_df = pd.concat(patient_dfs, ignore_index=True).drop_duplicates(subset=['nsrrid'])
        epoch_df = pd.concat(epoch_dfs, ignore_index=True)
base_cols = ['nsrrid', 'seg_id', 'dataset_name', 'epoch_id', 'path_head']
if event_cols:
if return_all_event_cols:
for col in event_cols:
if col and col not in base_cols:
base_cols.append(col)
elif event_cols[0]:
base_cols.append(event_cols[0])
if regression_targets:
for t in regression_targets:
col_name = f"{t}_mean"
if col_name in epoch_df.columns:
base_cols.append(col_name)
keep_cols = [c for c in base_cols if c in epoch_df.columns]
epoch_df = epoch_df[keep_cols].copy()
if regression_targets:
label_cols = [f"{t}_mean" for t in regression_targets]
existing = [c for c in label_cols if c in epoch_df.columns]
if existing:
epoch_df = epoch_df.dropna(subset=existing).reset_index(drop=True)
if regression_filter_config:
for col_name, filter_rules in regression_filter_config.items():
if col_name in epoch_df.columns:
                    mask = pd.Series(True, index=epoch_df.index)
if "min" in filter_rules:
mask = mask & (epoch_df[col_name] >= filter_rules["min"])
if "max" in filter_rules:
mask = mask & (epoch_df[col_name] <= filter_rules["max"])
epoch_df = epoch_df[mask].reset_index(drop=True)
if include_datasets is not None and 'dataset_name' in epoch_df.columns:
include_lower = [d.lower() for d in include_datasets]
mask = epoch_df['dataset_name'].astype(str).str.lower().isin(include_lower)
epoch_df = epoch_df[mask].copy()
patient_df = patient_df[patient_df['nsrrid'].isin(epoch_df['nsrrid'].unique())].copy()
return patient_df, epoch_df

    def __len__(self) -> int:
        return len(self.epoch_df)

    def get_class_counts(self) -> Optional[np.ndarray]:
        """Per-class epoch counts for the primary event column, or None."""
        return self._class_counts
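
    # Sketch (assuming a classification split): turn the counts into
    # inverse-frequency weights for torch.nn.CrossEntropyLoss, e.g.
    #   counts = torch.tensor(ds.get_class_counts(), dtype=torch.float32)
    #   weight = counts.sum() / (len(counts) * counts)
    #   loss_fn = torch.nn.CrossEntropyLoss(weight=weight)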

    def _resample_df(self, df: pd.DataFrame, target_hz: int) -> pd.DataFrame:
        """Linearly interpolate df onto a uniform time grid at target_hz."""
        if not np.issubdtype(df.index.dtype, np.number):
            # No numeric time index: assume rows are already evenly spaced at
            # target_hz and synthesize a time axis
            t = np.arange(len(df)) / float(target_hz)
            df = df.copy()
            df.index = t
t0 = float(df.index.min())
t1 = float(df.index.max())
t_target = np.arange(t0, t0 + self.window_size, 1.0 / target_hz)
if t_target[-1] > t1:
t_target = t_target[t_target <= t1 + 1e-9]
out = df.reindex(t_target).interpolate(method="linear", limit_direction="both")
return out.fillna(0.0)
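
    # At the defaults (window_size=300 s, target_hz=128) the target grid has
    # 300 * 128 = 38400 points spaced 1/128 s apart, starting at the first timestamp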

    def _prepare_psg(self, path_head: str, seg_id: int, cols) -> torch.Tensor:
        """Load one segment, resample, select/normalize channels, pad or trim to
        exactly window_size * sample_rate samples, and clamp amplitudes."""
        df_epoch = self._load_epoch_all_df(path_head, seg_id, columns=cols)
        df_epoch = self._resample_df(df_epoch, self.sample_rate)
        if cols is not None:
            for ch in cols:
                if ch not in df_epoch.columns:
                    # Channel absent from this recording: zero-fill so every
                    # item has the same channel layout
                    df_epoch[ch] = 0.0
                elif ch in NEED_NORM_COL:
                    df_epoch[ch] = to_pm1(df_epoch[ch])
            df_epoch = df_epoch[cols]
        samples_per_epoch = int(self.window_size * self.sample_rate)
        if len(df_epoch) < samples_per_epoch:
            # Zero-pad short segments at the tail, extending the time index
            pad = samples_per_epoch - len(df_epoch)
            tail = pd.DataFrame({c: 0.0 for c in df_epoch.columns},
                                index=df_epoch.index[-1] + (np.arange(1, pad + 1) / self.sample_rate))
            df_epoch = pd.concat([df_epoch, tail], axis=0)
        elif len(df_epoch) > samples_per_epoch:
            df_epoch = df_epoch.iloc[:samples_per_epoch]
        # (n_samples, n_channels) -> (n_channels, n_samples)
        x = torch.tensor(df_epoch.to_numpy(copy=False), dtype=torch.float32).t().contiguous()
        # Clamp extreme amplitudes; assumes channels are roughly unit-scaled
        return torch.clamp(x, min=-6, max=6)

    def __getitem__(self, idx: int):
        row = self.epoch_df.iloc[idx]
        nsrrid = row["nsrrid"]
        seg_id = int(row["seg_id"])
        cols = list(self.train_edf_cols) if self.train_edf_cols is not None else None
        output = {"psg": self._prepare_psg(row["path_head"], seg_id, cols)}
        if self.return_nsrrid:
            output["nsrrid"] = nsrrid
            output["seg_id"] = seg_id
        if self.split == "pretrain":
            # Pretrain items carry at most one label per window
            if self.patient_cols:
                y = torch.tensor(self.patient_df.loc[nsrrid, self.patient_cols].values.astype(float), dtype=torch.float32)
                output["label"] = y.long() if not self.return_nsrrid else y
            elif self.event_cols:
                if self.return_all_event_cols:
                    available_cols = [c for c in self.event_cols if c in row.index]
                    y = torch.tensor([row[c] for c in available_cols], dtype=torch.float32)
                else:
                    y = torch.tensor([row[self.event_cols[0]]], dtype=torch.float32)
                output["label"] = y
            return output
        # Downstream splits: per-epoch labels come from all rows of the segment
        idx_arr = self._seg_indices.get((nsrrid, seg_id)) if self._seg_indices is not None else None
        if idx_arr is not None:
            seg_df = self.all_epoch_df.iloc[idx_arr]
        else:
            # Fallback: filter and sort on the fly
            seg_df = self.all_epoch_df[
                (self.all_epoch_df['nsrrid'] == nsrrid) & (self.all_epoch_df['seg_id'] == seg_id)
            ].sort_values('epoch_id')
        if self.patient_cols:
            y = torch.tensor(self.patient_df.loc[nsrrid, self.patient_cols].values.astype(float), dtype=torch.float32)
            # Broadcast the patient-level label to every epoch in the window
            y = y.repeat(self.window_size // self.epoch_length)
            output["label"] = y
        elif self.event_cols:
            if self.return_all_event_cols:
                available_cols = [c for c in self.event_cols if c in seg_df.columns]
                y = torch.tensor(seg_df[available_cols].values.astype(float), dtype=torch.float32).squeeze(0)
            else:
                # (n_epochs, 1) -> (n_epochs,)
                y = torch.tensor(seg_df[self.event_cols].values.astype(float), dtype=torch.float32).squeeze(1)
            output["label"] = y
        elif self.regression_targets:
            label_cols = [f"{t}_mean" for t in self.regression_targets]
            y = torch.tensor([row[c] for c in label_cols], dtype=torch.float32)
            output["label"] = y
        return output

    def _build_epoch_all_path(self, path_head: str, epoch_id: int) -> Path:
        """Resolve the on-disk parquet file for one segment."""
        return Path(f"{path_head}/epoch-{epoch_id:05d}_all.parquet")
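
    # Example with a hypothetical path_head:
    #   _build_epoch_all_path("/data/shhs/200001", 7)
    #   -> Path("/data/shhs/200001/epoch-00007_all.parquet")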

    def _load_epoch_all_df(self, path_head: str, epoch_id: int, columns=None) -> pd.DataFrame:
        fp = self._build_epoch_all_path(path_head, epoch_id)
        if not fp.is_file():
            raise FileNotFoundError(f"Parquet missing: {fp}")
        df = pd.read_parquet(fp)
        if columns is not None:
            # Keep only requested channels that are present; channels missing
            # from the file are zero-filled later by the caller
            df = df[[c for c in columns if c in df.columns]]
        # Best-effort cast of non-float columns to float32; columns that cannot
        # be cast (e.g. strings) are left untouched
        for c in df.columns:
            if not np.issubdtype(df[c].dtype, np.floating):
                with suppress(Exception):
                    df[c] = df[c].astype(np.float32)
        return df
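

if __name__ == "__main__":
    # Minimal smoke test -- a sketch only: the csv_dir path, channel names,
    # and label column below are illustrative assumptions, not fixed values.
    from torch.utils.data import DataLoader

    ds = SleepEpochDataset(
        csv_dir="/path/to/your/postprocessed/data",  # hypothetical location
        split="train",
        event_cols="Stage",                          # 4-class staging labels
        train_edf_cols=["EEG", "EOG", "EMG"],        # hypothetical channel names
    )
    loader = DataLoader(ds, batch_size=4, shuffle=True, num_workers=0)
    batch = next(iter(loader))
    # psg: (batch, n_channels, window_size * sample_rate) = (4, 3, 38400);
    # label: (batch, window_size // epoch_length) = (4, 10)
    print(batch["psg"].shape, batch["label"].shape)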