|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
from torch.utils.data import DataLoader |
|
|
from torchvision import transforms |
|
|
import torchvision.models as models |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from skimage.color import rgb2ycbcr |
|
|
from skimage.metrics import peak_signal_noise_ratio,structural_similarity |
|
|
import matplotlib.pyplot as plt |
|
|
from torch.utils.data import Dataset |
|
|
import random |
|
|
from torch.nn.parallel import DataParallel |
|
|
import os |
|
|
import matplotlib.pyplot as plt |
|
|
import gradio as gr |
|
|
import io |
|
|
from torchvision.transforms.functional import to_pil_image |
|
|
|
|
|
class mydata(Dataset):
    """Paired LR/GT super-resolution dataset.

    LR and GT images are matched by sorted filename order, scaled from
    [0, 255] to [-1, 1], and returned as CHW float32 arrays under the
    keys 'LR' and 'GT'.
    """

    def __init__(self, LR_path, GT_path, in_memory = True, transform = None):
        self.LR_path = LR_path
        self.GT_path = GT_path
        self.in_memory = in_memory
        self.transform = transform

        # Filenames sorted so LR[i] pairs with GT[i].
        self.LR_img = sorted(os.listdir(LR_path))
        self.GT_img = sorted(os.listdir(GT_path))

        if in_memory:
            # Preload everything as uint8 RGB arrays to avoid per-item disk reads.
            self.LR_img = [self._load(self.LR_path, name) for name in self.LR_img]
            self.GT_img = [self._load(self.GT_path, name) for name in self.GT_img]

    @staticmethod
    def _load(root, name):
        # Decode one image file into a uint8 HWC RGB array.
        return np.array(Image.open(os.path.join(root, name)).convert("RGB")).astype(np.uint8)

    def __len__(self):
        return len(self.LR_img)

    def __getitem__(self, i):
        if self.in_memory:
            GT = self.GT_img[i].astype(np.float32)
            LR = self.LR_img[i].astype(np.float32)
        else:
            GT = np.array(Image.open(os.path.join(self.GT_path, self.GT_img[i])).convert("RGB"))
            LR = np.array(Image.open(os.path.join(self.LR_path, self.LR_img[i])).convert("RGB"))

        # Map pixel values from [0, 255] to [-1, 1].
        img_item = {'GT': (GT / 127.5) - 1.0, 'LR': (LR / 127.5) - 1.0}

        if self.transform is not None:
            img_item = self.transform(img_item)

        # HWC -> CHW for PyTorch convolutions.
        for key in ('GT', 'LR'):
            img_item[key] = img_item[key].transpose(2, 0, 1).astype(np.float32)

        return img_item
|
|
|
|
|
|
|
|
class testOnly_data(Dataset):
    """Inference-only dataset: yields LR images with no ground truth.

    Each item is a dict with key 'LR' holding the image scaled from
    [0, 255] to [-1, 1] as a CHW float32 array.
    """

    def __init__(self, LR_path, in_memory = True, transform = None):
        self.LR_path = LR_path
        self.LR_img = sorted(os.listdir(LR_path))
        self.in_memory = in_memory
        # Stored for interface parity with `mydata`; not applied in __getitem__.
        self.transform = transform
        if in_memory:
            # FIX: force RGB so grayscale/palette files still decode to HWC
            # arrays (matches `mydata`; a 2-D array would crash the
            # transpose(2, 0, 1) in __getitem__).
            self.LR_img = [np.array(Image.open(os.path.join(self.LR_path, lr)).convert("RGB")) for lr in self.LR_img]

    def __len__(self):
        return len(self.LR_img)

    def __getitem__(self, i):
        if self.in_memory:
            LR = self.LR_img[i]
        else:
            # FIX: same RGB coercion on the lazy-load path.
            LR = np.array(Image.open(os.path.join(self.LR_path, self.LR_img[i])).convert("RGB"))

        img_item = {}
        # Map pixel values from [0, 255] to [-1, 1], then HWC -> CHW.
        img_item['LR'] = (LR / 127.5) - 1.0
        img_item['LR'] = img_item['LR'].transpose(2, 0, 1).astype(np.float32)

        return img_item
