JensLundsgaard committed on
Commit
07b043c
·
verified ·
1 Parent(s): 7eeaafa

Upload dataset_ivf.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dataset_ivf.py +75 -0
dataset_ivf.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # dataset_ivf.py
2
+ import numpy as np, pandas as pd, torch
3
+ from torch.utils.data import Dataset
4
+ from PIL import Image, ImageFile
5
+ import os
6
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
7
class IVFSequenceDataset(Dataset):
    """Dataset that turns each DataFrame row into three image-sequence tensors.

    Every row of ``df`` must provide the columns ``embryo_paths``,
    ``empty_well_paths`` and ``sample_paths``. Each cell holds image file
    paths joined with ``|``. Indexing the dataset loads, normalizes and
    stacks the frames of each column and returns one tensor per column.
    """

    # Columns read from each row, in the order their tensors are returned.
    _PATH_COLUMNS = ("embryo_paths", "empty_well_paths", "sample_paths")

    def __init__(self, df, resize=500, norm="minmax01"):
        """Store the index table and the preprocessing options.

        Args:
            df: pandas DataFrame with one row per item (see class docstring
                for the required columns).
            resize: Side length in pixels of the square every frame is
                resized to, or ``None`` to keep the stored image size.
            norm: ``"zscore"`` or ``"minmax01"``; any other value leaves the
                pixel values unnormalized.
        """
        self.df = df
        self.resize = resize
        self.norm = norm

    def _read_gray(self, path):
        """Load one image file as a 2-D ``float32`` array.

        Args:
            path: Path of the image file to read.

        Returns:
            ``numpy.ndarray`` of dtype ``float32`` with one value per pixel.

        Raises:
            FileNotFoundError: If ``path`` does not exist (raised by Pillow).
        """
        # Image.open raises on failure (it never returns None), and the
        # context manager releases the file handle once the pixels are copied.
        with Image.open(path) as src:
            # Collapse multi-band images (RGB, RGBA, LA, ...) to one
            # luminance band so every frame is 2-D and frames stack cleanly.
            frame = src.convert("L") if len(src.getbands()) > 1 else src
            if self.resize is not None:
                frame = frame.resize((self.resize, self.resize), Image.BILINEAR)
            return np.array(frame, dtype="float32")

    def _normalize_video(self, vol):
        """Normalize a whole stack of frames with statistics of the stack.

        Args:
            vol: Array holding all frames of one sequence.

        Returns:
            For ``norm == "zscore"``: ``vol`` shifted and scaled by its mean
            and standard deviation. For ``norm == "minmax01"``: ``vol``
            rescaled so the 1st/99th percentiles map to 0/1, clipped to
            ``[0, 1]``. For any other ``norm``: ``vol`` unchanged.
        """
        if self.norm == "zscore":
            # Epsilon keeps the division finite for constant-valued stacks.
            m, s = vol.mean(), vol.std() + 1e-6
            vol = (vol - m) / s
        elif self.norm == "minmax01":
            # Percentiles instead of min/max limit the influence of outliers.
            lo, hi = np.percentile(vol, 1), np.percentile(vol, 99)
            vol = (vol - lo) / (hi - lo + 1e-6)
            vol = np.clip(vol, 0, 1)
        return vol

    def _load_volume(self, joined_paths):
        """Read a ``|``-separated list of image paths into one tensor.

        Returns a tensor of shape ``(frames, 1, height, width)``.
        """
        frames = [self._read_gray(p) for p in joined_paths.split("|")]
        vol = self._normalize_video(np.stack(frames, axis=0))
        # Insert a singleton axis right after the frame axis.
        return torch.from_numpy(vol[:, None, :, :])

    def __getitem__(self, idx):
        """Load the three image sequences referenced by row ``idx``.

        Args:
            idx: Positional row index into ``self.df``.

        Returns:
            Tuple ``(embryo, empty_well, sample)`` of tensors, each of shape
            ``(frames, 1, height, width)``.

        Raises:
            ValueError: If any of the three path columns is missing (NaN)
                in the row.
        """
        row = self.df.iloc[idx]
        if any(pd.isna(row[col]) for col in self._PATH_COLUMNS):
            print(f"Row {idx} has missing path data: ", row.to_string(index = False))
            raise ValueError(f"Row {idx} has missing path data")

        return tuple(self._load_volume(row[col]) for col in self._PATH_COLUMNS)

    def __len__(self):
        """Return the number of rows in the index table."""
        return len(self.df)
61
+
62
if __name__ == "__main__":
    # Rebuild the seeded 85/15 random split over index.csv and print the
    # cell_id of every row that falls into the validation part.
    index_table = pd.read_csv(os.path.abspath("index.csv"))
    dataset = IVFSequenceDataset(index_table, resize=128, norm="minmax01")

    n_total = len(dataset)
    n_train = int(0.85 * n_total)
    n_val = n_total - n_train

    # Fixed seed so the same rows are selected on every run.
    rng = torch.Generator().manual_seed(42)
    parts = torch.utils.data.random_split(dataset, [n_train, n_val], generator=rng)
    holdout = parts[1]

    # Lift the row limit so the full list is printed, not a truncated view.
    pd.set_option('display.max_rows', None)
    print(dataset.df.iloc[holdout.indices][['cell_id']])
75
+