File size: 7,313 Bytes
ef2a093
 
 
 
 
 
 
 
 
 
 
 
e28d517
ef2a093
 
 
 
8f3a127
ef2a093
 
 
 
8f3a127
 
 
 
 
ef2a093
8f3a127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef2a093
8f3a127
 
 
 
 
 
 
 
 
 
 
ef2a093
 
8f3a127
ef2a093
8f3a127
 
 
 
 
 
 
 
 
 
 
 
ef2a093
 
8f3a127
 
ef2a093
 
 
 
8f3a127
 
 
 
ef2a093
 
8f3a127
 
ef2a093
 
8f3a127
 
 
 
 
ef2a093
8f3a127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef2a093
 
 
 
 
 
 
 
 
 
 
 
314a753
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import gradio as gr
import os.path
import numpy as np
from collections import OrderedDict
import torch
import cv2
from PIL import Image, ImageOps
import utils_image as util
from network_fbcnn import FBCNN as net
import requests
import datetime

print(gr.__version__)

# Fetch the pretrained FBCNN weights once at startup; files already on disk
# are kept as-is.
for model_path in ['fbcnn_gray.pth', 'fbcnn_color.pth']:
    if os.path.exists(model_path):
        print(f'{model_path} exists.')
    else:
        print("Downloading model: ", f'{model_path}')
        url = 'https://github.com/jiaxi-jiang/FBCNN/releases/download/v1.0/{}'.format(os.path.basename(model_path))
        r = requests.get(url, allow_redirects=True)
        # Fail loudly on a bad response instead of silently writing an HTML
        # error page into the .pth file (the original skipped this check).
        r.raise_for_status()
        # Context manager guarantees the handle is closed (the original used
        # open(...).write(...) and leaked the file object).
        with open(model_path, 'wb') as f:
            f.write(r.content)

def inference(filepaths, is_gray, res_percentage, input_quality):
    """Run FBCNN JPEG-artifact removal on each uploaded image.

    Args:
        filepaths: gallery items as (path, caption, ...) tuples, or None
            when nothing has been uploaded.
        is_gray: use the grayscale model when True, else the color model.
        res_percentage: resize factor (1-100) applied before inference.
        input_quality: "intensity" slider value (1-100); inverted once into
            an estimated JPEG quality factor fed to the model.

    Returns:
        (outputs, before_afters): gallery entries [(image, caption), ...]
        and [(resized_input, restored), ...] pairs for the compare slider.
    """
    outputs = []
    before_afters = []
    if filepaths is None:
        return [], None

    # Invert the slider ONCE, outside the loop. The original re-assigned
    # `input_quality = 100 - input_quality` on every iteration, so the
    # second and subsequent files were processed with the inversion undone.
    quality = 100 - input_quality

    if is_gray:
        n_channels = 1
        model_name = 'fbcnn_gray.pth'
    else:
        n_channels = 3
        model_name = 'fbcnn_color.pth'
    nc = [64, 128, 256, 512]
    nb = 4

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: ", device)

    # The model depends only on is_gray, not on the individual file, so
    # build and load it once instead of once per image.
    print(f'Loading model from {model_name}')
    model = net(in_nc=n_channels, out_nc=n_channels, nc=nc, nb=nb, act_mode='R')
    model.load_state_dict(torch.load(model_name), strict=True)
    model.eval()
    for _, param in model.named_parameters():
        param.requires_grad = False
    model = model.to(device)
    print("Model loaded.")

    for filepath, *_ in filepaths:
        filename = os.path.basename(filepath)
        print("Processing: ", filename)
        input_img = np.array(Image.open(filepath).convert("RGB"))

        # utcnow() is deprecated; use an explicit timezone-aware timestamp.
        print("Datetime: ", datetime.datetime.now(datetime.timezone.utc))
        input_img_width, input_img_height = Image.fromarray(input_img).size
        print("Img size: ", (input_img_width, input_img_height))

        resized_input = Image.fromarray(input_img).resize(
            (
                int(input_img_width * (res_percentage / 100)),
                int(input_img_height * (res_percentage / 100))
            ), resample=Image.BICUBIC)
        input_img = np.array(resized_input)
        print("Input image resized to: ", resized_input.size)

        if n_channels == 1:
            gray = ImageOps.grayscale(Image.fromarray(input_img))
            # Keep the explicit HxWx1 shape (the original expanded the dims
            # into a dead variable `img` and passed the 2-D array instead).
            open_cv_image = np.expand_dims(np.array(gray), axis=2)
        else:
            open_cv_image = np.array(input_img)
            if open_cv_image.ndim == 2:
                open_cv_image = cv2.cvtColor(open_cv_image, cv2.COLOR_GRAY2RGB)
            else:
                # NOTE(review): the source is RGB, so this swap yields BGR;
                # the output channel swap below undoes it. Kept as-is.
                open_cv_image = cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB)

        img_L = util.uint2tensor4(open_cv_image).to(device)

        # Single quality-conditioned forward pass. (The original also ran an
        # unconditioned model(img_L) whose result was immediately discarded.)
        qf_input = torch.tensor([[1 - quality / 100]], device=device)
        with torch.no_grad():
            img_E, QF = model(img_L, qf_input)
        img_E = util.single2uint(util.tensor2single(img_E))

        if img_E.ndim == 3:
            img_E = img_E[:, :, [2, 1, 0]]  # swap channels back for display

        print("--inference finished")

        # Caption each result with its source filename (the original
        # appended the literal placeholder '(unknown)').
        outputs.append((img_E, filename))
        before_afters.append((input_img, img_E))

    return outputs, before_afters
    
