Spaces:
Runtime error
Runtime error
add dilation bar and improve UI
Browse files
app.py
CHANGED
|
@@ -1,19 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import numpy as np
|
| 3 |
from pathlib import Path
|
| 4 |
from matplotlib import pyplot as plt
|
| 5 |
import torch
|
| 6 |
import tempfile
|
| 7 |
-
import os
|
| 8 |
-
from omegaconf import OmegaConf
|
| 9 |
-
from sam_segment import predict_masks_with_sam
|
| 10 |
from lama_inpaint import inpaint_img_with_lama, build_lama_model, inpaint_img_with_builded_lama
|
| 11 |
from utils import load_img_to_array, save_array_to_img, dilate_mask, \
|
| 12 |
show_mask, show_points
|
| 13 |
from PIL import Image
|
|
|
|
| 14 |
from segment_anything import SamPredictor, sam_model_registry
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
def mkstemp(suffix, dir=None):
|
| 18 |
fd, path = tempfile.mkstemp(suffix=f"{suffix}", dir=dir)
|
| 19 |
os.close(fd)
|
|
@@ -21,9 +40,7 @@ def mkstemp(suffix, dir=None):
|
|
| 21 |
|
| 22 |
|
| 23 |
def get_sam_feat(img):
|
| 24 |
-
# predictor.set_image(img)
|
| 25 |
model['sam'].set_image(img)
|
| 26 |
-
# self.is_image_set = False
|
| 27 |
features = model['sam'].features
|
| 28 |
orig_h = model['sam'].orig_h
|
| 29 |
orig_w = model['sam'].orig_w
|
|
@@ -33,24 +50,18 @@ def get_sam_feat(img):
|
|
| 33 |
return features, orig_h, orig_w, input_h, input_w
|
| 34 |
|
| 35 |
|
| 36 |
-
def get_masked_img(img, w, h, features, orig_h, orig_w, input_h, input_w):
|
| 37 |
point_coords = [w, h]
|
| 38 |
point_labels = [1]
|
| 39 |
-
dilate_kernel_size = 15
|
| 40 |
|
| 41 |
-
# model['sam'].is_image_set = False
|
| 42 |
model['sam'].is_image_set = True
|
| 43 |
model['sam'].features = features
|
| 44 |
model['sam'].orig_h = orig_h
|
| 45 |
model['sam'].orig_w = orig_w
|
| 46 |
model['sam'].input_h = input_h
|
| 47 |
model['sam'].input_w = input_w
|
| 48 |
-
|
| 49 |
-
# model['sam'].
|
| 50 |
-
# model['sam'].input_size = input_size
|
| 51 |
-
# model['sam'].is_image_set = True
|
| 52 |
-
|
| 53 |
-
model['sam'].set_image(img)
|
| 54 |
masks, _, _ = model['sam'].predict(
|
| 55 |
point_coords=np.array([point_coords]),
|
| 56 |
point_labels=np.array(point_labels),
|
|
@@ -77,6 +88,7 @@ def get_masked_img(img, w, h, features, orig_h, orig_w, input_h, input_w):
|
|
| 77 |
show_points(plt.gca(), [point_coords], point_labels,
|
| 78 |
size=(width*0.04)**2)
|
| 79 |
show_mask(plt.gca(), mask, random_color=False)
|
|
|
|
| 80 |
plt.savefig(tmp_p, bbox_inches='tight', pad_inches=0)
|
| 81 |
figs.append(fig)
|
| 82 |
plt.close()
|
|
@@ -84,8 +96,7 @@ def get_masked_img(img, w, h, features, orig_h, orig_w, input_h, input_w):
|
|
| 84 |
|
| 85 |
|
| 86 |
def get_inpainted_img(img, mask0, mask1, mask2):
|
| 87 |
-
lama_config =
|
| 88 |
-
# lama_ckpt = "pretrained_models/big-lama"
|
| 89 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 90 |
out = []
|
| 91 |
for mask in [mask0, mask1, mask2]:
|
|
@@ -97,25 +108,27 @@ def get_inpainted_img(img, mask0, mask1, mask2):
|
|
| 97 |
return out
|
| 98 |
|
| 99 |
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
model = {}
|
| 102 |
# build the sam model
|
| 103 |
model_type="vit_h"
|
| 104 |
-
ckpt_p=
|
| 105 |
model_sam = sam_model_registry[model_type](checkpoint=ckpt_p)
|
| 106 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 107 |
model_sam.to(device=device)
|
| 108 |
-
# predictor = SamPredictor(model_sam)
|
| 109 |
model['sam'] = SamPredictor(model_sam)
|
| 110 |
|
| 111 |
# build the lama model
|
| 112 |
-
lama_config =
|
| 113 |
-
lama_ckpt =
|
| 114 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 115 |
-
# model_lama = build_lama_model(lama_config, lama_ckpt, device=device)
|
| 116 |
model['lama'] = build_lama_model(lama_config, lama_ckpt, device=device)
|
| 117 |
|
| 118 |
-
|
| 119 |
with gr.Blocks() as demo:
|
| 120 |
features = gr.State(None)
|
| 121 |
orig_h = gr.State(None)
|
|
@@ -123,36 +136,59 @@ with gr.Blocks() as demo:
|
|
| 123 |
input_h = gr.State(None)
|
| 124 |
input_w = gr.State(None)
|
| 125 |
|
| 126 |
-
with gr.Row():
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
with gr.Row():
|
| 131 |
w = gr.Number(label="Point Coordinate W")
|
| 132 |
h = gr.Number(label="Point Coordinate H")
|
| 133 |
-
|
| 134 |
-
sam_mask = gr.Button("Predict Mask
|
| 135 |
-
lama = gr.Button("Inpaint Image
|
| 136 |
-
|
| 137 |
|
| 138 |
# todo: maybe we can delete this row, for it's unnecessary to show the original mask for customers
|
| 139 |
-
with gr.Row():
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
def get_select_coords(img, evt: gr.SelectData):
|
| 158 |
dpi = plt.rcParams['figure.dpi']
|
|
@@ -160,22 +196,17 @@ with gr.Blocks() as demo:
|
|
| 160 |
fig = plt.figure(figsize=(width/dpi/0.77, height/dpi/0.77))
|
| 161 |
plt.imshow(img)
|
| 162 |
plt.axis('off')
|
|
|
|
| 163 |
show_points(plt.gca(), [[evt.index[0], evt.index[1]]], [1],
|
| 164 |
size=(width*0.04)**2)
|
| 165 |
return evt.index[0], evt.index[1], fig
|
| 166 |
|
| 167 |
img.select(get_select_coords, [img], [w, h, img_pointed])
|
| 168 |
-
# sam_feat.click(
|
| 169 |
-
# get_sam_feat,
|
| 170 |
-
# [img],
|
| 171 |
-
# []
|
| 172 |
-
# )
|
| 173 |
-
# img.change(get_sam_feat, [img], [])
|
| 174 |
img.upload(get_sam_feat, [img], [features, orig_h, orig_w, input_h, input_w])
|
| 175 |
|
| 176 |
sam_mask.click(
|
| 177 |
get_masked_img,
|
| 178 |
-
[img, w, h, features, orig_h, orig_w, input_h, input_w],
|
| 179 |
[img_with_mask_0, img_with_mask_1, img_with_mask_2, mask_0, mask_1, mask_2]
|
| 180 |
)
|
| 181 |
|
|
@@ -185,16 +216,16 @@ with gr.Blocks() as demo:
|
|
| 185 |
[img_rm_with_mask_0, img_rm_with_mask_1, img_rm_with_mask_2]
|
| 186 |
)
|
| 187 |
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
|
|
|
|
|
|
| 195 |
|
| 196 |
if __name__ == "__main__":
|
| 197 |
-
|
| 198 |
-
# demo.launch(max_threads=8)
|
| 199 |
-
demo.launch()
|
| 200 |
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
# sys.path.append(os.path.abspath(os.path.dirname(os.getcwd())))
|
| 4 |
+
# os.chdir("../")
|
| 5 |
import gradio as gr
|
| 6 |
import numpy as np
|
| 7 |
from pathlib import Path
|
| 8 |
from matplotlib import pyplot as plt
|
| 9 |
import torch
|
| 10 |
import tempfile
|
|
|
|
|
|
|
|
|
|
| 11 |
from lama_inpaint import inpaint_img_with_lama, build_lama_model, inpaint_img_with_builded_lama
|
| 12 |
from utils import load_img_to_array, save_array_to_img, dilate_mask, \
|
| 13 |
show_mask, show_points
|
| 14 |
from PIL import Image
|
| 15 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent / "third_party" / "segment-anything"))
|
| 16 |
from segment_anything import SamPredictor, sam_model_registry
|
| 17 |
+
import argparse
|
| 18 |
+
|
| 19 |
+
def setup_args(parser):
|
| 20 |
+
parser.add_argument(
|
| 21 |
+
"--lama_config", type=str,
|
| 22 |
+
default="./third_party/lama/configs/prediction/default.yaml",
|
| 23 |
+
help="The path to the config file of lama model. "
|
| 24 |
+
"Default: the config of big-lama",
|
| 25 |
+
)
|
| 26 |
+
parser.add_argument(
|
| 27 |
+
"--lama_ckpt", type=str,
|
| 28 |
+
default="pretrained_models/big-lama",
|
| 29 |
+
help="The path to the lama checkpoint.",
|
| 30 |
+
)
|
| 31 |
+
parser.add_argument(
|
| 32 |
+
"--sam_ckpt", type=str,
|
| 33 |
+
default="./pretrained_models/sam_vit_h_4b8939.pth",
|
| 34 |
+
help="The path to the SAM checkpoint to use for mask generation.",
|
| 35 |
+
)
|
| 36 |
def mkstemp(suffix, dir=None):
|
| 37 |
fd, path = tempfile.mkstemp(suffix=f"{suffix}", dir=dir)
|
| 38 |
os.close(fd)
|
|
|
|
| 40 |
|
| 41 |
|
| 42 |
def get_sam_feat(img):
|
|
|
|
| 43 |
model['sam'].set_image(img)
|
|
|
|
| 44 |
features = model['sam'].features
|
| 45 |
orig_h = model['sam'].orig_h
|
| 46 |
orig_w = model['sam'].orig_w
|
|
|
|
| 50 |
return features, orig_h, orig_w, input_h, input_w
|
| 51 |
|
| 52 |
|
| 53 |
+
def get_masked_img(img, w, h, features, orig_h, orig_w, input_h, input_w, dilate_kernel_size):
|
| 54 |
point_coords = [w, h]
|
| 55 |
point_labels = [1]
|
|
|
|
| 56 |
|
|
|
|
| 57 |
model['sam'].is_image_set = True
|
| 58 |
model['sam'].features = features
|
| 59 |
model['sam'].orig_h = orig_h
|
| 60 |
model['sam'].orig_w = orig_w
|
| 61 |
model['sam'].input_h = input_h
|
| 62 |
model['sam'].input_w = input_w
|
| 63 |
+
|
| 64 |
+
# model['sam'].set_image(img) # todo : update here for accelerating
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
masks, _, _ = model['sam'].predict(
|
| 66 |
point_coords=np.array([point_coords]),
|
| 67 |
point_labels=np.array(point_labels),
|
|
|
|
| 88 |
show_points(plt.gca(), [point_coords], point_labels,
|
| 89 |
size=(width*0.04)**2)
|
| 90 |
show_mask(plt.gca(), mask, random_color=False)
|
| 91 |
+
plt.tight_layout()
|
| 92 |
plt.savefig(tmp_p, bbox_inches='tight', pad_inches=0)
|
| 93 |
figs.append(fig)
|
| 94 |
plt.close()
|
|
|
|
| 96 |
|
| 97 |
|
| 98 |
def get_inpainted_img(img, mask0, mask1, mask2):
|
| 99 |
+
lama_config = args.lama_config
|
|
|
|
| 100 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 101 |
out = []
|
| 102 |
for mask in [mask0, mask1, mask2]:
|
|
|
|
| 108 |
return out
|
| 109 |
|
| 110 |
|
| 111 |
+
# get args
|
| 112 |
+
parser = argparse.ArgumentParser()
|
| 113 |
+
setup_args(parser)
|
| 114 |
+
args = parser.parse_args(sys.argv[1:])
|
| 115 |
+
# build models
|
| 116 |
model = {}
|
| 117 |
# build the sam model
|
| 118 |
model_type="vit_h"
|
| 119 |
+
ckpt_p=args.sam_ckpt
|
| 120 |
model_sam = sam_model_registry[model_type](checkpoint=ckpt_p)
|
| 121 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 122 |
model_sam.to(device=device)
|
|
|
|
| 123 |
model['sam'] = SamPredictor(model_sam)
|
| 124 |
|
| 125 |
# build the lama model
|
| 126 |
+
lama_config = args.lama_config
|
| 127 |
+
lama_ckpt = args.lama_ckpt
|
| 128 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 129 |
model['lama'] = build_lama_model(lama_config, lama_ckpt, device=device)
|
| 130 |
|
| 131 |
+
button_size = (100,50)
|
| 132 |
with gr.Blocks() as demo:
|
| 133 |
features = gr.State(None)
|
| 134 |
orig_h = gr.State(None)
|
|
|
|
| 136 |
input_h = gr.State(None)
|
| 137 |
input_w = gr.State(None)
|
| 138 |
|
| 139 |
+
with gr.Row().style(mobile_collapse=False, equal_height=True):
|
| 140 |
+
with gr.Column(variant="panel"):
|
| 141 |
+
with gr.Row():
|
| 142 |
+
gr.Markdown("## Input Image")
|
| 143 |
+
with gr.Row():
|
| 144 |
+
img = gr.Image(label="Input Image").style(height="200px")
|
| 145 |
+
with gr.Column(variant="panel"):
|
| 146 |
+
with gr.Row():
|
| 147 |
+
gr.Markdown("## Pointed Image")
|
| 148 |
+
with gr.Row():
|
| 149 |
+
img_pointed = gr.Plot(label='Pointed Image')
|
| 150 |
+
with gr.Column(variant="panel"):
|
| 151 |
+
with gr.Row():
|
| 152 |
+
gr.Markdown("## Control Panel")
|
| 153 |
with gr.Row():
|
| 154 |
w = gr.Number(label="Point Coordinate W")
|
| 155 |
h = gr.Number(label="Point Coordinate H")
|
| 156 |
+
dilate_kernel_size = gr.Slider(label="Dilate Kernel Size", minimum=0, maximum=100, step=1, value=15)
|
| 157 |
+
sam_mask = gr.Button("Predict Mask", variant="primary").style(full_width=True, size="sm")
|
| 158 |
+
lama = gr.Button("Inpaint Image", variant="primary").style(full_width=True, size="sm")
|
| 159 |
+
clear_button_image = gr.Button(value="Reset", label="Reset", variant="secondary").style(full_width=True, size="sm")
|
| 160 |
|
| 161 |
# todo: maybe we can delete this row, for it's unnecessary to show the original mask for customers
|
| 162 |
+
with gr.Row(variant="panel"):
|
| 163 |
+
with gr.Column():
|
| 164 |
+
with gr.Row():
|
| 165 |
+
gr.Markdown("## Segmentation Mask")
|
| 166 |
+
with gr.Row():
|
| 167 |
+
mask_0 = gr.outputs.Image(type="numpy", label="Segmentation Mask 0").style(height="200px")
|
| 168 |
+
mask_1 = gr.outputs.Image(type="numpy", label="Segmentation Mask 1").style(height="200px")
|
| 169 |
+
mask_2 = gr.outputs.Image(type="numpy", label="Segmentation Mask 2").style(height="200px")
|
| 170 |
+
|
| 171 |
+
with gr.Row(variant="panel"):
|
| 172 |
+
with gr.Column():
|
| 173 |
+
with gr.Row():
|
| 174 |
+
gr.Markdown("## Image with Mask")
|
| 175 |
+
with gr.Row():
|
| 176 |
+
img_with_mask_0 = gr.Plot(label="Image with Segmentation Mask 0")
|
| 177 |
+
img_with_mask_1 = gr.Plot(label="Image with Segmentation Mask 1")
|
| 178 |
+
img_with_mask_2 = gr.Plot(label="Image with Segmentation Mask 2")
|
| 179 |
+
|
| 180 |
+
with gr.Row(variant="panel"):
|
| 181 |
+
with gr.Column():
|
| 182 |
+
with gr.Row():
|
| 183 |
+
gr.Markdown("## Image Removed with Mask")
|
| 184 |
+
with gr.Row():
|
| 185 |
+
img_rm_with_mask_0 = gr.outputs.Image(
|
| 186 |
+
type="numpy", label="Image Removed with Segmentation Mask 0").style(height="200px")
|
| 187 |
+
img_rm_with_mask_1 = gr.outputs.Image(
|
| 188 |
+
type="numpy", label="Image Removed with Segmentation Mask 1").style(height="200px")
|
| 189 |
+
img_rm_with_mask_2 = gr.outputs.Image(
|
| 190 |
+
type="numpy", label="Image Removed with Segmentation Mask 2").style(height="200px")
|
| 191 |
+
|
| 192 |
|
| 193 |
def get_select_coords(img, evt: gr.SelectData):
|
| 194 |
dpi = plt.rcParams['figure.dpi']
|
|
|
|
| 196 |
fig = plt.figure(figsize=(width/dpi/0.77, height/dpi/0.77))
|
| 197 |
plt.imshow(img)
|
| 198 |
plt.axis('off')
|
| 199 |
+
plt.tight_layout()
|
| 200 |
show_points(plt.gca(), [[evt.index[0], evt.index[1]]], [1],
|
| 201 |
size=(width*0.04)**2)
|
| 202 |
return evt.index[0], evt.index[1], fig
|
| 203 |
|
| 204 |
img.select(get_select_coords, [img], [w, h, img_pointed])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
img.upload(get_sam_feat, [img], [features, orig_h, orig_w, input_h, input_w])
|
| 206 |
|
| 207 |
sam_mask.click(
|
| 208 |
get_masked_img,
|
| 209 |
+
[img, w, h, features, orig_h, orig_w, input_h, input_w, dilate_kernel_size],
|
| 210 |
[img_with_mask_0, img_with_mask_1, img_with_mask_2, mask_0, mask_1, mask_2]
|
| 211 |
)
|
| 212 |
|
|
|
|
| 216 |
[img_rm_with_mask_0, img_rm_with_mask_1, img_rm_with_mask_2]
|
| 217 |
)
|
| 218 |
|
| 219 |
+
|
| 220 |
+
def reset(*args):
|
| 221 |
+
return [None for _ in args]
|
| 222 |
+
|
| 223 |
+
clear_button_image.click(
|
| 224 |
+
reset,
|
| 225 |
+
[img, features, img_pointed, w, h, mask_0, mask_1, mask_2, img_with_mask_0, img_with_mask_1, img_with_mask_2, img_rm_with_mask_0, img_rm_with_mask_1, img_rm_with_mask_2],
|
| 226 |
+
[img, features, img_pointed, w, h, mask_0, mask_1, mask_2, img_with_mask_0, img_with_mask_1, img_with_mask_2, img_rm_with_mask_0, img_rm_with_mask_1, img_rm_with_mask_2]
|
| 227 |
+
)
|
| 228 |
|
| 229 |
if __name__ == "__main__":
|
| 230 |
+
demo.launch(share=True)
|
|
|
|
|
|
|
| 231 |
|