KenjieDec commited on
Commit
8f3a127
·
verified ·
1 Parent(s): 3349e70

Allowed Multiple Images Upload

Browse files

- Allowed upload of multiple images
- Used gradio's native ImageSlider as it's implemented now
- Removed zoom/xy-shift as it's redundant

Files changed (1) hide show
  1. app.py +161 -151
app.py CHANGED
@@ -9,186 +9,196 @@ import utils_image as util
9
  from network_fbcnn import FBCNN as net
10
  import requests
11
  import datetime
12
- from gradio_imageslider import ImageSlider
13
 
14
  for model_path in ['fbcnn_gray.pth','fbcnn_color.pth']:
15
  if os.path.exists(model_path):
16
  print(f'{model_path} exists.')
17
  else:
18
- print("downloading model")
19
  url = 'https://github.com/jiaxi-jiang/FBCNN/releases/download/v1.0/{}'.format(os.path.basename(model_path))
20
  r = requests.get(url, allow_redirects=True)
21
  open(model_path, 'wb').write(r.content)
22
 
23
- def inference(input_img, is_gray, res_percentage, input_quality, zoom, x_shift, y_shift):
 
 
 
 
24
 
25
- print("datetime:", datetime.datetime.utcnow())
26
- input_img_width, input_img_height = Image.fromarray(input_img).size
27
- print("img size:", (input_img_width, input_img_height))
28
-
29
- resized_input = Image.fromarray(input_img).resize(
30
- (
31
- int(input_img_width * (res_percentage/100)),
32
- int(input_img_height * (res_percentage/100))
33
- ), resample = Image.BICUBIC)
34
- input_img = np.array(resized_input)
35
- print("input image resized to:", resized_input.size)
36
-
37
- if is_gray:
38
- n_channels = 1
39
- model_name = 'fbcnn_gray.pth'
40
- else:
41
- n_channels = 3
42
- model_name = 'fbcnn_color.pth'
43
- nc = [64,128,256,512]
44
- nb = 4
45
-
46
- input_quality = 100 - input_quality
47
-
48
- model_path = model_name
49
-
50
- if os.path.exists(model_path):
51
- print(f'{model_path} already exists.')
52
- else:
53
- print("downloading model")
54
- os.makedirs(os.path.dirname(model_path), exist_ok=True)
55
- url = 'https://github.com/jiaxi-jiang/FBCNN/releases/download/v1.0/{}'.format(os.path.basename(model_path))
56
- r = requests.get(url, allow_redirects=True)
57
- open(model_path, 'wb').write(r.content)
58
-
59
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
60
- print("device:", device)
61
-
62
- print(f'loading model from {model_path}')
63
-
64
- model = net(in_nc=n_channels, out_nc=n_channels, nc=nc, nb=nb, act_mode='R')
65
- print("#model.load_state_dict(torch.load(model_path), strict=True)")
66
- model.load_state_dict(torch.load(model_path), strict=True)
67
- print("#model.eval()")
68
- model.eval()
69
- print("#for k, v in model.named_parameters()")
70
- for k, v in model.named_parameters():
71
- v.requires_grad = False
72
- print("#model.to(device)")
73
- model = model.to(device)
74
- print("Model loaded.")
75
-
76
- test_results = OrderedDict()
77
- test_results['psnr'] = []
78
- test_results['ssim'] = []
79
- test_results['psnrb'] = []
80
-
81
- print("#if n_channels")
82
- if n_channels == 1:
83
- open_cv_image = Image.fromarray(input_img)
84
- open_cv_image = ImageOps.grayscale(open_cv_image)
85
- open_cv_image = np.array(open_cv_image)
86
- img = np.expand_dims(open_cv_image, axis=2)
87
- elif n_channels == 3:
88
- open_cv_image = np.array(input_img)
89
- if open_cv_image.ndim == 2:
90
- open_cv_image = cv2.cvtColor(open_cv_image, cv2.COLOR_GRAY2RGB)
91
- else:
92
- open_cv_image = cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB)
93
-
94
- print("#util.uint2tensor4(open_cv_image)")
95
- img_L = util.uint2tensor4(open_cv_image)
96
-
97
- print("#img_L.to(device)")
98
- img_L = img_L.to(device)
99
-
100
- print("#model(img_L)")
101
- img_E, QF = model(img_L)
102
- print("#util.tensor2single(img_E)")
103
- img_E = util.tensor2single(img_E)
104
- print("#util.single2uint(img_E)")
105
- img_E = util.single2uint(img_E)
106
-
107
- print("#torch.tensor([[1-input_quality/100]]).cuda() || torch.tensor([[1-input_quality/100]])")
108
- qf_input = torch.tensor([[1-input_quality/100]]).cuda() if device == torch.device('cuda') else torch.tensor([[1-input_quality/100]])
109
- print("#util.single2uint(img_E)")
110
- img_E, QF = model(img_L, qf_input)
111
-
112
- print("#util.tensor2single(img_E)")
113
- img_E = util.tensor2single(img_E)
114
- print("#util.single2uint(img_E)")
115
- img_E = util.single2uint(img_E)
116
-
117
- if img_E.ndim == 3:
118
- img_E = img_E[:, :, [2, 1, 0]]
119
-
120
- print("--inference finished")
121
-
122
- (in_img, out_img) = zoom_image(zoom, x_shift, y_shift, input_img, img_E)
123
- print("--generating preview finished")
124
-
125
- return img_E, (in_img, out_img)
126
-
127
- def zoom_image(zoom, x_shift, y_shift, input_img, output_img = None):
128
- if output_img is None:
129
- return None
130
-
131
- img = Image.fromarray(input_img)
132
- out_img = Image.fromarray(output_img)
133
-
134
- img_w, img_h = img.size
135
- zoom_factor = (100 - zoom) / 100
136
- x_shift /= 100
137
- y_shift /= 100
138
-
139
- zoom_w, zoom_h = int(img_w * zoom_factor), int(img_h * zoom_factor)
140
- x_offset = int((img_w - zoom_w) * x_shift)
141
- y_offset = int((img_h - zoom_h) * y_shift)
142
-
143
- crop_box = (x_offset, y_offset, x_offset + zoom_w, y_offset + zoom_h)
144
- img = img.resize((img_w, img_h), Image.BILINEAR).crop(crop_box)
145
- out_img = out_img.resize((img_w, img_h), Image.BILINEAR).crop(crop_box)
146
-
147
- return (img, out_img)
148
 
 
 
 
 
 
 
 
 
 
 
 
149
  with gr.Blocks() as demo:
