File size: 4,658 Bytes
278bf2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# -*- coding: utf-8 -*-
"""Transforms for AVRA inference (subset used by load_transform)."""
from __future__ import division
import torch
import numpy as np


class SwapAxes(object):
    """Interchange two axes of an array via ``np.swapaxes``.

    Args:
        axis1: first axis to exchange.
        axis2: second axis to exchange.
    """

    def __init__(self, axis1, axis2):
        self.axis1 = axis1
        self.axis2 = axis2

    def __call__(self, image):
        first, second = self.axis1, self.axis2
        swapped = np.swapaxes(image, first, second)
        return swapped


class Return5D(object):
    """Insert a channel dimension at position 1 of a tensor.

    For a 3-D input ``(N, H, W)`` the result is ``(N, nc, H, W)``. The singleton
    channel is tiled ``nc`` times when ``nc != 1``; the ``repeat(1, nc, 1, 1)`` call
    assumes a 4-D tensor after ``unsqueeze``, i.e. a 3-D input. Presumably the
    "5D" in the name refers to the shape after batching — confirm with callers.
    """

    def __init__(self, nc=1):
        self.nc = nc

    def __call__(self, image_1):
        with_channel = image_1.unsqueeze(1)
        if self.nc == 1:
            return with_channel
        return with_channel.repeat(1, self.nc, 1, 1)


class ReturnStackedPA(object):
    """Stack strided slice ranges taken along each of the three axes of a volume.

    The input tensor is sliced along its first axis (``ax``), along its third
    axis (``cor``, via ``permute(2, 1, 0)``) and along its second axis (``sag``,
    via ``permute(1, 0, 2)``). The names presumably mean axial / coronal /
    sagittal — this depends on the input orientation; confirm with callers.

    Each stack is restricted to its ``*_lim`` range, subsampled with the
    matching ``stepsize`` entry, and given a channel dimension at position 1.
    That channel is tiled ``nc`` times when ``nc != 1``. The three stacks are
    concatenated along dim 0 in the order ax, cor, sag.

    Args:
        nc: number of channels per slice.
        ax_lim, cor_lim, sag_lim: ``(start, stop)`` slice bounds. They use Python
            slice semantics, so a negative ``stop`` counts from the end.
        stepsize: ``(ax_step, cor_step, sag_step)`` strides between kept slices.
        rnn: when False, ``squeeze(1)`` is applied to the result. This only
            removes the channel dimension when ``nc == 1``.
    """

    def __init__(self, nc=1, ax_lim=(50, -5), cor_lim=(20, 75), sag_lim=(30, -30),
                 stepsize=(2, 2, 2), rnn=True):
        self.nc = nc
        # Copy into fresh lists: the attributes remain lists as before, but no
        # two instances share (and can mutate) the same default object.
        self.ax_lim = list(ax_lim)
        self.cor_lim = list(cor_lim)
        self.sag_lim = list(sag_lim)
        self.stepsize = list(stepsize)
        self.rnn = rnn

    def __call__(self, ax):
        ax_s, cor_s, sag_s = self.stepsize
        # Reorder axes so that dim 0 indexes slices along the other two axes.
        sag = ax.permute(1, 0, 2)
        cor = ax.permute(2, 1, 0)
        ax = ax[self.ax_lim[0]:self.ax_lim[1], :, :]
        cor = cor[self.cor_lim[0]:self.cor_lim[1], :, :]
        sag = sag[self.sag_lim[0]:self.sag_lim[1], :, :]
        ax = ax[0::ax_s, :, :]
        cor = cor[0::cor_s, :, :]
        sag = sag[0::sag_s, :, :]
        ax = ax.unsqueeze(1)
        cor = cor.unsqueeze(1)
        sag = sag.unsqueeze(1)
        if self.nc != 1:
            ax = ax.repeat(1, self.nc, 1, 1)
            sag = sag.repeat(1, self.nc, 1, 1)
            cor = cor.repeat(1, self.nc, 1, 1)
        img = torch.cat((ax, cor, sag), dim=0)
        if not self.rnn:
            img = img.squeeze(1)
        return img


class ReduceSlices(object):
    """Subsample a volume by keeping every n-th element along its first three axes.

    Args:
        factor_hw: stride applied to both of the first two axes.
        factor_d: stride applied to the third axis.
    """

    def __init__(self, factor_hw, factor_d):
        self.f_h = self.f_w = factor_hw
        self.f_d = factor_d

    def __call__(self, image):
        strides = (
            slice(0, None, self.f_h),
            slice(0, None, self.f_w),
            slice(0, None, self.f_d),
        )
        return image[strides]


