suyoyog commited on
Commit
b47ff5a
·
verified ·
1 Parent(s): 6eb2296

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -131
app.py CHANGED
@@ -303,37 +303,37 @@ class Discriminator(nn.Module):
303
  return self.sigmoid(x)
304
 
305
 
306
- class vgg19(nn.Module):
307
- def __init__(self, pre_trained=True, require_grad=False):
308
- super(vgg19, self).__init__()
309
-
310
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
311
- num_gpus = torch.cuda.device_count()
312
- vgg_features = models.vgg19(pretrained=pre_trained).features
313
- self.seq_list = nn.ModuleList([nn.Sequential(ele) for ele in vgg_features])
314
-
315
- self.vgg_layer = ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
316
- 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
317
- 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
318
- 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
319
- 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5']
320
-
321
- if not require_grad:
322
- for parameter in self.parameters():
323
- parameter.requires_grad = False
324
- if num_gpus == 2:
325
- self = DataParallel(self, device_ids=[0,1]).to(device)
326
- else:
327
- self.to(device) # Move the entire model to the selected device
328
-
329
- def forward(self, x):
330
- vgg_outputs = []
331
-
332
- for layer in self.seq_list:
333
- x = layer(x)
334
- vgg_outputs.append(x)
335
-
336
- return vgg_outputs
337
 
338
 
339
 
@@ -400,123 +400,123 @@ class TVLoss(nn.Module):
400
 
401
 
402
  # Create the directory if it doesn't exist
403
- output_directory = '/kaggle/working/model/'
404
- os.makedirs(output_directory, exist_ok=True)
405
 
406
 
407
 
408
- def train(config):
409
- # print("config : ", config)
410
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
411
- num_gpus = torch.cuda.device_count()
412
- transform = transforms.Compose([crop(config['scale'], config['patch_size']), augmentation()])
413
- dataset = mydata(GT_path=config['GT_path'], LR_path=config['LR_path'], in_memory=config['in_memory'], transform=transform)
414
- loader = DataLoader(dataset, batch_size=config['batch_size'], shuffle=True, num_workers=config['num_workers'])
415
 
416
- generator = Generator()
417
 
418
- if config['fine_tuning']:
419
- generator.load_state_dict(torch.load(config['generator_path']))
420
- generator = nn.DataParallel(generator, device_ids=[0, 1]) if num_gpus == 2 else generator.to(device)
421
 
422
- generator.train()
423
 
424
- l2_loss = nn.MSELoss()
425
- g_optim = optim.Adam(generator.parameters(), lr=1e-4)
426
 
427
- pre_epoch = 0
428
- fine_epoch = 0
429
 
430
- #### Train using L2_loss
431
- while pre_epoch < config['pre_train_epoch']:
432
- for i, tr_data in enumerate(loader):
433
- gt = tr_data['GT'].to(device)
434
- lr = tr_data['LR'].to(device)
435
- output = generator(lr)
436
- loss = l2_loss(gt, output)
437
-
438
- g_optim.zero_grad()
439
- loss.backward()
440
- g_optim.step()
441
-
442
- pre_epoch += 1
443
-
444
- if pre_epoch % 2 == 0:
445
- print(pre_epoch)
446
- print(loss.item())
447
- print('=========')
448
-
449
- if pre_epoch % 4 == 0:
450
- torch.save(generator.state_dict(), '/kaggle/working/model/pre_trained_model_%03d.pt' % pre_epoch)
451
-
452
- #### Train using perceptual & adversarial loss
453
- vgg_net = vgg19().to(device)
454
- vgg_net = vgg_net.eval()
455
- discriminator = Discriminator()
456
- if num_gpus == 2:
457
- discriminator = DataParallel(discriminator, device_ids = [0,1]).to(device)
458
- else:
459
- discriminator = discriminator.to(device)
460
- discriminator.train()
461
 
462
- d_optim = optim.Adam(discriminator.parameters(), lr=1e-4)
463
- scheduler = optim.lr_scheduler.StepLR(g_optim, step_size=2000, gamma=0.1)
464
 
465
- VGG_loss = perceptual_loss(vgg_net)
466
- cross_ent = nn.BCELoss()
467
- tv_loss = TVLoss()
468
- real_label = torch.ones((gt.size(0), 2)).to(device)
469
- fake_label = torch.zeros((gt.size(0), 2)).to(device)
470
-
471
- while fine_epoch < config['fine_train_epoch']:
472
- scheduler.step()
473
-
474
- for i, tr_data in enumerate(loader):
475
- gt = tr_data['GT'].to(device)
476
- lr = tr_data['LR'].to(device)
477
 
478
- ## Training Discriminator
479
- output = generator(lr)
480
- fake_prob = discriminator(output, lr)
481
- real_prob = discriminator(gt, lr)
482
- d_loss_real = cross_ent(real_prob, real_label)
483
- d_loss_fake = cross_ent(fake_prob, fake_label)
484
 