150
  gr.Markdown("# JPEG Artifacts Removal [FBCNN]")
151
-
152
  with gr.Row():
153
- input_img = gr.Image(label="Input Image")
154
- output_img = gr.Image(label="Result", interactive=False)
 
 
 
 
 
 
 
 
 
 
155
 
156
  is_gray = gr.Checkbox(label="Grayscale (Check this if your image is grayscale)")
157
- max_res = gr.Slider(1, 100, step=0.5, label="Output image resolution Percentage (Higher% = longer processing time)")
158
- input_quality = gr.Slider(1, 100, step=1, label="Intensity (Higher = stronger JPEG artifact removal)")
159
- zoom = gr.Slider(0, 100, step=1, value=50, label="Zoom Percentage (0 = original size)")
160
- x_shift = gr.Slider(0, 100, step=1, label="Horizontal shift Percentage (Before/After)")
161
- y_shift = gr.Slider(0, 100, step=1, label="Vertical shift Percentage (Before/After)")
162
 
163
  run = gr.Button("Run")
164
 
165
  with gr.Row():
166
- before_after = ImageSlider(label="Before/After", type="pil", value=None)
167
-
 
 
168
  run.click(
169
  inference,
170
- inputs=[input_img, is_gray, max_res, input_quality, zoom, x_shift, y_shift],
171
- outputs=[output_img, before_after]
172
  )
173
 
174
- gr.Examples([
175
- ["doraemon.jpg", False, 100, 60, 58, 50, 50],
176
- ["tomandjerry.jpg", False, 100, 60, 60, 57, 44],
177
- ["somepanda.jpg", True, 100, 100, 70, 8, 24],
178
- ["cemetry.jpg", False, 100, 70, 80, 76, 62],
179
- ["michelangelo_david.jpg", True, 100, 30, 88, 53, 27],
180
- ["elon_musk.jpg", False, 100, 45, 75, 33, 30],
181
- ["text.jpg", True, 100, 70, 50, 11, 29]
182
- ], inputs=[input_img, is_gray, max_res, input_quality, zoom, x_shift, y_shift])
183
 
