AItool commited on
Commit
8f6c8a2
·
verified ·
1 Parent(s): 2d7ac41

Update app_enhance.py

Browse files
Files changed (1) hide show
  1. app_enhance.py +69 -121
app_enhance.py CHANGED
@@ -1,131 +1,79 @@
1
  import os
2
- import subprocess
3
- import spaces
4
- import torch
5
- import cv2
6
  import uuid
 
 
 
7
  import gradio as gr
8
- import numpy as np
9
-
10
- from PIL import Image
11
- from basicsr.archs.srvgg_arch import SRVGGNetCompact
12
- from gfpgan.utils import GFPGANer
13
- from realesrgan.utils import RealESRGANer
14
-
15
- def runcmd(cmd, verbose = False):
16
-
17
- process = subprocess.Popen(
18
- cmd,
19
- stdout = subprocess.PIPE,
20
- stderr = subprocess.PIPE,
21
- text = True,
22
- shell = True
23
- )
24
- std_out, std_err = process.communicate()
25
- if verbose:
26
- print(std_out.strip(), std_err)
27
- pass
28
-
29
-
30
- if not os.path.exists('GFPGANv1.4.pth'):
31
- runcmd("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P .")
32
- if not os.path.exists('realesr-general-x4v3.pth'):
33
- runcmd("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
34
-
35
-
36
-
37
- model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
38
- model_path = 'realesr-general-x4v3.pth'
39
- half = True if torch.cuda.is_available() else False
40
- upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
41
-
42
-
43
- @spaces.GPU(duration=15)
44
- def enhance_image(
45
- input_image: Image,
46
- scale: int,
47
- enhance_mode: str,
48
- ):
49
- only_face = enhance_mode == "Only Face Enhance"
50
- if enhance_mode == "Only Face Enhance":
51
- face_enhancer = GFPGANer(model_path='GFPGANv1.4.pth', upscale=scale, arch='clean', channel_multiplier=2)
52
- elif enhance_mode == "Only Image Enhance":
53
- face_enhancer = None
54
- else:
55
- face_enhancer = GFPGANer(model_path='GFPGANv1.4.pth', upscale=scale, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
56
-
57
- img = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
58
-
59
- h, w = img.shape[0:2]
60
- if h < 300:
61
- img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
62
-
63
- if face_enhancer is not None:
64
- _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=only_face, paste_back=True)
65
- else:
66
- output, _ = upsampler.enhance(img, outscale=scale)
67
-
68
- # if scale != 2:
69
- # interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
70
- # h, w = img.shape[0:2]
71
- # output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
72
-
73
- h, w = output.shape[0:2]
74
- max_size = 3480
75
- if h > max_size:
76
- w = int(w * max_size / h)
77
- h = max_size
78
-
79
- if w > max_size:
80
- h = int(h * max_size / w)
81
- w = max_size
82
-
83
- output = cv2.resize(output, (w, h), interpolation=cv2.INTER_LANCZOS4)
84
-
85
- enhanced_image = Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))
86
- tmpPrefix = "/tmp/gradio/"
87
-
88
- extension = 'png'
89
-
90
- targetDir = f"{tmpPrefix}output/"
91
- if not os.path.exists(targetDir):
92
- os.makedirs(targetDir)
93
-
94
- enhanced_path = f"{targetDir}{uuid.uuid4()}.{extension}"
95
- enhanced_image.save(enhanced_path, quality=100)
96
-
97
- return enhanced_image, enhanced_path
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
- def create_demo() -> gr.Blocks:
101
 
