mcosarinsky commited on
Commit
4e79318
·
1 Parent(s): 861836e

add example images

Browse files
utils/__pycache__/utils.cpython-39.pyc ADDED
Binary file (2.23 kB). View file
 
utils/example1.jpg ADDED

Git LFS Details

  • SHA256: e0c31c67600fe3cebdabaa2f214970a8c240b312e36046f06b68e42d299a89f7
  • Pointer size: 131 Bytes
  • Size of remote file: 109 kB
utils/example2.jpg ADDED

Git LFS Details

  • SHA256: 1b30ba33e69dfec7f52a780188748e478e17c5c1127ce9e08d06cc5ed6dfcc8c
  • Pointer size: 131 Bytes
  • Size of remote file: 347 kB
utils/example3.png ADDED

Git LFS Details

  • SHA256: 7f0803e5e06f532ace2dbd86df8a62f5a9ffa9db121fd2db37caff9fe00dd6c5
  • Pointer size: 131 Bytes
  • Size of remote file: 105 kB
utils/example4.jpg ADDED

Git LFS Details

  • SHA256: f6cc4d5b76a70511ed07a9c2b1c20fb19c7d25abca44adf0d53cadf2277b0a4d
  • Pointer size: 131 Bytes
  • Size of remote file: 124 kB
