# -*- coding: utf-8 -*-
"""Transforms for AVRA inference (subset used by load_transform)."""
from __future__ import division
import torch
import numpy as np
class SwapAxes(object):
    """Exchange two axes of an array (thin wrapper around np.swapaxes)."""

    def __init__(self, axis1, axis2):
        # The two axes to exchange when the transform is applied.
        self.axis1, self.axis2 = axis1, axis2

    def __call__(self, image):
        # np.swapaxes returns a view where possible, no data copy.
        return np.swapaxes(image, self.axis1, self.axis2)
class Return5D(object):
    """Insert a channel axis after dim 0, optionally replicated nc times.

    Turns an (N, H, W) tensor into (N, nc, H, W).
    """

    def __init__(self, nc=1):
        self.nc = nc  # number of channels to produce

    def __call__(self, image_1):
        with_channel = image_1.unsqueeze(1)  # (N, 1, H, W)
        if self.nc == 1:
            return with_channel
        # Replicate the singleton channel dim nc times.
        return with_channel.repeat(1, self.nc, 1, 1)
class ReturnStackedPA(object):
    """Slice a 3D volume into a stacked sequence of axial/coronal/sagittal planes.

    The input volume (indexed axially along dim 0) is re-permuted into the
    three anatomical orientations, each orientation is trimmed to its
    configured limits, subsampled by its stepsize, given a channel axis,
    and the three stacks are concatenated along the plane dimension.

    Fix vs. original: the list default arguments were mutable (shared across
    all instances — a classic Python pitfall); they are now tuples.  Callers
    passing lists are unaffected since the values are only indexed.
    """

    def __init__(self, nc=1, ax_lim=(50, -5), cor_lim=(20, 75), sag_lim=(30, -30),
                 stepsize=(2, 2, 2), rnn=True):
        self.nc = nc              # channels per plane
        self.ax_lim = ax_lim      # [start, stop] slice along the axial stack
        self.cor_lim = cor_lim    # [start, stop] slice along the coronal stack
        self.sag_lim = sag_lim    # [start, stop] slice along the sagittal stack
        self.stepsize = stepsize  # (axial, coronal, sagittal) subsampling strides
        self.rnn = rnn            # keep the channel axis (True) for RNN input

    def __call__(self, ax):
        ax_s, cor_s, sag_s = self.stepsize
        # Derive the other two orientations before trimming the axial stack.
        sag = ax.permute(1, 0, 2)
        cor = ax.permute(2, 1, 0)
        # Trim each stack to its anatomically useful range, then subsample.
        ax = ax[self.ax_lim[0]:self.ax_lim[1], :, :][0::ax_s, :, :]
        cor = cor[self.cor_lim[0]:self.cor_lim[1], :, :][0::cor_s, :, :]
        sag = sag[self.sag_lim[0]:self.sag_lim[1], :, :][0::sag_s, :, :]
        # Add a channel axis: (planes, 1, H, W).
        ax, cor, sag = ax.unsqueeze(1), cor.unsqueeze(1), sag.unsqueeze(1)
        if self.nc != 1:
            ax = ax.repeat(1, self.nc, 1, 1)
            sag = sag.repeat(1, self.nc, 1, 1)
            cor = cor.repeat(1, self.nc, 1, 1)
        # NOTE(review): concatenation requires the in-plane dimensions of the
        # three orientations to match — i.e. an (approximately) cubic volume.
        img = torch.cat((ax, cor, sag), dim=0)
        if not self.rnn:
            img = img.squeeze(1)  # CNN path: drop the channel axis
        return img
class ReduceSlices(object):
    """Downsample a volume by striding: in-plane by factor_hw, depth by factor_d."""

    def __init__(self, factor_hw, factor_d):
        # Same stride for height and width; separate stride for depth.
        self.f_h = factor_hw
        self.f_w = factor_hw
        self.f_d = factor_d

    def __call__(self, image):
        # Strided slicing keeps every f-th element starting from index 0.
        return image[::self.f_h, ::self.f_w, ::self.f_d]
