Merlimhhs commited on
Commit
1be1a61
·
verified ·
1 Parent(s): 1be7ff1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +192 -56
app.py CHANGED
@@ -8,33 +8,45 @@ import cv2
8
  import gradio as gr
9
  import numpy as np
10
  import torch
 
11
  from basicsr.archs.rrdbnet_arch import RRDBNet
12
  from realesrgan import RealESRGANer
13
 
 
14
  # =========================
15
  # CONFIG
16
  # =========================
17
  OUTSCALE = 2
18
 
19
- # Peso oficial do anime model mostrado no README do Real-ESRGAN
20
- MODEL_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
 
 
 
21
  MODEL_DIR = Path("weights")
22
  MODEL_PATH = MODEL_DIR / "RealESRGAN_x4plus_anime_6B.pth"
23
 
24
 
25
- def ensure_model() -> str:
 
 
 
26
  MODEL_DIR.mkdir(parents=True, exist_ok=True)
27
 
28
- if MODEL_PATH.exists() and MODEL_PATH.stat().st_size > 0:
29
  return str(MODEL_PATH)
30
 
 
31
  urlretrieve(MODEL_URL, MODEL_PATH)
 
32
  return str(MODEL_PATH)
33
 
34
 
 
 
 
35
  def build_upsampler():
36
  device = "cuda" if torch.cuda.is_available() else "cpu"
37
- use_half = device == "cuda"
38
 
39
  model_path = ensure_model()
40
 
@@ -47,117 +59,241 @@ def build_upsampler():
47
  scale=4,
48
  )
49
 
50
- return RealESRGANer(
51
  scale=4,
52
  model_path=model_path,
53
  model=model,
54
  tile=256,
55
  tile_pad=10,
56
  pre_pad=0,
57
- half=use_half,
58
  device=device,
59
  )
60
 
 
 
61
 
62
  UPSAMPLER = build_upsampler()
63
 
64
 
65
- def upscale_one_image(image: np.ndarray) -> np.ndarray:
66
- if image is None:
 
 
 
 
67
  return None
68
 
69
- if image.dtype != np.uint8:
70
- image = np.clip(image, 0, 1)
71
- image = (image * 255).astype(np.uint8)
72
 
73
- if image.ndim == 3 and image.shape[2] == 4:
74
- rgb = image[:, :, :3]
75
- alpha = image[:, :, 3]
 
 
 
 
 
76
 
77
  bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
78
- out_bgr, _ = UPSAMPLER.enhance(bgr, outscale=OUTSCALE)
79
- out_rgb = cv2.cvtColor(out_bgr, cv2.COLOR_BGR2RGB)
80
 
81
- alpha_up = cv2.resize(
 
 
 
 
 
 
 
 
 
 
82
  alpha,
83
- (alpha.shape[1] * OUTSCALE, alpha.shape[0] * OUTSCALE),
84
- interpolation=cv2.INTER_CUBIC,
 
 
 
85
  )
86
- return np.dstack([out_rgb, alpha_up])
87
 
88
- bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
89
- out_bgr, _ = UPSAMPLER.enhance(bgr, outscale=OUTSCALE)
90
- out_rgb = cv2.cvtColor(out_bgr, cv2.COLOR_BGR2RGB)
91
- return out_rgb
92
 
 
 
 
 
 
 
 
 
 
 
93
 
 
 
 
 
 
 
 
 
 
 
 
94
  def process_batch(files):
 
95
  if not files:
96
  return [], None
97
 
98
  previews = []
99
 
100
  with tempfile.TemporaryDirectory() as tmpdir:
 
101
  tmpdir = Path(tmpdir)
102
- out_dir = tmpdir / "upscaled"
103
- out_dir.mkdir(parents=True, exist_ok=True)
104
 
105
- for file_obj in files:
106
- in_path = Path(file_obj.name)
107
- image = cv2.imread(str(in_path), cv2.IMREAD_UNCHANGED)
 
 
 
 
 
 
 
 
 
 
 
108
 
109
- if image is None:
110
  continue
111
 
112
- if image.ndim == 3 and image.shape[2] == 4:
113
- image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
114
- elif image.ndim == 3:
115
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
 
 
 
 
 
 
 
 
 
 
116
 
117
- result = upscale_one_image(image)
118
  if result is None:
119
  continue
120
 
121
- out_name = f"{in_path.stem}_2x.png"
122
- out_path = out_dir / out_name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
- if result.ndim == 3 and result.shape[2] == 4:
125
- save_img = cv2.cvtColor(result, cv2.COLOR_RGBA2BGRA)
126
  else:
127
- save_img = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
 
 
 
 
 
128
 
129
- cv2.imwrite(str(out_path), save_img)
130
- previews.append((result, out_name))
 
131
 
132
  zip_path = tmpdir / "upscaled_images.zip"
133
- with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
134
- for img_file in out_dir.iterdir():
135
- zf.write(img_file, arcname=img_file.name)
136
 
137
- final_zip = Path(tempfile.gettempdir()) / "upscaled_images.zip"
138
- final_zip.write_bytes(zip_path.read_bytes())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
  return previews, str(final_zip)
141
 
142
 
 
 
 
143
  with gr.Blocks() as demo:
144
- gr.Markdown("# Anime Upscaler 2x\nUpload em lote e baixe um ZIP.")
145
 
146
- files_in = gr.Files(
147
- label="Envie várias imagens",
 
 
 
 
 
148
  file_types=["image"],
149
  file_count="multiple",
150
  )
151
 
152
- run_btn = gr.Button("Processar")
153
- gallery_out = gr.Gallery(label="Prévia", columns=2, height=420)
154
- zip_out = gr.File(label="Baixar ZIP")
 
 
 
 
 
 
 
 
 
 
155
 
156
  run_btn.click(
157
  fn=process_batch,
158
  inputs=files_in,
159
- outputs=[gallery_out, zip_out],
 
 
 
160
  )
161
 
 
 
 
 
162
  if __name__ == "__main__":
163
- demo.launch()
 
 
 
 
 
8
  import gradio as gr
9
  import numpy as np
10
  import torch
11
+
12
  from basicsr.archs.rrdbnet_arch import RRDBNet
13
  from realesrgan import RealESRGANer
14
 
15
+
16
  # =========================
17
  # CONFIG
18
  # =========================
19
  OUTSCALE = 2
20
 
21
+ MODEL_URL = (
22
+ "https://github.com/xinntao/Real-ESRGAN/releases/download/"
23
+ "v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
24
+ )
25
+
26
  MODEL_DIR = Path("weights")
27
  MODEL_PATH = MODEL_DIR / "RealESRGAN_x4plus_anime_6B.pth"
28
 
29
 
30
+ # =========================
31
+ # DOWNLOAD MODEL
32
+ # =========================
33
+ def ensure_model():
34
  MODEL_DIR.mkdir(parents=True, exist_ok=True)
35
 
36
+ if MODEL_PATH.exists():
37
  return str(MODEL_PATH)
38
 
39
+ print("Downloading model...")
40
  urlretrieve(MODEL_URL, MODEL_PATH)
41
+
42
  return str(MODEL_PATH)
43
 
44
 
45
+ # =========================
46
+ # BUILD UPSAMPLER
47
+ # =========================
48
  def build_upsampler():
49
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
50
 
51
  model_path = ensure_model()
52
 
 
59
  scale=4,
60
  )
61
 
62
+ upsampler = RealESRGANer(
63
  scale=4,
64
  model_path=model_path,
65
  model=model,
66
  tile=256,
67
  tile_pad=10,
68
  pre_pad=0,
69
+ half=torch.cuda.is_available(),
70
  device=device,
71
  )
72
 
73
+ return upsampler
74
+
75
 
76
  UPSAMPLER = build_upsampler()
77
 
78
 
79
+ # =========================
80
+ # UPSCALE SINGLE IMAGE
81
+ # =========================
82
+ def upscale_image(img):
83
+
84
+ if img is None:
85
  return None
86
 
87
+ if img.dtype != np.uint8:
88
+ img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
 
89
 
90
+ has_alpha = (
91
+ len(img.shape) == 3
92
+ and img.shape[2] == 4
93
+ )
94
+
95
+ if has_alpha:
96
+ rgb = img[:, :, :3]
97
+ alpha = img[:, :, 3]
98
 
99
  bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
 
 
100
 
101
+ output, _ = UPSAMPLER.enhance(
102
+ bgr,
103
+ outscale=OUTSCALE
104
+ )
105
+
106
+ output = cv2.cvtColor(
107
+ output,
108
+ cv2.COLOR_BGR2RGB
109
+ )
110
+
111
+ alpha = cv2.resize(
112
  alpha,
113
+ (
114
+ alpha.shape[1] * OUTSCALE,
115
+ alpha.shape[0] * OUTSCALE
116
+ ),
117
+ interpolation=cv2.INTER_CUBIC
118
  )
 
119
 
120
+ output = np.dstack([output, alpha])
121
+
122
+ return output
 
123
 