|
|
|
|
|
|
|
|
class crop(object):
    """Random paired crop transform.

    Cuts a patch_size x patch_size patch from the LR image and the
    spatially aligned (scale * patch_size) patch from the GT image.
    """

    def __init__(self, scale, patch_size):
        self.scale = scale
        self.patch_size = patch_size

    def __call__(self, sample):
        lr, gt = sample['LR'], sample['GT']
        height, width = lr.shape[:2]

        # Top-left corner of the LR patch, uniform over valid positions.
        # (Draw order: x then y, matching reproducible RNG consumption.)
        left = random.randrange(0, width - self.patch_size + 1)
        top = random.randrange(0, height - self.patch_size + 1)

        # Matching corner and size in GT coordinates.
        gt_left = left * self.scale
        gt_top = top * self.scale
        gt_size = self.scale * self.patch_size

        lr_patch = lr[top : top + self.patch_size, left : left + self.patch_size]
        gt_patch = gt[gt_top : gt_top + gt_size, gt_left : gt_left + gt_size]

        return {'LR' : lr_patch, 'GT' : gt_patch}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class augmentation(object):
    """Random paired augmentation transform.

    Applies (with probability 1/2 each) a horizontal flip, a vertical
    flip, and an H/W transpose — identically to LR and GT.
    """

    def __call__(self, sample):
        lr, gt = sample['LR'], sample['GT']

        # Draw all three coin flips up front; the draw order matters for
        # reproducibility with a seeded RNG.
        do_hflip = random.randrange(0, 2)
        do_vflip = random.randrange(0, 2)
        do_rot = random.randrange(0, 2)

        if do_hflip:
            # .copy() materializes the reversed view into a contiguous array.
            lr = np.fliplr(lr).copy()
            gt = np.fliplr(gt).copy()

        if do_vflip:
            lr = np.flipud(lr).copy()
            gt = np.flipud(gt).copy()

        if do_rot:
            # Swap H and W axes (a transpose, not a true 90-degree rotation).
            lr = lr.transpose(1, 0, 2)
            gt = gt.transpose(1, 0, 2)

        return {'LR' : lr, 'GT' : gt}
|
|
|
|
|
|
|
|
class RWMAB(nn.Module): |
|
|
def __init__(self, in_channels): |
|
|
super().__init__() |
|
|
self.layer1 = nn.Sequential( |
|
|
nn.Conv2d(in_channels, in_channels, (3, 3), stride=1, padding=1), |
|
|
nn.ReLU(), |
|
|
nn.Conv2d(in_channels, in_channels, (3, 3), stride=1, padding=1), |
|
|
) |
|
|
self.layer2 = nn.Sequential( |
|
|
nn.Conv2d(in_channels, in_channels, (1, 1), stride=1, padding=0), |
|
|
nn.Sigmoid(), |
|
|
) |
|
|
|
|
|
def forward(self, x): |
|
|
|
|
|
x_ = self.layer1(x) |
|
|
x__ = self.layer2(x_) |
|
|
|
|
|
x = x__ * x_ + x |
|
|
|
|
|
return x |
|
|
|
|
|
|
|
|
class ShortResidualBlock(nn.Module):
    """A chain of 16 RWMAB units wrapped in a long skip connection."""

    def __init__(self, in_channels):
        super().__init__()
        # 16 attention blocks applied sequentially.
        self.layers = nn.ModuleList([RWMAB(in_channels) for _ in range(16)])

    def forward(self, x):
        out = x
        for block in self.layers:
            out = block(out)
        # Residual: add the block input back onto the chain output.
        return out + x
