AItool commited on
Commit
4af397d
·
verified ·
1 Parent(s): 1d4d2e2

Update app_enhance.py

Browse files
Files changed (1) hide show
  1. app_enhance.py +121 -69
app_enhance.py CHANGED
@@ -1,79 +1,131 @@
1
  import os
2
- import uuid
3
  import subprocess
4
- from pathlib import Path
5
- from typing import Tuple
 
 
6
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- # Constants
9
- TMP_DIR = "/tmp/gradio/output/"
10
- GIF_EXT = "gif"
11
- PALETTE_PATH = "/tmp/gradio/palette.png"
12
-
13
- def ensure_tmp_dir():
14
- os.makedirs(TMP_DIR, exist_ok=True)
15
-
16
- def generate_interpolation_frames(img_a: str, img_b: str, exp: int = 4):
17
- """Runs inference_img.py to generate interpolated frames"""
18
- cmd = [
19
- "python3", "inference_img.py",
20
- "--img", img_a, img_b,
21
- "--exp", str(exp)
22
- ]
23
- subprocess.run(cmd, check=True)
24
-
25
- def create_palette():
26
- """Generates GIF palette from interpolated frames"""
27
- cmd = [
28
- "ffmpeg", "-y", "-r", "14", "-f", "image2",
29
- "-i", f"{TMP_DIR}img%d.png",
30
- "-vf", "palettegen=stats_mode=single",
31
- PALETTE_PATH
32
- ]
33
- subprocess.run(cmd, check=True)
34
-
35
- def write_gif(gif_path: str):
36
- """Creates final interpolated GIF using palette"""
37
- cmd = [
38
- "ffmpeg", "-y", "-r", "14", "-f", "image2",
39
- "-i", f"{TMP_DIR}img%d.png",
40
- "-i", PALETTE_PATH,
41
- "-lavfi", "paletteuse",
42
- gif_path
43
- ]
44
- subprocess.run(cmd, check=True)
45
-
46
- def enhance_image(img_a: str, img_b: str, mode: str) -> Tuple[str, str]:
47
- ensure_tmp_dir()
48
- gif_path = f"{TMP_DIR}{uuid.uuid4()}.{GIF_EXT}"
49
-
50
- try:
51
- generate_interpolation_frames(img_a, img_b)
52
- create_palette()
53
- write_gif(gif_path)
54
- return gif_path, gif_path
55
- except subprocess.CalledProcessError as e:
56
- raise gr.Error(f"Interpolation failed: {e}")
57
-
58
- # Gradio UI
59
- def build_interface():
60
- with gr.Blocks(title="RIFE Interpolation") as demo:
61
- with gr.Row():
62
- input_imageA = gr.Image(label="Image A", type="filepath")
63
- input_imageB = gr.Image(label="Image B", type="filepath")
64
- enhance_mode = gr.Dropdown(choices=["default"], value="default", label="Mode")
65
- output_image = gr.Image(label="Result GIF", type="filepath")
66
- output_path = gr.Textbox(label="GIF Path", interactive=False)
67
 
68
- g_btn = gr.Button("Interpolate")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  g_btn.click(
71
  fn=enhance_image,
72
- inputs=[input_imageA, input_imageB, enhance_mode],
73
- outputs=[output_image, output_path]
74
  )
75
 
76
- return demo
77
-
78
- demo = build_interface()
79
- demo.launch()
 
1
  import os
 
2
  import subprocess
3
+ import spaces
4
+ import torch
5
+ import cv2
6
+ import uuid
7
  import gradio as gr