124
+ else:
125
+ bgr = cv2.cvtColor(
126
+ img,
127
+ cv2.COLOR_RGB2BGR
128
+ )
129
+
130
+ output, _ = UPSAMPLER.enhance(
131
+ bgr,
132
+ outscale=OUTSCALE
133
+ )
134
 
135
+ output = cv2.cvtColor(
136
+ output,
137
+ cv2.COLOR_BGR2RGB
138
+ )
139
+
140
+ return output
141
+
142
+
143
+ # =========================
144
+ # BATCH PROCESS
145
+ # =========================
146
  def process_batch(files):
147
+
148
  if not files:
149
  return [], None
150
 
151
  previews = []
152
 
153
  with tempfile.TemporaryDirectory() as tmpdir:
154
+
155
  tmpdir = Path(tmpdir)
 
 
156
 
157
+ output_dir = tmpdir / "outputs"
158
+ output_dir.mkdir(
159
+ parents=True,
160
+ exist_ok=True
161
+ )
162
+
163
+ for file in files:
164
+
165
+ input_path = Path(file.name)
166
+
167
+ img = cv2.imread(
168
+ str(input_path),
169
+ cv2.IMREAD_UNCHANGED
170
+ )
171
 
172
+ if img is None:
173
  continue
174
 
175
+ # BGR -> RGB
176
+ if len(img.shape) == 3:
177
+
178
+ if img.shape[2] == 4:
179
+ img = cv2.cvtColor(
180
+ img,
181
+ cv2.COLOR_BGRA2RGBA
182
+ )
183
+ else:
184
+ img = cv2.cvtColor(
185
+ img,
186
+ cv2.COLOR_BGR2RGB
187
+ )
188
+
189
+ result = upscale_image(img)
190
 
 
191
  if result is None:
192
  continue
193
 
194
+ out_name = f"{input_path.stem}_2x.png"
195
+
196
+ out_path = output_dir / out_name
197
+
198
+ # RGB -> BGR
199
+ if len(result.shape) == 3:
200
+
201
+ if result.shape[2] == 4:
202
+ save_img = cv2.cvtColor(
203
+ result,
204
+ cv2.COLOR_RGBA2BGRA
205
+ )
206
+ else:
207
+ save_img = cv2.cvtColor(
208
+ result,
209
+ cv2.COLOR_RGB2BGR
210
+ )
211
 
 
 
212
  else:
213
+ save_img = result
214
+
215
+ cv2.imwrite(
216
+ str(out_path),
217
+ save_img
218
+ )
219
 
220
+ previews.append(
221
+ (result, out_name)
222
+ )
223
 
224
  zip_path = tmpdir / "upscaled_images.zip"
 
 
 
225
 
226
+ with zipfile.ZipFile(
227
+ zip_path,
228
+ "w",
229
+ compression=zipfile.ZIP_DEFLATED
230
+ ) as zipf:
231
+
232
+ for img_file in output_dir.iterdir():
233
+
234
+ zipf.write(
235
+ img_file,
236
+ arcname=img_file.name
237
+ )
238
+
239
+ final_zip = (
240
+ Path(tempfile.gettempdir())
241
+ / "upscaled_images.zip"
242
+ )
243
+
244
+ final_zip.write_bytes(
245
+ zip_path.read_bytes()
246
+ )
247
 
248
  return previews, str(final_zip)
249
 
250
 
251
+ # =========================
252
+ # UI
253
+ # =========================
254
  with gr.Blocks() as demo:
 
255
 
256
+ gr.Markdown(
257
+ "# Anime Upscaler 2x\n"
258
+ "Batch upscale com ZIP."
259
+ )
260
+
261
+ files_in = gr.File(
262
+ label="Envie imagens",
263
  file_types=["image"],
264
  file_count="multiple",
265
  )
266
 
267
+ run_btn = gr.Button(
268
+ "Processar"
269
+ )
270
+
271
+ gallery_out = gr.Gallery(
272
+ label="Preview",
273
+ columns=2,
274
+ height=400
275
+ )
276
+
277
+ zip_out = gr.File(
278
+ label="Download ZIP"
279
+ )
280
 
281
  run_btn.click(
282
  fn=process_batch,
283
  inputs=files_in,
284
+ outputs=[
285
+ gallery_out,
286
+ zip_out
287
+ ]
288
  )
289
 
290
+
291
+ # =========================
292
+ # START
293
+ # =========================
294
  if __name__ == "__main__":
295
+
296
+ demo.launch(
297
+ server_name="0.0.0.0",
298
+ server_port=7860
299
+ )