|
|
|
|
|
|
|
|
class Generator(nn.Module):
    """MedSRGAN generator: 4x upscaling via two PixelShuffle(2) stages.

    Pipeline: 3x3 conv to 64 channels -> `blocks` ShortResidualBlocks ->
    1x1 conv on the deep features, concatenated with the shallow features
    (128 channels total) -> two conv+PixelShuffle upsampling stages ->
    1x1 conv to RGB with tanh, so outputs lie in [-1, 1].
    """

    def __init__(self, in_channels=3, blocks=8, scale = 4):
        super().__init__()

        # Shallow feature extractor.
        self.conv = nn.Conv2d(in_channels, 64, (3, 3), stride=1, padding=1)

        self.short_blocks = nn.ModuleList(
            [ShortResidualBlock(64) for _ in range(blocks)]
        )

        # 1x1 projection of the deep features before fusion.
        self.conv2 = nn.Conv2d(64, 64, (1, 1), stride=1, padding=0)

        # NOTE(review): `scale` is accepted but the upsampling factor is
        # hard-wired to 4 by the two PixelShuffle(2) stages below.
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, (3, 3), stride=1, padding=1),
            nn.PixelShuffle(2),   # 256 -> 64 channels, 2x spatial
            nn.LeakyReLU(),
            nn.Conv2d(64, 256, (3, 3), stride=1, padding=1),
            nn.PixelShuffle(2),   # second 2x stage
            nn.LeakyReLU(),
            nn.Conv2d(64, 3, (1, 1), stride=1, padding=0),
            nn.Tanh(),            # match the [-1, 1] data scaling
        )

    def forward(self, x):
        shallow = self.conv(x)

        deep = shallow
        for block in self.short_blocks:
            deep = block(deep)

        # Fuse projected deep features with the shallow skip along channels.
        fused = torch.cat([self.conv2(deep), shallow], dim=1)

        return self.conv3(fused)
|
|
|
|
|
|
|
|
class D_Block(nn.Module): |
|
|
def __init__(self, in_channels, out_channels, stride=2): |
|
|
super().__init__() |
|
|
|
|
|
self.layer = nn.Sequential( |
|
|
nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=1), |
|
|
nn.BatchNorm2d(out_channels), |
|
|
nn.LeakyReLU(), |
|
|
) |
|
|
|
|
|
def forward(self, x): |
|
|
|
|
|
return self.layer(x) |
|
|
|
|
|
|
|
|
class Discriminator(nn.Module):
    """Two-stream discriminator scoring (generated HR, LR) image pairs.

    Stream 1 downsamples the generated/HR image to the LR resolution;
    stream 2 embeds the LR image at its native resolution. The streams
    are concatenated along channels and pushed through further strided
    blocks to a 2-unit sigmoid output.
    """

    def __init__(self, img_size=[64,64], in_channels=3):
        super().__init__()

        # --- stream 1: generated / high-resolution image ---
        self.conv_1_1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, (3, 3), stride=1, padding=1), nn.LeakyReLU()
        )
        self.block_1_1 = D_Block(64, 64, stride=2)
        self.block_1_2 = D_Block(64, 128, stride=1)
        self.block_1_3 = D_Block(128, 128)

        # --- stream 2: low-resolution input image ---
        self.conv_2_1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, (3, 3), stride=1, padding=1), nn.LeakyReLU()
        )
        self.block_2_2 = nn.Sequential(
            D_Block(64, 128, stride=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU()
        )

        # --- joint trunk after channel concat (128 + 128 = 256) ---
        self.block3 = D_Block(256, 256, stride=1)
        self.block4 = D_Block(256, 256)
        self.block5 = D_Block(256, 512, stride=1)
        self.block6 = D_Block(512, 512)
        self.block7 = D_Block(512, 1024)
        self.block8 = D_Block(1024, 1024)

        self.flatten = nn.Flatten()

        # NOTE(review): fc1 assumes the flattened trunk output is exactly
        # 4096 features, tying the model to a fixed input size.
        self.fc1 = nn.Sequential(nn.Linear(4096, 100), nn.LeakyReLU())
        self.fc2 = nn.Linear(100, 2)

        self.relu = nn.LeakyReLU(negative_slope=0.2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x1, x2):
        '''
        x1 is the array for generator image
        x2 is the array for lr image
        '''
        # Run each stream independently.
        hr_feat = self.conv_1_1(x1)
        for stage in (self.block_1_1, self.block_1_2, self.block_1_3):
            hr_feat = stage(hr_feat)
        lr_feat = self.block_2_2(self.conv_2_1(x2))

        # Fuse the streams and run the shared trunk.
        joint = torch.cat([hr_feat, lr_feat], dim=1)
        for stage in (self.block3, self.block4, self.block5,
                      self.block6, self.block7, self.block8):
            joint = stage(joint)

        # Classification head.
        logits = self.fc2(self.relu(self.fc1(self.flatten(joint))))
        return self.sigmoid(logits)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MeanShift(nn.Conv2d):
    """Fixed 1x1 convolution that channel-normalizes RGB input.

    With the default sign=-1 the layer computes (x - mean) / std per
    channel (ImageNet statistics by default). Its parameters are frozen.
    """

    def __init__(
        self, rgb_range = 1,
        norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225), sign=-1):

        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(norm_std)
        mean = torch.Tensor(norm_mean)
        # Diagonal weight: scale channel c by 1 / std[c].
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        # Bias: shift channel c by sign * rgb_range * mean[c] / std[c].
        self.bias.data = sign * rgb_range * mean / std
        # NOTE(review): this attribute is set but never read here.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Pure preprocessing step — never trained.
        for p in self.parameters():
            p.requires_grad = False
