Spaces:
Runtime error
Runtime error
Update codes.py
Browse files
codes.py
CHANGED
|
@@ -89,79 +89,7 @@ def load_image_pil_accelerated(image_path, dim=128):
|
|
| 89 |
|
| 90 |
|
| 91 |
def preprocess_image(image_path, dim = 128):
|
| 92 |
-
img =
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
'''
|
| 96 |
-
def load_image_from_url(image_path, dim = 128):
|
| 97 |
-
img = Image.open(image_path).convert("RGB")
|
| 98 |
-
img = img.resize((dim, dim))
|
| 99 |
-
return img
|
| 100 |
-
|
| 101 |
-
def preprocess_image(image_path, dim = 128):
|
| 102 |
-
img = load_img(image_path, target_size=(dim, dim))
|
| 103 |
-
img = img_to_array(img)
|
| 104 |
-
img = np.expand_dims(img, axis=0)
|
| 105 |
return img
|
| 106 |
|
| 107 |
-
|
| 108 |
-
def create_model(dim = 128):
|
| 109 |
-
# configure unet input shape (concatenation of moving and fixed images)
|
| 110 |
-
volshape = (dim,dim,3)
|
| 111 |
-
unet_input_features = 2*volshape[:-1]
|
| 112 |
-
inshape = (*volshape[:-1],unet_input_features)
|
| 113 |
-
nb_conv_per_level=1
|
| 114 |
-
enc_nf = [dim, dim, dim, dim]
|
| 115 |
-
dec_nf = [dim, dim, dim, dim, dim, int(dim/2)]
|
| 116 |
-
nb_upsample_skips = 0
|
| 117 |
-
nb_dec_convs = len(enc_nf)
|
| 118 |
-
final_convs = dec_nf[nb_dec_convs:]
|
| 119 |
-
dec_nf = dec_nf[:nb_dec_convs]
|
| 120 |
-
nb_levels = int(nb_dec_convs / nb_conv_per_level) + 1
|
| 121 |
-
source = tf.keras.Input(shape=volshape, name='source_input')
|
| 122 |
-
target = tf.keras.Input(shape=volshape, name='target_input')
|
| 123 |
-
inputs = [source, target]
|
| 124 |
-
unet_input = concatenate(inputs, name='input_concat')
|
| 125 |
-
#Define lyers
|
| 126 |
-
ndims = len(unet_input.get_shape()) - 2
|
| 127 |
-
MaxPooling = getattr(tf.keras.layers, 'MaxPooling%dD' % ndims)
|
| 128 |
-
Conv = getattr(tf.keras.layers, 'Conv%dD' % ndims)
|
| 129 |
-
UpSampling = getattr(tf.keras.layers, 'UpSampling%dD' % ndims)
|
| 130 |
-
# Encoder
|
| 131 |
-
enc_layers = []
|
| 132 |
-
lyr = unet_input
|
| 133 |
-
for level in range(nb_levels - 1):
|
| 134 |
-
for conv in range(nb_conv_per_level):
|
| 135 |
-
nfeat = enc_nf[level * nb_conv_per_level + conv]
|
| 136 |
-
lyr = Conv(nfeat, kernel_size=3, padding='same', strides=1,activation = LeakyReLU(0.2), kernel_initializer = 'he_normal')(lyr)
|
| 137 |
-
enc_layers.append(lyr)
|
| 138 |
-
lyr = MaxPooling(2)(lyr)
|
| 139 |
-
|
| 140 |
-
# Decoder
|
| 141 |
-
for level in range(nb_levels - 1):
|
| 142 |
-
real_level = nb_levels - level - 2
|
| 143 |
-
for conv in range(nb_conv_per_level):
|
| 144 |
-
nfeat = dec_nf[level * nb_conv_per_level + conv]
|
| 145 |
-
lyr = Conv(nfeat, kernel_size=3, padding='same', strides=1,activation = LeakyReLU(0.2), kernel_initializer = 'he_normal')(lyr)
|
| 146 |
-
# upsample
|
| 147 |
-
if level < (nb_levels - 1 - nb_upsample_skips):
|
| 148 |
-
upsampled = UpSampling(size=(2,) * ndims)(lyr)
|
| 149 |
-
lyr = concatenate([upsampled, enc_layers.pop()])
|
| 150 |
-
|
| 151 |
-
# Final convolution
|
| 152 |
-
for num, nfeat in enumerate(final_convs):
|
| 153 |
-
lyr = Conv(nfeat, kernel_size=3, padding='same', strides=1,activation = LeakyReLU(0.2), kernel_initializer = 'he_normal')(lyr)
|
| 154 |
-
|
| 155 |
-
unet = tf.keras.models.Model(inputs=inputs, outputs=lyr)
|
| 156 |
-
# transform the results into a flow field.
|
| 157 |
-
disp_tensor = Conv(ndims, kernel_size=3, padding='same', name='disp')(unet.output)
|
| 158 |
-
# using keras, we can easily form new models via tensor pointers
|
| 159 |
-
def_model = tf.keras.models.Model(inputs, disp_tensor)
|
| 160 |
-
# build transformer layer
|
| 161 |
-
spatial_transformer = SpatialTransformer()
|
| 162 |
-
# warp the moving image with the transformer
|
| 163 |
-
moved_image_tensor = spatial_transformer([source, disp_tensor])
|
| 164 |
-
outputs = [moved_image_tensor, disp_tensor]
|
| 165 |
-
vxm_model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
|
| 166 |
-
return vxm_model
|
| 167 |
-
'''
|
|
|
|
| 89 |
|
| 90 |
|
| 91 |
def preprocess_image(image_path, dim = 128):
|
| 92 |
+
img = torch.zeros([1,3,dim,dim])
|
| 93 |
+
img[0] = load_image_pil_accelerated(image_path, dim)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
return img
|
| 95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|