485
- d_loss = d_loss_real + d_loss_fake
486
 
487
- g_optim.zero_grad()
488
- d_optim.zero_grad()
489
- d_loss.backward()
490
- d_optim.step()
491
 
492
- output = generator(lr)
493
- fake_prob = discriminator(output, lr)
494
 
495
- _percep_loss, hr_feat, sr_feat = VGG_loss((gt + 1.0) / 2.0, (output + 1.0) / 2.0, layer=config['feat_layer'])
496
 
497
- L2_loss = l2_loss(output, gt)
498
- percep_loss = config['vgg_rescale_coeff'] * _percep_loss
499
- adversarial_loss = config['adv_coeff'] * cross_ent(fake_prob, real_label)
500
- total_variance_loss = config['tv_loss_coeff'] * tv_loss(config['vgg_rescale_coeff'] * (hr_feat - sr_feat)**2)
501
 
502
- g_loss = percep_loss + adversarial_loss + total_variance_loss + L2_loss
503
 
504
- g_optim.zero_grad()
505
- d_optim.zero_grad()
506
- g_loss.backward()
507
- g_optim.step()
508
 
509
- fine_epoch += 1
510
 
511
- if fine_epoch % 2 == 0:
512
- print(fine_epoch)
513
- print(g_loss.item())
514
- print(d_loss.item())
515
- print('=========')
516
 
517
- if fine_epoch % 4 == 0:
518
- torch.save(generator.state_dict(), '/kaggle/working/model/MedSRGAN_gene_%03d.pt' % fine_epoch)
519
- torch.save(discriminator.state_dict(), '/kaggle/working/model/MedSRGAN_discrim_%03d.pt' % fine_epoch)
520
 
521
 
522
 
