Commit ·
8a6f09e
1
Parent(s): e1e6b7c
added dataset.py
Browse files- dataset.py +156 -0
dataset.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils.data import Dataset
|
| 2 |
+
import numpy as np
|
| 3 |
+
import h5py, torch, random, logging
|
| 4 |
+
from skimage.feature import peak_local_max
|
| 5 |
+
from skimage import measure
|
| 6 |
+
|
| 7 |
+
def clean_patch(p, center):
|
| 8 |
+
w, h = p.shape
|
| 9 |
+
cc = measure.label(p > 0)
|
| 10 |
+
if cc.max() == 1:
|
| 11 |
+
return p
|
| 12 |
+
|
| 13 |
+
# logging.warn(f"{cc.max()} peaks located in a patch")
|
| 14 |
+
lmin = np.inf
|
| 15 |
+
cc_lmin = None
|
| 16 |
+
for _c in range(1, cc.max()+1):
|
| 17 |
+
lmax = peak_local_max(p * (cc==_c), min_distance=1)
|
| 18 |
+
if lmax.shape[0] == 0:continue # single pixel component
|
| 19 |
+
lc = lmax.mean(axis=0)
|
| 20 |
+
dist = ((lc - center)**2).sum()
|
| 21 |
+
if dist < lmin:
|
| 22 |
+
cc_lmin = _c
|
| 23 |
+
lmin = dist
|
| 24 |
+
return p * (cc == cc_lmin)
|
| 25 |
+
|
| 26 |
+
class BraggNNDataset(Dataset):
|
| 27 |
+
def __init__(self, pfile, ffile, psz=15, rnd_shift=0, use='train', train_frac=0.8):
|
| 28 |
+
self.psz = psz
|
| 29 |
+
self.rnd_shift = rnd_shift
|
| 30 |
+
|
| 31 |
+
with h5py.File(pfile, "r") as h5fd:
|
| 32 |
+
if use == 'train':
|
| 33 |
+
sti, edi = 0, int(train_frac * h5fd['peak_fidx'].shape[0])
|
| 34 |
+
elif use == 'validation':
|
| 35 |
+
sti, edi = int(train_frac * h5fd['peak_fidx'].shape[0]), None
|
| 36 |
+
else:
|
| 37 |
+
logging.error(f"unsupported use: {use}. This class is written for building either training or validation set")
|
| 38 |
+
|
| 39 |
+
mask = h5fd['npeaks'][sti:edi] == 1 # use only single-peak patches
|
| 40 |
+
mask = mask & ((h5fd['deviations'][sti:edi] >= 0) & (h5fd['deviations'][sti:edi] < 1))
|
| 41 |
+
|
| 42 |
+
self.peak_fidx= h5fd['peak_fidx'][sti:edi][mask]
|
| 43 |
+
self.peak_row = h5fd['peak_row'][sti:edi][mask]
|
| 44 |
+
self.peak_col = h5fd['peak_col'][sti:edi][mask]
|
| 45 |
+
|
| 46 |
+
self.fidx_base = self.peak_fidx.min()
|
| 47 |
+
# only loaded frames that will be used
|
| 48 |
+
with h5py.File(ffile, "r") as h5fd:
|
| 49 |
+
self.frames = h5fd['frames'][self.peak_fidx.min():self.peak_fidx.max()+1]
|
| 50 |
+
|
| 51 |
+
self.len = self.peak_fidx.shape[0]
|
| 52 |
+
|
| 53 |
+
def __getitem__(self, idx):
|
| 54 |
+
_frame = self.frames[self.peak_fidx[idx] - self.fidx_base]
|
| 55 |
+
if self.rnd_shift > 0:
|
| 56 |
+
row_shift = np.random.randint(-self.rnd_shift, self.rnd_shift+1)
|
| 57 |
+
col_shift = np.random.randint(-self.rnd_shift, self.rnd_shift+1)
|
| 58 |
+
else:
|
| 59 |
+
row_shift, col_shift = 0, 0
|
| 60 |
+
prow_rnd = int(self.peak_row[idx]) + row_shift
|
| 61 |
+
pcol_rnd = int(self.peak_col[idx]) + col_shift
|
| 62 |
+
|
| 63 |
+
row_base = max(0, prow_rnd-self.psz//2)
|
| 64 |
+
col_base = max(0, pcol_rnd-self.psz//2 )
|
| 65 |
+
|
| 66 |
+
crop_img = _frame[row_base:(prow_rnd + self.psz//2 + self.psz%2), \
|
| 67 |
+
col_base:(pcol_rnd + self.psz//2 + self.psz%2)]
|
| 68 |
+
# if((crop_img > 0).sum() == 1): continue # ignore single non-zero peak
|
| 69 |
+
if crop_img.size != self.psz ** 2:
|
| 70 |
+
c_pad_l = (self.psz - crop_img.shape[1]) // 2
|
| 71 |
+
c_pad_r = self.psz - c_pad_l - crop_img.shape[1]
|
| 72 |
+
|
| 73 |
+
r_pad_t = (self.psz - crop_img.shape[0]) // 2
|
| 74 |
+
r_pad_b = self.psz - r_pad_t - crop_img.shape[0]
|
| 75 |
+
|
| 76 |
+
logging.warn(f"sample {idx} touched edge when crop the patch: {crop_img.shape}")
|
| 77 |
+
crop_img = np.pad(crop_img, ((r_pad_t, r_pad_b), (c_pad_l, c_pad_r)), mode='constant')
|
| 78 |
+
else:
|
| 79 |
+
c_pad_l, r_pad_t = 0 ,0
|
| 80 |
+
|
| 81 |
+
_center = np.array([self.peak_row[idx] - row_base + r_pad_t, self.peak_col[idx] - col_base + c_pad_l])
|
| 82 |
+
crop_img = clean_patch(crop_img, _center)
|
| 83 |
+
if crop_img.max() != crop_img.min():
|
| 84 |
+
_min, _max = crop_img.min().astype(np.float32), crop_img.max().astype(np.float32)
|
| 85 |
+
feature = (crop_img - _min) / (_max - _min)
|
| 86 |
+
else:
|
| 87 |
+
logging.warn("sample %d has unique intensity sum of %d" % (idx, crop_img.sum()))
|
| 88 |
+
feature = crop_img
|
| 89 |
+
|
| 90 |
+
px = (self.peak_col[idx] - col_base + c_pad_l) / self.psz
|
| 91 |
+
py = (self.peak_row[idx] - row_base + r_pad_t) / self.psz
|
| 92 |
+
|
| 93 |
+
return feature[np.newaxis], np.array([px, py]).astype(np.float32)
|
| 94 |
+
|
| 95 |
+
def __len__(self):
|
| 96 |
+
return self.len
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class PatchWiseDataset(Dataset):
|
| 100 |
+
def __init__(self, pfile, ffile, psz=15, rnd_shift=0, use='train', train_frac=1):
|
| 101 |
+
self.psz = psz
|
| 102 |
+
self.rnd_shift = rnd_shift
|
| 103 |
+
with h5py.File(pfile, "r") as h5fd:
|
| 104 |
+
if use == 'train':
|
| 105 |
+
sti, edi = 0, int(train_frac * h5fd['peak_fidx'].shape[0])
|
| 106 |
+
elif use == 'validation':
|
| 107 |
+
sti, edi = int(train_frac * h5fd['peak_fidx'].shape[0]), None
|
| 108 |
+
else:
|
| 109 |
+
logging.error(f"unsupported use: {use}. This class is written for building either training or validation set")
|
| 110 |
+
|
| 111 |
+
mask = h5fd['npeaks'][sti:edi] == 1 # use only single-peak patches
|
| 112 |
+
mask = mask & ((h5fd['deviations'][sti:edi] >= 0) & (h5fd['deviations'][sti:edi] < 1))
|
| 113 |
+
|
| 114 |
+
self.peak_fidx= h5fd['peak_fidx'][sti:edi][mask]
|
| 115 |
+
self.peak_row = h5fd['peak_row'][sti:edi][mask]
|
| 116 |
+
self.peak_col = h5fd['peak_col'][sti:edi][mask]
|
| 117 |
+
|
| 118 |
+
self.fidx_base = self.peak_fidx.min()
|
| 119 |
+
# only loaded frames that will be used
|
| 120 |
+
with h5py.File(ffile, 'r') as h5fd:
|
| 121 |
+
if use == 'train':
|
| 122 |
+
sti, edi = 0, int(train_frac * h5fd['frames'].shape[0])
|
| 123 |
+
elif use == 'validation':
|
| 124 |
+
sti, edi = int(train_frac * h5fd['frames'].shape[0]), None
|
| 125 |
+
else:
|
| 126 |
+
logging.error(f"unsupported use: {use}. This class is written for building either training or validation set")
|
| 127 |
+
|
| 128 |
+
self.crop_img = h5fd['frames'][sti:edi]
|
| 129 |
+
self.len = self.peak_fidx.shape[0]
|
| 130 |
+
|
| 131 |
+
def __getitem__(self, idx):
|
| 132 |
+
crop_img = self.crop_img[idx]
|
| 133 |
+
|
| 134 |
+
row_shift, col_shift = 0, 0
|
| 135 |
+
c_pad_l, r_pad_t = 0 ,0
|
| 136 |
+
prow_rnd = int(self.peak_row[idx]) + row_shift
|
| 137 |
+
pcol_rnd = int(self.peak_col[idx]) + col_shift
|
| 138 |
+
|
| 139 |
+
row_base = max(0, prow_rnd-self.psz//2)
|
| 140 |
+
col_base = max(0, pcol_rnd-self.psz//2)
|
| 141 |
+
|
| 142 |
+
if crop_img.max() != crop_img.min():
|
| 143 |
+
_min, _max = crop_img.min().astype(np.float32), crop_img.max().astype(np.float32)
|
| 144 |
+
feature = (crop_img - _min) / (_max - _min)
|
| 145 |
+
else:
|
| 146 |
+
#logging.warn("sample %d has unique intensity sum of %d" % (idx, crop_img.sum()))
|
| 147 |
+
feature = crop_img
|
| 148 |
+
|
| 149 |
+
px = (self.peak_col[idx] - col_base + c_pad_l) / self.psz
|
| 150 |
+
py = (self.peak_row[idx] - row_base + r_pad_t) / self.psz
|
| 151 |
+
|
| 152 |
+
return feature[np.newaxis], np.array([px, py]).astype(np.float32)
|
| 153 |
+
|
| 154 |
+
def __len__(self):
|
| 155 |
+
return self.len
|
| 156 |
+
|