102
- with gr.Blocks() as demo:
103
- with gr.Row():
104
- with gr.Column():
105
- scale = gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Scale")
106
- with gr.Column():
107
- enhance_mode = gr.Dropdown(
108
- label="Enhance Mode",
109
- choices=[
110
- "Only Face Enhance",
111
- "Only Image Enhance",
112
- "Face Enhance + Image Enhance",
113
- ],
114
- value="Face Enhance + Image Enhance",
115
- )
116
- g_btn = gr.Button("Enhance Image")
117
- with gr.Row():
118
- with gr.Column():
119
- input_image = gr.Image(label="Input Image", type="pil")
120
- with gr.Column():
121
- output_image = gr.Image(label="Enhanced Image", type="pil", interactive=False)
122
- enhance_image_path = gr.File(label="Download the Enhanced Image", interactive=False)
123
-
124
-
125
  g_btn.click(
126
  fn=enhance_image,
127
- inputs=[input_image, scale, enhance_mode],
128
- outputs=[output_image, enhance_image_path],
129
  )
130
 
131
- return demo
 
 
 
 
1
  import os
 
 
 
 
2
  import uuid
3
+ import subprocess
4
+ from pathlib import Path
5
+ from typing import Tuple
6
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ # Constants
9
+ TMP_DIR = "/tmp/gradio/output/"
10
+ GIF_EXT = "gif"
11
+ PALETTE_PATH = "/tmp/gradio/palette.png"
12
+
13
+ def ensure_tmp_dir():
14
+ os.makedirs(TMP_DIR, exist_ok=True)
15
+
16
+ def generate_interpolation_frames(img_a: str, img_b: str, exp: int = 4):
17
+ """Runs inference_img.py to generate interpolated frames"""
18
+ cmd = [
19
+ "python3", "inference_img.py",
20
+ "--img", img_a, img_b,
21
+ "--exp", str(exp)
22
+ ]
23
+ subprocess.run(cmd, check=True)
24
+
25
+ def create_palette():
26
+ """Generates GIF palette from interpolated frames"""
27
+ cmd = [
28
+ "ffmpeg", "-y", "-r", "14", "-f", "image2",
29
+ "-i", f"{TMP_DIR}img%d.png",
30
+ "-vf", "palettegen=stats_mode=single",
31
+ PALETTE_PATH
32
+ ]
33
+ subprocess.run(cmd, check=True)
34
+
35
+ def write_gif(gif_path: str):
36
+ """Creates final interpolated GIF using palette"""
37
+ cmd = [
38
+ "ffmpeg", "-y", "-r", "14", "-f", "image2",
39
+ "-i", f"{TMP_DIR}img%d.png",
40
+ "-i", PALETTE_PATH,
41
+ "-lavfi", "paletteuse",
42
+ gif_path
43
+ ]
44
+ subprocess.run(cmd, check=True)
45
+
46
+ def enhance_image(img_a: str, img_b: str, mode: str) -> Tuple[str, str]:
47
+ ensure_tmp_dir()
48
+ gif_path = f"{TMP_DIR}{uuid.uuid4()}.{GIF_EXT}"
49
+
50
+ try:
51
+ generate_interpolation_frames(img_a, img_b)
52
+ create_palette()
53
+ write_gif(gif_path)
54
+ return gif_path, gif_path
55
+ except subprocess.CalledProcessError as e:
56
+ raise gr.Error(f"Interpolation failed: {e}")
57
+
58
+ # Gradio UI
59
+ def build_interface():
60
+ with gr.Blocks(title="RIFE Interpolation") as demo:
61
+ with gr.Row():
62
+ input_imageA = gr.Image(label="Image A", type="filepath")
63
+ input_imageB = gr.Image(label="Image B", type="filepath")
64
+ enhance_mode = gr.Dropdown(choices=["default"], value="default", label="Mode")
65
+ output_image = gr.Image(label="Result GIF", type="filepath")
66
+ output_path = gr.Textbox(label="GIF Path", interactive=False)
67
 
68
+ g_btn = gr.Button("Interpolate")
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  g_btn.click(
71
  fn=enhance_image,
72
+ inputs=[input_imageA, input_imageB, enhance_mode],
73
+ outputs=[output_image, output_path]
74
  )
75
 
76
+ return demo
77
+
78
+ demo = build_interface()
79
+ demo.launch()