class CenterCrop(object):
    """Shift intensities to start at 0, crop a box, and divide by the mean.

    The crop box has size ``(output_x, output_y, output_z)``. It is centered on
    the middle of the first three axes, shifted by ``(offset_x, offset_y,
    offset_z)``. On any axis where that box would fall outside the volume
    (start below 0 or end past the edge), the whole axis is kept instead.

    The image minimum is subtracted first. The mean used for scaling is computed
    over the whole min-shifted volume, before cropping. A constant volume has
    mean 0 after the shift and is returned as zeros rather than NaN.
    """

    def __init__(self, output_x, output_y, output_z, offset_x=0, offset_y=0, offset_z=0):
        self.output_x = int(output_x)
        self.output_y = int(output_y)
        self.output_z = int(output_z)
        self.offset_x = int(offset_x)
        self.offset_y = int(offset_y)
        self.offset_z = int(offset_z)

    @staticmethod
    def _crop_window(center, offset, size, extent):
        """Return ``(start, size)`` for one axis, or ``(0, extent)`` if it does not fit."""
        start = int(center + offset - round(size / 2.))
        # A negative start would wrap around under Python slicing and silently
        # give a wrong or empty crop, so treat it like an overflow past the end.
        if start < 0 or start + size > extent:
            return 0, extent
        return start, size

    def __call__(self, image):
        image = image - image.min()
        img_mean = image.mean()
        x_orig, y_orig, z_orig = image.shape[:3]
        x, new_x = self._crop_window(int(x_orig / 2.), self.offset_x, self.output_x, x_orig)
        y, new_y = self._crop_window(int(y_orig / 2.), self.offset_y, self.output_y, y_orig)
        z, new_z = self._crop_window(int(z_orig / 2.), self.offset_z, self.output_z, z_orig)
        image = image[x:x + new_x, y:y + new_y, z:z + new_z]
        # A constant volume is all zeros after the min shift, so its mean is 0.
        # Dividing by 1.0 instead keeps the result finite (zeros, not NaN).
        if img_mean == 0:
            img_mean = 1.0
        return image / img_mean


class RandomCrop(object):
    """Crop a box of size ``(output_x, output_y, output_z)`` at a uniformly random position.

    Every start index in ``[0, extent - size]`` is equally likely on each of the
    first three axes. It is drawn from NumPy's global RNG. If a requested size
    exceeds the volume on an axis, the start is 0 and slicing keeps the whole axis.
    """

    def __init__(self, output_x, output_y, output_z):
        self.output_x = output_x
        self.output_y = output_y
        self.output_z = output_z

    @staticmethod
    def _random_start(extent, size):
        """Draw a start index so that ``[start, start + size)`` fits in ``extent``."""
        # np.random.randint excludes `high`, so +1 makes the last valid start
        # (extent - size) reachable; oversized crops clamp to a start of 0.
        return np.random.randint(0, max(0, extent - size) + 1)

    def __call__(self, image):
        x_orig, y_orig, z_orig = image.shape[:3]
        new_x, new_y, new_z = self.output_x, self.output_y, self.output_z
        x = self._random_start(x_orig, new_x)
        y = self._random_start(y_orig, new_y)
        z = self._random_start(z_orig, new_z)
        return image[x:x + new_x, y:y + new_y, z:z + new_z]


class RandomMirrorLR(object):
    """Flip an array along ``axis`` with probability 1/2.

    The coin flip is one standard-normal draw from NumPy's global RNG
    (``np.random.randn() > 0``). ``np.flip`` returns a view, so the flipped
    result is copied into a new array. Unflipped inputs are returned as-is.
    """

    def __init__(self, axis):
        self.axis = axis

    def __call__(self, image):
        should_flip = np.random.randn() > 0
        if not should_flip:
            return image
        mirrored = np.flip(image, self.axis)
        return mirrored.copy()


class RandomNoise(object):
    """With probability ``p``, add zero-mean Gaussian noise to a tensor.

    On each call, a noise scale is drawn uniformly from ``[0, noise_var)`` and
    multiplied by a standard-normal sample of the image's shape. Despite its
    name, ``noise_var`` therefore bounds the noise standard deviation, not the
    variance. All draws use torch's global RNG, and the noise tensor is created
    with torch's default dtype/device.
    """

    def __init__(self, noise_var=0.1, p=0.5):
        self.noise_var = noise_var
        self.p = p

    def __call__(self, image):
        apply_noise = torch.rand(1)[0] < self.p
        if not apply_noise:
            return image
        scale = torch.rand(1)[0] * self.noise_var
        noise = torch.randn(image.shape)
        return image + noise * scale


class PerImageNormalization(object):
    """Standardize an image to zero mean and unit standard deviation.

    Works on any object exposing ``mean()`` and ``std()``. The estimator differs
    by type: ``numpy.ndarray.std`` defaults to the population std (``ddof=0``),
    while ``torch.Tensor.std`` defaults to the unbiased sample std.

    A constant image (std 0), or a one-element torch tensor (unbiased std is
    NaN), is returned mean-centered but unscaled instead of filled with NaN/inf.
    """

    def __call__(self, image):
        image = image - image.mean()
        std = image.std()
        # `std > 0` is False for both 0 and NaN, so degenerate inputs skip the
        # division rather than turning every element into NaN/inf.
        if std > 0:
            image = image / std
        return image


class ToTensorFSL(object):
    """Move the last axis of a 3-D NumPy array to the front and wrap it as a tensor.

    ``(A, B, C)`` becomes ``(C, A, B)``. ``torch.from_numpy`` shares memory with
    the transposed view, so no data is copied.
    """

    def __call__(self, image):
        last_axis_first = image.transpose(2, 0, 1)
        return torch.from_numpy(last_axis_first)