# FBCNN / app.py — JPEG artifacts removal demo (Hugging Face Space)
# Uploaded by KenjieDec ("Update app.py", commit e28d517, verified)
import gradio as gr
import os.path
import numpy as np
from collections import OrderedDict
import torch
import cv2
from PIL import Image, ImageOps
import utils_image as util
from network_fbcnn import FBCNN as net
import requests
import datetime
# Log the Gradio version at startup (useful when debugging Space builds).
print(gr.__version__)

# Fetch the pretrained FBCNN weights (grayscale + color) from the official
# GitHub release if they are not already present next to this script.
for model_path in ['fbcnn_gray.pth', 'fbcnn_color.pth']:
    if os.path.exists(model_path):
        print(f'{model_path} exists.')
    else:
        print("Downloading model: ", f'{model_path}')
        url = 'https://github.com/jiaxi-jiang/FBCNN/releases/download/v1.0/{}'.format(os.path.basename(model_path))
        r = requests.get(url, allow_redirects=True)
        # Fail loudly on a bad download instead of silently saving an
        # HTML error page as a .pth file.
        r.raise_for_status()
        # Context manager guarantees the file handle is closed.
        with open(model_path, 'wb') as f:
            f.write(r.content)
def _load_model(is_gray, device):
    """Build the FBCNN network, load the matching released checkpoint,
    freeze its parameters, and move it to *device*.

    Returns:
        (model, n_channels): the eval-mode model and 1 (grayscale) or 3 (color).
    """
    if is_gray:
        n_channels = 1
        model_name = 'fbcnn_gray.pth'
    else:
        n_channels = 3
        model_name = 'fbcnn_color.pth'
    print(f'Loading model from {model_name}')
    # nc/nb match the architecture the released checkpoints were trained with.
    model = net(in_nc=n_channels, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R')
    model.load_state_dict(torch.load(model_name), strict=True)
    model.eval()
    # Inference only: no gradients needed.
    for _, v in model.named_parameters():
        v.requires_grad = False
    print("Model loaded.")
    return model.to(device), n_channels


def inference(filepaths, is_gray, res_percentage, input_quality):
    """Run FBCNN JPEG-artifact removal on each uploaded image.

    Args:
        filepaths: gallery value — list of (filepath, caption) tuples, or None.
        is_gray: if True use the 1-channel grayscale model, else the color one.
        res_percentage: resize factor (1-100) applied before inference.
        input_quality: "Intensity" slider (1-100); higher means stronger
            artifact removal. Mapped to the model's quality-factor input.

    Returns:
        (outputs, before_afters): gallery items [(image, caption), ...] and a
        list of (input, output) pairs for the before/after slider; ([], None)
        when nothing was uploaded.
    """
    outputs = []
    before_afters = []
    if filepaths is None:
        return [], None

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: ", device)
    # Load the checkpoint once for the whole batch (it only depends on
    # is_gray) instead of re-loading it from disk for every image.
    model, n_channels = _load_model(is_gray, device)

    # Map the slider to the model quality factor exactly once. The original
    # code ran `input_quality = 100 - input_quality` inside the loop, so the
    # effective quality flipped back and forth on every second image.
    quality = 100 - input_quality
    qf_input = torch.tensor([[1 - quality / 100]]).to(device)

    for filepath, *_ in filepaths:
        filename = os.path.basename(filepath)
        print("Processing: ", filename)
        print("Datetime: ", datetime.datetime.now(datetime.timezone.utc))

        pil_img = Image.open(filepath).convert("RGB")
        width, height = pil_img.size
        print("Img size: ", (width, height))
        # Downscale before inference to bound processing time.
        resized_input = pil_img.resize(
            (int(width * (res_percentage / 100)),
             int(height * (res_percentage / 100))),
            resample=Image.BICUBIC)
        input_img = np.array(resized_input)
        print("Input image resized to: ", resized_input.size)

        if n_channels == 1:
            # Grayscale model: collapse to one channel, keep HxWx1 layout.
            gray = ImageOps.grayscale(Image.fromarray(input_img))
            img = np.expand_dims(np.array(gray), axis=2)
        else:
            img = input_img
            if img.ndim == 2:
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
            else:
                # NOTE(review): the array is already RGB (loaded via PIL), so
                # this swap feeds the model channel-swapped data; the output
                # swap below undoes it. Kept to preserve the demo's behavior.
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        img_L = util.uint2tensor4(img).to(device)
        # Single flexible-quality forward pass. (The original additionally ran
        # a default-QF pass first and immediately discarded its result.)
        img_E, QF = model(img_L, qf_input)
        img_E = util.single2uint(util.tensor2single(img_E))
        if img_E.ndim == 3:
            img_E = img_E[:, :, [2, 1, 0]]  # undo the channel swap above
        print("--inference finished")

        # The original also called an undefined check_file_exist() here, which
        # raised NameError at runtime; its result was never used, so the call
        # is dropped.
        outputs.append((img_E, '(unknown)'))
        before_afters.append((input_img, img_E))

    return outputs, before_afters
def select_image(event: gr.SelectData, before_afters):
    """Gallery ``.select`` handler: map the clicked thumbnail to its
    (before, after) pair and remember which index was chosen.

    Returns a 2-tuple matching ``outputs=[before_after, current_index]``.
    The original returned a bare ``None`` on an invalid selection, which does
    not match the two declared outputs; return ``(None, 0)`` instead so the
    slider is cleared and the stored index is reset.
    """
    index = event.index
    if index is None or index >= len(before_afters):
        return None, 0
    return before_afters[index], index
def select_changed_image(index, before_afters):
    """Return the (before, after) pair stored at *index*, or ``None`` when
    the index is missing or out of range (e.g. the gallery was just cleared).
    """
    valid = index is not None and index < len(before_afters)
    return before_afters[index] if valid else None
# ---- Gradio UI ----
with gr.Blocks() as demo:
    gr.Markdown("# JPEG Artifacts Removal [FBCNN]")
    with gr.Row():
        # Users drop one or more JPEGs on the left; results appear on the right.
        input_img = gr.Gallery(
            label="Input Image(s)",
            file_types=['image'],
            type="filepath",
            height="auto"
        )
        output_img = gr.Gallery(
            label="Results",
            height="auto",
            interactive=False
        )
    is_gray = gr.Checkbox(label="Grayscale (Check this if your image is grayscale)")
    max_res = gr.Slider(1, 100, step=0.5, value=100, label="Output image resolution Percentage (Higher% = longer processing time)")
    input_quality = gr.Slider(1, 100, step=1, value=40, label="Intensity (Higher = stronger JPEG artifact removal)")
    run = gr.Button("Run")
    with gr.Row():
        # Per-session state: the (before, after) pairs from the last run and
        # the index of the currently selected result.
        before_afters = gr.State([])
        current_index = gr.State(0)
        before_after = gr.ImageSlider(label="Before/After (Select One)", value=None, height="auto")
    run.click(
        inference,
        inputs=[input_img, is_gray, max_res, input_quality],
        outputs=[output_img, before_afters]
    )
    # Clicking a result thumbnail updates the before/after slider.
    output_img.select(
        select_image,
        inputs=[before_afters],
        outputs=[before_after, current_index]
    )
    # When the results gallery changes (a new run), refresh the slider from
    # the remembered index.
    output_img.change(
        select_changed_image,
        inputs=[current_index, before_afters],
        outputs=[before_after]
    )
    gr.Examples(
        examples=[
            [[("doraemon.jpg", "doraemon.jpg")], False, 100, 60],
            [[("tomandjerry.jpg", "tomandjerry.jpg")], False, 100, 60],
            [[("somepanda.jpg", "somepanda.jpg")], True, 100, 100],
            [[("cemetry.jpg", "cemetry.jpg")], False, 100, 70],
            [[("michelangelo_david.jpg", "michelangelo_david.jpg")], True, 100, 30],
            [[("elon_musk.jpg", "elon_musk.jpg")], False, 100, 45],
            [[("text.jpg", "text.jpg")], True, 100, 70]
        ],
        inputs=[input_img, is_gray, max_res, input_quality],
        outputs=[output_img, before_afters]
    )
    gr.Markdown("""
JPEG Artifacts are noticeable distortions of images caused by JPEG lossy compression.
Note that this is not an AI Upscaler, but just a JPEG Compression Artifact Remover.
[Original Demo](https://huggingface.co/spaces/danielsapit/JPEG_Artifacts_Removal)
[FBCNN GitHub Repo](https://github.com/jiaxi-jiang/FBCNN)
[Towards Flexible Blind JPEG Artifacts Removal (FBCNN, ICCV 2021)](https://arxiv.org/abs/2109.14573)
[Jiaxi Jiang](https://jiaxi-jiang.github.io/),
[Kai Zhang](https://cszn.github.io/),
[Radu Timofte](http://people.ee.ethz.ch/~timofter/)
""")

demo.launch()