"""MedSRGAN super-resolution demo.

Defines the paired-image datasets, patch transforms, generator and
discriminator networks and loss modules, and finally launches a Gradio
interface that upscales a single uploaded image with a trained generator.
"""

import functools
import io
import os
import random

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from PIL import Image
from skimage.color import rgb2ycbcr
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from torch.nn.parallel import DataParallel
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image


def _load_rgb(path):
    """Read the image at *path* and return it as an (H, W, 3) uint8 array."""
    # The context manager releases the file handle as soon as the pixels have
    # been copied into the NumPy array.
    with Image.open(path) as img:
        return np.array(img.convert("RGB")).astype(np.uint8)


class mydata(Dataset):
    """Paired low-resolution (LR) / ground-truth (GT) image dataset.

    File names in both directories are sorted independently and paired by
    index, so the two directories are presumably expected to contain
    matching files in the same order (not verified here).

    Args:
        LR_path: Directory holding the low-resolution images.
        GT_path: Directory holding the ground-truth images.
        in_memory: If true, decode every image once in ``__init__`` and keep
            the arrays in RAM; otherwise decode lazily in ``__getitem__``.
        transform: Optional callable applied to the ``{'GT', 'LR'}`` dict of
            HWC arrays before they are converted to CHW.
    """

    def __init__(self, LR_path, GT_path, in_memory=True, transform=None):
        self.LR_path = LR_path
        self.GT_path = GT_path
        self.in_memory = in_memory
        self.transform = transform

        # File names first; replaced by decoded arrays below when in_memory.
        self.LR_img = sorted(os.listdir(LR_path))
        self.GT_img = sorted(os.listdir(GT_path))

        if in_memory:
            self.LR_img = [
                _load_rgb(os.path.join(self.LR_path, lr)) for lr in self.LR_img
            ]
            self.GT_img = [
                _load_rgb(os.path.join(self.GT_path, gt)) for gt in self.GT_img
            ]

    def __len__(self):
        return len(self.LR_img)

    def __getitem__(self, i):
        """Return ``{'GT': ..., 'LR': ...}`` as float32 CHW arrays in [-1, 1]."""
        if self.in_memory:
            GT = self.GT_img[i].astype(np.float32)
            LR = self.LR_img[i].astype(np.float32)
        else:
            GT = _load_rgb(os.path.join(self.GT_path, self.GT_img[i]))
            LR = _load_rgb(os.path.join(self.LR_path, self.LR_img[i]))

        # Map 8-bit pixel values [0, 255] onto [-1, 1].
        img_item = {
            'GT': (GT / 127.5) - 1.0,
            'LR': (LR / 127.5) - 1.0,
        }

        if self.transform is not None:
            img_item = self.transform(img_item)

        # HWC -> CHW, the channel layout nn.Conv2d consumes.
        img_item['GT'] = img_item['GT'].transpose(2, 0, 1).astype(np.float32)
        img_item['LR'] = img_item['LR'].transpose(2, 0, 1).astype(np.float32)
        return img_item


class testOnly_data(Dataset):
    """Low-resolution-only dataset for inference (no ground truth).

    Images are converted to 3-channel RGB on load, consistent with
    ``mydata``, so grayscale / palette / RGBA files no longer break the
    HWC -> CHW transpose or the 3-channel generator input.

    Args:
        LR_path: Directory holding the low-resolution images.
        in_memory: If true, decode every image once in ``__init__``.
        transform: Accepted for signature parity with ``mydata``; unused.
    """

    def __init__(self, LR_path, in_memory=True, transform=None):
        self.LR_path = LR_path
        self.LR_img = sorted(os.listdir(LR_path))
        self.in_memory = in_memory

        if in_memory:
            self.LR_img = [
                _load_rgb(os.path.join(self.LR_path, lr)) for lr in self.LR_img
            ]

    def __len__(self):
        return len(self.LR_img)

    def __getitem__(self, i):
        """Return ``{'LR': ...}`` as a float32 CHW array in [-1, 1]."""
        if self.in_memory:
            LR = self.LR_img[i]
        else:
            LR = _load_rgb(os.path.join(self.LR_path, self.LR_img[i]))

        scaled = (LR / 127.5) - 1.0
        return {'LR': scaled.transpose(2, 0, 1).astype(np.float32)}


