Spaces:
Sleeping
Sleeping
File size: 3,706 Bytes
f95931d a0dd22b f95931d a0dd22b f95931d a0dd22b f95931d a0dd22b f95931d a0dd22b f95931d a0dd22b f95931d a0dd22b f95931d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import os
import h5py
import torch
import numpy as np
from skimage import measure
from torchvision import transforms
from torch.utils.data import Dataset
from skimage.measure import label, regionprops
class FrameDataset(Dataset):
    """Load dark-subtracted, thresholded detector frames from raw binary files.

    Both ``ffile`` and ``dfile`` are read as an ``fHead``-byte header followed
    by back-to-back frames of ``NrPixels * NrPixels`` native-endian ``uint16``
    values. The first frame of each file is skipped: the dark image is frame 1
    of ``dfile`` and the data frames are frames ``1..nFrames`` of ``ffile``.

    Each data frame has the dark image subtracted and pixels below ``thresh``
    set to 0. It is then stored as an ``int`` array of shape
    ``(NrPixels, NrPixels)`` in ``self.frames``. All frames are held in memory
    at once (typically 8 bytes/pixel for ``int``; with the defaults that is
    tens of GB — NOTE(review): confirm this is acceptable).

    Frames are served as lists of up to ``batch_size`` arrays, both through
    iteration and through ``__getitem__``. The final batch may be shorter.

    NOTE(review): ``__len__`` returns the number of *frames*, while
    ``__getitem__`` indexes *batches*. Batch indices past the end return empty
    lists, so a ``torch.utils.data.DataLoader`` over this dataset will see
    many empty items — confirm the intended use.
    """

    def __init__(self, ffile, dfile, NrPixels=2048, nFrames=1440, batch_size=100, thresh=100, fHead=8192):
        """Read the dark image and ``nFrames`` data frames into memory.

        Args:
            ffile: Path to the raw data file.
            dfile: Path to the raw dark file.
            NrPixels: Frame edge length in pixels (frames are square).
            nFrames: Number of data frames to read after the skipped frame 0.
            batch_size: Number of frames per batch.
            thresh: Dark-subtracted intensities below this value are zeroed.
            fHead: Header size in bytes at the start of each file.

        Raises:
            ValueError: If a file ends before a requested frame is complete.
        """
        self.NrPixels = NrPixels
        self.batch_size = batch_size
        self.length = nFrames
        # Initialised here so next() works even before iter() is called.
        self.batch_start = 0
        self.batch_end = batch_size
        frame_bytes = NrPixels * NrPixels * 2  # uint16 = 2 bytes per pixel
        # Dark image: frame 1 of the dark file (frame 0 is skipped).
        with open(dfile, 'rb') as darkf:
            self.dark = self._read_frame(darkf, fHead + frame_bytes, NrPixels).astype(float)
        self.frames = []
        with open(ffile, 'rb') as f:
            # Skip frame 0; read frames 1..nFrames.
            for fNr in range(1, nFrames + 1):
                this_frame = self._read_frame(f, fHead + fNr * frame_bytes, NrPixels)
                this_frame = this_frame.astype(float) - self.dark
                this_frame[this_frame < thresh] = 0
                self.frames.append(this_frame.astype(int))

    @staticmethod
    def _read_frame(fh, offset, NrPixels):
        """Read one ``(NrPixels, NrPixels)`` uint16 frame starting at byte ``offset``."""
        count = NrPixels * NrPixels
        fh.seek(offset, os.SEEK_SET)
        data = np.fromfile(fh, dtype=np.uint16, count=count)
        # np.fromfile silently returns fewer items on a short file; fail clearly instead.
        if data.size != count:
            raise ValueError(
                f"expected {count} pixels at byte offset {offset} of "
                f"{getattr(fh, 'name', fh)!r}, got {data.size}"
            )
        return data.reshape(NrPixels, NrPixels)

    def __iter__(self):
        """Restart batch iteration from the first frame and return ``self``."""
        self.batch_start = 0
        self.batch_end = self.batch_size
        return self

    def __next__(self):
        """Return the next list of up to ``batch_size`` frames.

        The last batch is shorter when the frame count is not a multiple of
        ``batch_size``. Once every frame has been served, the cursor is rewound
        and ``StopIteration`` is raised.
        """
        if self.batch_start >= self.length:
            self.batch_start = 0
            self.batch_end = self.batch_size
            raise StopIteration
        f_batch = self.frames[self.batch_start:self.batch_end]
        self.batch_start += self.batch_size
        self.batch_end += self.batch_size
        return f_batch

    def __len__(self):
        """Return the number of frames loaded (not the number of batches)."""
        return self.length

    def __getitem__(self, index):
        """Return batch ``index``: frames ``[index*batch_size, (index+1)*batch_size)``."""
        start = index * self.batch_size
        return self.frames[start:start + self.batch_size]

    def get_peaks_skimage(self, frames, window=7, return_positions=False):
        """Cut a fixed-size patch around each mid-sized connected region.

        Each frame is labelled with ``skimage.measure.label`` (nonzero pixels
        are foreground). A region is kept when both of these hold:

        - its area lies in ``[4, 150]``;
        - the ``(2*window+1)``-square window around its truncated centroid
          lies fully inside the frame.

        For each kept region, that window is extracted with pixels of every
        other label set to 0.

        Args:
            frames: Iterable of frames, each reshapeable to
                ``(NrPixels, NrPixels)``.
            window: Half-width of the square patch. NOTE(review): the original
                code referenced an undefined global ``window``; 7 (15x15
                patches) is a chosen default — confirm the intended value.
            return_positions: If True, also return the top-left ``[y, x]`` of
                each patch.

        Returns:
            An ``np.ndarray`` of patches with shape
            ``(n, 2*window+1, 2*window+1)`` (shape ``(0,)`` if none are
            found). When ``return_positions`` is True, returns the tuple
            ``(patches, positions)`` instead.
        """
        size = 2 * window + 1
        patches = []
        xy_positions = []
        for frame in frames:
            # Use the values as-is: reinterpreting the int frames' raw bytes as
            # uint16 (np.frombuffer) would scramble them.
            frame_array = np.asarray(frame).reshape(self.NrPixels, self.NrPixels)
            labels = measure.label(frame_array)
            for props in regionprops(labels):
                if props.area < 4 or props.area > 150:
                    continue
                y0, x0 = props.centroid
                start_y = int(y0) - window
                start_x = int(x0) - window
                end_y = start_y + size
                end_x = start_x + size
                # Slice ends are exclusive, so end == NrPixels is still in bounds.
                if start_x < 0 or start_y < 0 or end_x > self.NrPixels or end_y > self.NrPixels:
                    continue
                # Cut the window first, then mask: avoids copying the full frame per region.
                sub_img = frame_array[start_y:end_y, start_x:end_x].copy()
                sub_img[labels[start_y:end_y, start_x:end_x] != props.label] = 0
                patches.append(sub_img)
                xy_positions.append([start_y, start_x])
        patches = np.array(patches)
        if return_positions:
            return patches, np.array(xy_positions)
        return patches

    def normalize_patches(self, patches):
        """Scale each patch so its maximum maps to 255 and convert it to int.

        Values are truncated toward zero by ``astype(int)``. A patch whose
        maximum is 0 (or that is empty) becomes an all-zero int array instead
        of dividing by zero. Input patches are not modified.

        Returns:
            A list of int ``np.ndarray``, one per input patch.
        """
        normalized_patches = []
        for patch in patches:
            patch = np.array(patch, dtype=float)  # always a copy
            peak = patch.max() if patch.size else 0.0
            if peak == 0:
                normalized_patches.append(np.zeros(patch.shape, dtype=int))
                continue
            normalized_patches.append((patch / peak * 255).astype(int))
        return normalized_patches
|