def select_image(event: gr.SelectData, before_afters):
    """Gallery `select` handler: show the chosen before/after pair.

    Args:
        event: Gradio selection event; `event.index` is the gallery index.
        before_afters: list of (input, restored) pairs from `inference`.

    Returns:
        (pair, index) for the ImageSlider and the `current_index` state.
    """
    index = event.index
    # Guard stale/out-of-range selections. Always return values for BOTH
    # wired outputs (before_after, current_index) — the original returned a
    # single None here, which does not match the two-output wiring.
    if index is None or not (0 <= index < len(before_afters)):
        return None, 0
    return before_afters[index], index
  
def select_changed_image(index, before_afters):
    """Return the before/after pair at `index`, or None when the stored
    selection no longer points inside `before_afters`."""
    out_of_range = index is None or index >= len(before_afters)
    return None if out_of_range else before_afters[index]
  
# ---------------------------------------------------------------------------
# Gradio UI: two galleries (uploads / results), model options, a Run button,
# and a before/after comparison slider driven by gallery selection events.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# JPEG Artifacts Removal [FBCNN]")
    
    with gr.Row():
        # Upload gallery; type="filepath" makes inference() receive paths.
        input_img = gr.Gallery(
            label="Input Image(s)",
            file_types=['image'],
            type="filepath",
            height="auto"
        )

        # Read-only gallery that displays the restored images.
        output_img = gr.Gallery(
            label="Results",
            height="auto",
            interactive=False
        )
    
    is_gray = gr.Checkbox(label="Grayscale (Check this if your image is grayscale)")
    # Percentage resize applied before inference (smaller = faster).
    max_res = gr.Slider(1, 100, step=0.5, value=100, label="Output image resolution Percentage (Higher% = longer processing time)")
    # Inverted inside inference() into an estimated JPEG quality factor.
    input_quality = gr.Slider(1, 100, step=1, value=40, label="Intensity (Higher = stronger JPEG artifact removal)")
    
    run = gr.Button("Run")

    with gr.Row():
        # before_afters: per-image (input, restored) pairs kept in session
        # state so selection handlers can look them up later.
        before_afters = gr.State([])
        # current_index: last gallery selection; used to restore the slider
        # after the results gallery is re-rendered.
        current_index = gr.State(0)
        before_after = gr.ImageSlider(label="Before/After (Select One)", value=None, height="auto")
        
    run.click(
        inference, 
        inputs=[input_img, is_gray, max_res, input_quality], 
        outputs=[output_img, before_afters]
    )
    
    # Clicking a result shows its before/after pair in the slider and
    # remembers the selected index.
    output_img.select(
        select_image,
        inputs=[before_afters],
        outputs=[before_after, current_index]
    )
    
    # When the results gallery content changes, refresh the slider from the
    # remembered selection index.
    output_img.change(
        select_changed_image,
        inputs=[current_index, before_afters],
        outputs=[before_after]
    )
    
    # Example rows map onto [input_img, is_gray, max_res, input_quality].
    gr.Examples(
      examples=[
        [[("doraemon.jpg", "doraemon.jpg")], False, 100, 60],
        [[("tomandjerry.jpg", "tomandjerry.jpg")], False,  100, 60],
        [[("somepanda.jpg", "somepanda.jpg")], True,  100, 100],
        [[("cemetry.jpg", "cemetry.jpg")], False,  100, 70],
        [[("michelangelo_david.jpg", "michelangelo_david.jpg")], True,  100, 30],
        [[("elon_musk.jpg", "elon_musk.jpg")], False, 100, 45],
        [[("text.jpg", "text.jpg")], True,  100, 70]
      ], 
      inputs=[input_img, is_gray, max_res, input_quality], 
      outputs=[output_img, before_afters]
    )
    
    gr.Markdown("""
    JPEG Artifacts are noticeable distortions of images caused by JPEG lossy compression.
    Note that this is not an AI Upscaler, but just a JPEG Compression Artifact Remover.
    [Original Demo](https://huggingface.co/spaces/danielsapit/JPEG_Artifacts_Removal)
    [FBCNN GitHub Repo](https://github.com/jiaxi-jiang/FBCNN)  
    [Towards Flexible Blind JPEG Artifacts Removal (FBCNN, ICCV 2021)](https://arxiv.org/abs/2109.14573)  
    [Jiaxi Jiang](https://jiaxi-jiang.github.io/),  
    [Kai Zhang](https://cszn.github.io/),  
    [Radu Timofte](http://people.ee.ethz.ch/~timofter/)
    """)

demo.launch()