# suyoyog's picture
# Update app.py
# 152a480 verified
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.models as models
import numpy as np
from PIL import Image
from skimage.color import rgb2ycbcr
from skimage.metrics import peak_signal_noise_ratio,structural_similarity
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
import random
from torch.nn.parallel import DataParallel
import os
import matplotlib.pyplot as plt
import gradio as gr
import io
from torchvision.transforms.functional import to_pil_image
class mydata(Dataset):
    """Paired LR/GT image dataset for super-resolution training.

    Files in `LR_path` and `GT_path` are paired by sorted filename order.
    Pixel values are rescaled from [0, 255] to [-1, 1] and returned
    channel-first (CHW) under the keys 'LR' and 'GT'.
    """

    def __init__(self, LR_path, GT_path, in_memory=True, transform=None):
        self.LR_path = LR_path
        self.GT_path = GT_path
        self.in_memory = in_memory
        self.transform = transform
        # Pairing relies on identical sorted ordering of both directories.
        self.LR_img = sorted(os.listdir(LR_path))
        self.GT_img = sorted(os.listdir(GT_path))
        if in_memory:
            # Eagerly decode everything once so __getitem__ does no disk I/O.
            self.LR_img = [
                np.array(Image.open(os.path.join(self.LR_path, name)).convert("RGB")).astype(np.uint8)
                for name in self.LR_img
            ]
            self.GT_img = [
                np.array(Image.open(os.path.join(self.GT_path, name)).convert("RGB")).astype(np.uint8)
                for name in self.GT_img
            ]

    def __len__(self):
        return len(self.LR_img)

    def __getitem__(self, i):
        if self.in_memory:
            GT = self.GT_img[i].astype(np.float32)
            LR = self.LR_img[i].astype(np.float32)
        else:
            GT = np.array(Image.open(os.path.join(self.GT_path, self.GT_img[i])).convert("RGB"))
            LR = np.array(Image.open(os.path.join(self.LR_path, self.LR_img[i])).convert("RGB"))
        # Map [0, 255] -> [-1, 1].
        item = {'GT': (GT / 127.5) - 1.0, 'LR': (LR / 127.5) - 1.0}
        if self.transform is not None:
            item = self.transform(item)
        # HWC -> CHW for PyTorch convolutions.
        item['GT'] = item['GT'].transpose(2, 0, 1).astype(np.float32)
        item['LR'] = item['LR'].transpose(2, 0, 1).astype(np.float32)
        return item
class testOnly_data(Dataset):
    """Inference-time dataset: low-resolution images only.

    Images are rescaled from [0, 255] to [-1, 1] and returned channel-first
    under the 'LR' key. `transform` is accepted for interface parity with
    `mydata` but is not applied, matching the original behaviour.
    """

    def __init__(self, LR_path, in_memory=True, transform=None):
        self.LR_path = LR_path
        self.LR_img = sorted(os.listdir(LR_path))
        self.in_memory = in_memory
        if in_memory:
            # FIX: force 3-channel RGB like `mydata` does; grayscale or
            # palette files would otherwise decode to 2-D arrays and crash
            # the CHW transpose in __getitem__.
            self.LR_img = [
                np.array(Image.open(os.path.join(self.LR_path, lr)).convert("RGB"))
                for lr in self.LR_img
            ]

    def __len__(self):
        return len(self.LR_img)

    def __getitem__(self, i):
        if self.in_memory:
            LR = self.LR_img[i]
        else:
            # Same RGB coercion on the lazy path.
            LR = np.array(Image.open(os.path.join(self.LR_path, self.LR_img[i])).convert("RGB"))
        img_item = {}
        # Map [0, 255] -> [-1, 1], then HWC -> CHW.
        img_item['LR'] = (LR / 127.5) - 1.0
        img_item['LR'] = img_item['LR'].transpose(2, 0, 1).astype(np.float32)
        return img_item
class crop(object):
    """Random paired crop transform.

    Cuts a `patch_size` x `patch_size` patch from the LR image and the
    spatially aligned, `scale`-times-larger patch from the GT image.
    """

    def __init__(self, scale, patch_size):
        self.scale = scale
        self.patch_size = patch_size

    def __call__(self, sample):
        lr, gt = sample['LR'], sample['GT']
        height, width = lr.shape[:2]
        # Uniformly random top-left corner for the LR patch.
        left = random.randrange(0, width - self.patch_size + 1)
        top = random.randrange(0, height - self.patch_size + 1)
        # The GT patch sits at the scaled-up position with scaled-up size.
        gt_left = left * self.scale
        gt_top = top * self.scale
        gt_size = self.scale * self.patch_size
        lr_patch = lr[top:top + self.patch_size, left:left + self.patch_size]
        gt_patch = gt[gt_top:gt_top + gt_size, gt_left:gt_left + gt_size]
        return {'LR': lr_patch, 'GT': gt_patch}