|
|
|
|
|
|
|
|
|
|
|
class perceptual_loss(nn.Module):
    """VGG feature-space MSE between an HR image and its SR reconstruction.

    Inputs are ImageNet-normalized via MeanShift, passed through the
    supplied VGG wrapper, and compared at the requested layer.
    """

    def __init__(self, vgg):
        super(perceptual_loss, self).__init__()
        self.normalization_mean = [0.485, 0.456, 0.406]
        self.normalization_std = [0.229, 0.224, 0.225]
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.transform = MeanShift(norm_mean=self.normalization_mean,
                                   norm_std=self.normalization_std).to(self.device)
        self.vgg = vgg.to(self.device)
        self.criterion = nn.MSELoss()

    def forward(self, HR, SR, layer='relu5_4'):
        """Return (mse_loss, hr_features, sr_features) at `layer`.

        NOTE(review): assumes the `vgg` wrapper returns a sequence of
        per-layer outputs and exposes a `vgg_layer` list naming them —
        confirm against the wrapper's definition.
        """
        hr = self.transform(HR).to(self.device)
        sr = self.transform(SR).to(self.device)

        # Select the requested layer's features from both forward passes.
        idx = self.vgg.vgg_layer.index(layer)
        hr_feat = self.vgg(hr)[idx].to(self.device)
        sr_feat = self.vgg(sr)[idx].to(self.device)

        return self.criterion(hr_feat, sr_feat), hr_feat, sr_feat
|
|
|
|
|
class TVLoss(nn.Module):
    """Anisotropic total-variation loss on a NCHW batch.

    Sums squared differences between vertically and horizontally
    adjacent pixels, normalizes each term by its element count, and
    averages over the batch.
    """

    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        batch = x.size()[0]
        height = x.size()[2]
        width = x.size()[3]

        # Squared differences between vertically / horizontally adjacent pixels.
        vert_tv = (x[:, :, 1:, :] - x[:, :, :height - 1, :]).pow(2).sum()
        horiz_tv = (x[:, :, :, 1:] - x[:, :, :, :width - 1]).pow(2).sum()

        # Per-batch-item element counts of each difference tensor.
        vert_count = self.tensor_size(x[:, :, 1:, :])
        horiz_count = self.tensor_size(x[:, :, :, 1:])

        return self.tv_loss_weight * 2 * (vert_tv / vert_count + horiz_tv / horiz_count) / batch

    @staticmethod
    def tensor_size(t):
        # Number of elements per batch item: C * H * W.
        return t.size()[1] * t.size()[2] * t.size()[3]
