suyoyog committed
Commit a10d7c5 · verified · 1 Parent(s): c4f2c70

may be final

Files changed (1)
  1. app.py +26 -256
app.py CHANGED
@@ -112,6 +112,9 @@ class crop(object):
 
         return {'LR' : LR_patch, 'GT' : GT_patch}
 
+
+
+
 class augmentation(object):
 
     def __call__(self, sample):
@@ -120,8 +123,6 @@ class augmentation(object):
         hor_flip = random.randrange(0,2)
         ver_flip = random.randrange(0,2)
         rot = random.randrange(0,2)
-
-        # print(f"Horizontal Flip: {hor_flip}, Vertical Flip: {ver_flip}, Rotation: {rot}")
         if hor_flip:
             temp_LR = np.fliplr(LR_img)
             LR_img = temp_LR.copy()
@@ -253,6 +254,7 @@ class Discriminator(nn.Module):
         self.block_1_2 = D_Block(64, 128, stride=1)
         self.block_1_3 = D_Block(128, 128)
 
+
         # Layer for LR image size
         self.conv_2_1 = nn.Sequential(
             nn.Conv2d(in_channels, 64, (3, 3), stride=1, padding=1), nn.LeakyReLU()
@@ -294,48 +296,12 @@ class Discriminator(nn.Module):
         )
 
         x = self.flatten(x)
-
-        # print(x.shape)
-
         x = self.fc1(x)
         x = self.fc2(self.relu(x))
-        # outputs the range of 0 to 1
-        # print(x.shape)
         return self.sigmoid(x)
 
 
-# class vgg19(nn.Module):
-#     def __init__(self, pre_trained=True, require_grad=False):
-#         super(vgg19, self).__init__()
-
-#         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-#         num_gpus = torch.cuda.device_count()
-#         vgg_features = models.vgg19(pretrained=pre_trained).features
-#         self.seq_list = nn.ModuleList([nn.Sequential(ele) for ele in vgg_features])
-
-#         self.vgg_layer = ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
-#                           'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
-#                           'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
-#                           'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
-#                           'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5']
-
-#         if not require_grad:
-#             for parameter in self.parameters():
-#                 parameter.requires_grad = False
-#         if num_gpus == 2:
-#             self = DataParallel(self, device_ids=[0,1]).to(device)
-#         else:
-#             self.to(device) # Move the entire model to the selected device
-
-#     def forward(self, x):
-#         vgg_outputs = []
-
-#         for layer in self.seq_list:
-#             x = layer(x)
-#             vgg_outputs.append(x)
-
-#         return vgg_outputs
-
+
 
 
 class MeanShift(nn.Conv2d):
@@ -353,6 +319,7 @@ class MeanShift(nn.Conv2d):
             p.requires_grad = False
 
 
+
 class perceptual_loss(nn.Module):
     def __init__(self, vgg):
        super(perceptual_loss, self).__init__()
@@ -398,134 +365,12 @@ class TVLoss(nn.Module):
     @staticmethod
     def tensor_size(t):
         return t.size()[1] * t.size()[2] * t.size()[3]
-
-
-# Create the directory if it doesn't exist
-#output_directory = '/kaggle/working/model/'
-#os.makedirs(output_directory, exist_ok=True)
-
-
-
-# def train(config):
-#     # print("config : ", config)
-#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-#     num_gpus = torch.cuda.device_count()
-#     transform = transforms.Compose([crop(config['scale'], config['patch_size']), augmentation()])
-#     dataset = mydata(GT_path=config['GT_path'], LR_path=config['LR_path'], in_memory=config['in_memory'], transform=transform)
-#     loader = DataLoader(dataset, batch_size=config['batch_size'], shuffle=True, num_workers=config['num_workers'])
-
-#     generator = Generator()
-
-#     if config['fine_tuning']:
-#         generator.load_state_dict(torch.load(config['generator_path']))
-#     generator = nn.DataParallel(generator, device_ids=[0, 1]) if num_gpus == 2 else generator.to(device)
-
-#     generator.train()
-
-#     l2_loss = nn.MSELoss()
-#     g_optim = optim.Adam(generator.parameters(), lr=1e-4)
-
-#     pre_epoch = 0
-#     fine_epoch = 0
-
-#     #### Train using L2_loss
-#     while pre_epoch < config['pre_train_epoch']:
-#         for i, tr_data in enumerate(loader):
-#             gt = tr_data['GT'].to(device)
-#             lr = tr_data['LR'].to(device)
-#             output = generator(lr)
-#             loss = l2_loss(gt, output)
-
-#             g_optim.zero_grad()
-#             loss.backward()
-#             g_optim.step()
-
-#         pre_epoch += 1
-
-#         if pre_epoch % 2 == 0:
-#             print(pre_epoch)
-#             print(loss.item())
-#             print('=========')
-
-#         if pre_epoch % 4 == 0:
-#             torch.save(generator.state_dict(), '/kaggle/working/model/pre_trained_model_%03d.pt' % pre_epoch)
-
-#     #### Train using perceptual & adversarial loss
-#     vgg_net = vgg19().to(device)
-#     vgg_net = vgg_net.eval()
-#     discriminator = Discriminator()
-#     if num_gpus == 2:
-#         discriminator = DataParallel(discriminator, device_ids = [0,1]).to(device)
-#     else:
-#         discriminator = discriminator.to(device)
-#     discriminator.train()
-
-#     d_optim = optim.Adam(discriminator.parameters(), lr=1e-4)
-#     scheduler = optim.lr_scheduler.StepLR(g_optim, step_size=2000, gamma=0.1)
-
-#     VGG_loss = perceptual_loss(vgg_net)
-#     cross_ent = nn.BCELoss()
-#     tv_loss = TVLoss()
-#     real_label = torch.ones((gt.size(0), 2)).to(device)
-#     fake_label = torch.zeros((gt.size(0), 2)).to(device)
-
-#     while fine_epoch < config['fine_train_epoch']:
-#         scheduler.step()
-
-#         for i, tr_data in enumerate(loader):
-#             gt = tr_data['GT'].to(device)
-#             lr = tr_data['LR'].to(device)
-
-#             ## Training Discriminator
-#             output = generator(lr)
-#             fake_prob = discriminator(output, lr)
-#             real_prob = discriminator(gt, lr)
-#             d_loss_real = cross_ent(real_prob, real_label)
-#             d_loss_fake = cross_ent(fake_prob, fake_label)
-
-#             d_loss = d_loss_real + d_loss_fake
-
-#             g_optim.zero_grad()
-#             d_optim.zero_grad()
-#             d_loss.backward()
-#             d_optim.step()
-
-#             output = generator(lr)
-#             fake_prob = discriminator(output, lr)
-
-#             _percep_loss, hr_feat, sr_feat = VGG_loss((gt + 1.0) / 2.0, (output + 1.0) / 2.0, layer=config['feat_layer'])
-
-#             L2_loss = l2_loss(output, gt)
-#             percep_loss = config['vgg_rescale_coeff'] * _percep_loss
-#             adversarial_loss = config['adv_coeff'] * cross_ent(fake_prob, real_label)
-#             total_variance_loss = config['tv_loss_coeff'] * tv_loss(config['vgg_rescale_coeff'] * (hr_feat - sr_feat)**2)
-
-#             g_loss = percep_loss + adversarial_loss + total_variance_loss + L2_loss
-
-#             g_optim.zero_grad()
-#             d_optim.zero_grad()
-#             g_loss.backward()
-#             g_optim.step()
-
-#         fine_epoch += 1
-
-#         if fine_epoch % 2 == 0:
-#             print(fine_epoch)
-#             print(g_loss.item())
-#             print(d_loss.item())
-#             print('=========')
-
-#         if fine_epoch % 4 == 0:
-#             torch.save(generator.state_dict(), '/kaggle/working/model/MedSRGAN_gene_%03d.pt' % fine_epoch)
-#             torch.save(discriminator.state_dict(), '/kaggle/working/model/MedSRGAN_discrim_%03d.pt' % fine_epoch)
 
 
 
 config = {
     'LR_path': '/kaggle/input/lumber-spine-dataset/LRimages',
     'GT_path': '/kaggle/input/lumber-spine-dataset/HRimages',
-    # 'LR_path': '/kaggle/input/custom-dataset/custom_dataset/train_LR',
-    # 'GT_path': '/kaggle/input/custom-dataset/custom_dataset/train_HR',
     'res_num': 16,
     'num_workers': 0,
     'batch_size': 16,
@@ -543,101 +388,36 @@ config = {
     'generator_path': None, # Set the path if available
     'mode': 'train'
 }
-
 
-# List all folders in /kaggle/input/
-#input_directory = '/kaggle/input/lumber-spine-dataset/'
-# input_directory = '/kaggle/input/custom-dataset/custom_dataset/'
-#all_folders = [f for f in os.listdir(input_directory) if os.path.isdir(os.path.join(input_directory, f))]
-
-# Print the list of folders
-# print("All folders in /kaggle/input/:")
-# for folder in all_folders:
-#     print(folder)
-
-#train(config)
-
-# def preprocess_input(image_path):
-#     # Load the input image
-#     input_image = Image.open(image_path).convert("RGB")
-#     input_image = np.array(input_image) / 127.5 - 1.0
-#     input_image = input_image.transpose(2, 0, 1).astype(np.float32)
-#     return torch.tensor(input_image).unsqueeze(0)
-
-
-def preprocess_input(image):
-    # Convert Gradio image to PIL format
-    image_pil = to_pil_image(image)
-    image_pil=image_pil.convert("RGB")
-    # Convert image to numpy array and normalize
-    input_image = np.array(image_pil) / 127.5 - 1.0
-    # Transpose dimensions
-    input_image = input_image.transpose(2, 0, 1).astype(np.float32)
-    # Convert to PyTorch tensor and add batch dimension
-    input_tensor = torch.tensor(input_image).unsqueeze(0)
-    return input_tensor
 
+def preprocess_input(input_image):
+    input_image = np.array(input_image) / 127.5 - 1.0
+    input_image = input_image.transpose(2, 0, 1).astype(np.float32)
+    return torch.tensor(input_image).unsqueeze(0)
 
 
-
-# def test_single_image(generator_path, input_image_path):
-#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-#     num_gpus = torch.cuda.device_count()
-#     # Load the generator model
-#     generator = Generator()
-
-#     if num_gpus == 2:
-#         state_dict = torch.load(generator_path)
-#         if 'module.' in list(state_dict.keys())[0]:
-#             state_dict = {k[7:]: v for k, v in state_dict.items()}
-#         # Load the state dictionary to the model
-#         generator.load_state_dict(state_dict)
-#         generator = DataParallel(generator, device_ids= [0,1]).to(device)
-
-#     else:
-#         generator = generator.to(device)
-#         generator.load_state_dict(torch.load(generator_path))
-#     generator.eval()
-
-#     # Preprocess the input image
-#     input_image = preprocess_input(input_image_path).to(device)
-
-#     # Perform inference
-#     with torch.no_grad():
-#         output = generator(input_image)
-#     output = output[0].cpu().numpy()
-#     output = (output + 1.0) / 2.0
-#     output = output.transpose(1, 2, 0)
-#     plt.imshow(output)
-#     plt.show()
-
-def test_single_image(input_image):
-    return input_image
+def test_single_image( input_image):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    num_gpus = torch.cuda.device_count()
-    generator_path = 'pre_trained_model_064.pt'
+    # generator_path = './pre_trained_model_064.pt'
+    generator_path = './MedSRGAN_gene_016.pt'
+
+
 
     # Load the generator model
-    generator = Generator()
-
-    if num_gpus == 2:
-        state_dict = torch.load(generator_path)
-        if 'module.' in list(state_dict.keys())[0]:
-            state_dict = {k[7:]: v for k, v in state_dict.items()}
-        # Load the state dictionary to the model
-        generator.load_state_dict(state_dict)
-        generator = torch.nn.DataParallel(generator, device_ids=[0, 1]).to(device)
-    else:
-        generator = generator.to(device)
-        generator.load_state_dict(torch.load(generator_path))
+    generator = Generator() #for ours
+    state_dict = torch.load(generator_path, map_location='cpu')
+    if 'module.' in list(state_dict.keys())[0]:
+        state_dict = {k[7:]: v for k, v in state_dict.items()}
+
+    generator.load_state_dict(state_dict)
     generator.eval()
 
     # Preprocess the input image
-    input_tensor = preprocess_input(input_image).to(device)
+    input_image = preprocess_input(input_image).to(device)
 
     # Perform inference
     with torch.no_grad():
-        output = generator(input_tensor)
+        output = generator(input_image)
         output = output[0].cpu().numpy()
         output = (output + 1.0) / 2.0
         output = output.transpose(1, 2, 0)
@@ -646,23 +426,13 @@ def test_single_image(input_image):
     return output_image
 
 
-
-
-
-# Load the generator model path
-#generator_path = 'pre_trained_model_064.pt'
-
-
-
-#test_single_image('pre_trained_model_064.pt',uploaded_image_data)
-# Define Gradio interface
 uploaded_image_data = gr.components.Image(type="pil", label="Upload Image")
-with gr.Column(scale=2, min_width=400):
-    output = gr.components.Image(type="pil", label="Super-Resolated Image")
+# with gr.Column(scale=2, min_width=400):
+output = gr.components.Image(type="pil", label="Super-Resolated Image")
 
 
 # Deploy the interface
-gr.Interface(test_single_image, inputs=uploaded_image_data, outputs=output,
+gr.Interface(test_single_image, inputs=uploaded_image_data , outputs=output,
             title="Image Super-Resolution",
            description="Upload an image to see it super-resolated."
            ).launch()
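
Note on the normalization used by the new inference path: preprocess_input maps 8-bit pixel values into [-1, 1] (x / 127.5 - 1.0), and test_single_image maps the generator output back to [0, 1] ((y + 1.0) / 2.0) before building the output image. Below is a minimal, self-contained sketch of that round trip; it uses a random array instead of the app's Generator, and the final uint8/PIL conversion is an assumption for illustration, not taken from this commit.

import numpy as np
from PIL import Image

# Stand-in for an uploaded RGB image (the real app receives a PIL image from Gradio).
rgb = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)

# Forward mapping from preprocess_input: [0, 255] -> [-1, 1], HWC -> CHW, add batch dim.
x = rgb.astype(np.float32) / 127.5 - 1.0
x = x.transpose(2, 0, 1)[None, ...]        # shape (1, 3, 64, 64)

# Inverse mapping from test_single_image: [-1, 1] -> [0, 1], CHW -> HWC.
y = x[0]                                   # pretend this is the generator output
y = (y + 1.0) / 2.0
y = y.transpose(1, 2, 0)
out = Image.fromarray((y * 255.0).round().astype(np.uint8))  # assumed uint8 conversion
print(out.size)                            # (64, 64)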