utils/plotting.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import matplotlib.pyplot as plt
4
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
5
+
6
+ def getDenseMask(landmarks, h, w):
7
+ RL, LL, H = landmarks[:44], landmarks[44:94], landmarks[94:]
8
+ img = np.zeros([h, w], dtype='uint8')
9
+ RL = RL.reshape(-1, 1, 2).astype('int')
10
+ LL = LL.reshape(-1, 1, 2).astype('int')
11
+ H = H.reshape(-1, 1, 2).astype('int')
12
+ img = cv2.drawContours(img, [RL], -1, 1, -1)
13
+ img = cv2.drawContours(img, [LL], -1, 1, -1)
14
+ img = cv2.drawContours(img, [H], -1, 2, -1)
15
+ return img
16
+
17
+ def drawOnTop(img, landmarks, original_shape):
18
+ h, w = original_shape
19
+ output = getDenseMask(landmarks, h, w)
20
+ image = np.zeros([h,w,3])
21
+ image[:,:,0] = img + 0.3*(output==1).astype('float') - 0.1*(output==2).astype('float')
22
+ image[:,:,1] = img + 0.3*(output==2).astype('float') - 0.1*(output==1).astype('float')
23
+ image[:,:,2] = img - 0.1*(output==1).astype('float') - 0.2*(output==2).astype('float')
24
+ image = np.clip(image,0,1)
25
+ RL, LL, H = landmarks[:44], landmarks[44:94], landmarks[94:]
26
+ for l in RL: image = cv2.circle(image,(int(l[0]),int(l[1])),5,(1,0,1),-1)
27
+ for l in LL: image = cv2.circle(image,(int(l[0]),int(l[1])),5,(1,0,1),-1)
28
+ for l in H: image = cv2.circle(image,(int(l[0]),int(l[1])),5,(1,1,0),-1)
29
+ return image
30
+
31
+ def create_overlay(img, landmarks):
32
+ h, w = img.shape[:2]
33
+ dense_mask = getDenseMask(landmarks, h, w)
34
+ overlay = np.zeros([h, w, 3])
35
+
36
+ overlay[:,:,0] = img + 0.3 * (dense_mask == 1).astype('float') - 0.1 * (dense_mask == 2).astype('float')
37
+ overlay[:,:,1] = img + 0.3 * (dense_mask == 2).astype('float') - 0.1 * (dense_mask == 1).astype('float')
38
+ overlay[:,:,2] = img - 0.1 * (dense_mask == 1).astype('float') - 0.2 * (dense_mask == 2).astype('float')
39
+ overlay = np.clip(overlay, 0, 1)
40
+
41
+ return overlay
42
+
43
+ def plot_side_by_side_comparison(img_orig, means_orig, uncertainty_orig, img_corr, means_corr, uncertainty_corr):
44
+
45
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 7))
46
+
47
+ fig.set_constrained_layout(True)
48
+
49
+ vmax = max(np.max(np.mean(uncertainty_orig, axis=1)), np.max(np.mean(uncertainty_corr, axis=1)))
50
+
51
+ # --- Original ---
52
+ overlay_orig = create_overlay(img_orig, means_orig)
53
+ ax1.imshow(overlay_orig)
54
+ scatter1 = ax1.scatter(
55
+ means_orig[:, 0], means_orig[:, 1],
56
+ c=np.mean(uncertainty_orig, axis=1),
57
+ cmap='hot', s=50, vmin=0, vmax=vmax
58
+ )
59
+ ax1.set_title("Original", fontsize=16, pad=10)
60
+ ax1.axis('off')
61
+
62
+ # --- Corrupted ---
63
+ overlay_corr = create_overlay(img_corr, means_corr)
64
+ ax2.imshow(overlay_corr)
65
+ scatter2 = ax2.scatter(
66
+ means_corr[:, 0], means_corr[:, 1],
67
+ c=np.mean(uncertainty_corr, axis=1),
68
+ cmap='hot', s=50, vmin=0, vmax=vmax
69
+ )
70
+ ax2.set_title("Corrupted", fontsize=16, pad=10)
71
+ ax2.axis('off')
72
+
73
+ # Shared colorbar
74
+ cbar = fig.colorbar(scatter2, ax=[ax1, ax2], fraction=0.046, pad=0.01, shrink=0.85)
75
+ cbar.ax.tick_params(labelsize=10)
76
+
77
+ return fig
utils/segmentation.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import torch
4
+ import scipy.sparse as sp
5
+ import sys
6
+ import os
7
+ from zipfile import ZipFile
8
+ from .plotting import plot_side_by_side_comparison
9
+
10
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
11
+ from models.HybridGNet2IGSC import Hybrid
12
+
13
+ hybrid = None
14
+
15
+ def scipy_to_torch_sparse(scp_matrix):
16
+ values = scp_matrix.data
17
+ indices = np.vstack((scp_matrix.row, scp_matrix.col))
18
+ i = torch.LongTensor(indices)
19
+ v = torch.FloatTensor(values)
20
+ shape = scp_matrix.shape
21
+
22
+ sparse_tensor = torch.sparse.FloatTensor(i, v, torch.Size(shape))
23
+ return sparse_tensor
24
+
25
+ ## Adjacency Matrix
26
+ def mOrgan(N):
27
+ sub = np.zeros([N, N])
28
+ for i in range(0, N):
29
+ sub[i, i-1] = 1
30
+ sub[i, (i+1)%N] = 1
31
+ return sub
32
+
33
+ ## Downsampling Matrix
34
+ def mOrganD(N):
35
+ N2 = int(np.ceil(N/2))
36
+ sub = np.zeros([N2, N])
37
+
38
+ for i in range(0, N2):
39
+ if (2*i+1) == N:
40
+ sub[i, 2*i] = 1
41
+ else:
42
+ sub[i, 2*i] = 1/2
43
+ sub[i, 2*i+1] = 1/2
44
+
45
+ return sub
46
+
47
+ def mOrganU(N):
48
+ N2 = int(np.ceil(N/2))
49
+ sub = np.zeros([N, N2])
50
+
51
+ for i in range(0, N):
52
+ if i % 2 == 0:
53
+ sub[i, i//2] = 1
54
+ else:
55
+ sub[i, i//2] = 1/2
56
+ sub[i, (i//2 + 1) % N2] = 1/2
57
+
58
+ return sub
59
+
60
+ def genMatrixesLungsHeart():
61
+ RLUNG = 44
62
+ LLUNG = 50
63
+ HEART = 26
64
+
65
+ Asub1 = mOrgan(RLUNG)
66
+ Asub2 = mOrgan(LLUNG)
67
+ Asub3 = mOrgan(HEART)
68
+
69
+ ADsub1 = mOrgan(int(np.ceil(RLUNG / 2)))
70
+ ADsub2 = mOrgan(int(np.ceil(LLUNG / 2)))
71
+ ADsub3 = mOrgan(int(np.ceil(HEART / 2)))
72
+
73
+ Dsub1 = mOrganD(RLUNG)
74
+ Dsub2 = mOrganD(LLUNG)
75
+ Dsub3 = mOrganD(HEART)
76
+
77
+ Usub1 = mOrganU(RLUNG)
78
+ Usub2 = mOrganU(LLUNG)
79
+ Usub3 = mOrganU(HEART)
80
+
81
+ p1 = RLUNG
82
+ p2 = p1 + LLUNG
83
+ p3 = p2 + HEART
84
+
85
+ p1_ = int(np.ceil(RLUNG / 2))
86
+ p2_ = p1_ + int(np.ceil(LLUNG / 2))
87
+ p3_ = p2_ + int(np.ceil(HEART / 2))
88
+
89
+ A = np.zeros([p3, p3])
90
+
91
+ A[:p1, :p1] = Asub1
92
+ A[p1:p2, p1:p2] = Asub2
93
+ A[p2:p3, p2:p3] = Asub3
94
+
95
+ AD = np.zeros([p3_, p3_])
96
+
97
+ AD[:p1_, :p1_] = ADsub1
98
+ AD[p1_:p2_, p1_:p2_] = ADsub2
99
+ AD[p2_:p3_, p2_:p3_] = ADsub3
100
+
101
+ D = np.zeros([p3_, p3])
102
+
103
+ D[:p1_, :p1] = Dsub1
104
+ D[p1_:p2_, p1:p2] = Dsub2
105
+ D[p2_:p3_, p2:p3] = Dsub3
106
+
107
+ U = np.zeros([p3, p3_])
108
+
109
+ U[:p1, :p1_] = Usub1
110
+ U[p1:p2, p1_:p2_] = Usub2
111
+ U[p2:p3, p2_:p3_] = Usub3
112
+
113
+ return A, AD, D, U
114
+
115
+ def zip_files(files, output_name="complete_results.zip"):
116
+ with ZipFile(output_name, "w") as zipObj:
117
+ for file in files:
118
+ zipObj.write(file, arcname=file.split("/")[-1])
119
+ return output_name
120
+
121
+ def getMasks(landmarks, h, w):
122
+ RL, LL, H = landmarks[:44], landmarks[44:94], landmarks[94:]
123
+ RL_mask, LL_mask, H_mask = [np.zeros([h, w], dtype='uint8') for _ in range(3)]
124
+ RL_mask = cv2.drawContours(RL_mask, [RL.reshape(-1,1,2).astype('int')], -1, 255, -1)
125
+ LL_mask = cv2.drawContours(LL_mask, [LL.reshape(-1,1,2).astype('int')], -1, 255, -1)
126
+ H_mask = cv2.drawContours(H_mask, [H.reshape(-1,1,2).astype('int')], -1, 255, -1)
127
+ return RL_mask, LL_mask, H_mask
128
+
129
+ def pad_to_square(img):
130
+ h, w = img.shape[:2]
131
+ if h > w:
132
+ padw = h - w
133
+ auxw = padw % 2
134
+ img = np.pad(img, ((0,0),(padw//2, padw//2+auxw)), 'constant')
135
+ return img, (0, padw, 0, auxw)
136
+ else:
137
+ padh = w - h
138
+ auxh = padh % 2
139
+ img = np.pad(img, ((padh//2, padh//2+auxh),(0,0)), 'constant')
140
+ return img, (padh, 0, auxh, 0)
141
+
142
+ def preprocess(img):
143
+ img, padding = pad_to_square(img)
144
+ h, w = img.shape[:2]
145
+ if h != 1024 or w != 1024:
146
+ img = cv2.resize(img, (1024,1024), interpolation=cv2.INTER_CUBIC)
147
+ return img, (h, w, padding)
148
+
149
+ def removePreprocess(output, info):
150
+ h, w, padding = info
151
+ padh, padw, auxh, auxw = padding
152
+ if h != 1024 or w != 1024:
153
+ output = output * h
154
+ else:
155
+ output = output * 1024
156
+ output[:,:,0] -= padw//2
157
+ output[:,:,1] -= padh//2
158
+ return output
159
+
160
+ def loadModel(device):
161
+ global hybrid
162
+ A, AD, D, U = genMatrixesLungsHeart()
163
+ N1, N2 = A.shape[0], AD.shape[0]
164
+ A, AD, D, U = [sp.csc_matrix(x).tocoo() for x in [A, AD, D, U]]
165
+ D_, U_ = [D.copy()], [U.copy()]
166
+ A_ = [A.copy(), A.copy(), A.copy(), AD.copy(), AD.copy(), AD.copy()]
167
+ config = {'n_nodes':[N1,N1,N1,N2,N2,N2], 'latents':64, 'inputsize':1024,
168
+ 'filters':[2,32,32,32,16,16,16], 'skip_features':32, 'eval_sampling':True}
169
+ A_t, D_t, U_t = ([scipy_to_torch_sparse(x).to(device) for x in X] for X in (A_,D_,U_))
170
+ hybrid = Hybrid(config.copy(), D_t, U_t, A_t).to(device)
171
+ hybrid.load_state_dict(torch.load("weights/weights.pt", map_location=device))
172
+ hybrid.eval()
173
+ return hybrid
174
+
175
+ def predict_landmarks(img, n_samples=100):
176
+ global hybrid
177
+ img_proc, (h, w, padding) = preprocess(img)
178
+ data = torch.from_numpy(img_proc).unsqueeze(0).unsqueeze(0).to(next(hybrid.parameters()).device).float()
179
+ with torch.no_grad():
180
+ mu, log_var, conv6, conv5 = hybrid.encode(data)
181
+ zs = [hybrid.sampling(mu, log_var) for _ in range(n_samples)]
182
+ z_exp = torch.stack(zs, dim=0)
183
+ conv6_exp, conv5_exp = conv6.repeat(n_samples,1,1,1), conv5.repeat(n_samples,1,1,1)
184
+ output, _, _ = hybrid.decode(z_exp, conv6_exp, conv5_exp)
185
+ output = output.cpu().numpy().reshape(n_samples,-1,2)
186
+ output = removePreprocess(output, (h,w,padding)).astype('int')
187
+ means, stds = np.mean(output,axis=0), np.std(output,axis=0)
188
+ return means, stds
189
+
190
+
191
+ def segment(input_img, noise_std=0.0):
192
+ """
193
+ input_img: dict with keys "image" (numpy array) and optionally "mask"
194
+ noise_std: standard deviation of Gaussian noise to add for robustness
195
+ Returns: path to comparison figure, list of saved files
196
+ """
197
+ global hybrid
198
+
199
+ if hybrid is None:
200
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
201
+ hybrid = loadModel(device)
202
+
203
+ # Original image and corrupted version
204
+ img_orig = input_img["image"].astype(np.float32) / 255.0
205
+ mask = input_img.get("mask", None)
206
+ if mask is not None:
207
+ mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY).astype(np.float32) / 255.0
208
+ mask = 1.0 - mask
209
+ img_corr = np.minimum(img_orig, mask)
210
+ else:
211
+ img_corr = img_orig.copy()
212
+
213
+ if noise_std > 0:
214
+ noise = np.random.normal(0, noise_std, img_corr.shape)
215
+ img_corr = np.clip(img_corr + noise, 0.0, 1.0)
216
+
217
+ # Predict landmarks
218
+ means_orig, stds_orig = predict_landmarks(img_orig)
219
+ means_corr, stds_corr = predict_landmarks(img_corr)
220
+
221
+ # Save landmarks and masks
222
+ os.makedirs("tmp", exist_ok=True)
223
+
224
+ RL, LL, H = means_orig[:44], means_orig[44:94], means_orig[94:]
225
+ np.savetxt("tmp/RL_landmarks.txt", RL, delimiter=" ", fmt="%d")
226
+ np.savetxt("tmp/LL_landmarks.txt", LL, delimiter=" ", fmt="%d")
227
+ np.savetxt("tmp/H_landmarks.txt", H, delimiter=" ", fmt="%d")
228
+
229
+ RL_mask, LL_mask, H_mask = getMasks(means_orig, img_orig.shape[0], img_orig.shape[1])
230
+ cv2.imwrite("tmp/RL_mask.png", RL_mask)
231
+ cv2.imwrite("tmp/LL_mask.png", LL_mask)
232
+ cv2.imwrite("tmp/H_mask.png", H_mask)
233
+
234
+ RL_std, LL_std, H_std = stds_orig[:44], stds_orig[44:94], stds_orig[94:]
235
+ np.savetxt("tmp/RL_std.txt", RL_std, delimiter=" ", fmt="%.4f")
236
+ np.savetxt("tmp/LL_std.txt", LL_std, delimiter=" ", fmt="%.4f")
237
+ np.savetxt("tmp/H_std.txt", H_std, delimiter=" ", fmt="%.4f")
238
+
239
+ zipf = zip_files([
240
+ "tmp/RL_landmarks.txt","tmp/LL_landmarks.txt","tmp/H_landmarks.txt",
241
+ "tmp/RL_mask.png","tmp/LL_mask.png","tmp/H_mask.png",
242
+ "tmp/RL_std.txt","tmp/LL_std.txt","tmp/H_std.txt"
243
+ ])
244
+
245
+ # Optional: plot side-by-side comparison
246
+ fig = plot_side_by_side_comparison(img_orig, means_orig, stds_orig, img_corr, means_corr, stds_corr)
247
+ output_path = "tmp/segmentation_comparison.png"
248
+ fig.savefig(output_path, dpi=300)
249
+ import matplotlib.pyplot as plt
250
+ plt.close(fig)
251
+
252
+ saved_files = [
253
+ "tmp/RL_landmarks.txt","tmp/LL_landmarks.txt","tmp/H_landmarks.txt",
254
+ "tmp/RL_mask.png","tmp/LL_mask.png","tmp/H_mask.png",
255
+ "tmp/RL_std.txt","tmp/LL_std.txt","tmp/H_std.txt",
256
+ zipf
257
+ ]
258
+
259
+ return output_path, saved_files