class crop(object):
    """Random aligned crop of an LR patch and the corresponding GT patch.

    Args:
        scale: Factor by which LR coordinates are multiplied to address the
            GT image (assumes GT is ``scale`` times larger than LR).
        patch_size: Side length of the square LR patch, in LR pixels.
    """

    def __init__(self, scale, patch_size):
        self.scale = scale
        self.patch_size = patch_size

    def __call__(self, sample):
        LR_img, GT_img = sample['LR'], sample['GT']
        ih, iw = LR_img.shape[:2]

        # randrange raises ValueError if the LR image is smaller than the
        # requested patch in either dimension.
        ix = random.randrange(0, iw - self.patch_size + 1)
        iy = random.randrange(0, ih - self.patch_size + 1)

        tx = ix * self.scale
        ty = iy * self.scale
        gt_size = self.scale * self.patch_size

        LR_patch = LR_img[iy:iy + self.patch_size, ix:ix + self.patch_size]
        GT_patch = GT_img[ty:ty + gt_size, tx:tx + gt_size]

        return {'LR': LR_patch, 'GT': GT_patch}


class augmentation(object):
    """Apply the same random flips / transpose to an LR-GT pair.

    Each of horizontal flip, vertical flip and height/width transpose is
    applied independently with probability 1/2.
    """

    def __call__(self, sample):
        LR_img, GT_img = sample['LR'], sample['GT']

        # Draw order (hor, ver, rot) is kept so seeded runs stay reproducible.
        hor_flip = random.randrange(0, 2)
        ver_flip = random.randrange(0, 2)
        rot = random.randrange(0, 2)

        if hor_flip:
            # .copy() turns the negative-stride view into a regular array.
            LR_img = np.fliplr(LR_img).copy()
            GT_img = np.fliplr(GT_img).copy()

        if ver_flip:
            LR_img = np.flipud(LR_img).copy()
            GT_img = np.flipud(GT_img).copy()

        if rot:
            # Swap the height and width axes; channels stay last.
            LR_img = LR_img.transpose(1, 0, 2)
            GT_img = GT_img.transpose(1, 0, 2)

        return {'LR': LR_img, 'GT': GT_img}


