File size: 3,706 Bytes
f95931d
 
 
a0dd22b
 
 
f95931d
 
a0dd22b
f95931d
 
 
 
 
 
 
 
 
a0dd22b
f95931d
 
 
a0dd22b
f95931d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0dd22b
f95931d
 
 
 
 
a0dd22b
f95931d
 
 
 
 
a0dd22b
 
f95931d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os 
import h5py 
import torch
import numpy as np
from skimage import measure
from torchvision import transforms
from torch.utils.data import Dataset
from skimage.measure import label, regionprops

class FrameDataset(Dataset):
    """Dark-subtracted, thresholded detector frames, served in batches.

    Both input files are read as raw binary: an ``fHead``-byte header
    followed by consecutive ``NrPixels x NrPixels`` frames of
    native-byte-order ``uint16`` values. All requested frames are loaded
    into memory in ``__init__``.

    Args:
        ffile: Path of the file holding the data frames.
        dfile: Path of the file holding the dark frame. Frame index 1 (the
            second frame after the header) is used as the dark image.
        NrPixels: Side length of a square frame, in pixels.
        nFrames: Number of data frames to load. Frames 1..nFrames of
            ``ffile`` are read, i.e. frame 0 is skipped.
        batch_size: Number of frames per batch for ``__getitem__`` and
            iteration.
        thresh: Dark-subtracted values below this are set to 0.
        fHead: Header size in bytes of both files.
    """

    def __init__(self, ffile, dfile, NrPixels=2048, nFrames=1440, batch_size=100, thresh=100, fHead=8192):
        self.NrPixels = NrPixels
        self.batch_size = batch_size
        self.length = nFrames
        # Iteration cursor; reset by __iter__, initialised here so next()
        # before iter() does not raise AttributeError.
        self.batch_start = 0
        self.batch_end = batch_size

        # Read dark frame: frame 1 of dfile (frame 0 is skipped, mirroring
        # how frame 0 of ffile is skipped below).
        with open(dfile, 'rb') as darkf:
            self.dark = self._read_frame(darkf, 1, NrPixels, fHead).astype(float)

        # Read frames, subtract dark, zero everything below thresh.
        self.frames = []
        with open(ffile, 'rb') as f:
            for fNr in range(1, nFrames + 1):  # Skip first frame
                this_frame = self._read_frame(f, fNr, NrPixels, fHead).astype(float)
                this_frame -= self.dark
                this_frame[this_frame < thresh] = 0
                self.frames.append(this_frame.astype(int))

    @staticmethod
    def _read_frame(fh, frame_nr, NrPixels, fHead):
        """Return frame ``frame_nr`` of the open binary file ``fh`` as a 2-D uint16 array.

        Raises:
            ValueError: if the file ends before the whole frame is read.
        """
        count = NrPixels * NrPixels
        fh.seek(fHead + frame_nr * count * 2, os.SEEK_SET)  # 2 bytes per uint16 pixel
        frame = np.fromfile(fh, dtype=np.uint16, count=count)
        if frame.size != count:
            raise ValueError(
                f"frame {frame_nr} is truncated: expected {count} pixels, got {frame.size}"
            )
        return frame.reshape(NrPixels, NrPixels)

    def __iter__(self):
        """Reset the batch cursor and return ``self`` as the iterator."""
        self.batch_start = 0
        self.batch_end = self.batch_size
        return self

    def __next__(self):
        """Return the next batch (a list of frames); the final batch may be shorter."""
        if self.batch_start >= self.length:
            # Rewind so the dataset can be iterated again.
            self.batch_start = 0
            self.batch_end = self.batch_size
            raise StopIteration
        f_batch = self.frames[self.batch_start:self.batch_end]
        self.batch_start += self.batch_size
        self.batch_end += self.batch_size
        return f_batch

    def __len__(self):
        """Return the number of frames (not the number of batches).

        NOTE(review): ``__getitem__`` is indexed by *batch*, so indices
        ``>= ceil(len / batch_size)`` yield empty lists; a sampler driven by
        this length will see mostly empty items. Kept unchanged for callers.
        """
        return self.length

    def __getitem__(self, index):
        """Return batch ``index``: frames ``[index*batch_size, (index+1)*batch_size)``."""
        f_batch = self.frames[index*self.batch_size:(index+1)*self.batch_size]
        return f_batch

    def get_peaks_skimage(self, frames, window=7, return_positions=False):
        """Cut a square patch around each isolated peak in ``frames``.

        Connected non-zero regions of each frame are labelled. Regions with an
        area outside [4, 150] pixels, or whose ``(2*window+1)``-pixel square
        around the centroid would extend past the frame, are skipped. Inside
        a kept patch, pixels belonging to other regions are zeroed.

        Args:
            frames: Iterable of frames. Array-likes are used as-is; raw
                bytes-like objects are decoded as native-order uint16.
                Each is reshaped to ``(NrPixels, NrPixels)``.
            window: Half-width of a patch. NOTE(review): the original code
                referenced an undefined name ``window``; 7 is an assumed
                default — confirm the intended value.
            return_positions: If True, also return the ``[start_y, start_x]``
                top-left corner of every patch.

        Returns:
            ``np.array`` of the stacked patches (an empty array if none), or
            ``(patches, positions)`` when ``return_positions`` is True.
        """
        patches = []
        xy_positions = []
        for frame in frames:
            if isinstance(frame, (bytes, bytearray, memoryview)):
                frame_array = np.frombuffer(frame, dtype=np.uint16)
            else:
                frame_array = np.asarray(frame)
            frame_array = frame_array.reshape(self.NrPixels, self.NrPixels)
            labels = measure.label(frame_array)
            n_rows, n_cols = frame_array.shape
            for props in regionprops(labels):
                if props.area < 4 or props.area > 150:
                    continue
                y0, x0 = props.centroid
                start_x = int(x0) - window
                end_x = int(x0) + window + 1
                start_y = int(y0) - window
                end_y = int(y0) + window + 1
                # end_* are exclusive slice bounds, so they may equal the frame size.
                if start_x < 0 or end_x > n_cols or start_y < 0 or end_y > n_rows:
                    continue
                # Copy only the window (not the whole frame), then blank pixels
                # that belong to any other labelled region.
                sub_img = frame_array[start_y:end_y, start_x:end_x].copy()
                sub_img[labels[start_y:end_y, start_x:end_x] != props.label] = 0
                patches.append(sub_img)
                xy_positions.append([start_y, start_x])
        patches = np.array(patches)
        if return_positions:
            return patches, np.array(xy_positions)
        return patches

    def normalize_patches(self, patches):
        """Scale each patch so its maximum maps to 255, truncating to int.

        Patches whose maximum is 0 are returned as all-zero int arrays
        instead of dividing by zero (which would produce NaN -> garbage ints).
        """
        normalized_patches = []
        for patch in patches:
            patch = np.asarray(patch, dtype=float)
            peak = patch.max()
            if peak != 0:
                patch = patch / peak * 255
            else:
                patch = np.zeros_like(patch)
            normalized_patches.append(patch.astype(int))
        return normalized_patches