JoyfulCalf commited on
Commit
bc2ef1d
·
verified ·
1 Parent(s): 7ee2c2d

Update unet.py

Browse files
Files changed (1) hide show
  1. unet.py +558 -0
unet.py CHANGED
@@ -0,0 +1,558 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function, division
2
+ import os
3
+ import numpy as np
4
+ from PIL import Image
5
+ import glob
6
+ #import SimpleITK as sitk
7
+ from torch import optim
8
+ import torch.utils.data
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+ import torch.nn
13
+ import torchvision
14
+ import matplotlib.pyplot as plt
15
+ import natsort
16
+ from torch.utils.data.sampler import SubsetRandomSampler
17
+ from Data_Loader import Images_Dataset, Images_Dataset_folder
18
+ import torchsummary
19
+ #from torch.utils.tensorboard import SummaryWriter
20
+ #from tensorboardX import SummaryWriter
21
+
22
+ import shutil
23
+ import random
24
+ from Models import Unet_dict, NestedUNet, U_Net, R2U_Net, AttU_Net, R2AttU_Net
25
+ from losses import calc_loss, dice_loss, threshold_predictions_v,threshold_predictions_p
26
+ from ploting import plot_kernels, LayerActivations, input_images, plot_grad_flow
27
+ from Metrics import dice_coeff, accuracy_score
28
+ import time
29
+ #from ploting import VisdomLinePlotter
30
+ #from visdom import Visdom
31
+
32
+
33
+ #######################################################
34
+ #Checking if GPU is used
35
+ #######################################################
36
+
37
+ train_on_gpu = torch.cuda.is_available()
38
+
39
+ if not train_on_gpu:
40
+ print('CUDA is not available. Training on CPU')
41
+ else:
42
+ print('CUDA is available. Training on GPU')
43
+
44
+ device = torch.device("cuda:0" if train_on_gpu else "cpu")
45
+
46
+ #######################################################
47
+ #Setting the basic paramters of the model
48
+ #######################################################
49
+
50
+ batch_size = 4
51
+ print('batch_size = ' + str(batch_size))
52
+
53
+ valid_size = 0.15
54
+
55
+ epoch = 15
56
+ print('epoch = ' + str(epoch))
57
+
58
+ random_seed = random.randint(1, 100)
59
+ print('random_seed = ' + str(random_seed))
60
+
61
+ shuffle = True
62
+ valid_loss_min = np.Inf
63
+ num_workers = 4
64
+ lossT = []
65
+ lossL = []
66
+ lossL.append(np.inf)
67
+ lossT.append(np.inf)
68
+ epoch_valid = epoch-2
69
+ n_iter = 1
70
+ i_valid = 0
71
+
72
+ pin_memory = False
73
+ if train_on_gpu:
74
+ pin_memory = True
75
+
76
+ #plotter = VisdomLinePlotter(env_name='Tutorial Plots')
77
+
78
+ #######################################################
79
+ #Setting up the model
80
+ #######################################################
81
+
82
+ model_Inputs = [U_Net, R2U_Net, AttU_Net, R2AttU_Net, NestedUNet]
83
+
84
+
85
+ def model_unet(model_input, in_channel=3, out_channel=1):
86
+ model_test = model_input(in_channel, out_channel)
87
+ return model_test
88
+
89
+ #passsing this string so that if it's AttU_Net or R2ATTU_Net it doesn't throw an error at torchSummary
90
+
91
+
92
+ model_test = model_unet(model_Inputs[0], 3, 1)
93
+
94
+ model_test.to(device)
95
+
96
+ #######################################################
97
+ #Getting the Summary of Model
98
+ #######################################################
99
+
100
+ torchsummary.summary(model_test, input_size=(3, 128, 128))
101
+
102
+ #######################################################
103
+ #Passing the Dataset of Images and Labels
104
+ #######################################################
105
+
106
+ t_data = '/flush1/bat161/segmentation/New_Trails/venv/DATA/new_3C_I_ori/'
107
+ l_data = '/flush1/bat161/segmentation/New_Trails/venv/DATA/new_3C_L_ori/'
108
+ test_image = '/flush1/bat161/segmentation/New_Trails/venv/DATA/test_new_3C_I_ori/0131_0009.png'
109
+ test_label = '/flush1/bat161/segmentation/New_Trails/venv/DATA/test_new_3C_L_ori/0131_0009.png'
110
+ test_folderP = '/flush1/bat161/segmentation/New_Trails/venv/DATA/test_new_3C_I_ori/*'
111
+ test_folderL = '/flush1/bat161/segmentation/New_Trails/venv/DATA/test_new_3C_L_ori/*'
112
+
113
+ Training_Data = Images_Dataset_folder(t_data,
114
+ l_data)
115
+
116
+ #######################################################
117
+ #Giving a transformation for input data
118
+ #######################################################
119
+
120
+ data_transform = torchvision.transforms.Compose([
121
+ # torchvision.transforms.Resize((128,128)),
122
+ # torchvision.transforms.CenterCrop(96),
123
+ torchvision.transforms.ToTensor(),
124
+ torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
125
+ ])
126
+
127
+ #######################################################
128
+ #Trainging Validation Split
129
+ #######################################################
130
+
131
+ num_train = len(Training_Data)
132
+ indices = list(range(num_train))
133
+ split = int(np.floor(valid_size * num_train))
134
+
135
+ if shuffle:
136
+ np.random.seed(random_seed)
137
+ np.random.shuffle(indices)
138
+
139
+ train_idx, valid_idx = indices[split:], indices[:split]
140
+ train_sampler = SubsetRandomSampler(train_idx)
141
+ valid_sampler = SubsetRandomSampler(valid_idx)
142
+
143
+ train_loader = torch.utils.data.DataLoader(Training_Data, batch_size=batch_size, sampler=train_sampler,
144
+ num_workers=num_workers, pin_memory=pin_memory,)
145
+
146
+ valid_loader = torch.utils.data.DataLoader(Training_Data, batch_size=batch_size, sampler=valid_sampler,
147
+ num_workers=num_workers, pin_memory=pin_memory,)
148
+
149
+ #######################################################
150
+ #Using Adam as Optimizer
151
+ #######################################################
152
+
153
+ initial_lr = 0.001
154
+ opt = torch.optim.Adam(model_test.parameters(), lr=initial_lr) # try SGD
155
+ #opt = optim.SGD(model_test.parameters(), lr = initial_lr, momentum=0.99)
156
+
157
+ MAX_STEP = int(1e10)
158
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, MAX_STEP, eta_min=1e-5)
159
+ #scheduler = optim.lr_scheduler.CosineAnnealingLr(opt, epoch, 1)
160
+
161
+ #######################################################
162
+ #Writing the params to tensorboard
163
+ #######################################################
164
+
165
+ #writer1 = SummaryWriter()
166
+ #dummy_inp = torch.randn(1, 3, 128, 128)
167
+ #model_test.to('cpu')
168
+ #writer1.add_graph(model_test, model_test(torch.randn(3, 3, 128, 128, requires_grad=True)))
169
+ #model_test.to(device)
170
+
171
+ #######################################################
172
+ #Creating a Folder for every data of the program
173
+ #######################################################
174
+
175
+ New_folder = './model'
176
+
177
+ if os.path.exists(New_folder) and os.path.isdir(New_folder):
178
+ shutil.rmtree(New_folder)
179
+
180
+ try:
181
+ os.mkdir(New_folder)
182
+ except OSError:
183
+ print("Creation of the main directory '%s' failed " % New_folder)
184
+ else:
185
+ print("Successfully created the main directory '%s' " % New_folder)
186
+
187
+ #######################################################
188
+ #Setting the folder of saving the predictions
189
+ #######################################################
190
+
191
+ read_pred = './model/pred'
192
+
193
+ #######################################################
194
+ #Checking if prediction folder exixts
195
+ #######################################################
196
+
197
+ if os.path.exists(read_pred) and os.path.isdir(read_pred):
198
+ shutil.rmtree(read_pred)
199
+
200
+ try:
201
+ os.mkdir(read_pred)
202
+ except OSError:
203
+ print("Creation of the prediction directory '%s' failed of dice loss" % read_pred)
204
+ else:
205
+ print("Successfully created the prediction directory '%s' of dice loss" % read_pred)
206
+
207
+ #######################################################
208
+ #checking if the model exists and if true then delete
209
+ #######################################################
210
+
211
+ read_model_path = './model/Unet_D_' + str(epoch) + '_' + str(batch_size)
212
+
213
+ if os.path.exists(read_model_path) and os.path.isdir(read_model_path):
214
+ shutil.rmtree(read_model_path)
215
+ print('Model folder there, so deleted for newer one')
216
+
217
+ try:
218
+ os.mkdir(read_model_path)
219
+ except OSError:
220
+ print("Creation of the model directory '%s' failed" % read_model_path)
221
+ else:
222
+ print("Successfully created the model directory '%s' " % read_model_path)
223
+
224
+ #######################################################
225
+ #Training loop
226
+ #######################################################
227
+
228
+ for i in range(epoch):
229
+
230
+ train_loss = 0.0
231
+ valid_loss = 0.0
232
+ since = time.time()
233
+ scheduler.step(i)
234
+ lr = scheduler.get_lr()
235
+
236
+ #######################################################
237
+ #Training Data
238
+ #######################################################
239
+
240
+ model_test.train()
241
+ k = 1
242
+
243
+ for x, y in train_loader:
244
+ x, y = x.to(device), y.to(device)
245
+
246
+ #If want to get the input images with their Augmentation - To check the data flowing in net
247
+ input_images(x, y, i, n_iter, k)
248
+
249
+ # grid_img = torchvision.utils.make_grid(x)
250
+ #writer1.add_image('images', grid_img, 0)
251
+
252
+ # grid_lab = torchvision.utils.make_grid(y)
253
+
254
+ opt.zero_grad()
255
+
256
+ y_pred = model_test(x)
257
+ lossT = calc_loss(y_pred, y) # Dice_loss Used
258
+
259
+ train_loss += lossT.item() * x.size(0)
260
+ lossT.backward()
261
+ # plot_grad_flow(model_test.named_parameters(), n_iter)
262
+ opt.step()
263
+ x_size = lossT.item() * x.size(0)
264
+ k = 2
265
+
266
+ # for name, param in model_test.named_parameters():
267
+ # name = name.replace('.', '/')
268
+ # writer1.add_histogram(name, param.data.cpu().numpy(), i + 1)
269
+ # writer1.add_histogram(name + '/grad', param.grad.data.cpu().numpy(), i + 1)
270
+
271
+
272
+ #######################################################
273
+ #Validation Step
274
+ #######################################################
275
+
276
+ model_test.eval()
277
+ torch.no_grad() #to increase the validation process uses less memory
278
+
279
+ for x1, y1 in valid_loader:
280
+ x1, y1 = x1.to(device), y1.to(device)
281
+
282
+ y_pred1 = model_test(x1)
283
+ lossL = calc_loss(y_pred1, y1) # Dice_loss Used
284
+
285
+ valid_loss += lossL.item() * x1.size(0)
286
+ x_size1 = lossL.item() * x1.size(0)
287
+
288
+ #######################################################
289
+ #Saving the predictions
290
+ #######################################################
291
+
292
+ im_tb = Image.open(test_image)
293
+ im_label = Image.open(test_label)
294
+ s_tb = data_transform(im_tb)
295
+ s_label = data_transform(im_label)
296
+ s_label = s_label.detach().numpy()
297
+
298
+ pred_tb = model_test(s_tb.unsqueeze(0).to(device)).cpu()
299
+ pred_tb = F.sigmoid(pred_tb)
300
+ pred_tb = pred_tb.detach().numpy()
301
+
302
+ #pred_tb = threshold_predictions_v(pred_tb)
303
+
304
+ x1 = plt.imsave(
305
+ './model/pred/img_iteration_' + str(n_iter) + '_epoch_'
306
+ + str(i) + '.png', pred_tb[0][0])
307
+
308
+ # accuracy = accuracy_score(pred_tb[0][0], s_label)
309
+
310
+ #######################################################
311
+ #To write in Tensorboard
312
+ #######################################################
313
+
314
+ train_loss = train_loss / len(train_idx)
315
+ valid_loss = valid_loss / len(valid_idx)
316
+
317
+ if (i+1) % 1 == 0:
318
+ print('Epoch: {}/{} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(i + 1, epoch, train_loss,
319
+ valid_loss))
320
+ # writer1.add_scalar('Train Loss', train_loss, n_iter)
321
+ # writer1.add_scalar('Validation Loss', valid_loss, n_iter)
322
+ #writer1.add_image('Pred', pred_tb[0]) #try to get output of shape 3
323
+
324
+
325
+ #######################################################
326
+ #Early Stopping
327
+ #######################################################
328
+
329
+ if valid_loss <= valid_loss_min and epoch_valid >= i: # and i_valid <= 2:
330
+
331
+ print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model '.format(valid_loss_min, valid_loss))
332
+ torch.save(model_test.state_dict(),'./model/Unet_D_' +
333
+ str(epoch) + '_' + str(batch_size) + '/Unet_epoch_' + str(epoch)
334
+ + '_batchsize_' + str(batch_size) + '.pth')
335
+ # print(accuracy)
336
+ if round(valid_loss, 4) == round(valid_loss_min, 4):
337
+ print(i_valid)
338
+ i_valid = i_valid+1
339
+ valid_loss_min = valid_loss
340
+ #if i_valid ==3:
341
+ # break
342
+
343
+ #######################################################
344
+ # Extracting the intermediate layers
345
+ #######################################################
346
+
347
+ #####################################
348
+ # for kernals
349
+ #####################################
350
+ x1 = torch.nn.ModuleList(model_test.children())
351
+ # x2 = torch.nn.ModuleList(x1[16].children())
352
+ #x3 = torch.nn.ModuleList(x2[0].children())
353
+
354
+ #To get filters in the layers
355
+ #plot_kernels(x1.weight.detach().cpu(), 7)
356
+
357
+ #####################################
358
+ # for images
359
+ #####################################
360
+ x2 = len(x1)
361
+ dr = LayerActivations(x1[x2-1]) #Getting the last Conv Layer
362
+
363
+ img = Image.open(test_image)
364
+ s_tb = data_transform(img)
365
+
366
+ pred_tb = model_test(s_tb.unsqueeze(0).to(device)).cpu()
367
+ pred_tb = F.sigmoid(pred_tb)
368
+ pred_tb = pred_tb.detach().numpy()
369
+
370
+ plot_kernels(dr.features, n_iter, 7, cmap="rainbow")
371
+
372
+ time_elapsed = time.time() - since
373
+ print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
374
+ n_iter += 1
375
+
376
+ #######################################################
377
+ #closing the tensorboard writer
378
+ #######################################################
379
+
380
+ #writer1.close()
381
+
382
+ #######################################################
383
+ #if using dict
384
+ #######################################################
385
+
386
+ #model_test.filter_dict
387
+
388
+ #######################################################
389
+ #Loading the model
390
+ #######################################################
391
+
392
+ test1 =model_test.load_state_dict(torch.load('./model/Unet_D_' +
393
+ str(epoch) + '_' + str(batch_size)+ '/Unet_epoch_' + str(epoch)
394
+ + '_batchsize_' + str(batch_size) + '.pth'))
395
+
396
+
397
+ #######################################################
398
+ #checking if cuda is available
399
+ #######################################################
400
+
401
+ if torch.cuda.is_available():
402
+ torch.cuda.empty_cache()
403
+
404
+ #######################################################
405
+ #Loading the model
406
+ #######################################################
407
+
408
+ model_test.load_state_dict(torch.load('./model/Unet_D_' +
409
+ str(epoch) + '_' + str(batch_size)+ '/Unet_epoch_' + str(epoch)
410
+ + '_batchsize_' + str(batch_size) + '.pth'))
411
+
412
+ model_test.eval()
413
+
414
+ #######################################################
415
+ #opening the test folder and creating a folder for generated images
416
+ #######################################################
417
+
418
+ read_test_folder = glob.glob(test_folderP)
419
+ x_sort_test = natsort.natsorted(read_test_folder) # To sort
420
+
421
+
422
+ read_test_folder112 = './model/gen_images'
423
+
424
+
425
+ if os.path.exists(read_test_folder112) and os.path.isdir(read_test_folder112):
426
+ shutil.rmtree(read_test_folder112)
427
+
428
+ try:
429
+ os.mkdir(read_test_folder112)
430
+ except OSError:
431
+ print("Creation of the testing directory %s failed" % read_test_folder112)
432
+ else:
433
+ print("Successfully created the testing directory %s " % read_test_folder112)
434
+
435
+
436
+ #For Prediction Threshold
437
+
438
+ read_test_folder_P_Thres = './model/pred_threshold'
439
+
440
+
441
+ if os.path.exists(read_test_folder_P_Thres) and os.path.isdir(read_test_folder_P_Thres):
442
+ shutil.rmtree(read_test_folder_P_Thres)
443
+
444
+ try:
445
+ os.mkdir(read_test_folder_P_Thres)
446
+ except OSError:
447
+ print("Creation of the testing directory %s failed" % read_test_folder_P_Thres)
448
+ else:
449
+ print("Successfully created the testing directory %s " % read_test_folder_P_Thres)
450
+
451
+ #For Label Threshold
452
+
453
+ read_test_folder_L_Thres = './model/label_threshold'
454
+
455
+
456
+ if os.path.exists(read_test_folder_L_Thres) and os.path.isdir(read_test_folder_L_Thres):
457
+ shutil.rmtree(read_test_folder_L_Thres)
458
+
459
+ try:
460
+ os.mkdir(read_test_folder_L_Thres)
461
+ except OSError:
462
+ print("Creation of the testing directory %s failed" % read_test_folder_L_Thres)
463
+ else:
464
+ print("Successfully created the testing directory %s " % read_test_folder_L_Thres)
465
+
466
+
467
+
468
+
469
+ #######################################################
470
+ #saving the images in the files
471
+ #######################################################
472
+
473
+ img_test_no = 0
474
+
475
+ for i in range(len(read_test_folder)):
476
+ im = Image.open(x_sort_test[i])
477
+
478
+ im1 = im
479
+ im_n = np.array(im1)
480
+ im_n_flat = im_n.reshape(-1, 1)
481
+
482
+ for j in range(im_n_flat.shape[0]):
483
+ if im_n_flat[j] != 0:
484
+ im_n_flat[j] = 255
485
+
486
+ s = data_transform(im)
487
+ pred = model_test(s.unsqueeze(0).cuda()).cpu()
488
+ pred = F.sigmoid(pred)
489
+ pred = pred.detach().numpy()
490
+
491
+ # pred = threshold_predictions_p(pred) #Value kept 0.01 as max is 1 and noise is very small.
492
+
493
+ if i % 24 == 0:
494
+ img_test_no = img_test_no + 1
495
+
496
+ x1 = plt.imsave('./model/gen_images/im_epoch_' + str(epoch) + 'int_' + str(i)
497
+ + '_img_no_' + str(img_test_no) + '.png', pred[0][0])
498
+
499
+
500
+ ####################################################
501
+ #Calculating the Dice Score
502
+ ####################################################
503
+
504
+ data_transform = torchvision.transforms.Compose([
505
+ # torchvision.transforms.Resize((128,128)),
506
+ # torchvision.transforms.CenterCrop(96),
507
+ torchvision.transforms.Grayscale(),
508
+ # torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
509
+ ])
510
+
511
+
512
+
513
+ read_test_folderP = glob.glob('./model/gen_images/*')
514
+ x_sort_testP = natsort.natsorted(read_test_folderP)
515
+
516
+
517
+ read_test_folderL = glob.glob(test_folderL)
518
+ x_sort_testL = natsort.natsorted(read_test_folderL) # To sort
519
+
520
+
521
+ dice_score123 = 0.0
522
+ x_count = 0
523
+ x_dice = 0
524
+
525
+ for i in range(len(read_test_folderP)):
526
+
527
+ x = Image.open(x_sort_testP[i])
528
+ s = data_transform(x)
529
+ s = np.array(s)
530
+ s = threshold_predictions_v(s)
531
+
532
+ #save the images
533
+ x1 = plt.imsave('./model/pred_threshold/im_epoch_' + str(epoch) + 'int_' + str(i)
534
+ + '_img_no_' + str(img_test_no) + '.png', s)
535
+
536
+ y = Image.open(x_sort_testL[i])
537
+ s2 = data_transform(y)
538
+ s3 = np.array(s2)
539
+ # s2 =threshold_predictions_v(s2)
540
+
541
+ #save the Images
542
+ y1 = plt.imsave('./model/label_threshold/im_epoch_' + str(epoch) + 'int_' + str(i)
543
+ + '_img_no_' + str(img_test_no) + '.png', s3)
544
+
545
+ total = dice_coeff(s, s3)
546
+ print(total)
547
+
548
+ if total <= 0.3:
549
+ x_count += 1
550
+ if total > 0.3:
551
+ x_dice = x_dice + total
552
+ dice_score123 = dice_score123 + total
553
+
554
+
555
+ print('Dice Score : ' + str(dice_score123/len(read_test_folderP)))
556
+ #print(x_count)
557
+ #print(x_dice)
558
+ #print('Dice Score : ' + str(float(x_dice/(len(read_test_folderP)-x_count))))