Update app.py
Browse files
app.py
CHANGED
|
@@ -16,6 +16,7 @@ import os
|
|
| 16 |
import matplotlib.pyplot as plt
|
| 17 |
import gradio as gr
|
| 18 |
import io
|
|
|
|
| 19 |
|
| 20 |
class mydata(Dataset):
|
| 21 |
def __init__(self, LR_path, GT_path, in_memory = True, transform = None):
|
|
@@ -556,45 +557,98 @@ config = {
|
|
| 556 |
|
| 557 |
#train(config)
|
| 558 |
|
| 559 |
-
def preprocess_input(image_path):
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 563 |
input_image = input_image.transpose(2, 0, 1).astype(np.float32)
|
| 564 |
-
|
|
|
|
|
|
|
|
|
|
| 565 |
|
| 566 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 567 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 568 |
num_gpus = torch.cuda.device_count()
|
|
|
|
|
|
|
| 569 |
# Load the generator model
|
| 570 |
generator = Generator()
|
| 571 |
-
|
| 572 |
if num_gpus == 2:
|
| 573 |
state_dict = torch.load(generator_path)
|
| 574 |
if 'module.' in list(state_dict.keys())[0]:
|
| 575 |
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
| 576 |
# Load the state dictionary to the model
|
| 577 |
generator.load_state_dict(state_dict)
|
| 578 |
-
generator = DataParallel(generator, device_ids=
|
| 579 |
-
|
| 580 |
else:
|
| 581 |
generator = generator.to(device)
|
| 582 |
generator.load_state_dict(torch.load(generator_path))
|
| 583 |
generator.eval()
|
| 584 |
|
| 585 |
# Preprocess the input image
|
| 586 |
-
|
| 587 |
|
| 588 |
# Perform inference
|
| 589 |
with torch.no_grad():
|
| 590 |
-
output = generator(
|
| 591 |
output = output[0].cpu().numpy()
|
| 592 |
output = (output + 1.0) / 2.0
|
| 593 |
output = output.transpose(1, 2, 0)
|
| 594 |
-
|
| 595 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 596 |
# Load the generator model path
|
| 597 |
-
generator_path = '
|
| 598 |
|
| 599 |
|
| 600 |
|
|
|
|
| 16 |
import matplotlib.pyplot as plt
|
| 17 |
import gradio as gr
|
| 18 |
import io
|
| 19 |
+
from torchvision.transforms.functional import to_pil_image
|
| 20 |
|
| 21 |
class mydata(Dataset):
|
| 22 |
def __init__(self, LR_path, GT_path, in_memory = True, transform = None):
|
|
|
|
| 557 |
|
| 558 |
#train(config)
|
| 559 |
|
| 560 |
+
# def preprocess_input(image_path):
|
| 561 |
+
# # Load the input image
|
| 562 |
+
# input_image = Image.open(image_path).convert("RGB")
|
| 563 |
+
# input_image = np.array(input_image) / 127.5 - 1.0
|
| 564 |
+
# input_image = input_image.transpose(2, 0, 1).astype(np.float32)
|
| 565 |
+
# return torch.tensor(input_image).unsqueeze(0)
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
def preprocess_input(image):
|
| 569 |
+
# Convert Gradio image to PIL format
|
| 570 |
+
image_pil = to_pil_image(image)
|
| 571 |
+
# Convert image to numpy array and normalize
|
| 572 |
+
input_image = np.array(image_pil) / 127.5 - 1.0
|
| 573 |
+
# Transpose dimensions
|
| 574 |
input_image = input_image.transpose(2, 0, 1).astype(np.float32)
|
| 575 |
+
# Convert to PyTorch tensor and add batch dimension
|
| 576 |
+
input_tensor = torch.tensor(input_image).unsqueeze(0)
|
| 577 |
+
return input_tensor
|
| 578 |
+
|
| 579 |
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
# def test_single_image(generator_path, input_image_path):
|
| 583 |
+
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 584 |
+
# num_gpus = torch.cuda.device_count()
|
| 585 |
+
# # Load the generator model
|
| 586 |
+
# generator = Generator()
|
| 587 |
+
|
| 588 |
+
# if num_gpus == 2:
|
| 589 |
+
# state_dict = torch.load(generator_path)
|
| 590 |
+
# if 'module.' in list(state_dict.keys())[0]:
|
| 591 |
+
# state_dict = {k[7:]: v for k, v in state_dict.items()}
|
| 592 |
+
# # Load the state dictionary to the model
|
| 593 |
+
# generator.load_state_dict(state_dict)
|
| 594 |
+
# generator = DataParallel(generator, device_ids= [0,1]).to(device)
|
| 595 |
+
|
| 596 |
+
# else:
|
| 597 |
+
# generator = generator.to(device)
|
| 598 |
+
# generator.load_state_dict(torch.load(generator_path))
|
| 599 |
+
# generator.eval()
|
| 600 |
+
|
| 601 |
+
# # Preprocess the input image
|
| 602 |
+
# input_image = preprocess_input(input_image_path).to(device)
|
| 603 |
+
|
| 604 |
+
# # Perform inference
|
| 605 |
+
# with torch.no_grad():
|
| 606 |
+
# output = generator(input_image)
|
| 607 |
+
# output = output[0].cpu().numpy()
|
| 608 |
+
# output = (output + 1.0) / 2.0
|
| 609 |
+
# output = output.transpose(1, 2, 0)
|
| 610 |
+
# plt.imshow(output)
|
| 611 |
+
# plt.show()
|
| 612 |
+
|
| 613 |
+
def test_single_image(input_image):
|
| 614 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 615 |
num_gpus = torch.cuda.device_count()
|
| 616 |
+
generator_path = 'pre_trained_model_064.pt'
|
| 617 |
+
|
| 618 |
# Load the generator model
|
| 619 |
generator = Generator()
|
| 620 |
+
|
| 621 |
if num_gpus == 2:
|
| 622 |
state_dict = torch.load(generator_path)
|
| 623 |
if 'module.' in list(state_dict.keys())[0]:
|
| 624 |
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
| 625 |
# Load the state dictionary to the model
|
| 626 |
generator.load_state_dict(state_dict)
|
| 627 |
+
generator = torch.nn.DataParallel(generator, device_ids=[0, 1]).to(device)
|
|
|
|
| 628 |
else:
|
| 629 |
generator = generator.to(device)
|
| 630 |
generator.load_state_dict(torch.load(generator_path))
|
| 631 |
generator.eval()
|
| 632 |
|
| 633 |
# Preprocess the input image
|
| 634 |
+
input_tensor = preprocess_input(input_image).to(device)
|
| 635 |
|
| 636 |
# Perform inference
|
| 637 |
with torch.no_grad():
|
| 638 |
+
output = generator(input_tensor)
|
| 639 |
output = output[0].cpu().numpy()
|
| 640 |
output = (output + 1.0) / 2.0
|
| 641 |
output = output.transpose(1, 2, 0)
|
| 642 |
+
# Convert output to PIL image format
|
| 643 |
+
output_image = to_pil_image(output)
|
| 644 |
+
return output_image
|
| 645 |
+
|
| 646 |
+
|
| 647 |
+
|
| 648 |
+
|
| 649 |
+
|
| 650 |
# Load the generator model path
|
| 651 |
+
#generator_path = 'pre_trained_model_064.pt'
|
| 652 |
|
| 653 |
|
| 654 |
|