class CenterCrop(object):
    """Center-crop a 3D volume (with optional per-axis offsets) and normalize it.

    Intensities are shifted so the minimum is 0, the volume is cropped around
    its center, and the result is divided by the mean of the *uncropped*
    shifted volume.  If the requested extent does not fit along an axis, the
    full extent of that axis is kept instead.

    Fix vs. original: crop start indices are clamped to 0.  Previously a
    sufficiently negative offset produced a negative start index, which
    Python interprets as "from the end" and silently yielded an empty or
    wrong slice.
    """

    def __init__(self, output_x, output_y, output_z, offset_x=0, offset_y=0, offset_z=0):
        self.output_x = int(output_x)
        self.output_y = int(output_y)
        self.output_z = int(output_z)
        self.offset_x = int(offset_x)  # shift of the crop center along x
        self.offset_y = int(offset_y)
        self.offset_z = int(offset_z)

    def __call__(self, image):
        img_min = image.min()
        image = image - img_min          # shift so the minimum intensity is 0
        img_mean = image.mean()          # mean of the shifted, UNCROPPED volume
        x_orig, y_orig, z_orig = image.shape[:3]
        x_mid = int(x_orig / 2.)
        y_mid = int(y_orig / 2.)
        z_mid = int(z_orig / 2.)
        new_x, new_y, new_z = self.output_x, self.output_y, self.output_z
        # Clamp to 0 so a large negative offset cannot create a negative
        # start index (negative indices wrap around in Python slicing).
        x = max(0, int(x_mid + self.offset_x - round(new_x / 2.)))
        y = max(0, int(y_mid + self.offset_y - round(new_y / 2.)))
        z = max(0, int(z_mid + self.offset_z - round(new_z / 2.)))
        # If the crop would run past the far edge, fall back to the full axis.
        if x + new_x > x_orig:
            x = 0
            new_x = x_orig
        if y + new_y > y_orig:
            y = 0
            new_y = y_orig
        if z + new_z > z_orig:
            z = 0
            new_z = z_orig
        image = image[x:x + new_x, y:y + new_y, z:z + new_z]
        # NOTE(review): img_mean == 0 (a constant volume) would divide by
        # zero; assumed not to occur for real image data — confirm upstream.
        image = image / img_mean
        return image
class RandomCrop(object):
    """Randomly crop a 3D volume to (output_x, output_y, output_z).

    If the volume is smaller than the requested size along an axis, the
    crop starts at 0 and the full extent of that axis is returned.

    Fix vs. original: np.random.randint's upper bound is exclusive, so the
    original ``randint(0, size - new)`` could never place the crop flush
    against the far edge; the bound is now ``size - new + 1``.
    """

    def __init__(self, output_x, output_y, output_z):
        self.output_x = output_x
        self.output_y = output_y
        self.output_z = output_z

    def __call__(self, image):
        x_orig, y_orig, z_orig = image.shape[:3]
        new_x, new_y, new_z = self.output_x, self.output_y, self.output_z
        # +1 makes every valid start position reachable (inclusive far edge);
        # max(1, ...) keeps randint valid when the volume is smaller than the crop.
        x0 = np.random.randint(0, max(1, x_orig - new_x + 1))
        y0 = np.random.randint(0, max(1, y_orig - new_y + 1))
        z0 = np.random.randint(0, max(1, z_orig - new_z + 1))
        return image[x0:x0 + new_x, y0:y0 + new_y, z0:z0 + new_z]
class RandomMirrorLR(object):
    """With ~50% probability, flip the image along the configured axis."""

    def __init__(self, axis):
        self.axis = axis  # axis to mirror

    def __call__(self, image):
        # A standard-normal draw is positive half the time — a coin flip.
        # (Kept as randn() so the RNG stream matches the original exactly.)
        if np.random.randn() > 0:
            return np.flip(image, self.axis).copy()
        return image
class RandomNoise(object):
    """With probability p, add zero-mean Gaussian noise of random magnitude."""

    def __init__(self, noise_var=0.1, p=0.5):
        self.noise_var = noise_var  # upper bound for the noise scale
        self.p = p                  # probability of applying noise

    def __call__(self, image):
        # Guard clause: skip the noise branch (draw order matches original).
        if not (torch.rand(1)[0] < self.p):
            return image
        scale = torch.rand(1)[0] * self.noise_var  # noise std in [0, noise_var)
        return image + torch.randn(image.shape) * scale
class PerImageNormalization(object):
    """Standardize an image to zero mean and unit standard deviation."""

    def __call__(self, image):
        centered = image - image.mean()
        # std of the centered image equals the original sequencing exactly.
        return centered / centered.std()
class ToTensorFSL(object):
    """Convert an (H, W, D) numpy volume to a torch tensor ordered (D, H, W)."""

    def __call__(self, image):
        # Move the last axis to the front, then wrap without copying data.
        return torch.from_numpy(image.transpose((2, 0, 1)))