Amould commited on
Commit
fffea84
·
verified ·
1 Parent(s): 7cb6898

Create codes.py

Browse files
Files changed (1) hide show
  1. codes.py +179 -0
codes.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ import numpy as np
3
+ from PIL import Image
4
+ import itertools
5
+ import glob
6
+ import random
7
+ import torch
8
+ import torchvision
9
+ import torchvision.transforms as transforms
10
+ import torch.nn as nn
11
+ import torch.optim as optim
12
+ import torch.nn.functional as F
13
+ from torch.nn.functional import relu as RLU
14
+
15
+ registration_method = 'Additive_Recurence' #{'Rawblock', 'matching_points', 'Additive_Recurence', 'Multiplicative_Recurence'} #'recurrent_matrix',
16
+ imposed_point = 0
17
+ Arch = 'ResNet'
18
+ Fix_Torch_Wrap = False
19
+ BW_Position = False
20
+ dim = 128
21
+ dim0 =224
22
+ crop_ratio = dim/dim0
23
+
24
+
25
+
26
+
27
+ class Identity(nn.Module):
28
+ def __init__(self):
29
+ super(Identity, self).__init__()
30
+ def forward(self, x):
31
+ return x
32
+
33
+ class Build_IRmodel_Resnet(nn.Module):
34
+ def __init__(self, resnet_model, registration_method = 'Additive_Recurence', BW_Position=False):
35
+ super(Build_IRmodel_Resnet, self).__init__()
36
+ self.resnet_model = resnet_model
37
+ self.BW_Position = BW_Position
38
+ self.N_parameters = 6
39
+ self.registration_method = registration_method
40
+ self.fc1 =nn.Linear(6, 64)
41
+ self.fc2 =nn.Linear(64, 128*3)
42
+ self.fc3 =nn.Linear(512, self.N_parameters)
43
+ def forward(self, input_X_batch):
44
+ source = input_X_batch['source']
45
+ target = input_X_batch['target']
46
+ if 'Recurence' in self.registration_method:
47
+ M_i = input_X_batch['M_i'].view(-1, 6)
48
+ M_rep = F.relu(self.fc1(M_i))
49
+ M_rep = F.relu(self.fc2(M_rep)).view(-1,3,1,128)
50
+ concatenated_input = torch.cat((source,target,M_rep), dim=2)
51
+ else:
52
+ concatenated_input = torch.cat((source,target), dim=2)
53
+ resnet_output = self.resnet_model(concatenated_input)
54
+ predicted_line = self.fc3(resnet_output)
55
+ if 'Recurence' in self.registration_method:
56
+ predicted_part_mtrx = predicted_line.view(-1, 2, 3)
57
+ Prd_Affine_mtrx = predicted_part_mtrx + input_X_batch['M_i']
58
+ predction = {'predicted_part_mtrx':predicted_part_mtrx,
59
+ 'Affine_mtrx': Prd_Affine_mtrx}
60
+ else:
61
+ Prd_Affine_mtrx = predicted_line.view(-1, 2, 3)
62
+ predction = {'Affine_mtrx': Prd_Affine_mtrx}
63
+ return predction
64
+
65
+ from torchvision.models import resnet18
66
+
67
+ core_model_tst = resnet18(pretrained=True)
68
+ core_model_tst.fc = Identity()
69
+ core_model_tst.load_state_dict(torch.load(file_savingfolder+'core_model'+ext+'.pth'))
70
+ core_model_tst.to(device)
71
+ IR_Model_tst = Build_IRmodel_Resnet(core_model_tst, registration_method)
72
+ IR_Model_tst.load_state_dict(torch.load(file_savingfolder+'IR_Model'+ext+'.pth'))
73
+ IR_Model_tst.to(device)
74
+ IR_Model_tst.eval()
75
+
76
+
77
+ def pil_to_numpy(im):
78
+ im.load()
79
+ # Unpack data
80
+ e = Image._getencoder(im.mode, "raw", im.mode)
81
+ e.setimage(im.im)
82
+ # NumPy buffer for the result
83
+ shape, typestr = Image._conv_type_shape(im)
84
+ data = np.empty(shape, dtype=np.dtype(typestr))
85
+ mem = data.data.cast("B", (data.data.nbytes,))
86
+ bufsize, s, offset = 65536, 0, 0
87
+ while not s:
88
+ l, s, d = e.encode(bufsize)
89
+ mem[offset:offset + len(d)] = d
90
+ offset += len(d)
91
+ if s < 0:
92
+ raise RuntimeError("encoder error %d in tobytes" % s)
93
+ return data
94
+
95
+ def load_image_pil_accelerated(image_path, dim=128):
96
+ image = Image.open(image_path).convert("RGB")
97
+ array = pil_to_numpy(image)
98
+ tensor = torch.from_numpy(np.rollaxis(array,2,0)/255).to(torch.float32)
99
+ tensor = torchvision.transforms.Resize((dim,dim))(tensor)
100
+ return tensor
101
+
102
+
103
+ def preprocess_image(image_path, dim = 128):
104
+ img = load_image_pil_accelerated(image_path, dim)
105
+ return img.unsqueeze(0)
106
+
107
+ '''
108
+ def load_image_from_url(image_path, dim = 128):
109
+ img = Image.open(image_path).convert("RGB")
110
+ img = img.resize((dim, dim))
111
+ return img
112
+
113
+ def preprocess_image(image_path, dim = 128):
114
+ img = load_img(image_path, target_size=(dim, dim))
115
+ img = img_to_array(img)
116
+ img = np.expand_dims(img, axis=0)
117
+ return img
118
+
119
+
120
+ def create_model(dim = 128):
121
+ # configure unet input shape (concatenation of moving and fixed images)
122
+ volshape = (dim,dim,3)
123
+ unet_input_features = 2*volshape[:-1]
124
+ inshape = (*volshape[:-1],unet_input_features)
125
+ nb_conv_per_level=1
126
+ enc_nf = [dim, dim, dim, dim]
127
+ dec_nf = [dim, dim, dim, dim, dim, int(dim/2)]
128
+ nb_upsample_skips = 0
129
+ nb_dec_convs = len(enc_nf)
130
+ final_convs = dec_nf[nb_dec_convs:]
131
+ dec_nf = dec_nf[:nb_dec_convs]
132
+ nb_levels = int(nb_dec_convs / nb_conv_per_level) + 1
133
+ source = tf.keras.Input(shape=volshape, name='source_input')
134
+ target = tf.keras.Input(shape=volshape, name='target_input')
135
+ inputs = [source, target]
136
+ unet_input = concatenate(inputs, name='input_concat')
137
+ #Define lyers
138
+ ndims = len(unet_input.get_shape()) - 2
139
+ MaxPooling = getattr(tf.keras.layers, 'MaxPooling%dD' % ndims)
140
+ Conv = getattr(tf.keras.layers, 'Conv%dD' % ndims)
141
+ UpSampling = getattr(tf.keras.layers, 'UpSampling%dD' % ndims)
142
+ # Encoder
143
+ enc_layers = []
144
+ lyr = unet_input
145
+ for level in range(nb_levels - 1):
146
+ for conv in range(nb_conv_per_level):
147
+ nfeat = enc_nf[level * nb_conv_per_level + conv]
148
+ lyr = Conv(nfeat, kernel_size=3, padding='same', strides=1,activation = LeakyReLU(0.2), kernel_initializer = 'he_normal')(lyr)
149
+ enc_layers.append(lyr)
150
+ lyr = MaxPooling(2)(lyr)
151
+
152
+ # Decoder
153
+ for level in range(nb_levels - 1):
154
+ real_level = nb_levels - level - 2
155
+ for conv in range(nb_conv_per_level):
156
+ nfeat = dec_nf[level * nb_conv_per_level + conv]
157
+ lyr = Conv(nfeat, kernel_size=3, padding='same', strides=1,activation = LeakyReLU(0.2), kernel_initializer = 'he_normal')(lyr)
158
+ # upsample
159
+ if level < (nb_levels - 1 - nb_upsample_skips):
160
+ upsampled = UpSampling(size=(2,) * ndims)(lyr)
161
+ lyr = concatenate([upsampled, enc_layers.pop()])
162
+
163
+ # Final convolution
164
+ for num, nfeat in enumerate(final_convs):
165
+ lyr = Conv(nfeat, kernel_size=3, padding='same', strides=1,activation = LeakyReLU(0.2), kernel_initializer = 'he_normal')(lyr)
166
+
167
+ unet = tf.keras.models.Model(inputs=inputs, outputs=lyr)
168
+ # transform the results into a flow field.
169
+ disp_tensor = Conv(ndims, kernel_size=3, padding='same', name='disp')(unet.output)
170
+ # using keras, we can easily form new models via tensor pointers
171
+ def_model = tf.keras.models.Model(inputs, disp_tensor)
172
+ # build transformer layer
173
+ spatial_transformer = SpatialTransformer()
174
+ # warp the moving image with the transformer
175
+ moved_image_tensor = spatial_transformer([source, disp_tensor])
176
+ outputs = [moved_image_tensor, disp_tensor]
177
+ vxm_model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
178
+ return vxm_model
179
+ '''