class RWMAB(nn.Module):
    """Residual block whose conv features are gated by a sigmoid mask.

    Computes ``x + sigmoid(conv1x1(f)) * f`` with ``f = conv3x3(relu(conv3x3(x)))``.
    The channel count and spatial size are preserved.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, (3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels, in_channels, (3, 3), stride=1, padding=1),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, (1, 1), stride=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, x):
        features = self.layer1(x)
        mask = self.layer2(features)
        return mask * features + x


class ShortResidualBlock(nn.Module):
    """Sixteen chained ``RWMAB`` blocks wrapped in a skip connection."""

    def __init__(self, in_channels):
        super().__init__()
        self.layers = nn.ModuleList([RWMAB(in_channels) for _ in range(16)])

    def forward(self, x):
        # No layer operates in place, so the input can be reused for the
        # skip connection without cloning it first.
        out = x
        for layer in self.layers:
            out = layer(out)
        return out + x


class Generator(nn.Module):
    """Super-resolution generator producing a 4x larger image.

    Two ``PixelShuffle(2)`` stages each double the spatial size; the final
    ``Tanh`` bounds the output to [-1, 1].

    Args:
        in_channels: Number of channels of the input image.
        blocks: Number of ``ShortResidualBlock`` modules in the trunk.
        scale: Currently unused; the upsampling factor is fixed by the two
            pixel-shuffle stages.
    """

    def __init__(self, in_channels=3, blocks=8, scale=4):
        super().__init__()
        # Add the noise part here
        self.conv = nn.Conv2d(in_channels, 64, (3, 3), stride=1, padding=1)

        self.short_blocks = nn.ModuleList(
            [ShortResidualBlock(64) for _ in range(blocks)]
        )

        self.conv2 = nn.Conv2d(64, 64, (1, 1), stride=1, padding=0)

        self.conv3 = nn.Sequential(
            # 128 input channels: trunk output concatenated with shallow features.
            nn.Conv2d(128, 256, (3, 3), stride=1, padding=1),
            nn.PixelShuffle(2),
            nn.LeakyReLU(),
            nn.Conv2d(64, 256, (3, 3), stride=1, padding=1),
            nn.PixelShuffle(2),  # Remove if output is 2x the input
            nn.LeakyReLU(),
            nn.Conv2d(64, 3, (1, 1), stride=1, padding=0),  # Change 64 -> 256
            nn.Tanh(),
        )

    def forward(self, x):
        shallow = self.conv(x)

        # No in-place ops follow, so no defensive clone of ``shallow`` is needed.
        deep = shallow
        for layer in self.short_blocks:
            deep = layer(deep)

        fused = torch.cat([self.conv2(deep), shallow], dim=1)
        return self.conv3(fused)


class D_Block(nn.Module):
    """3x3 convolution followed by batch norm and LeakyReLU."""

    def __init__(self, in_channels, out_channels, stride=2):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
        )

    def forward(self, x):
        return self.layer(x)


class Discriminator(nn.Module):
    """Two-input discriminator over a (generated/HR image, LR image) pair.

    The high-resolution branch downsamples by 4 (two stride-2 blocks) so it
    can be concatenated channel-wise with the LR branch, which keeps its
    spatial size.

    ``fc1`` expects exactly 4096 flattened features, i.e. a 2x2 map with
    1024 channels after the four remaining stride-2 blocks. That holds, for
    example, for a 24x24 LR input paired with a 96x96 image; other sizes
    need a different ``fc1`` input width.

    Args:
        img_size: Currently unused (only referenced by a disabled variant of
            ``fc1``). An immutable tuple replaces the former mutable list
            default.
        in_channels: Number of channels of both input images.
    """

    def __init__(self, img_size=(64, 64), in_channels=3):
        super().__init__()

        # Layers for generator image size
        self.conv_1_1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, (3, 3), stride=1, padding=1),
            nn.LeakyReLU(),
        )
        self.block_1_1 = D_Block(64, 64, stride=2)  # stride= 2 if output 4x
        self.block_1_2 = D_Block(64, 128, stride=1)
        self.block_1_3 = D_Block(128, 128)

        # Layer for LR image size
        self.conv_2_1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, (3, 3), stride=1, padding=1),
            nn.LeakyReLU(),
        )
        # NOTE: D_Block already ends in BatchNorm + LeakyReLU, so these are
        # applied twice here. Left untouched so that existing checkpoints
        # keep loading.
        self.block_2_2 = nn.Sequential(
            D_Block(64, 128, stride=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
        )

        # Shared trunk applied to the concatenated branches (128 + 128 channels).
        self.block3 = D_Block(256, 256, stride=1)
        self.block4 = D_Block(256, 256)
        self.block5 = D_Block(256, 512, stride=1)
        self.block6 = D_Block(512, 512)
        self.block7 = D_Block(512, 1024)
        self.block8 = D_Block(1024, 1024)

        self.flatten = nn.Flatten()

        # Input width depends on the input image size (see class docstring).
        self.fc1 = nn.Sequential(nn.Linear(4096, 100), nn.LeakyReLU())
        self.fc2 = nn.Linear(100, 2)

        self.relu = nn.LeakyReLU(negative_slope=0.2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x1, x2):
        """Score an image pair.

        Args:
            x1: Batch for the generator-image branch.
            x2: Batch for the LR-image branch.

        Returns:
            Tensor of shape (batch, 2) with sigmoid-activated outputs.
        """
        x_1 = self.conv_1_1(x1)
        x_1 = self.block_1_1(x_1)
        x_1 = self.block_1_2(x_1)
        x_1 = self.block_1_3(x_1)

        x_2 = self.block_2_2(self.conv_2_1(x2))

        x = torch.cat([x_1, x_2], dim=1)
        for block in (
            self.block3,
            self.block4,
            self.block5,
            self.block6,
            self.block7,
            self.block8,
        ):
            x = block(x)

        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(self.relu(x))
        return self.sigmoid(x)


class MeanShift(nn.Conv2d):
    """Frozen 1x1 convolution applying per-channel mean/std normalisation.

    With the defaults (``sign=-1``, ``rgb_range=1``) the layer computes
    ``(x - mean) / std`` for each of the three channels.
    """

    def __init__(
        self,
        rgb_range=1,
        norm_mean=(0.485, 0.456, 0.406),
        norm_std=(0.229, 0.224, 0.225),
        sign=-1,
    ):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(norm_std)
        # Identity kernel scaled by 1/std per output channel.
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(norm_mean) / std
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # The normalisation constants must never be trained.
        for p in self.parameters():
            p.requires_grad = False


class perceptual_loss(nn.Module):
    """MSE between feature maps of a feature extractor for HR and SR images.

    Args:
        vgg: Feature-extractor module. It is called on a normalised batch
            and must expose a ``vgg_layer`` sequence of layer names whose
            order matches the indexable output it returns (presumably a
            project-specific VGG wrapper; confirm against its definition).
    """

    def __init__(self, vgg):
        super(perceptual_loss, self).__init__()
        self.normalization_mean = [0.485, 0.456, 0.406]
        self.normalization_std = [0.229, 0.224, 0.225]
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.transform = MeanShift(
            norm_mean=self.normalization_mean,
            norm_std=self.normalization_std,
        ).to(self.device)
        self.vgg = vgg.to(self.device)
        self.criterion = nn.MSELoss()

    def forward(self, HR, SR, layer='relu5_4'):
        """Return ``(loss, hr_features, sr_features)`` for the named layer.

        Raises:
            ValueError: If *layer* is not listed in ``self.vgg.vgg_layer``.
        """
        hr = self.transform(HR).to(self.device)
        sr = self.transform(SR).to(self.device)

        vgg_outputs_hr = self.vgg(hr)
        vgg_outputs_sr = self.vgg(sr)

        # Look the layer position up once and reuse it for both outputs.
        layer_index = self.vgg.vgg_layer.index(layer)

        hr_feat = vgg_outputs_hr[layer_index].to(self.device)
        sr_feat = vgg_outputs_sr[layer_index].to(self.device)

        return self.criterion(hr_feat, sr_feat), hr_feat, sr_feat


class TVLoss(nn.Module):
    """Total-variation loss: mean squared difference of neighbouring pixels.

    Args:
        tv_loss_weight: Multiplier applied to the returned loss.
    """

    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        batch_size = x.size()[0]

        # Differences between vertically / horizontally adjacent pixels.
        h_diff = x[:, :, 1:, :] - x[:, :, :-1, :]
        w_diff = x[:, :, :, 1:] - x[:, :, :, :-1]

        # Per-sample element counts used to normalise each direction.
        count_h = self.tensor_size(h_diff)
        count_w = self.tensor_size(w_diff)

        h_tv = torch.pow(h_diff, 2).sum()
        w_tv = torch.pow(w_diff, 2).sum()

        return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

    @staticmethod
    def tensor_size(t):
        """Return the number of elements per sample (C * H * W) of *t*."""
        return t.size()[1] * t.size()[2] * t.size()[3]


config = {
    'LR_path': '/kaggle/input/lumber-spine-dataset/LRimages',
    'GT_path': '/kaggle/input/lumber-spine-dataset/HRimages',
    'res_num': 16,
    'num_workers': 0,
    'batch_size': 16,
    'L2_coeff': 1.0,
    'adv_coeff': 1e-3,
    'tv_loss_coeff': 0.0,
    'pre_train_epoch': 16,
    'fine_train_epoch': 8,
    'scale': 4,
    'patch_size': 24,
    'feat_layer': 'relu5_4',
    'vgg_rescale_coeff': 0.006,
    'fine_tuning': False,
    'in_memory': True,
    'generator_path': None,  # Set the path if available
    'mode': 'train',
}

# Checkpoint served by the demo.
# Alternative checkpoint: './pre_trained_model_064.pt'
GENERATOR_PATH = './MedSRGAN_gene_016.pt'


def preprocess_input(input_image):
    """Convert an image into a (1, 3, H, W) float32 tensor scaled to [-1, 1].

    Args:
        input_image: A PIL image or an array-like of 8-bit pixel values in
            HW, HW1, HW3 or HW4 layout.

    Returns:
        A CPU tensor of shape (1, 3, H, W).
    """
    if isinstance(input_image, Image.Image):
        # Handles grayscale, palette and RGBA uploads uniformly.
        input_image = input_image.convert("RGB")

    pixels = np.array(input_image)
    if pixels.ndim == 2:
        pixels = pixels[:, :, None]
    if pixels.shape[2] == 1:
        # Replicate the single channel; the generator expects 3 channels.
        pixels = np.repeat(pixels, 3, axis=2)
    elif pixels.shape[2] > 3:
        # Drop alpha (or any extra) channels.
        pixels = pixels[:, :, :3]

    scaled = pixels / 127.5 - 1.0
    chw = scaled.transpose(2, 0, 1).astype(np.float32)
    return torch.tensor(chw).unsqueeze(0)


@functools.lru_cache(maxsize=2)
def _load_generator(generator_path, device_name):
    """Build the generator, load its weights and move it to *device_name*.

    Cached so the checkpoint is read from disk once per (path, device)
    instead of on every request.
    """
    generator = Generator()

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    state_dict = torch.load(generator_path, map_location='cpu')

    # Checkpoints saved from a DataParallel wrapper prefix every key with
    # 'module.'; strip it so the keys match the bare model.
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    generator.load_state_dict(state_dict)
    generator.to(torch.device(device_name))
    generator.eval()
    return generator


def test_single_image(input_image):
    """Super-resolve one image with the trained generator.

    Args:
        input_image: PIL image (as delivered by the Gradio component) or an
            array-like accepted by ``preprocess_input``.

    Returns:
        The enlarged image as an RGB ``PIL.Image``.

    Raises:
        ValueError: If no image was provided.
    """
    if input_image is None:
        raise ValueError("No input image provided; please upload an image.")

    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    # The model and the input must live on the same device.
    generator = _load_generator(GENERATOR_PATH, device_name)

    # Preprocess the input image
    batch = preprocess_input(input_image).to(device)

    # Perform inference
    with torch.no_grad():
        result = generator(batch)

    result = result[0].cpu().numpy()
    result = (result + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    result = result.transpose(1, 2, 0)  # CHW -> HWC

    # Convert explicitly to 8-bit: float ndarrays are rejected by
    # to_pil_image on older torchvision releases.
    result = (np.clip(result, 0.0, 1.0) * 255.0).round().astype(np.uint8)

    # Convert output to PIL image format
    return to_pil_image(result)


uploaded_image_data = gr.components.Image(type="pil", label="Upload Medical Image(Spine)")

output = gr.components.Image(type="pil", label="Enhanced medical image")

# Deploy the interface
gr.Interface(
    test_single_image,
    inputs=uploaded_image_data,
    outputs=output,
    title="Medical Image Super-Resolution with GAN",
    description="Upload a medical image to see it enhanced. This model is trained on 22,352 low-resolution (LR) and high-resolution (HR) medical images for 64 epochs, followed by 16 fine-tuning epochs. The enlarged image is 4 times the original size.",
).launch()