8
+ import numpy as np
9
+
10
+ from PIL import Image
11
+ from basicsr.archs.srvgg_arch import SRVGGNetCompact
12
+ from gfpgan.utils import GFPGANer
13
+ from realesrgan.utils import RealESRGANer
14
+
15
+ def runcmd(cmd, verbose = False):
16
+
17
+ process = subprocess.Popen(
18
+ cmd,
19
+ stdout = subprocess.PIPE,
20
+ stderr = subprocess.PIPE,
21
+ text = True,
22
+ shell = True
23
+ )
24
+ std_out, std_err = process.communicate()
25
+ if verbose:
26
+ print(std_out.strip(), std_err)
27
+ pass
28
+
29
+
30
+ if not os.path.exists('GFPGANv1.4.pth'):
31
+ runcmd("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P .")
32
+ if not os.path.exists('realesr-general-x4v3.pth'):
33
+ runcmd("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
34
+
35
+
36
+
37
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
38
+ model_path = 'realesr-general-x4v3.pth'
39
+ half = True if torch.cuda.is_available() else False
40
+ upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ @spaces.GPU(duration=15)
44
+ def enhance_image(
45
+ input_image: Image,
46
+ scale: int,
47
+ enhance_mode: str,
48
+ ):
49
+ only_face = enhance_mode == "Only Face Enhance"
50
+ if enhance_mode == "Only Face Enhance":
51
+ face_enhancer = GFPGANer(model_path='GFPGANv1.4.pth', upscale=scale, arch='clean', channel_multiplier=2)
52
+ elif enhance_mode == "Only Image Enhance":
53
+ face_enhancer = None
54
+ else:
55
+ face_enhancer = GFPGANer(model_path='GFPGANv1.4.pth', upscale=scale, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
56
+
57
+ img = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
58
 
59
+ h, w = img.shape[0:2]
60
+ if h < 300:
61
+ img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
62
+
63
+ if face_enhancer is not None:
64
+ _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=only_face, paste_back=True)
65
+ else:
66
+ output, _ = upsampler.enhance(img, outscale=scale)
67
+
68
+ # if scale != 2:
69
+ # interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
70
+ # h, w = img.shape[0:2]
71
+ # output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
72
+
73
+ h, w = output.shape[0:2]
74
+ max_size = 3480
75
+ if h > max_size:
76
+ w = int(w * max_size / h)
77
+ h = max_size
78
+
79
+ if w > max_size:
80
+ h = int(h * max_size / w)
81
+ w = max_size
82
+
83
+ output = cv2.resize(output, (w, h), interpolation=cv2.INTER_LANCZOS4)
84
+
85
+ enhanced_image = Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))
86
+ tmpPrefix = "/tmp/gradio/"
87
+
88
+ extension = 'png'
89
+
90
+ targetDir = f"{tmpPrefix}output/"
91
+ if not os.path.exists(targetDir):
92
+ os.makedirs(targetDir)
93
+
94
+ enhanced_path = f"{targetDir}{uuid.uuid4()}.{extension}"
95
+ enhanced_image.save(enhanced_path, quality=100)
96
+
97
+ return enhanced_image, enhanced_path
98
+
99
+
100
+ def create_demo() -> gr.Blocks:
101
+
102
+ with gr.Blocks() as demo:
103
+ with gr.Row():
104
+ with gr.Column():
105
+ scale = gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Scale")
106
+ with gr.Column():
107
+ enhance_mode = gr.Dropdown(
108
+ label="Enhance Mode",
109
+ choices=[
110
+ "Only Face Enhance",
111
+ "Only Image Enhance",
112
+ "Face Enhance + Image Enhance",
113
+ ],
114
+ value="Face Enhance + Image Enhance",
115
+ )
116
+ g_btn = gr.Button("Enhance Image")
117
+ with gr.Row():
118
+ with gr.Column():
119
+ input_image = gr.Image(label="Input Image", type="pil")
120
+ with gr.Column():
121
+ output_image = gr.Image(label="Enhanced Image", type="pil", interactive=False)
122
+ enhance_image_path = gr.File(label="Download the Enhanced Image", interactive=False)
123
+
124
+
125
  g_btn.click(
126
  fn=enhance_image,
127
+ inputs=[input_image, scale, enhance_mode],
128
+ outputs=[output_image, enhance_image_path],
129
  )
130
 
131
+ return demo