# salmasoma: Set up inference-only HyperClinical Streamlit app with runtime HF asset download (278bf2b)
# -*- coding: utf-8 -*-
"""Transforms for AVRA inference (subset used by load_transform)."""
from __future__ import division
import torch
import numpy as np
class SwapAxes(object):
    """Interchange two axes of an array-like volume using ``np.swapaxes``.

    Args:
        axis1: First axis to exchange.
        axis2: Second axis to exchange.
    """

    def __init__(self, axis1, axis2):
        self.axis1, self.axis2 = axis1, axis2

    def __call__(self, image):
        """Return ``image`` with ``self.axis1`` and ``self.axis2`` exchanged."""
        first, second = self.axis1, self.axis2
        swapped = np.swapaxes(image, first, second)
        return swapped
class Return5D(object):
    """Insert a channel axis at dim 1 and optionally tile it.

    Args:
        nc: Number of channels to produce. When it is not 1, the inserted
            singleton axis is repeated ``nc`` times.
    """

    def __init__(self, nc=1):
        self.nc = nc

    def __call__(self, image_1):
        """Return ``image_1`` with a channel axis added at dim 1.

        Relies on ``Tensor.unsqueeze``/``Tensor.repeat``, so ``image_1`` must
        be a torch tensor. The four repeat sizes assume a 3-D input (4-D after
        unsqueeze) -- TODO confirm with callers.
        """
        with_channel = image_1.unsqueeze(1)
        if self.nc == 1:
            return with_channel
        return with_channel.repeat(1, self.nc, 1, 1)
class ReturnStackedPA(object):
    """Stack cropped, strided axial, coronal and sagittal slices of a volume.

    Two extra views are derived from the input via ``permute``. Each of the
    three stacks is cropped and strided along its first axis, and the stacks
    are concatenated along dim 0 in the order axial, coronal, sagittal.

    Args:
        nc: Channel count. When not 1, the singleton channel axis is tiled.
        ax_lim, cor_lim, sag_lim: ``(start, stop)`` slice bounds applied
            along the first axis of the axial, coronal and sagittal stacks.
        stepsize: ``(axial, coronal, sagittal)`` strides applied to the
            cropped stacks.
        rnn: If False, ``squeeze(1)`` is applied to the result. This only
            removes the channel axis when it has size 1.
    """

    def __init__(self, nc=1, ax_lim=(50, -5), cor_lim=(20, 75),
                 sag_lim=(30, -30), stepsize=(2, 2, 2), rnn=True):
        self.nc = nc
        # Defaults are immutable tuples, copied into fresh lists. Instances
        # therefore never share (and cannot corrupt) a single default list.
        self.ax_lim = list(ax_lim)
        self.cor_lim = list(cor_lim)
        self.sag_lim = list(sag_lim)
        self.stepsize = list(stepsize)
        self.rnn = rnn

    def __call__(self, ax):
        """Return the stacked slice tensor for the 3-D tensor ``ax``.

        ``torch.cat`` needs matching trailing dims. The coronal/sagittal
        stacks are permutations of the input, so all three input dims must be
        equal (a cubic volume).
        """
        ax_s, cor_s, sag_s = self.stepsize
        # Derive the other views from the uncropped input first.
        sag = ax.permute(1, 0, 2)
        cor = ax.permute(2, 1, 0)
        # Crop along the first axis, then keep every n-th slice.
        ax = ax[self.ax_lim[0]:self.ax_lim[1]][0::ax_s]
        cor = cor[self.cor_lim[0]:self.cor_lim[1]][0::cor_s]
        sag = sag[self.sag_lim[0]:self.sag_lim[1]][0::sag_s]
        # Add a channel axis at dim 1 and tile it if several channels are needed.
        stacks = []
        for view in (ax, cor, sag):
            view = view.unsqueeze(1)
            if self.nc != 1:
                view = view.repeat(1, self.nc, 1, 1)
            stacks.append(view)
        img = torch.cat(stacks, dim=0)
        if not self.rnn:
            img = img.squeeze(1)
        return img
class ReduceSlices(object):
    """Sub-sample a volume by keeping every n-th element along each axis.

    Args:
        factor_hw: Stride applied to axes 0 and 1.
        factor_d: Stride applied to axis 2.
    """

    def __init__(self, factor_hw, factor_d):
        self.f_h = self.f_w = factor_hw
        self.f_d = factor_d

    def __call__(self, image):
        """Return ``image`` strided from index 0 along its first three axes."""
        window = (
            slice(0, None, self.f_h),
            slice(0, None, self.f_w),
            slice(0, None, self.f_d),
        )
        return image[window]