184
- zoom.release(zoom_image, inputs=[zoom, x_shift, y_shift, input_img, output_img], outputs=[before_after])
185
- x_shift.release(zoom_image, inputs=[zoom, x_shift, y_shift, input_img, output_img], outputs=[before_after])
186
- y_shift.release(zoom_image, inputs=[zoom, x_shift, y_shift, input_img, output_img], outputs=[before_after])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
  gr.Markdown("""
189
  JPEG Artifacts are noticeable distortions of images caused by JPEG lossy compression.
190
  Note that this is not an AI Upscaler, but just a JPEG Compression Artifact Remover.
191
-
192
  [Original Demo](https://huggingface.co/spaces/danielsapit/JPEG_Artifacts_Removal)
193
  [FBCNN GitHub Repo](https://github.com/jiaxi-jiang/FBCNN)
194
  [Towards Flexible Blind JPEG Artifacts Removal (FBCNN, ICCV 2021)](https://arxiv.org/abs/2109.14573)
 
9
  from network_fbcnn import FBCNN as net
10
  import requests
11
  import datetime
 
12
 
13
  for model_path in ['fbcnn_gray.pth','fbcnn_color.pth']:
14
  if os.path.exists(model_path):
15
  print(f'{model_path} exists.')
16
  else:
17
+ print("Downloading model: ", f'{model_path}')
18
  url = 'https://github.com/jiaxi-jiang/FBCNN/releases/download/v1.0/{}'.format(os.path.basename(model_path))
19
  r = requests.get(url, allow_redirects=True)
20
  open(model_path, 'wb').write(r.content)
21
 
22
+ def inference(filepaths, is_gray, res_percentage, input_quality):
23
+ outputs = []
24
+ before_afters = []
25
+ if filepaths is None:
26
+ return [], None
27
 
28
+ for filepath, *_ in filepaths:
29
+ filename = os.path.basename(filepath)
30
+ print("Processing: ", filename)
31
+ input_img = np.array(Image.open(filepath).convert("RGB"))
32
+
33
+ print("Datetime: ", datetime.datetime.utcnow())
34
+ input_img_width, input_img_height = Image.fromarray(input_img).size
35
+ print("Img size: ", (input_img_width, input_img_height))
36
+
37
+ resized_input = Image.fromarray(input_img).resize(
38
+ (
39
+ int(input_img_width * (res_percentage/100)),
40
+ int(input_img_height * (res_percentage/100))
41
+ ), resample = Image.BICUBIC)
42
+ input_img = np.array(resized_input)
43
+ print("Input image resized to: ", resized_input.size)
44
+
45
+ if is_gray:
46
+ n_channels = 1
47
+ model_name = 'fbcnn_gray.pth'
48
+ else:
49
+ n_channels = 3
50
+ model_name = 'fbcnn_color.pth'
51
+ nc = [64,128,256,512]
52
+ nb = 4
53
+
54
+ input_quality = 100 - input_quality
55
+
56
+ model_path = model_name
57
+
58
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
59
+ print("Device: ", device)
60
+
61
+ print(f'Loading model from {model_path}')
62
+
63
+ model = net(in_nc=n_channels, out_nc=n_channels, nc=nc, nb=nb, act_mode='R')
64
+ print("#model.load_state_dict(torch.load(model_path), strict=True)")
65
+ model.load_state_dict(torch.load(model_path), strict=True)
66
+ print("#model.eval()")
67
+ model.eval()
68
+ print("#for k, v in model.named_parameters()")
69
+ for k, v in model.named_parameters():
70
+ v.requires_grad = False
71
+ print("#model.to(device)")
72
+ model = model.to(device)
73
+ print("Model loaded.")
74
+
75
+ test_results = OrderedDict()
76
+ test_results['psnr'] = []
77
+ test_results['ssim'] = []
78
+ test_results['psnrb'] = []
79
+
80
+ print("#if n_channels")
81
+ if n_channels == 1:
82
+ open_cv_image = Image.fromarray(input_img)
83
+ open_cv_image = ImageOps.grayscale(open_cv_image)
84
+ open_cv_image = np.array(open_cv_image)
85
+ img = np.expand_dims(open_cv_image, axis=2)
86
+ elif n_channels == 3:
87
+ open_cv_image = np.array(input_img)
88
+ if open_cv_image.ndim == 2:
89
+ open_cv_image = cv2.cvtColor(open_cv_image, cv2.COLOR_GRAY2RGB)
90
+ else:
91
+ open_cv_image = cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB)
92
+
93
+ print("#util.uint2tensor4(open_cv_image)")
94
+ img_L = util.uint2tensor4(open_cv_image)
95
+
96
+ print("#img_L.to(device)")
97
+ img_L = img_L.to(device)
98
+
99
+ print("#model(img_L)")
100
+ img_E, QF = model(img_L)
101
+ print("#util.tensor2single(img_E)")
102
+ img_E = util.tensor2single(img_E)
103
+ print("#util.single2uint(img_E)")
104
+ img_E = util.single2uint(img_E)
105
+
106
+ print("#torch.tensor([[1-input_quality/100]]).cuda() || torch.tensor([[1-input_quality/100]])")
107
+ qf_input = torch.tensor([[1-input_quality/100]]).cuda() if device == torch.device('cuda') else torch.tensor([[1-input_quality/100]])
108
+ print("#util.single2uint(img_E)")
109
+ img_E, QF = model(img_L, qf_input)
110
+
111
+ print("#util.tensor2single(img_E)")
112
+ img_E = util.tensor2single(img_E)
113
+ print("#util.single2uint(img_E)")
114
+ img_E = util.single2uint(img_E)
115
+
116
+ if img_E.ndim == 3:
117
+ img_E = img_E[:, :, [2, 1, 0]]
118
+
119
+ print("--inference finished")
120
+
121
+ image_path = check_file_exist("output_images", filename)
122
+
123
+ outputs.append((img_E, f'{filename}'))
124
+ before_afters.append((input_img, img_E))
125
+
126
+ return outputs, before_afters
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
+ def select_image(event: gr.SelectData, before_afters):
129
+ index = event.index
130
+ if index is None or index >= len(before_afters):
131
+ return None
132
+ return before_afters[index], index
133
+
134
+ def select_changed_image(index, before_afters):
135
+ if index is None or index >= len(before_afters):
136
+ return None
137
+ return before_afters[index]
138
+
139
  with gr.Blocks() as demo:
140
  gr.Markdown("# JPEG Artifacts Removal [FBCNN]")
141
+
142
  with gr.Row():
143
+ input_img = gr.Gallery(
144
+ label="Input Image(s)",
145
+ file_types=['image'],
146
+ type="filepath",
147
+ height="auto"
148
+ )
149
+
150
+ output_img = gr.Gallery(
151
+ label="Results",
152
+ height="auto",
153
+ interactive=False
154
+ )
155
 
156
  is_gray = gr.Checkbox(label="Grayscale (Check this if your image is grayscale)")
157
+ max_res = gr.Slider(1, 100, step=0.5, value=100, label="Output image resolution Percentage (Higher% = longer processing time)")
158
+ input_quality = gr.Slider(1, 100, step=1, value=40, label="Intensity (Higher = stronger JPEG artifact removal)")
 
 
 
159
 
160
  run = gr.Button("Run")
161
 
162
  with gr.Row():
163
+ before_afters = gr.State([])
164
+ current_index = gr.State(0)
165
+ before_after = gr.ImageSlider(label="Before/After (Select One)", value=None, height="auto")
166
+
167
  run.click(
168
  inference,
169
+ inputs=[input_img, is_gray, max_res, input_quality],
170
+ outputs=[output_img, before_afters]
171
  )
172
 
173
+ output_img.select(
174
+ select_image,
175
+ inputs=[before_afters],
176
+ outputs=[before_after, current_index]
177
+ )
 
 
 
 
178
 
179
+ output_img.change(
180
+ select_changed_image,
181
+ inputs=[current_index, before_afters],
182
+ outputs=[before_after]
183
+ )
184
+
185
+ gr.Examples(
186
+ examples=[
187
+ [[("doraemon.jpg", "doraemon.jpg")], False, 100, 60],
188
+ [[("tomandjerry.jpg", "tomandjerry.jpg")], False, 100, 60],
189
+ [[("somepanda.jpg", "somepanda.jpg")], True, 100, 100],
190
+ [[("cemetry.jpg", "cemetry.jpg")], False, 100, 70],
191
+ [[("michelangelo_david.jpg", "michelangelo_david.jpg")], True, 100, 30],
192
+ [[("elon_musk.jpg", "elon_musk.jpg")], False, 100, 45],
193
+ [[("text.jpg", "text.jpg")], True, 100, 70]
194
+ ],
195
+ inputs=[input_img, is_gray, max_res, input_quality],
196
+ outputs=[output_img, before_afters]
197
+ )
198
 
199
  gr.Markdown("""
200
  JPEG Artifacts are noticeable distortions of images caused by JPEG lossy compression.
201
  Note that this is not an AI Upscaler, but just a JPEG Compression Artifact Remover.
 
202
  [Original Demo](https://huggingface.co/spaces/danielsapit/JPEG_Artifacts_Removal)
203
  [FBCNN GitHub Repo](https://github.com/jiaxi-jiang/FBCNN)
204
  [Towards Flexible Blind JPEG Artifacts Removal (FBCNN, ICCV 2021)](https://arxiv.org/abs/2109.14573)