@@ -545,14 +545,14 @@ config = {
545
 
546
 
547
  # List all folders in /kaggle/input/
548
- input_directory = '/kaggle/input/lumber-spine-dataset/'
549
  # input_directory = '/kaggle/input/custom-dataset/custom_dataset/'
550
- all_folders = [f for f in os.listdir(input_directory) if os.path.isdir(os.path.join(input_directory, f))]
551
 
552
  # Print the list of folders
553
- print("All folders in /kaggle/input/:")
554
- for folder in all_folders:
555
- print(folder)
556
 
557
  #train(config)
558
 
 
303
  return self.sigmoid(x)
304
 
305
 
306
+ # class vgg19(nn.Module):
307
+ # def __init__(self, pre_trained=True, require_grad=False):
308
+ # super(vgg19, self).__init__()
309
+
310
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
311
+ # num_gpus = torch.cuda.device_count()
312
+ # vgg_features = models.vgg19(pretrained=pre_trained).features
313
+ # self.seq_list = nn.ModuleList([nn.Sequential(ele) for ele in vgg_features])
314
+
315
+ # self.vgg_layer = ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
316
+ # 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
317
+ # 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
318
+ # 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
319
+ # 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5']
320
+
321
+ # if not require_grad:
322
+ # for parameter in self.parameters():
323
+ # parameter.requires_grad = False
324
+ # if num_gpus == 2:
325
+ # self = DataParallel(self, device_ids=[0,1]).to(device)
326
+ # else:
327
+ # self.to(device) # Move the entire model to the selected device
328
+
329
+ # def forward(self, x):
330
+ # vgg_outputs = []
331
+
332
+ # for layer in self.seq_list:
333
+ # x = layer(x)
334
+ # vgg_outputs.append(x)
335
+
336
+ # return vgg_outputs
337
 
338
 
339
 
 
400
 
401
 
402
  # Create the directory if it doesn't exist
403
+ #output_directory = '/kaggle/working/model/'
404
+ #os.makedirs(output_directory, exist_ok=True)
405
 
406
 
407
 
408
+ # def train(config):
409
+ # # print("config : ", config)
410
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
411
+ # num_gpus = torch.cuda.device_count()
412
+ # transform = transforms.Compose([crop(config['scale'], config['patch_size']), augmentation()])
413
+ # dataset = mydata(GT_path=config['GT_path'], LR_path=config['LR_path'], in_memory=config['in_memory'], transform=transform)
414
+ # loader = DataLoader(dataset, batch_size=config['batch_size'], shuffle=True, num_workers=config['num_workers'])
415
 
416
+ # generator = Generator()
417
 
418
+ # if config['fine_tuning']:
419
+ # generator.load_state_dict(torch.load(config['generator_path']))
420
+ # generator = nn.DataParallel(generator, device_ids=[0, 1]) if num_gpus == 2 else generator.to(device)
421
 
422
+ # generator.train()
423
 
424
+ # l2_loss = nn.MSELoss()
425
+ # g_optim = optim.Adam(generator.parameters(), lr=1e-4)
426
 
427
+ # pre_epoch = 0
428
+ # fine_epoch = 0
429
 
430
+ # #### Train using L2_loss
431
+ # while pre_epoch < config['pre_train_epoch']:
432
+ # for i, tr_data in enumerate(loader):
433
+ # gt = tr_data['GT'].to(device)
434
+ # lr = tr_data['LR'].to(device)
435
+ # output = generator(lr)
436
+ # loss = l2_loss(gt, output)
437
+
438
+ # g_optim.zero_grad()
439
+ # loss.backward()
440
+ # g_optim.step()
441
+
442
+ # pre_epoch += 1
443
+
444
+ # if pre_epoch % 2 == 0:
445
+ # print(pre_epoch)
446
+ # print(loss.item())
447
+ # print('=========')
448
+
449
+ # if pre_epoch % 4 == 0:
450
+ # torch.save(generator.state_dict(), '/kaggle/working/model/pre_trained_model_%03d.pt' % pre_epoch)
451
+
452
+ # #### Train using perceptual & adversarial loss
453
+ # vgg_net = vgg19().to(device)
454
+ # vgg_net = vgg_net.eval()
455
+ # discriminator = Discriminator()
456
+ # if num_gpus == 2:
457
+ # discriminator = DataParallel(discriminator, device_ids = [0,1]).to(device)
458
+ # else:
459
+ # discriminator = discriminator.to(device)
460
+ # discriminator.train()
461
 
462
+ # d_optim = optim.Adam(discriminator.parameters(), lr=1e-4)
463
+ # scheduler = optim.lr_scheduler.StepLR(g_optim, step_size=2000, gamma=0.1)
464
 
465
+ # VGG_loss = perceptual_loss(vgg_net)
466
+ # cross_ent = nn.BCELoss()
467
+ # tv_loss = TVLoss()
468
+ # real_label = torch.ones((gt.size(0), 2)).to(device)
469
+ # fake_label = torch.zeros((gt.size(0), 2)).to(device)
470
+
471
+ # while fine_epoch < config['fine_train_epoch']:
472
+ # scheduler.step()
473
+
474
+ # for i, tr_data in enumerate(loader):
475
+ # gt = tr_data['GT'].to(device)
476
+ # lr = tr_data['LR'].to(device)
477
 
478
+ # ## Training Discriminator
479
+ # output = generator(lr)
480
+ # fake_prob = discriminator(output, lr)
481
+ # real_prob = discriminator(gt, lr)
482
+ # d_loss_real = cross_ent(real_prob, real_label)
483
+ # d_loss_fake = cross_ent(fake_prob, fake_label)
484
 
485
+ # d_loss = d_loss_real + d_loss_fake
486
 
487
+ # g_optim.zero_grad()
488
+ # d_optim.zero_grad()
489
+ # d_loss.backward()
490
+ # d_optim.step()
491
 
492
+ # output = generator(lr)
493
+ # fake_prob = discriminator(output, lr)
494
 
495
+ # _percep_loss, hr_feat, sr_feat = VGG_loss((gt + 1.0) / 2.0, (output + 1.0) / 2.0, layer=config['feat_layer'])
496
 
497
+ # L2_loss = l2_loss(output, gt)
498
+ # percep_loss = config['vgg_rescale_coeff'] * _percep_loss
499
+ # adversarial_loss = config['adv_coeff'] * cross_ent(fake_prob, real_label)
500
+ # total_variance_loss = config['tv_loss_coeff'] * tv_loss(config['vgg_rescale_coeff'] * (hr_feat - sr_feat)**2)
501
 
502
+ # g_loss = percep_loss + adversarial_loss + total_variance_loss + L2_loss
503
 
504
+ # g_optim.zero_grad()
505
+ # d_optim.zero_grad()
506
+ # g_loss.backward()
507
+ # g_optim.step()
508
 
509
+ # fine_epoch += 1
510
 
511
+ # if fine_epoch % 2 == 0:
512
+ # print(fine_epoch)
513
+ # print(g_loss.item())
514
+ # print(d_loss.item())
515
+ # print('=========')
516
 
517
+ # if fine_epoch % 4 == 0:
518
+ # torch.save(generator.state_dict(), '/kaggle/working/model/MedSRGAN_gene_%03d.pt' % fine_epoch)
519
+ # torch.save(discriminator.state_dict(), '/kaggle/working/model/MedSRGAN_discrim_%03d.pt' % fine_epoch)
520
 
521
 
522
 
 
545
 
546
 
547
  # List all folders in /kaggle/input/
548
+ #input_directory = '/kaggle/input/lumber-spine-dataset/'
549
  # input_directory = '/kaggle/input/custom-dataset/custom_dataset/'
550
+ #all_folders = [f for f in os.listdir(input_directory) if os.path.isdir(os.path.join(input_directory, f))]
551
 
552
  # Print the list of folders
553
+ # print("All folders in /kaggle/input/:")
554
+ # for folder in all_folders:
555
+ # print(folder)
556
 
557
  #train(config)
558