Amould committed on
Commit
2a6b806
·
verified ·
1 Parent(s): b45b4ad

Update codes.py

Browse files
Files changed (1) hide show
  1. codes.py +2 -74
codes.py CHANGED
@@ -89,79 +89,7 @@ def load_image_pil_accelerated(image_path, dim=128):
89
 
90
 
91
  def preprocess_image(image_path, dim = 128):
92
- img = load_image_pil_accelerated(image_path, dim)
93
- return img.unsqueeze(0)
94
-
95
- '''
96
- def load_image_from_url(image_path, dim = 128):
97
- img = Image.open(image_path).convert("RGB")
98
- img = img.resize((dim, dim))
99
- return img
100
-
101
- def preprocess_image(image_path, dim = 128):
102
- img = load_img(image_path, target_size=(dim, dim))
103
- img = img_to_array(img)
104
- img = np.expand_dims(img, axis=0)
105
  return img
106
 
107
-
108
- def create_model(dim = 128):
109
- # configure unet input shape (concatenation of moving and fixed images)
110
- volshape = (dim,dim,3)
111
- unet_input_features = 2*volshape[:-1]
112
- inshape = (*volshape[:-1],unet_input_features)
113
- nb_conv_per_level=1
114
- enc_nf = [dim, dim, dim, dim]
115
- dec_nf = [dim, dim, dim, dim, dim, int(dim/2)]
116
- nb_upsample_skips = 0
117
- nb_dec_convs = len(enc_nf)
118
- final_convs = dec_nf[nb_dec_convs:]
119
- dec_nf = dec_nf[:nb_dec_convs]
120
- nb_levels = int(nb_dec_convs / nb_conv_per_level) + 1
121
- source = tf.keras.Input(shape=volshape, name='source_input')
122
- target = tf.keras.Input(shape=volshape, name='target_input')
123
- inputs = [source, target]
124
- unet_input = concatenate(inputs, name='input_concat')
125
- #Define lyers
126
- ndims = len(unet_input.get_shape()) - 2
127
- MaxPooling = getattr(tf.keras.layers, 'MaxPooling%dD' % ndims)
128
- Conv = getattr(tf.keras.layers, 'Conv%dD' % ndims)
129
- UpSampling = getattr(tf.keras.layers, 'UpSampling%dD' % ndims)
130
- # Encoder
131
- enc_layers = []
132
- lyr = unet_input
133
- for level in range(nb_levels - 1):
134
- for conv in range(nb_conv_per_level):
135
- nfeat = enc_nf[level * nb_conv_per_level + conv]
136
- lyr = Conv(nfeat, kernel_size=3, padding='same', strides=1,activation = LeakyReLU(0.2), kernel_initializer = 'he_normal')(lyr)
137
- enc_layers.append(lyr)
138
- lyr = MaxPooling(2)(lyr)
139
-
140
- # Decoder
141
- for level in range(nb_levels - 1):
142
- real_level = nb_levels - level - 2
143
- for conv in range(nb_conv_per_level):
144
- nfeat = dec_nf[level * nb_conv_per_level + conv]
145
- lyr = Conv(nfeat, kernel_size=3, padding='same', strides=1,activation = LeakyReLU(0.2), kernel_initializer = 'he_normal')(lyr)
146
- # upsample
147
- if level < (nb_levels - 1 - nb_upsample_skips):
148
- upsampled = UpSampling(size=(2,) * ndims)(lyr)
149
- lyr = concatenate([upsampled, enc_layers.pop()])
150
-
151
- # Final convolution
152
- for num, nfeat in enumerate(final_convs):
153
- lyr = Conv(nfeat, kernel_size=3, padding='same', strides=1,activation = LeakyReLU(0.2), kernel_initializer = 'he_normal')(lyr)
154
-
155
- unet = tf.keras.models.Model(inputs=inputs, outputs=lyr)
156
- # transform the results into a flow field.
157
- disp_tensor = Conv(ndims, kernel_size=3, padding='same', name='disp')(unet.output)
158
- # using keras, we can easily form new models via tensor pointers
159
- def_model = tf.keras.models.Model(inputs, disp_tensor)
160
- # build transformer layer
161
- spatial_transformer = SpatialTransformer()
162
- # warp the moving image with the transformer
163
- moved_image_tensor = spatial_transformer([source, disp_tensor])
164
- outputs = [moved_image_tensor, disp_tensor]
165
- vxm_model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
166
- return vxm_model
167
- '''
 
89
 
90
 
91
def preprocess_image(image_path, dim = 128):
    """Load one image and return it as a single-image batch tensor.

    Parameters
    ----------
    image_path : path/URL accepted by ``load_image_pil_accelerated``.
    dim : int, optional
        Side length the image is resized to (default 128).

    Returns
    -------
    torch.Tensor of shape (1, 3, dim, dim); slot 0 holds the loaded image.
    # NOTE(review): assumes load_image_pil_accelerated returns a (3, dim, dim)
    # tensor — confirm against its definition earlier in this file.
    """
    # Pre-allocate the batch, then copy the decoded image into position 0.
    batch = torch.zeros([1, 3, dim, dim])
    batch[0] = load_image_pil_accelerated(image_path, dim)
    return batch
95