class CenterCrop(object):
    """Crop a centred (optionally offset) window and rescale intensities.

    Intensities are shifted so the volume minimum becomes 0. The window is
    then cropped, and the crop is divided by the mean of the *whole* shifted
    volume (computed before cropping).

    Along any axis where the requested window does not fit inside the volume
    (start before index 0 or end past the last index), the whole axis is kept
    instead.

    Args:
        output_x, output_y, output_z: Requested window size per axis.
        offset_x, offset_y, offset_z: Shift of the window centre per axis.
    """

    def __init__(self, output_x, output_y, output_z, offset_x=0, offset_y=0, offset_z=0):
        self.output_x = int(output_x)
        self.output_y = int(output_y)
        self.output_z = int(output_z)
        self.offset_x = int(offset_x)
        self.offset_y = int(offset_y)
        self.offset_z = int(offset_z)

    @staticmethod
    def _axis_window(size, requested, offset):
        """Return ``(start, length)`` of the crop window along one axis."""
        start = int(int(size / 2.) + offset - round(requested / 2.))
        if start < 0 or start + requested > size:
            # The window does not fit, so keep the full axis. A negative start
            # must not reach the slice: negative indices count from the end
            # and would select a wrong or empty region.
            return 0, size
        return start, requested

    def __call__(self, image):
        """Return the normalised crop of the 3-D (or higher) volume ``image``."""
        # Shift so the minimum is 0. The normalising mean is taken over the
        # full shifted volume, before cropping.
        image = image - image.min()
        img_mean = image.mean()

        x_orig, y_orig, z_orig = image.shape[:3]
        x, new_x = self._axis_window(x_orig, self.output_x, self.offset_x)
        y, new_y = self._axis_window(y_orig, self.output_y, self.offset_y)
        z, new_z = self._axis_window(z_orig, self.output_z, self.offset_z)
        image = image[x:x + new_x, y:y + new_y, z:z + new_z]
        # NOTE: a constant volume has mean 0 after the shift, so this
        # divides by zero.
        return image / img_mean
class RandomCrop(object):
    """Crop a fixed-size window at a uniformly random position.

    Every start index from 0 up to and including ``size - output`` may be
    drawn. If the volume is not larger than the window along an axis, the
    start is 0 and the slice simply ends at the volume edge.

    Args:
        output_x, output_y, output_z: Window size per axis.
    """

    def __init__(self, output_x, output_y, output_z):
        self.output_x = output_x
        self.output_y = output_y
        self.output_z = output_z

    def __call__(self, image):
        """Return a random ``output_x x output_y x output_z`` crop of ``image``."""
        x_orig, y_orig, z_orig = image.shape[:3]
        new_x, new_y, new_z = self.output_x, self.output_y, self.output_z
        # ``np.random.randint``'s upper bound is exclusive. The +1 makes the
        # last valid start (window flush with the far edge) reachable.
        x = np.random.randint(0, max(0, x_orig - new_x) + 1)
        y = np.random.randint(0, max(0, y_orig - new_y) + 1)
        z = np.random.randint(0, max(0, z_orig - new_z) + 1)
        return image[x:x + new_x, y:y + new_y, z:z + new_z]
class RandomMirrorLR(object):
    """Flip a NumPy volume along one axis on roughly half of the calls.

    Args:
        axis: Axis passed to ``np.flip``.
    """

    def __init__(self, axis):
        self.axis = axis

    def __call__(self, image):
        """Return ``image`` unchanged or flipped along ``self.axis``."""
        # Coin toss: one standard-normal draw, flip when it is positive.
        flip = np.random.randn() > 0
        if not flip:
            return image
        # Copy so the caller gets an independent array instead of a view.
        mirrored = np.flip(image, self.axis)
        return mirrored.copy()
class RandomNoise(object):
    """Randomly add zero-mean Gaussian noise with a randomly drawn scale.

    Args:
        noise_var: Upper bound of the random factor that multiplies standard
            normal noise. Despite the name, that factor scales ``torch.randn``
            output directly, so it acts as a standard deviation.
        p: Probability that noise is added on a given call.
    """

    def __init__(self, noise_var=0.1, p=0.5):
        self.noise_var = noise_var
        self.p = p

    def __call__(self, image):
        """Return ``image`` itself, or ``image`` plus scaled Gaussian noise."""
        # Draw order (gate, scale, noise) matters for seeded reproducibility.
        if not torch.rand(1)[0] < self.p:
            return image
        sigma = torch.rand(1)[0] * self.noise_var
        noise = torch.randn(image.shape) * sigma
        return image + noise
class PerImageNormalization(object):
    """Standardise an image to zero mean and unit standard deviation.

    The standard deviation is taken from the already-centred image via the
    input's own ``std`` method. Whether that is the population or the sample
    estimate depends on the array type passed in (NumPy vs torch) -- confirm
    with callers. A constant image has std 0, so the division yields NaN.
    """

    def __call__(self, image):
        """Return ``image`` centred on 0 and divided by its standard deviation."""
        centred = image - image.mean()
        spread = centred.std()
        return centred / spread
class ToTensorFSL(object):
    """Convert a NumPy volume to a torch tensor, moving axis 2 to the front."""

    def __call__(self, image):
        """Return ``image`` reordered from axes (0, 1, 2) to (2, 0, 1) as a tensor."""
        reordered = image.transpose((2, 0, 1))
        return torch.from_numpy(reordered)