intelAIcorrect / app.py
snehilchatterjee's picture
Update app.py
23e5276 verified
Raw
History Blame Contribute Delete
13.9 kB
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
from torch import nn
import cv2
from super_image import EdsrModel, ImageLoader
device='cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using: {device}')
class _conv(nn.Conv2d):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias):
super(_conv, self).__init__(in_channels = in_channels, out_channels = out_channels,
kernel_size = kernel_size, stride = stride, padding = (kernel_size) // 2, bias = True)
self.weight.data = torch.normal(torch.zeros((out_channels, in_channels, kernel_size, kernel_size)), 0.02)
self.bias.data = torch.zeros((out_channels))
for p in self.parameters():
p.requires_grad = True
class conv(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size, BN = False, act = None, stride = 1, bias = True):
super(conv, self).__init__()
m = []
m.append(_conv(in_channels = in_channel, out_channels = out_channel,
kernel_size = kernel_size, stride = stride, padding = (kernel_size) // 2, bias = True))
if BN:
m.append(nn.BatchNorm2d(num_features = out_channel))
if act is not None:
m.append(act)
self.body = nn.Sequential(*m)
def forward(self, x):
out = self.body(x)
return out
class ResBlock(nn.Module):
def __init__(self, channels, kernel_size, act = nn.ReLU(inplace = True), bias = True):
super(ResBlock, self).__init__()
m = []
m.append(conv(channels, channels, kernel_size, BN = True, act = act))
m.append(conv(channels, channels, kernel_size, BN = True, act = None))
self.body = nn.Sequential(*m)
def forward(self, x):
res = self.body(x)
res += x
return res
class BasicBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, num_res_block, act = nn.ReLU(inplace = True)):
super(BasicBlock, self).__init__()
m = []
self.conv = conv(in_channels, out_channels, kernel_size, BN = False, act = act)
for i in range(num_res_block):
m.append(ResBlock(out_channels, kernel_size, act))
m.append(conv(out_channels, out_channels, kernel_size, BN = True, act = None))
self.body = nn.Sequential(*m)
def forward(self, x):
res = self.conv(x)
out = self.body(res)
out += res
return out
class Upsampler(nn.Module):
def __init__(self, channel, kernel_size, scale, act = nn.ReLU(inplace = True)):
super(Upsampler, self).__init__()
m = []
m.append(conv(channel, channel * scale * scale, kernel_size))
m.append(nn.PixelShuffle(scale))
if act is not None:
m.append(act)
self.body = nn.Sequential(*m)
def forward(self, x):
out = self.body(x)
return out
class discrim_block(nn.Module):
def __init__(self, in_feats, out_feats, kernel_size, act = nn.LeakyReLU(inplace = True)):
super(discrim_block, self).__init__()
m = []
m.append(conv(in_feats, out_feats, kernel_size, BN = True, act = act))
m.append(conv(out_feats, out_feats, kernel_size, BN = True, act = act, stride = 2))
self.body = nn.Sequential(*m)
def forward(self, x):
out = self.body(x)
return out
class MiniSRGAN(nn.Module):
def __init__(self, img_feat = 3, n_feats = 64, kernel_size = 3, num_block = 8, act = nn.PReLU(), scale=4):
super(MiniSRGAN, self).__init__()
self.conv01 = conv(in_channel = img_feat, out_channel = n_feats, kernel_size = 9, BN = False, act = act)
resblocks = [ResBlock(channels = n_feats, kernel_size = 3, act = act) for _ in range(num_block)]
self.body = nn.Sequential(*resblocks)
self.conv02 = conv(in_channel = n_feats, out_channel = n_feats, kernel_size = 3, BN = True, act = None)
if(scale == 4):
upsample_blocks = [Upsampler(channel = n_feats, kernel_size = 3, scale = 2, act = act) for _ in range(2)]
else:
upsample_blocks = [Upsampler(channel = n_feats, kernel_size = 3, scale = scale, act = act)]
self.tail = nn.Sequential(*upsample_blocks)
self.last_conv = conv(in_channel = n_feats, out_channel = img_feat, kernel_size = 3, BN = False, act = nn.Tanh())
def forward(self, x):
x = self.conv01(x)
_skip_connection = x
x = self.body(x)
x = self.conv02(x)
feat = x + _skip_connection
x = self.tail(feat)
x = self.last_conv(x)
return x, feat
class TinySRGAN(nn.Module):
def __init__(self, img_feat = 3, n_feats = 32, kernel_size = 3, num_block = 6, act = nn.PReLU(), scale=4):
super(TinySRGAN, self).__init__()
self.conv01 = conv(in_channel = img_feat, out_channel = n_feats, kernel_size = 9, BN = False, act = act)
resblocks = [ResBlock(channels = n_feats, kernel_size = 3, act = act) for _ in range(num_block)]
self.body = nn.Sequential(*resblocks)
self.conv02 = conv(in_channel = n_feats, out_channel = n_feats, kernel_size = 3, BN = True, act = None)
if(scale == 4):
upsample_blocks = [Upsampler(channel = n_feats, kernel_size = 3, scale = 2, act = act) for _ in range(2)]
else:
upsample_blocks = [Upsampler(channel = n_feats, kernel_size = 3, scale = scale, act = act)]
self.tail = nn.Sequential(*upsample_blocks)
self.last_conv = conv(in_channel = n_feats, out_channel = img_feat, kernel_size = 3, BN = False, act = nn.Tanh())
def forward(self, x):
x = self.conv01(x)
_skip_connection = x
x = self.body(x)
x = self.conv02(x)
feat = x + _skip_connection
x = self.tail(feat)
x = self.last_conv(x)
return x, feat
def build_generator():
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, expansion=6, stride=1, alpha=1.0):
super(ResidualBlock, self).__init__()
self.expansion = expansion
self.stride = stride
self.in_channels = in_channels
self.out_channels = int(out_channels * alpha)
self.pointwise_conv_filters = self._make_divisible(self.out_channels, 8)
self.conv1 = nn.Conv2d(in_channels, in_channels * expansion, kernel_size=1, stride=1, padding=0, bias=True)
self.bn1 = nn.BatchNorm2d(in_channels * expansion)
self.conv2 = nn.Conv2d(in_channels * expansion, in_channels * expansion, kernel_size=3, stride=stride, padding=1, groups=in_channels * expansion, bias=True)
self.bn2 = nn.BatchNorm2d(in_channels * expansion)
self.conv3 = nn.Conv2d(in_channels * expansion, self.pointwise_conv_filters, kernel_size=1, stride=1, padding=0, bias=True)
self.bn3 = nn.BatchNorm2d(self.pointwise_conv_filters)
self.relu = nn.ReLU(inplace=True)
self.skip_add = (stride == 1 and in_channels == self.pointwise_conv_filters)
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.skip_add:
out = out + identity
return out
@staticmethod
def _make_divisible(v, divisor, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
class Generator(nn.Module):
def __init__(self, in_channels, num_residual_blocks, gf):
super(Generator, self).__init__()
self.num_residual_blocks = num_residual_blocks
self.gf = gf
self.conv1 = nn.Conv2d(in_channels, gf, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(gf)
self.prelu1 = nn.PReLU()
self.residual_blocks = self.make_layer(ResidualBlock, gf, num_residual_blocks)
self.conv2 = nn.Conv2d(gf, gf, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(gf)
self.upsample1 = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
nn.Conv2d(gf, gf, kernel_size=3, stride=1, padding=1),
nn.PReLU()
)
self.upsample2 = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
nn.Conv2d(gf, gf, kernel_size=3, stride=1, padding=1),
nn.PReLU()
)
self.conv3 = nn.Conv2d(gf, 3, kernel_size=3, stride=1, padding=1)
self.tanh = nn.Tanh()
def make_layer(self, block, out_channels, blocks):
layers = []
for _ in range(blocks):
layers.append(block(out_channels, out_channels))
return nn.Sequential(*layers)
def forward(self, x):
out1 = self.prelu1(self.bn1(self.conv1(x)))
out = self.residual_blocks(out1)
out = self.bn2(self.conv2(out))
out = out + out1
out = self.upsample1(out)
out = self.upsample2(out)
out = self.tanh(self.conv3(out))
return out
return Generator(3, 6, 32)
def numpify(imgs):
all_images = []
for img in imgs:
img = img.permute(1,2,0).to('cpu') ### MIGHT CRASH HERE
all_images.append(img)
return np.stack(all_images, axis=0)
transform = transforms.Compose([
transforms.ToTensor()
])
# Function to translate the image
def translate_image(image, sharpen, model_name, save):
print('Translating!')
desired_width = 480
original_width, original_height = image.size
desired_height = int((original_height / original_width) * desired_width)
resized_image = image.resize((desired_width, desired_height))
if(model_name=='MobileSR'):
model=build_generator().to(device)
model.load_state_dict(torch.load('./weights/mobile_sr.pt', map_location=torch.device('cpu')))
low_res = transform(resized_image)
low_res = low_res.unsqueeze(dim=0).to(device)
model.eval()
with torch.no_grad():
sr = model(low_res)
fake_imgs = numpify(sr)
sr_img = Image.fromarray((((fake_imgs[0] + 1) / 2) * 255).astype(np.uint8))
elif(model_name=='MiniSRGAN'):
model = MiniSRGAN().to(device)
model.load_state_dict(torch.load('./weights/miniSRGAN.pt', map_location=torch.device('cpu')))
model.eval()
inputs = np.array(resized_image)
inputs = (inputs / 127.5) - 1.0
inputs = torch.tensor(inputs.transpose(2, 0, 1).astype(np.float32)).to(device)
with torch.no_grad():
output, _ = model(torch.unsqueeze(inputs,dim=0))
output = output[0].cpu().numpy()
output = np.clip(output, -1.0, 1.0)
output = (output + 1.0) / 2.0
output = output.transpose(1, 2, 0)
sr_img = Image.fromarray((output * 255.0).astype(np.uint8))
elif(model_name=='TinySRGAN'):
model = TinySRGAN().to(device)
model.load_state_dict(torch.load('./weights/tinySRGAN.pt', map_location=torch.device('cpu')))
inputs = np.array(resized_image)
inputs = (inputs / 127.5) - 1.0
inputs = torch.tensor(inputs.transpose(2, 0, 1).astype(np.float32)).to(device)
model.eval()
with torch.no_grad():
output, _ = model(torch.unsqueeze(inputs,dim=0))
output = output[0].cpu().numpy()
output = (output + 1.0) / 2.0
output = output.transpose(1, 2, 0)
sr_img = Image.fromarray((output * 255.0).astype(np.uint8))
if sharpen:
sr_img_cv = np.array(sr_img)
sr_img_cv = cv2.cvtColor(sr_img_cv, cv2.COLOR_RGB2BGR)
kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
sharpened_sr_img_cv = cv2.filter2D(sr_img_cv, -1, kernel)
sharpened_sr_img = Image.fromarray(cv2.cvtColor(sharpened_sr_img_cv, cv2.COLOR_BGR2RGB))
if(save=="True"):
sharpened_sr_img.save('super_resolved_image.png')
return sharpened_sr_img
else:
if(save=="True"):
sr_img.save('super_resolved_image.png')
return sr_img
interface = gr.Interface(
fn=translate_image,
inputs=[
gr.Image(type="pil"),
gr.Checkbox(label="Sharpen Image"),
gr.Radio(choices=["MobileSR", "MiniSRGAN", "TinySRGAN"], label="Select Model", value="MobileSR"),
gr.Radio(choices=["True", "False"], label="Save Output", value="False")
],
outputs=gr.Image(type="pil", label="Translated Image"),
title="Correction App",
description="Upload an image and get the translated version. Some images may be blurry, you can tick the checkbox to sharpen them. Choose between three different models for translation.",
allow_flagging='never'
)
# Launch the Gradio app
interface.launch()