suyoyog commited on
Commit
89b0cd6
·
verified ·
1 Parent(s): 524c441

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -14
app.py CHANGED
@@ -16,6 +16,7 @@ import os
16
  import matplotlib.pyplot as plt
17
  import gradio as gr
18
  import io
 
19
 
20
  class mydata(Dataset):
21
  def __init__(self, LR_path, GT_path, in_memory = True, transform = None):
@@ -556,45 +557,98 @@ config = {
556
 
557
  #train(config)
558
 
559
- def preprocess_input(image_path):
560
- # Load the input image
561
- input_image = Image.open(image_path).convert("RGB")
562
- input_image = np.array(input_image) / 127.5 - 1.0
 
 
 
 
 
 
 
 
 
 
563
  input_image = input_image.transpose(2, 0, 1).astype(np.float32)
564
- return torch.tensor(input_image).unsqueeze(0)
 
 
 
565
 
566
- def test_single_image(generator_path, input_image_path):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
567
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
568
  num_gpus = torch.cuda.device_count()
 
 
569
  # Load the generator model
570
  generator = Generator()
571
-
572
  if num_gpus == 2:
573
  state_dict = torch.load(generator_path)
574
  if 'module.' in list(state_dict.keys())[0]:
575
  state_dict = {k[7:]: v for k, v in state_dict.items()}
576
  # Load the state dictionary to the model
577
  generator.load_state_dict(state_dict)
578
- generator = DataParallel(generator, device_ids= [0,1]).to(device)
579
-
580
  else:
581
  generator = generator.to(device)
582
  generator.load_state_dict(torch.load(generator_path))
583
  generator.eval()
584
 
585
  # Preprocess the input image
586
- input_image = preprocess_input(input_image_path).to(device)
587
 
588
  # Perform inference
589
  with torch.no_grad():
590
- output = generator(input_image)
591
  output = output[0].cpu().numpy()
592
  output = (output + 1.0) / 2.0
593
  output = output.transpose(1, 2, 0)
594
- plt.imshow(output)
595
- plt.show()
 
 
 
 
 
 
596
  # Load the generator model path
597
- generator_path = '/pre_trained_model_064.pt'
598
 
599
 
600
 
 
16
  import matplotlib.pyplot as plt
17
  import gradio as gr
18
  import io
19
+ from torchvision.transforms.functional import to_pil_image
20
 
21
  class mydata(Dataset):
22
  def __init__(self, LR_path, GT_path, in_memory = True, transform = None):
 
557
 
558
  #train(config)
559
 
560
+ # def preprocess_input(image_path):
561
+ # # Load the input image
562
+ # input_image = Image.open(image_path).convert("RGB")
563
+ # input_image = np.array(input_image) / 127.5 - 1.0
564
+ # input_image = input_image.transpose(2, 0, 1).astype(np.float32)
565
+ # return torch.tensor(input_image).unsqueeze(0)
566
+
567
+
568
+ def preprocess_input(image):
569
+ # Convert Gradio image to PIL format
570
+ image_pil = to_pil_image(image)
571
+ # Convert image to numpy array and normalize
572
+ input_image = np.array(image_pil) / 127.5 - 1.0
573
+ # Transpose dimensions
574
  input_image = input_image.transpose(2, 0, 1).astype(np.float32)
575
+ # Convert to PyTorch tensor and add batch dimension
576
+ input_tensor = torch.tensor(input_image).unsqueeze(0)
577
+ return input_tensor
578
+
579
 
580
+
581
+
582
+ # def test_single_image(generator_path, input_image_path):
583
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
584
+ # num_gpus = torch.cuda.device_count()
585
+ # # Load the generator model
586
+ # generator = Generator()
587
+
588
+ # if num_gpus == 2:
589
+ # state_dict = torch.load(generator_path)
590
+ # if 'module.' in list(state_dict.keys())[0]:
591
+ # state_dict = {k[7:]: v for k, v in state_dict.items()}
592
+ # # Load the state dictionary to the model
593
+ # generator.load_state_dict(state_dict)
594
+ # generator = DataParallel(generator, device_ids= [0,1]).to(device)
595
+
596
+ # else:
597
+ # generator = generator.to(device)
598
+ # generator.load_state_dict(torch.load(generator_path))
599
+ # generator.eval()
600
+
601
+ # # Preprocess the input image
602
+ # input_image = preprocess_input(input_image_path).to(device)
603
+
604
+ # # Perform inference
605
+ # with torch.no_grad():
606
+ # output = generator(input_image)
607
+ # output = output[0].cpu().numpy()
608
+ # output = (output + 1.0) / 2.0
609
+ # output = output.transpose(1, 2, 0)
610
+ # plt.imshow(output)
611
+ # plt.show()
612
+
613
+ def test_single_image(input_image):
614
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
615
  num_gpus = torch.cuda.device_count()
616
+ generator_path = 'pre_trained_model_064.pt'
617
+
618
  # Load the generator model
619
  generator = Generator()
620
+
621
  if num_gpus == 2:
622
  state_dict = torch.load(generator_path)
623
  if 'module.' in list(state_dict.keys())[0]:
624
  state_dict = {k[7:]: v for k, v in state_dict.items()}
625
  # Load the state dictionary to the model
626
  generator.load_state_dict(state_dict)
627
+ generator = torch.nn.DataParallel(generator, device_ids=[0, 1]).to(device)
 
628
  else:
629
  generator = generator.to(device)
630
  generator.load_state_dict(torch.load(generator_path))
631
  generator.eval()
632
 
633
  # Preprocess the input image
634
+ input_tensor = preprocess_input(input_image).to(device)
635
 
636
  # Perform inference
637
  with torch.no_grad():
638
+ output = generator(input_tensor)
639
  output = output[0].cpu().numpy()
640
  output = (output + 1.0) / 2.0
641
  output = output.transpose(1, 2, 0)
642
+ # Convert output to PIL image format
643
+ output_image = to_pil_image(output)
644
+ return output_image
645
+
646
+
647
+
648
+
649
+
650
  # Load the generator model path
651
+ #generator_path = 'pre_trained_model_064.pt'
652
 
653
 
654