Spaces:
Sleeping
Sleeping
| import os | |
| import h5py | |
| import torch | |
| import numpy as np | |
| from skimage import measure | |
| from torchvision import transforms | |
| from torch.utils.data import Dataset | |
| from skimage.measure import label, regionprops | |
class FrameDataset(Dataset):
    """Dark-subtracted, thresholded detector frames read from raw binary files.

    Both files are read as a header of ``fHead`` bytes followed by
    back-to-back ``NrPixels x NrPixels`` frames of native-endian ``uint16``.
    Frame 0 of each file is skipped: the dark image is frame 1 of ``dfile``
    and the data are frames ``1..nFrames`` of ``ffile``.

    Iterating the dataset (``for batch in ds``) yields lists of up to
    ``batch_size`` frames (the last batch may be shorter); ``ds[i]`` returns
    the ``i``-th such batch. ``len(ds)`` is the number of frames, not the
    number of batches.
    """

    def __init__(self, ffile, dfile, NrPixels=2048, nFrames=1440, batch_size=100, thresh=100, fHead=8192):
        """Load the dark frame and all data frames into memory.

        Args:
            ffile: path of the raw data file.
            dfile: path of the raw dark file.
            NrPixels: frame edge length in pixels (frames are square).
            nFrames: number of data frames to read (after the skipped frame 0).
            batch_size: number of frames per batch when iterating.
            thresh: dark-subtracted pixels below this value are set to 0.
            fHead: size of the file header in bytes.

        Raises:
            ValueError: if ``batch_size < 1`` or a requested frame is
                truncated / past the end of its file.
        """
        if batch_size < 1:
            # A zero or negative step would make ``__next__`` loop forever.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.NrPixels = NrPixels
        self.batch_size = batch_size
        self.length = nFrames

        # Dark image: frame 1 of the dark file (frame 0 is skipped), kept as float.
        with open(dfile, 'rb') as darkf:
            self.dark = self._read_frame(darkf, 1, NrPixels, fHead)

        # Data frames 1..nFrames (frame 0 is skipped): dark-subtracted, pixels
        # below ``thresh`` zeroed, then cast to int.
        # NOTE(review): every frame is held in memory as a platform int array;
        # with the default sizes that is very large — consider a smaller dtype.
        self.frames = []
        with open(ffile, 'rb') as f:
            for fNr in range(1, nFrames + 1):
                this_frame = self._read_frame(f, fNr, NrPixels, fHead) - self.dark
                this_frame[this_frame < thresh] = 0
                self.frames.append(this_frame.astype(int))

    @staticmethod
    def _read_frame(fh, frame_nr, n_pixels, header_bytes):
        """Return frame ``frame_nr`` of the open binary file ``fh`` as a float array."""
        count = n_pixels * n_pixels
        # Pixels are 2-byte uint16; frames follow the header back to back.
        fh.seek(header_bytes + frame_nr * count * 2, os.SEEK_SET)
        raw = np.fromfile(fh, dtype=np.uint16, count=count)
        if raw.size != count:
            raise ValueError(
                f"frame {frame_nr} is truncated: expected {count} pixels, got {raw.size}"
            )
        return raw.reshape(n_pixels, n_pixels).astype(float)

    def __iter__(self):
        """Reset the batch cursor and return ``self`` as the iterator."""
        self.batch_start = 0
        self.batch_end = self.batch_size
        return self

    def __next__(self):
        """Return the next list of up to ``batch_size`` frames.

        The final batch may be shorter than ``batch_size``. Once all frames
        have been returned, the cursor is reset and ``StopIteration`` is raised.
        """
        if self.batch_start >= self.length:
            self.batch_start = 0
            self.batch_end = self.batch_size
            raise StopIteration
        f_batch = self.frames[self.batch_start:self.batch_end]
        self.batch_start += self.batch_size
        self.batch_end += self.batch_size
        return f_batch

    def __len__(self):
        """Return the number of frames (not the number of batches)."""
        return self.length

    def __getitem__(self, index):
        """Return batch number ``index`` as a list of up to ``batch_size`` frames."""
        f_batch = self.frames[index * self.batch_size:(index + 1) * self.batch_size]
        return f_batch

    def get_peaks_skimage(self, frames, window=7, return_positions=False):
        """Extract square patches around connected bright regions in ``frames``.

        Each frame is labelled with ``skimage.measure.label``. A region is
        skipped if its area is outside ``[4, 150]`` pixels, or if its
        ``(2*window+1)``-square patch around the (truncated) centroid would
        leave the frame. Inside each patch, pixels not belonging to the region
        are zeroed.

        Args:
            frames: iterable of frames (``NrPixels x NrPixels`` arrays, or
                anything reshapeable to that), e.g. a batch from this dataset.
            window: half-width of the patch. NOTE(review): the default of 7
                (15x15 patches) is an assumption — confirm the intended size.
            return_positions: if True, also return the ``[start_y, start_x]``
                top-left corner of every patch.

        Returns:
            Array of patches with shape ``(n_patches, 2*window+1, 2*window+1)``,
            or ``(patches, positions)`` with ``positions`` of shape
            ``(n_patches, 2)`` when ``return_positions`` is True.
        """
        size = 2 * window + 1
        n_pix = self.NrPixels
        patches = []
        xy_positions = []
        for frame in frames:
            # Frames are already numeric arrays; view them directly instead of
            # reinterpreting their raw bytes with a fixed dtype.
            frame_array = np.asarray(frame).reshape(n_pix, n_pix)
            labels = measure.label(frame_array)
            for props in regionprops(labels):
                if props.area < 4 or props.area > 150:
                    continue
                y0, x0 = props.centroid
                start_x = int(x0) - window
                start_y = int(y0) - window
                end_x = start_x + size
                end_y = start_y + size
                # Slice ends are exclusive, so an end equal to n_pix still fits.
                if start_x < 0 or end_x > n_pix or start_y < 0 or end_y > n_pix:
                    continue
                # Crop first, then mask other regions: avoids copying the
                # whole frame for every region.
                sub_img = frame_array[start_y:end_y, start_x:end_x].copy()
                sub_img[labels[start_y:end_y, start_x:end_x] != props.label] = 0
                patches.append(sub_img)
                xy_positions.append([start_y, start_x])
        # reshape keeps a well-defined 3-D/2-D shape even when nothing was found.
        patches = np.array(patches).reshape(-1, size, size)
        if return_positions:
            return patches, np.array(xy_positions, dtype=int).reshape(-1, 2)
        return patches

    def normalize_patches(self, patches):
        """Scale each patch so its maximum maps to 255, truncating to ints.

        A patch whose maximum is 0 is returned unscaled (cast to int) instead
        of dividing by zero. Input patches are not modified.

        Returns:
            A list with one int array per input patch.
        """
        normalized_patches = []
        for patch in patches:
            patch = np.asarray(patch, dtype=float)
            peak = patch.max()
            if peak != 0:
                patch = patch / peak * 255
            normalized_patches.append(patch.astype(int))
        return normalized_patches