class augmentation(object):
    """Random paired augmentation transform.

    Applies, each with probability 1/2: horizontal flip, vertical flip,
    and a 90-degree transpose. LR and GT always receive the same choice
    so the pair stays spatially aligned.
    """

    def __call__(self, sample):
        lr, gt = sample['LR'], sample['GT']
        if random.randrange(0, 2):
            # .copy() materialises the flipped view into a contiguous array.
            lr = np.fliplr(lr).copy()
            gt = np.fliplr(gt).copy()
        if random.randrange(0, 2):
            lr = np.flipud(lr).copy()
            gt = np.flipud(gt).copy()
        if random.randrange(0, 2):
            # Swap H and W axes (90-degree transpose), channels untouched.
            lr = lr.transpose(1, 0, 2)
            gt = gt.transpose(1, 0, 2)
        return {'LR': lr, 'GT': gt}
class RWMAB(nn.Module):
    """Residual Whole Map Attention Block.

    Two 3x3 convolutions produce a feature map; a 1x1 convolution followed
    by a sigmoid produces an attention map of the same shape. The output is
    the attention-weighted features plus the identity shortcut.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, (3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels, in_channels, (3, 3), stride=1, padding=1),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, (1, 1), stride=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, x):
        features = self.layer1(x)
        attention = self.layer2(features)
        # Attention-gated features with a residual connection.
        return attention * features + x
class ShortResidualBlock(nn.Module):
    """A chain of 16 RWMAB blocks wrapped in one long skip connection."""

    def __init__(self, in_channels):
        super().__init__()
        self.layers = nn.ModuleList([RWMAB(in_channels) for _ in range(16)])

    def forward(self, x):
        out = x
        for block in self.layers:
            out = block(out)
        # Long residual: add the block input back onto the refined features.
        return out + x
class Generator(nn.Module):
    """MedSRGAN-style 4x super-resolution generator.

    A head convolution lifts the input to 64 channels, `blocks`
    ShortResidualBlocks refine it, and the tail concatenates refined and
    head features (64 + 64 = 128 channels) before two PixelShuffle(2)
    stages upsample 4x in total, ending in a 3-channel Tanh output
    in [-1, 1]. (`scale` is accepted for interface parity but unused.)
    """

    def __init__(self, in_channels=3, blocks=8, scale=4):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 64, (3, 3), stride=1, padding=1)
        self.short_blocks = nn.ModuleList(
            [ShortResidualBlock(64) for _ in range(blocks)]
        )
        self.conv2 = nn.Conv2d(64, 64, (1, 1), stride=1, padding=0)
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, (3, 3), stride=1, padding=1),
            nn.PixelShuffle(2),  # 256 -> 64 channels, 2x spatial
            nn.LeakyReLU(),
            nn.Conv2d(64, 256, (3, 3), stride=1, padding=1),
            nn.PixelShuffle(2),  # second 2x stage -> 4x total
            nn.LeakyReLU(),
            nn.Conv2d(64, 3, (1, 1), stride=1, padding=0),
            nn.Tanh(),
        )

    def forward(self, x):
        head = self.conv(x)
        refined = head
        for block in self.short_blocks:
            refined = block(refined)
        # Fuse refined features with the head features along channels.
        merged = torch.cat([self.conv2(refined), head], dim=1)
        return self.conv3(merged)
class D_Block(nn.Module):
    """Discriminator building block: 3x3 conv -> BatchNorm -> LeakyReLU.

    `stride=2` (the default) halves the spatial resolution; `stride=1`
    keeps it unchanged.
    """

    def __init__(self, in_channels, out_channels, stride=2):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
        )

    def forward(self, x):
        return self.layer(x)
class Discriminator(nn.Module):
    """Two-stream MedSRGAN discriminator.

    One stream processes the (generated or real) high-resolution image,
    downsampling 4x; the other processes the low-resolution image at its
    native resolution. The streams are concatenated channel-wise
    (128 + 128 = 256) and reduced by further D_Blocks to a 2-unit
    sigmoid output.

    NOTE(review): fc1 hard-codes 4096 input features, which holds only
    when the final 1024-channel map is 2x2 spatially (e.g. 128x128 HR with
    32x32 LR inputs); the `img_size` argument is currently unused — verify
    against the training input sizes.
    """

    def __init__(self, img_size=[64, 64], in_channels=3):
        super().__init__()
        # Stream 1: HR / generated image, reduced 4x spatially.
        self.conv_1_1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, (3, 3), stride=1, padding=1), nn.LeakyReLU()
        )
        self.block_1_1 = D_Block(64, 64, stride=2)
        self.block_1_2 = D_Block(64, 128, stride=1)
        self.block_1_3 = D_Block(128, 128)
        # Stream 2: LR image at native resolution.
        self.conv_2_1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, (3, 3), stride=1, padding=1), nn.LeakyReLU()
        )
        self.block_2_2 = nn.Sequential(
            D_Block(64, 128, stride=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU()
        )
        # Joint trunk after channel concatenation.
        self.block3 = D_Block(256, 256, stride=1)
        self.block4 = D_Block(256, 256)
        self.block5 = D_Block(256, 512, stride=1)
        self.block6 = D_Block(512, 512)
        self.block7 = D_Block(512, 1024)
        self.block8 = D_Block(1024, 1024)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Sequential(nn.Linear(4096, 100), nn.LeakyReLU())
        self.fc2 = nn.Linear(100, 2)
        self.relu = nn.LeakyReLU(negative_slope=0.2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x1, x2):
        """x1: generated/HR image; x2: the corresponding LR image."""
        hr_feat = self.conv_1_1(x1)
        hr_feat = self.block_1_1(hr_feat)
        hr_feat = self.block_1_2(hr_feat)
        hr_feat = self.block_1_3(hr_feat)
        lr_feat = self.block_2_2(self.conv_2_1(x2))
        joint = torch.cat([hr_feat, lr_feat], dim=1)
        for trunk in (self.block3, self.block4, self.block5,
                      self.block6, self.block7, self.block8):
            joint = trunk(joint)
        out = self.fc1(self.flatten(joint))
        out = self.fc2(self.relu(out))
        return self.sigmoid(out)
class MeanShift(nn.Conv2d):
    """Frozen 1x1 convolution implementing per-channel normalisation.

    With the default sign=-1 it computes (x - rgb_range*mean) / std, i.e.
    ImageNet-style standardisation of an RGB tensor; its parameters are
    fixed (requires_grad=False) because this is preprocessing, not a
    trainable layer.
    """

    def __init__(
            self, rgb_range=1,
            norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225), sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        stds = torch.Tensor(norm_std)
        means = torch.Tensor(norm_mean)
        # Diagonal kernel divides each channel by its std.
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / stds.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * means / stds
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Never trained.
        for param in self.parameters():
            param.requires_grad = False
class perceptual_loss(nn.Module):
    """VGG feature-space (perceptual) MSE between HR and SR images.

    `vgg` must be a feature extractor that returns a list of layer
    activations and exposes `vgg_layer`, the list of layer names indexing
    that output (project-specific wrapper — not a raw torchvision VGG).
    Returns (mse_loss, hr_features, sr_features) for the requested layer.
    """

    def __init__(self, vgg):
        super(perceptual_loss, self).__init__()
        self.normalization_mean = [0.485, 0.456, 0.406]
        self.normalization_std = [0.229, 0.224, 0.225]
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # ImageNet normalisation applied before feeding the VGG extractor.
        self.transform = MeanShift(norm_mean=self.normalization_mean,
                                   norm_std=self.normalization_std).to(self.device)
        self.vgg = vgg.to(self.device)
        self.criterion = nn.MSELoss()

    def forward(self, HR, SR, layer='relu5_4'):
        hr = self.transform(HR).to(self.device)
        sr = self.transform(SR).to(self.device)
        feats_hr = self.vgg(hr)
        feats_sr = self.vgg(sr)
        # Pick the activations for the requested named layer.
        idx = self.vgg.vgg_layer.index(layer)
        hr_feat = feats_hr[idx].to(self.device)
        sr_feat = feats_sr[idx].to(self.device)
        return self.criterion(hr_feat, sr_feat), hr_feat, sr_feat
class TVLoss(nn.Module):
    """Anisotropic total-variation loss.

    Sums squared differences of vertically and horizontally adjacent
    pixels, normalises each direction by its number of difference terms,
    and averages over the batch; scaled by `tv_loss_weight`.
    """

    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        batch = x.size(0)
        height = x.size(2)
        width = x.size(3)
        diff_v = x[:, :, 1:, :] - x[:, :, :height - 1, :]
        diff_h = x[:, :, :, 1:] - x[:, :, :, :width - 1]
        # Per-direction term counts (C * H' * W') for normalisation.
        norm_v = self.tensor_size(x[:, :, 1:, :])
        norm_h = self.tensor_size(x[:, :, :, 1:])
        tv = diff_v.pow(2).sum() / norm_v + diff_h.pow(2).sum() / norm_h
        return self.tv_loss_weight * 2 * tv / batch

    @staticmethod
    def tensor_size(t):
        # Elements per batch item: channels * height * width.
        return t.size()[1] * t.size()[2] * t.size()[3]
# Training/runtime configuration. Paths are Kaggle dataset mounts; only a
# subset of these keys is used by the inference code visible in this file.
config = {
    'LR_path': '/kaggle/input/lumber-spine-dataset/LRimages',  # low-resolution inputs
    'GT_path': '/kaggle/input/lumber-spine-dataset/HRimages',  # high-resolution targets
    'res_num': 16,             # residual blocks (presumably per ShortResidualBlock) — TODO confirm
    'num_workers': 0,          # DataLoader workers
    'batch_size': 16,
    'L2_coeff': 1.0,           # pixel (L2) loss weight
    'adv_coeff': 1e-3,         # adversarial loss weight
    'tv_loss_coeff': 0.0,      # TVLoss weight (disabled)
    'pre_train_epoch': 16,     # epochs of L2-only pre-training
    'fine_train_epoch': 8,     # epochs of GAN fine-tuning
    'scale': 4,                # super-resolution factor
    'patch_size': 24,          # LR crop size (GT crop is scale * patch_size)
    'feat_layer': 'relu5_4',   # VGG layer for perceptual loss
    'vgg_rescale_coeff': 0.006,  # perceptual loss weight
    'fine_tuning': False,
    'in_memory': True,         # preload all images into RAM
    'generator_path': None,    # Set the path if available
    'mode': 'train'
}
def preprocess_input(input_image):
    """Convert a PIL image (or HxWx3 array) with [0, 255] values into a
    1x3xHxW float32 tensor scaled to [-1, 1]."""
    scaled = np.array(input_image) / 127.5 - 1.0
    # HWC -> CHW, then add the batch dimension.
    chw = scaled.transpose(2, 0, 1).astype(np.float32)
    return torch.tensor(chw).unsqueeze(0)
def test_single_image(input_image):
    """Run the super-resolution generator on one uploaded image.

    Loads generator weights from ./MedSRGAN_gene_016.pt, upscales
    `input_image` 4x, and returns the result as a PIL image.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator_path = './MedSRGAN_gene_016.pt'
    generator = Generator()
    state_dict = torch.load(generator_path, map_location='cpu')
    # Strip the 'module.' prefix left by DataParallel checkpoints.
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    generator.load_state_dict(state_dict)
    # FIX: the model must live on the same device as the input tensor;
    # previously only the input was moved, which crashes on CUDA machines.
    generator = generator.to(device)
    generator.eval()
    # Preprocess the input image ([0,255] -> [-1,1], 1x3xHxW).
    input_tensor = preprocess_input(input_image).to(device)
    with torch.no_grad():
        output = generator(input_tensor)
    output = output[0].cpu().numpy()
    # Tanh output in [-1, 1] -> [0, 1]; clip guards tiny overshoot.
    output = np.clip((output + 1.0) / 2.0, 0.0, 1.0)
    output = output.transpose(1, 2, 0)
    # FIX: to_pil_image does not accept 3-channel float ndarrays;
    # convert to uint8 HWC before building the PIL image.
    output_image = to_pil_image((output * 255.0).round().astype(np.uint8))
    return output_image
# Gradio UI wiring: one image input, one image output, bound to
# test_single_image, then launched as a web app.
uploaded_image_data = gr.components.Image(type="pil", label="Upload Medical Image(Spine)")
# with gr.Column(scale=2, min_width=400):
output = gr.components.Image(type="pil", label="Enhanced medical image")
# Deploy the interface
gr.Interface(test_single_image, inputs=uploaded_image_data , outputs=output,
title="Medical Image Super-Resolution with GAN",
description="Upload a medical image to see it enhanced. This model is trained on 22,352 low-resolution (LR) and high-resolution (HR) medical images for 64 epochs, followed by 16 fine-tuning epochs. The enlarged image is 4 times the original size."
).launch()