|
|
|
|
|
|
|
|
|
|
|
# Training / inference configuration. Paths are Kaggle-specific; only a
# subset of these keys is used by the inference code visible in this file.
config = {
    'LR_path': '/kaggle/input/lumber-spine-dataset/LRimages',   # low-resolution inputs
    'GT_path': '/kaggle/input/lumber-spine-dataset/HRimages',   # high-resolution targets
    'res_num': 16,            # number of residual units (presumably per block) — TODO confirm against training code
    'num_workers': 0,         # DataLoader worker processes
    'batch_size': 16,
    'L2_coeff': 1.0,          # pixel (MSE) loss weight
    'adv_coeff': 1e-3,        # adversarial loss weight
    'tv_loss_coeff': 0.0,     # total-variation loss weight (disabled)
    'pre_train_epoch': 16,    # epochs of L2-only pretraining
    'fine_train_epoch': 8,    # epochs of adversarial fine-tuning
    'scale': 4,               # super-resolution factor
    'patch_size': 24,         # LR crop size fed to `crop`
    'feat_layer': 'relu5_4',  # VGG layer for perceptual loss
    'vgg_rescale_coeff': 0.006,  # perceptual loss weight
    'fine_tuning': False,
    'in_memory': True,        # preload the datasets into RAM
    'generator_path': None,   # optional checkpoint to resume from
    'mode': 'train'
}
|
|
|
|
|
|
|
|
def preprocess_input(input_image):
    """Convert a PIL image (or HWC uint8 array) into a 1xCxHxW float32
    tensor scaled to [-1, 1], ready for the generator."""
    # Scale [0, 255] -> [-1, 1].
    scaled = np.array(input_image) / 127.5 - 1.0
    # HWC -> CHW, then add the batch dimension.
    chw = scaled.transpose(2, 0, 1).astype(np.float32)
    return torch.tensor(chw).unsqueeze(0)
|
|
|
|
|
|
|
|
def test_single_image(input_image):
    """Run the pretrained generator on one image and return the SR PIL image.

    Loads './MedSRGAN_gene_016.pt', strips any DataParallel 'module.'
    key prefix, upsamples the input 4x, and converts the [-1, 1] tanh
    output back to an 8-bit RGB image.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    generator_path = './MedSRGAN_gene_016.pt'

    generator = Generator()
    state_dict = torch.load(generator_path, map_location='cpu')
    # Checkpoints saved from a DataParallel model prefix every key with 'module.'.
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}

    generator.load_state_dict(state_dict)
    # BUG FIX: the model must live on the same device as the input tensor,
    # otherwise inference crashes on CUDA machines.
    generator = generator.to(device)
    generator.eval()

    input_tensor = preprocess_input(input_image).to(device)

    with torch.no_grad():
        output = generator(input_tensor)

    output = output[0].cpu().numpy()
    # Map tanh output [-1, 1] -> [0, 1]; clip numerical overshoot.
    output = np.clip((output + 1.0) / 2.0, 0.0, 1.0)
    output = output.transpose(1, 2, 0)
    # BUG FIX: to_pil_image does not accept 3-channel float ndarrays;
    # convert to uint8 first.
    output = (output * 255.0).round().astype(np.uint8)

    return to_pil_image(output)
|
|
|
|
|
|
|
|
# Gradio input widget: the user uploads a PIL image of a (spine) scan.
uploaded_image_data = gr.components.Image(type="pil", label="Upload Medical Image(Spine)")

# Gradio output widget: the 4x super-resolved PIL image.
output = gr.components.Image(type="pil", label="Enhanced medical image")

# Wire the inference function into a web demo and start the server.
# NOTE: .launch() here is a module-level side effect — importing this
# file starts the UI.
gr.Interface(test_single_image, inputs=uploaded_image_data , outputs=output,
title="Medical Image Super-Resolution with GAN",
description="Upload a medical image to see it enhanced. This model is trained on 22,352 low-resolution (LR) and high-resolution (HR) medical images for 64 epochs, followed by 16 fine-tuning epochs. The enlarged image is 4 times the original size."
).launch()