Spaces:
Running
on
Zero
Running
on
Zero
LuJingyi-John
commited on
Commit
·
11c0865
1
Parent(s):
6678b47
Simplify for HF Spaces deployment
Browse files- Remove SAM mask refinement functionality
- Remove output_path file saving features
- Simplify UI to focus on core drag inpainting
- Remove complex dependencies for better HF Spaces compatibility
- app.py +20 -43
- utils/refine_mask.py +0 -168
- utils/ui_utils.py +66 -90
app.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
| 1 |
import gradio as gr
|
|
|
|
|
|
|
| 2 |
from utils.ui_utils import *
|
| 3 |
|
| 4 |
CANVAS_SIZE = 400
|
|
@@ -22,7 +24,6 @@ def create_interface():
|
|
| 22 |
canvas = gr.Image(type="numpy", tool="sketch", label=" ", height=CANVAS_SIZE, width=CANVAS_SIZE)
|
| 23 |
with gr.Row():
|
| 24 |
fit_btn = gr.Button("Resize Image")
|
| 25 |
-
if_sam_box = gr.Checkbox(label='Refine mask (SAM)')
|
| 26 |
|
| 27 |
# Control Points Column
|
| 28 |
with gr.Column():
|
|
@@ -40,20 +41,15 @@ def create_interface():
|
|
| 40 |
run_btn = gr.Button("Inpaint")
|
| 41 |
reset_btn = gr.Button("Reset All")
|
| 42 |
|
| 43 |
-
#
|
| 44 |
-
with gr.Row(
|
| 45 |
-
sam_ks = gr.Slider(minimum=11, maximum=51, value=21, step=2, label='How much to refine mask with SAM', interactive=True)
|
| 46 |
inpaint_ks = gr.Slider(minimum=0, maximum=25, value=5, step=1, label='How much to expand inpainting mask', interactive=True)
|
| 47 |
-
output_path = gr.Textbox(value='output/app', label="Output path")
|
| 48 |
|
| 49 |
setup_events(
|
| 50 |
components={
|
| 51 |
'canvas': canvas,
|
| 52 |
'input_img': input_img,
|
| 53 |
'output_img': output_img,
|
| 54 |
-
'output_path': output_path,
|
| 55 |
-
'if_sam_box': if_sam_box,
|
| 56 |
-
'sam_ks': sam_ks,
|
| 57 |
'inpaint_ks': inpaint_ks,
|
| 58 |
},
|
| 59 |
state=state,
|
|
@@ -75,21 +71,21 @@ def setup_events(components, state, buttons):
|
|
| 75 |
clear_all,
|
| 76 |
[state['canvas_size']],
|
| 77 |
[components['canvas'], components['input_img'], components['output_img'],
|
| 78 |
-
state['points_list'], components['
|
| 79 |
)
|
| 80 |
|
| 81 |
components['canvas'].clear(
|
| 82 |
clear_all,
|
| 83 |
[state['canvas_size']],
|
| 84 |
[components['canvas'], components['input_img'], components['output_img'],
|
| 85 |
-
state['points_list'], components['
|
| 86 |
)
|
| 87 |
|
| 88 |
# Image manipulation events
|
| 89 |
def setup_image_events():
|
| 90 |
buttons['fit'].click(
|
| 91 |
clear_point,
|
| 92 |
-
[components['canvas'], state['points_list'], components['
|
| 93 |
[components['input_img']]
|
| 94 |
).then(
|
| 95 |
resize,
|
|
@@ -101,41 +97,21 @@ def setup_events(components, state, buttons):
|
|
| 101 |
def setup_canvas_events():
|
| 102 |
components['canvas'].edit(
|
| 103 |
visualize_user_drag,
|
| 104 |
-
[components['canvas'], state['points_list']
|
| 105 |
[components['input_img']]
|
| 106 |
).then(
|
| 107 |
preview_out_image,
|
| 108 |
-
[components['canvas'], state['points_list'], components['
|
| 109 |
[components['output_img'], state['inpaint_mask']]
|
| 110 |
)
|
| 111 |
|
| 112 |
-
components['if_sam_box'].change(
|
| 113 |
-
visualize_user_drag,
|
| 114 |
-
[components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box']],
|
| 115 |
-
[components['input_img']]
|
| 116 |
-
).then(
|
| 117 |
-
preview_out_image,
|
| 118 |
-
[components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
|
| 119 |
-
[components['output_img'], state['inpaint_mask']]
|
| 120 |
-
)
|
| 121 |
-
|
| 122 |
-
components['sam_ks'].change(
|
| 123 |
-
visualize_user_drag,
|
| 124 |
-
[components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box']],
|
| 125 |
-
[components['input_img']]
|
| 126 |
-
).then(
|
| 127 |
-
preview_out_image,
|
| 128 |
-
[components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
|
| 129 |
-
[components['output_img'], state['inpaint_mask']]
|
| 130 |
-
)
|
| 131 |
-
|
| 132 |
components['inpaint_ks'].change(
|
| 133 |
visualize_user_drag,
|
| 134 |
-
[components['canvas'], state['points_list']
|
| 135 |
[components['input_img']]
|
| 136 |
).then(
|
| 137 |
preview_out_image,
|
| 138 |
-
[components['canvas'], state['points_list'], components['
|
| 139 |
[components['output_img'], state['inpaint_mask']]
|
| 140 |
)
|
| 141 |
|
|
@@ -143,11 +119,11 @@ def setup_events(components, state, buttons):
|
|
| 143 |
def setup_input_events():
|
| 144 |
components['input_img'].select(
|
| 145 |
add_point,
|
| 146 |
-
[components['canvas'], state['points_list'], components['
|
| 147 |
[components['input_img']]
|
| 148 |
).then(
|
| 149 |
preview_out_image,
|
| 150 |
-
[components['canvas'], state['points_list'], components['
|
| 151 |
[components['output_img'], state['inpaint_mask']]
|
| 152 |
)
|
| 153 |
|
|
@@ -155,21 +131,21 @@ def setup_events(components, state, buttons):
|
|
| 155 |
def setup_point_events():
|
| 156 |
buttons['undo'].click(
|
| 157 |
undo_point,
|
| 158 |
-
[components['canvas'], state['points_list'], components['
|
| 159 |
[components['input_img']]
|
| 160 |
).then(
|
| 161 |
preview_out_image,
|
| 162 |
-
[components['canvas'], state['points_list'], components['
|
| 163 |
[components['output_img'], state['inpaint_mask']]
|
| 164 |
)
|
| 165 |
|
| 166 |
buttons['clear'].click(
|
| 167 |
clear_point,
|
| 168 |
-
[components['canvas'], state['points_list'], components['
|
| 169 |
[components['input_img']]
|
| 170 |
).then(
|
| 171 |
preview_out_image,
|
| 172 |
-
[components['canvas'], state['points_list'], components['
|
| 173 |
[components['output_img'], state['inpaint_mask']]
|
| 174 |
)
|
| 175 |
|
|
@@ -177,7 +153,7 @@ def setup_events(components, state, buttons):
|
|
| 177 |
def setup_processing_events():
|
| 178 |
buttons['run'].click(
|
| 179 |
preview_out_image,
|
| 180 |
-
|
| 181 |
[components['output_img'], state['inpaint_mask']]
|
| 182 |
).then(
|
| 183 |
inpaint,
|
|
@@ -195,7 +171,8 @@ def setup_events(components, state, buttons):
|
|
| 195 |
|
| 196 |
def main():
|
| 197 |
app = create_interface()
|
| 198 |
-
|
|
|
|
| 199 |
|
| 200 |
if __name__ == '__main__':
|
| 201 |
main()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
import tempfile
|
| 3 |
+
import os
|
| 4 |
from utils.ui_utils import *
|
| 5 |
|
| 6 |
CANVAS_SIZE = 400
|
|
|
|
| 24 |
canvas = gr.Image(type="numpy", tool="sketch", label=" ", height=CANVAS_SIZE, width=CANVAS_SIZE)
|
| 25 |
with gr.Row():
|
| 26 |
fit_btn = gr.Button("Resize Image")
|
|
|
|
| 27 |
|
| 28 |
# Control Points Column
|
| 29 |
with gr.Column():
|
|
|
|
| 41 |
run_btn = gr.Button("Inpaint")
|
| 42 |
reset_btn = gr.Button("Reset All")
|
| 43 |
|
| 44 |
+
# Generation Parameters
|
| 45 |
+
with gr.Row():
|
|
|
|
| 46 |
inpaint_ks = gr.Slider(minimum=0, maximum=25, value=5, step=1, label='How much to expand inpainting mask', interactive=True)
|
|
|
|
| 47 |
|
| 48 |
setup_events(
|
| 49 |
components={
|
| 50 |
'canvas': canvas,
|
| 51 |
'input_img': input_img,
|
| 52 |
'output_img': output_img,
|
|
|
|
|
|
|
|
|
|
| 53 |
'inpaint_ks': inpaint_ks,
|
| 54 |
},
|
| 55 |
state=state,
|
|
|
|
| 71 |
clear_all,
|
| 72 |
[state['canvas_size']],
|
| 73 |
[components['canvas'], components['input_img'], components['output_img'],
|
| 74 |
+
state['points_list'], components['inpaint_ks'], state['inpaint_mask']]
|
| 75 |
)
|
| 76 |
|
| 77 |
components['canvas'].clear(
|
| 78 |
clear_all,
|
| 79 |
[state['canvas_size']],
|
| 80 |
[components['canvas'], components['input_img'], components['output_img'],
|
| 81 |
+
state['points_list'], components['inpaint_ks'], state['inpaint_mask']]
|
| 82 |
)
|
| 83 |
|
| 84 |
# Image manipulation events
|
| 85 |
def setup_image_events():
|
| 86 |
buttons['fit'].click(
|
| 87 |
clear_point,
|
| 88 |
+
[components['canvas'], state['points_list'], components['inpaint_ks']],
|
| 89 |
[components['input_img']]
|
| 90 |
).then(
|
| 91 |
resize,
|
|
|
|
| 97 |
def setup_canvas_events():
|
| 98 |
components['canvas'].edit(
|
| 99 |
visualize_user_drag,
|
| 100 |
+
[components['canvas'], state['points_list']],
|
| 101 |
[components['input_img']]
|
| 102 |
).then(
|
| 103 |
preview_out_image,
|
| 104 |
+
[components['canvas'], state['points_list'], components['inpaint_ks']],
|
| 105 |
[components['output_img'], state['inpaint_mask']]
|
| 106 |
)
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
components['inpaint_ks'].change(
|
| 109 |
visualize_user_drag,
|
| 110 |
+
[components['canvas'], state['points_list']],
|
| 111 |
[components['input_img']]
|
| 112 |
).then(
|
| 113 |
preview_out_image,
|
| 114 |
+
[components['canvas'], state['points_list'], components['inpaint_ks']],
|
| 115 |
[components['output_img'], state['inpaint_mask']]
|
| 116 |
)
|
| 117 |
|
|
|
|
| 119 |
def setup_input_events():
|
| 120 |
components['input_img'].select(
|
| 121 |
add_point,
|
| 122 |
+
[components['canvas'], state['points_list'], components['inpaint_ks']],
|
| 123 |
[components['input_img']]
|
| 124 |
).then(
|
| 125 |
preview_out_image,
|
| 126 |
+
[components['canvas'], state['points_list'], components['inpaint_ks']],
|
| 127 |
[components['output_img'], state['inpaint_mask']]
|
| 128 |
)
|
| 129 |
|
|
|
|
| 131 |
def setup_point_events():
|
| 132 |
buttons['undo'].click(
|
| 133 |
undo_point,
|
| 134 |
+
[components['canvas'], state['points_list'], components['inpaint_ks']],
|
| 135 |
[components['input_img']]
|
| 136 |
).then(
|
| 137 |
preview_out_image,
|
| 138 |
+
[components['canvas'], state['points_list'], components['inpaint_ks']],
|
| 139 |
[components['output_img'], state['inpaint_mask']]
|
| 140 |
)
|
| 141 |
|
| 142 |
buttons['clear'].click(
|
| 143 |
clear_point,
|
| 144 |
+
[components['canvas'], state['points_list'], components['inpaint_ks']],
|
| 145 |
[components['input_img']]
|
| 146 |
).then(
|
| 147 |
preview_out_image,
|
| 148 |
+
[components['canvas'], state['points_list'], components['inpaint_ks']],
|
| 149 |
[components['output_img'], state['inpaint_mask']]
|
| 150 |
)
|
| 151 |
|
|
|
|
| 153 |
def setup_processing_events():
|
| 154 |
buttons['run'].click(
|
| 155 |
preview_out_image,
|
| 156 |
+
[components['canvas'], state['points_list'], components['inpaint_ks']],
|
| 157 |
[components['output_img'], state['inpaint_mask']]
|
| 158 |
).then(
|
| 159 |
inpaint,
|
|
|
|
| 171 |
|
| 172 |
def main():
|
| 173 |
app = create_interface()
|
| 174 |
+
# HF Space compatible launch
|
| 175 |
+
app.queue().launch()
|
| 176 |
|
| 177 |
if __name__ == '__main__':
|
| 178 |
main()
|
utils/refine_mask.py
DELETED
|
@@ -1,168 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import urllib.request
|
| 3 |
-
from typing import Optional
|
| 4 |
-
|
| 5 |
-
import cv2
|
| 6 |
-
import numpy as np
|
| 7 |
-
import torch
|
| 8 |
-
import torch.nn as nn
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def download_model(checkpoint_path: str, model_name: str = "efficientvit_sam_l0.pt") -> str:
|
| 12 |
-
"""
|
| 13 |
-
Download the model checkpoint if not found locally.
|
| 14 |
-
|
| 15 |
-
Args:
|
| 16 |
-
checkpoint_path: Local path where model should be saved
|
| 17 |
-
model_name: Name of the model file to download
|
| 18 |
-
|
| 19 |
-
Returns:
|
| 20 |
-
str: Path to the downloaded checkpoint
|
| 21 |
-
"""
|
| 22 |
-
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
|
| 23 |
-
|
| 24 |
-
base_url = "https://huggingface.co/mit-han-lab/efficientvit-sam/resolve/main"
|
| 25 |
-
model_url = f"{base_url}/{model_name}"
|
| 26 |
-
|
| 27 |
-
try:
|
| 28 |
-
print(f"Downloading model from {model_url}...")
|
| 29 |
-
urllib.request.urlretrieve(model_url, checkpoint_path)
|
| 30 |
-
print(f"Model successfully downloaded to {checkpoint_path}")
|
| 31 |
-
return checkpoint_path
|
| 32 |
-
except Exception as e:
|
| 33 |
-
raise RuntimeError(f"Failed to download model: {str(e)}")
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
class SamMaskRefiner(nn.Module):
|
| 37 |
-
CHECKPOINT_DIR = 'checkpoints'
|
| 38 |
-
MODEL_CONFIGS = {
|
| 39 |
-
'l0': 'efficientvit_sam_l0.pt',
|
| 40 |
-
'l1': 'efficientvit_sam_l1.pt',
|
| 41 |
-
'l2': 'efficientvit_sam_l2.pt'
|
| 42 |
-
}
|
| 43 |
-
|
| 44 |
-
def __init__(self, model_name: str = 'l0') -> None:
|
| 45 |
-
"""
|
| 46 |
-
Initialize SAM predictor with specified model version.
|
| 47 |
-
|
| 48 |
-
Args:
|
| 49 |
-
model_name: Model version to use ('l0', 'l1', or 'l2'). Defaults to 'l0'.
|
| 50 |
-
|
| 51 |
-
Raises:
|
| 52 |
-
ValueError: If invalid model_name is provided
|
| 53 |
-
RuntimeError: If model loading fails after download attempt
|
| 54 |
-
"""
|
| 55 |
-
super().__init__()
|
| 56 |
-
|
| 57 |
-
if model_name not in self.MODEL_CONFIGS:
|
| 58 |
-
raise ValueError(f"Invalid model_name. Choose from: {list(self.MODEL_CONFIGS.keys())}")
|
| 59 |
-
|
| 60 |
-
model_filename = self.MODEL_CONFIGS[model_name]
|
| 61 |
-
checkpoint_path = os.path.join(self.CHECKPOINT_DIR, model_filename)
|
| 62 |
-
|
| 63 |
-
try:
|
| 64 |
-
from efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
|
| 65 |
-
from efficientvit.sam_model_zoo import create_efficientvit_sam_model
|
| 66 |
-
except ImportError:
|
| 67 |
-
raise ImportError(
|
| 68 |
-
"Failed to import EfficientViT modules. Please ensure the package is installed:\n"
|
| 69 |
-
"pip install git+https://github.com/mit-han-lab/efficientvit.git"
|
| 70 |
-
)
|
| 71 |
-
|
| 72 |
-
if not os.path.exists(checkpoint_path):
|
| 73 |
-
print(f"Checkpoint not found at {checkpoint_path}. Attempting to download...")
|
| 74 |
-
checkpoint_path = download_model(checkpoint_path, model_filename)
|
| 75 |
-
|
| 76 |
-
try:
|
| 77 |
-
model_type = f'efficientvit-sam-{model_name}'
|
| 78 |
-
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 79 |
-
self.model = create_efficientvit_sam_model(model_type, True, checkpoint_path).eval()
|
| 80 |
-
self.model = self.model.requires_grad_(False).to(device)
|
| 81 |
-
self.predictor = EfficientViTSamPredictor(self.model)
|
| 82 |
-
print(f"\033[92mEfficientViT-SAM model loaded from: {checkpoint_path}\033[0m")
|
| 83 |
-
except Exception as e:
|
| 84 |
-
raise RuntimeError(f"Failed to load model: {str(e)}")
|
| 85 |
-
|
| 86 |
-
def sample_points_from_mask(self, mask: np.ndarray, max_points: int = 128) -> np.ndarray:
|
| 87 |
-
"""
|
| 88 |
-
Sample points uniformly from masked regions.
|
| 89 |
-
|
| 90 |
-
Args:
|
| 91 |
-
mask: Binary mask array of shape (H, W) with 0-1 values.
|
| 92 |
-
max_points: Maximum number of points to sample.
|
| 93 |
-
|
| 94 |
-
Returns:
|
| 95 |
-
np.ndarray: Array of shape (N, 2) containing [x,y] coordinates.
|
| 96 |
-
"""
|
| 97 |
-
y_indices, x_indices = np.where(mask > 0.5)
|
| 98 |
-
total_points = len(y_indices)
|
| 99 |
-
|
| 100 |
-
if total_points <= max_points:
|
| 101 |
-
return np.stack([x_indices, y_indices], axis=1)
|
| 102 |
-
|
| 103 |
-
y_min, y_max = y_indices.min(), y_indices.max()
|
| 104 |
-
x_min, x_max = x_indices.min(), x_indices.max()
|
| 105 |
-
|
| 106 |
-
aspect_ratio = (x_max - x_min) / max(y_max - y_min, 1)
|
| 107 |
-
ny = int(np.sqrt(max_points / aspect_ratio))
|
| 108 |
-
nx = int(ny * aspect_ratio)
|
| 109 |
-
|
| 110 |
-
x_bins = np.linspace(x_min, x_max + 1, nx + 1, dtype=np.int32)
|
| 111 |
-
y_bins = np.linspace(y_min, y_max + 1, ny + 1, dtype=np.int32)
|
| 112 |
-
|
| 113 |
-
x_dig = np.digitize(x_indices, x_bins) - 1
|
| 114 |
-
y_dig = np.digitize(y_indices, y_bins) - 1
|
| 115 |
-
bin_indices = y_dig * nx + x_dig
|
| 116 |
-
unique_bins = np.unique(bin_indices)
|
| 117 |
-
|
| 118 |
-
points = []
|
| 119 |
-
for idx in unique_bins:
|
| 120 |
-
bin_y = idx // nx
|
| 121 |
-
bin_x = idx % nx
|
| 122 |
-
mask = (y_dig == bin_y) & (x_dig == bin_x)
|
| 123 |
-
|
| 124 |
-
if np.any(mask):
|
| 125 |
-
px = int(np.mean(x_indices[mask]))
|
| 126 |
-
py = int(np.mean(y_indices[mask]))
|
| 127 |
-
points.append([px, py])
|
| 128 |
-
|
| 129 |
-
points = np.array(points)
|
| 130 |
-
|
| 131 |
-
if len(points) > max_points:
|
| 132 |
-
indices = np.linspace(0, len(points) - 1, max_points, dtype=int)
|
| 133 |
-
points = points[indices]
|
| 134 |
-
|
| 135 |
-
return points
|
| 136 |
-
|
| 137 |
-
def refine_mask(self, image: np.ndarray, input_mask: np.ndarray, kernel_size: int = 21) -> np.ndarray:
|
| 138 |
-
"""
|
| 139 |
-
Refine an input mask using the SAM (Segment Anything Model) model.
|
| 140 |
-
|
| 141 |
-
Args:
|
| 142 |
-
image: RGB image, shape (H, W, 3), values in [0, 255]
|
| 143 |
-
input_mask: Binary mask, shape (H, W), values in {0, 1}
|
| 144 |
-
kernel_size: Size of morphological kernel (default: 21)
|
| 145 |
-
|
| 146 |
-
Returns:
|
| 147 |
-
Refined binary mask, shape (H, W), values in {0, 1}
|
| 148 |
-
"""
|
| 149 |
-
points = self.sample_points_from_mask(input_mask, max_points=128)
|
| 150 |
-
if len(points) == 0:
|
| 151 |
-
return input_mask
|
| 152 |
-
|
| 153 |
-
self.predictor.set_image(image)
|
| 154 |
-
masks_pred, _, _ = self.predictor.predict(
|
| 155 |
-
point_coords=points,
|
| 156 |
-
point_labels=np.ones(len(points)),
|
| 157 |
-
multimask_output=False
|
| 158 |
-
)
|
| 159 |
-
sam_mask = masks_pred[0]
|
| 160 |
-
|
| 161 |
-
kernel = np.ones((kernel_size, kernel_size), np.uint8)
|
| 162 |
-
expanded_input = cv2.dilate(input_mask.astype(np.uint8), kernel)
|
| 163 |
-
preserved_input = cv2.erode(input_mask.astype(np.uint8), kernel)
|
| 164 |
-
|
| 165 |
-
sam_mask = np.logical_and(expanded_input, sam_mask).astype(np.uint8)
|
| 166 |
-
sam_mask = np.logical_or(preserved_input, sam_mask).astype(np.uint8)
|
| 167 |
-
|
| 168 |
-
return sam_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/ui_utils.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import os
|
| 2 |
import pickle
|
| 3 |
from time import perf_counter
|
|
|
|
| 4 |
|
| 5 |
import cv2
|
| 6 |
import gradio as gr
|
|
@@ -10,7 +11,6 @@ from PIL import Image
|
|
| 10 |
from diffusers import AutoPipelineForInpainting, AutoencoderTiny, LCMScheduler
|
| 11 |
|
| 12 |
from utils.drag import bi_warp
|
| 13 |
-
from utils.refine_mask import SamMaskRefiner
|
| 14 |
|
| 15 |
|
| 16 |
__all__ = [
|
|
@@ -19,10 +19,13 @@ __all__ = [
|
|
| 19 |
'add_point', 'undo_point', 'clear_point',
|
| 20 |
]
|
| 21 |
|
|
|
|
|
|
|
|
|
|
| 22 |
# UI functions
|
| 23 |
def clear_all(length):
|
| 24 |
"""Reset UI by clearing all input images and parameters."""
|
| 25 |
-
return (gr.Image(value=None, height=length, width=length),) * 3 + ([],
|
| 26 |
|
| 27 |
def resize(canvas, gen_length, canvas_length):
|
| 28 |
"""Resize canvas while maintaining aspect ratio."""
|
|
@@ -51,57 +54,35 @@ def process_canvas(canvas):
|
|
| 51 |
return image, mask
|
| 52 |
|
| 53 |
# Point manipulation functions
|
| 54 |
-
def add_point(canvas, points,
|
| 55 |
"""Add selected point to points list and update image."""
|
| 56 |
if canvas is None:
|
| 57 |
return None
|
| 58 |
points.append(evt.index)
|
| 59 |
-
return visualize_user_drag(canvas, points
|
| 60 |
|
| 61 |
-
def undo_point(canvas, points,
|
| 62 |
"""Remove last point and update image."""
|
| 63 |
if canvas is None:
|
| 64 |
return None
|
| 65 |
if len(points) > 0:
|
| 66 |
points.pop()
|
| 67 |
-
return visualize_user_drag(canvas, points
|
| 68 |
|
| 69 |
-
def clear_point(canvas, points,
|
| 70 |
"""Clear all points and update image."""
|
| 71 |
if canvas is None:
|
| 72 |
return None
|
| 73 |
points.clear()
|
| 74 |
-
return visualize_user_drag(canvas, points
|
| 75 |
|
| 76 |
# Visualization tools
|
| 77 |
-
def
|
| 78 |
-
"""
|
| 79 |
-
global sam_refiner
|
| 80 |
-
try:
|
| 81 |
-
if 'sam_refiner' not in globals():
|
| 82 |
-
sam_refiner = SamMaskRefiner()
|
| 83 |
-
return sam_refiner.refine_mask(image, mask, kernel_size)
|
| 84 |
-
except ImportError:
|
| 85 |
-
gr.Warning("EfficientVit not installed. Please install with: pip install git+https://github.com/mit-han-lab/efficientvit.git")
|
| 86 |
-
return mask
|
| 87 |
-
except Exception as e:
|
| 88 |
-
gr.Warning(f"Error refining mask: {str(e)}")
|
| 89 |
-
return mask
|
| 90 |
-
|
| 91 |
-
def visualize_user_drag(canvas, points, sam_ks, if_sam=False, output_path=None):
|
| 92 |
-
"""Visualize control points and motion vectors on the input image.
|
| 93 |
-
|
| 94 |
-
Args:
|
| 95 |
-
canvas (dict): Gradio canvas containing image and mask
|
| 96 |
-
points (list): List of (x,y) coordinate pairs for control points
|
| 97 |
-
sam_ks (int): Kernel size for SAM mask refinement
|
| 98 |
-
if_sam (bool): Whether to use SAM refinement on mask
|
| 99 |
-
"""
|
| 100 |
if canvas is None:
|
| 101 |
return None
|
| 102 |
|
| 103 |
image, mask = process_canvas(canvas)
|
| 104 |
-
mask = refine_mask(image, mask, sam_ks) if if_sam and mask.sum() > 0 else mask
|
| 105 |
|
| 106 |
# Apply colored mask overlay
|
| 107 |
result = image.copy()
|
|
@@ -120,29 +101,11 @@ def visualize_user_drag(canvas, points, sam_ks, if_sam=False, output_path=None):
|
|
| 120 |
else:
|
| 121 |
cv2.circle(image, tuple(point), 10, (255, 0, 0), -1) # Start point
|
| 122 |
prev_point = point
|
| 123 |
-
|
| 124 |
-
if output_path:
|
| 125 |
-
os.makedirs(output_path, exist_ok=True)
|
| 126 |
-
Image.fromarray(image).save(os.path.join(output_path, 'user_drag_i4p.png'))
|
| 127 |
return image
|
| 128 |
|
| 129 |
-
def preview_out_image(canvas, points,
|
| 130 |
-
"""Preview warped image result and generate inpainting mask.
|
| 131 |
-
|
| 132 |
-
Args:
|
| 133 |
-
canvas (dict): Gradio canvas containing the input image and mask
|
| 134 |
-
points (list): List of (x,y) coordinate pairs defining source and target positions for warping
|
| 135 |
-
sam_ks (int): Kernel size parameter for SAM mask refinement
|
| 136 |
-
inpaint_ks (int): Kernel size parameter for inpainting mask generation
|
| 137 |
-
if_sam (bool): Whether to use SAM model for mask refinement
|
| 138 |
-
output_path (str, optional): Directory path to save original image and metadata
|
| 139 |
-
|
| 140 |
-
Returns:
|
| 141 |
-
tuple:
|
| 142 |
-
- ndarray: Warped image with grid pattern overlay on regions needing inpainting
|
| 143 |
-
- ndarray: Binary mask (255 for inpainting regions, 0 elsewhere)
|
| 144 |
-
- (None, None): If canvas is empty or fewer than 2 control points provided
|
| 145 |
-
"""
|
| 146 |
if canvas is None:
|
| 147 |
return None, None
|
| 148 |
|
|
@@ -155,15 +118,7 @@ def preview_out_image(canvas, points, sam_ks, inpaint_ks, if_sam=False, output_p
|
|
| 155 |
size_valid = all(max(x.shape[:2] if len(x.shape) > 2 else x.shape) == 512 for x in (image, mask))
|
| 156 |
if not (shapes_valid and size_valid):
|
| 157 |
gr.Warning('Click Resize Image Button first.')
|
| 158 |
-
|
| 159 |
-
mask = refine_mask(image, mask, sam_ks) if if_sam and mask.sum() > 0 else mask
|
| 160 |
-
|
| 161 |
-
if output_path:
|
| 162 |
-
os.makedirs(output_path, exist_ok=True)
|
| 163 |
-
Image.fromarray(image).save(os.path.join(output_path, 'original_image.png'))
|
| 164 |
-
metadata = {'mask': mask, 'points': points}
|
| 165 |
-
with open(os.path.join(output_path, 'meta_data_i4p.pkl'), 'wb') as f:
|
| 166 |
-
pickle.dump(metadata, f)
|
| 167 |
|
| 168 |
handle_pts, target_pts, inpaint_mask = bi_warp(mask, points, inpaint_ks)
|
| 169 |
image[target_pts[:, 1], target_pts[:, 0]] = image[handle_pts[:, 1], handle_pts[:, 0]]
|
|
@@ -172,9 +127,6 @@ def preview_out_image(canvas, points, sam_ks, inpaint_ks, if_sam=False, output_p
|
|
| 172 |
background = np.ones_like(mask) * 255
|
| 173 |
background[::10] = background[:, ::10] = 0
|
| 174 |
image = np.where(inpaint_mask[..., np.newaxis]==1, background[..., np.newaxis], image)
|
| 175 |
-
|
| 176 |
-
if output_path:
|
| 177 |
-
Image.fromarray(image).save(os.path.join(output_path, 'preview_image.png'))
|
| 178 |
|
| 179 |
return image, (inpaint_mask * 255).astype(np.uint8)
|
| 180 |
|
|
@@ -187,11 +139,26 @@ def setup_pipeline(device='cuda', model_version='v1-5'):
|
|
| 187 |
}
|
| 188 |
model_id, lora_id, vae_id = MODEL_CONFIGS[model_version]
|
| 189 |
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
| 192 |
pipe.load_lora_weights(lora_id)
|
| 193 |
pipe.fuse_lora()
|
| 194 |
-
pipe.vae = AutoencoderTiny.from_pretrained(vae_id, torch_dtype=
|
| 195 |
pipe = pipe.to(device)
|
| 196 |
|
| 197 |
# Pre-compute prompt embeddings during setup
|
|
@@ -206,19 +173,20 @@ def setup_pipeline(device='cuda', model_version='v1-5'):
|
|
| 206 |
|
| 207 |
return pipe
|
| 208 |
|
| 209 |
-
|
| 210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
def inpaint(image, inpaint_mask):
|
| 213 |
-
"""Perform efficient inpainting on masked regions using Stable Diffusion.
|
| 214 |
-
|
| 215 |
-
Args:
|
| 216 |
-
image (ndarray): Input RGB image array (warped preview image)
|
| 217 |
-
inpaint_mask (ndarray): Binary mask array where 255 indicates regions to inpaint
|
| 218 |
-
|
| 219 |
-
Returns:
|
| 220 |
-
ndarray: Inpainted image with masked regions filled in
|
| 221 |
-
"""
|
| 222 |
if image is None:
|
| 223 |
return None
|
| 224 |
|
|
@@ -226,6 +194,10 @@ def inpaint(image, inpaint_mask):
|
|
| 226 |
return image
|
| 227 |
|
| 228 |
start = perf_counter()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
pipe_id = 'xl' if 'xl' in pipe.config._name_or_path else 'v1-5'
|
| 230 |
inpaint_strength = 0.99 if pipe_id == 'xl' else 1.0
|
| 231 |
|
|
@@ -254,18 +226,22 @@ def inpaint(image, inpaint_mask):
|
|
| 254 |
}
|
| 255 |
|
| 256 |
# Run pipeline
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
|
| 269 |
# Post-process results
|
| 270 |
inpaint_mask = (inpaint_mask[..., np.newaxis] / 255).astype(np.uint8)
|
| 271 |
-
return (inpainted * 255).astype(np.uint8) * inpaint_mask + image * (1 - inpaint_mask)
|
|
|
|
| 1 |
import os
|
| 2 |
import pickle
|
| 3 |
from time import perf_counter
|
| 4 |
+
import tempfile
|
| 5 |
|
| 6 |
import cv2
|
| 7 |
import gradio as gr
|
|
|
|
| 11 |
from diffusers import AutoPipelineForInpainting, AutoencoderTiny, LCMScheduler
|
| 12 |
|
| 13 |
from utils.drag import bi_warp
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
__all__ = [
|
|
|
|
| 19 |
'add_point', 'undo_point', 'clear_point',
|
| 20 |
]
|
| 21 |
|
| 22 |
+
# Global variables for lazy loading
|
| 23 |
+
pipe = None
|
| 24 |
+
|
| 25 |
# UI functions
|
| 26 |
def clear_all(length):
|
| 27 |
"""Reset UI by clearing all input images and parameters."""
|
| 28 |
+
return (gr.Image(value=None, height=length, width=length),) * 3 + ([], 2, None)
|
| 29 |
|
| 30 |
def resize(canvas, gen_length, canvas_length):
|
| 31 |
"""Resize canvas while maintaining aspect ratio."""
|
|
|
|
| 54 |
return image, mask
|
| 55 |
|
| 56 |
# Point manipulation functions
|
| 57 |
+
def add_point(canvas, points, inpaint_ks, evt: gr.SelectData):
|
| 58 |
"""Add selected point to points list and update image."""
|
| 59 |
if canvas is None:
|
| 60 |
return None
|
| 61 |
points.append(evt.index)
|
| 62 |
+
return visualize_user_drag(canvas, points)
|
| 63 |
|
| 64 |
+
def undo_point(canvas, points, inpaint_ks):
|
| 65 |
"""Remove last point and update image."""
|
| 66 |
if canvas is None:
|
| 67 |
return None
|
| 68 |
if len(points) > 0:
|
| 69 |
points.pop()
|
| 70 |
+
return visualize_user_drag(canvas, points)
|
| 71 |
|
| 72 |
+
def clear_point(canvas, points, inpaint_ks):
|
| 73 |
"""Clear all points and update image."""
|
| 74 |
if canvas is None:
|
| 75 |
return None
|
| 76 |
points.clear()
|
| 77 |
+
return visualize_user_drag(canvas, points)
|
| 78 |
|
| 79 |
# Visualization tools
|
| 80 |
+
def visualize_user_drag(canvas, points):
|
| 81 |
+
"""Visualize control points and motion vectors on the input image."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
if canvas is None:
|
| 83 |
return None
|
| 84 |
|
| 85 |
image, mask = process_canvas(canvas)
|
|
|
|
| 86 |
|
| 87 |
# Apply colored mask overlay
|
| 88 |
result = image.copy()
|
|
|
|
| 101 |
else:
|
| 102 |
cv2.circle(image, tuple(point), 10, (255, 0, 0), -1) # Start point
|
| 103 |
prev_point = point
|
| 104 |
+
|
|
|
|
|
|
|
|
|
|
| 105 |
return image
|
| 106 |
|
| 107 |
+
def preview_out_image(canvas, points, inpaint_ks):
|
| 108 |
+
"""Preview warped image result and generate inpainting mask."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
if canvas is None:
|
| 110 |
return None, None
|
| 111 |
|
|
|
|
| 118 |
size_valid = all(max(x.shape[:2] if len(x.shape) > 2 else x.shape) == 512 for x in (image, mask))
|
| 119 |
if not (shapes_valid and size_valid):
|
| 120 |
gr.Warning('Click Resize Image Button first.')
|
| 121 |
+
return image, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
handle_pts, target_pts, inpaint_mask = bi_warp(mask, points, inpaint_ks)
|
| 124 |
image[target_pts[:, 1], target_pts[:, 0]] = image[handle_pts[:, 1], handle_pts[:, 0]]
|
|
|
|
| 127 |
background = np.ones_like(mask) * 255
|
| 128 |
background[::10] = background[:, ::10] = 0
|
| 129 |
image = np.where(inpaint_mask[..., np.newaxis]==1, background[..., np.newaxis], image)
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
return image, (inpaint_mask * 255).astype(np.uint8)
|
| 132 |
|
|
|
|
| 139 |
}
|
| 140 |
model_id, lora_id, vae_id = MODEL_CONFIGS[model_version]
|
| 141 |
|
| 142 |
+
# Check if CUDA is available, fallback to CPU
|
| 143 |
+
if not torch.cuda.is_available():
|
| 144 |
+
device = 'cpu'
|
| 145 |
+
torch_dtype = torch.float32
|
| 146 |
+
variant = None
|
| 147 |
+
else:
|
| 148 |
+
torch_dtype = torch.float16
|
| 149 |
+
variant = "fp16"
|
| 150 |
+
|
| 151 |
+
gr.Info('Loading inpainting pipeline...')
|
| 152 |
+
pipe = AutoPipelineForInpainting.from_pretrained(
|
| 153 |
+
model_id,
|
| 154 |
+
torch_dtype=torch_dtype,
|
| 155 |
+
variant=variant,
|
| 156 |
+
safety_checker=None
|
| 157 |
+
)
|
| 158 |
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
| 159 |
pipe.load_lora_weights(lora_id)
|
| 160 |
pipe.fuse_lora()
|
| 161 |
+
pipe.vae = AutoencoderTiny.from_pretrained(vae_id, torch_dtype=torch_dtype)
|
| 162 |
pipe = pipe.to(device)
|
| 163 |
|
| 164 |
# Pre-compute prompt embeddings during setup
|
|
|
|
| 173 |
|
| 174 |
return pipe
|
| 175 |
|
| 176 |
+
def get_pipeline():
|
| 177 |
+
"""Lazy load pipeline only when needed."""
|
| 178 |
+
global pipe
|
| 179 |
+
if pipe is None:
|
| 180 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 181 |
+
pipe = setup_pipeline(device=device, model_version='v1-5')
|
| 182 |
+
if device == 'cuda':
|
| 183 |
+
pipe.cached_prompt_embeds = pipe.encode_prompt('', 'cuda', 1, False)[0]
|
| 184 |
+
else:
|
| 185 |
+
pipe.cached_prompt_embeds = pipe.encode_prompt('', 'cpu', 1, False)[0]
|
| 186 |
+
return pipe
|
| 187 |
|
| 188 |
def inpaint(image, inpaint_mask):
|
| 189 |
+
"""Perform efficient inpainting on masked regions using Stable Diffusion."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
if image is None:
|
| 191 |
return None
|
| 192 |
|
|
|
|
| 194 |
return image
|
| 195 |
|
| 196 |
start = perf_counter()
|
| 197 |
+
|
| 198 |
+
# Get pipeline (lazy loading)
|
| 199 |
+
pipe = get_pipeline()
|
| 200 |
+
|
| 201 |
pipe_id = 'xl' if 'xl' in pipe.config._name_or_path else 'v1-5'
|
| 202 |
inpaint_strength = 0.99 if pipe_id == 'xl' else 1.0
|
| 203 |
|
|
|
|
| 226 |
}
|
| 227 |
|
| 228 |
# Run pipeline
|
| 229 |
+
try:
|
| 230 |
+
if pipe_id == 'v1-5':
|
| 231 |
+
inpainted = pipe(
|
| 232 |
+
prompt_embeds=pipe.cached_prompt_embeds,
|
| 233 |
+
**common_params
|
| 234 |
+
).images[0]
|
| 235 |
+
else:
|
| 236 |
+
inpainted = pipe(
|
| 237 |
+
prompt_embeds=pipe.cached_prompt_embeds,
|
| 238 |
+
pooled_prompt_embeds=pipe.cached_pooled_prompt_embeds,
|
| 239 |
+
**common_params
|
| 240 |
+
).images[0]
|
| 241 |
+
except Exception as e:
|
| 242 |
+
gr.Warning(f"Inpainting failed: {str(e)}")
|
| 243 |
+
return image
|
| 244 |
|
| 245 |
# Post-process results
|
| 246 |
inpaint_mask = (inpaint_mask[..., np.newaxis] / 255).astype(np.uint8)
|
| 247 |
+
return (inpainted * 255).astype(np.uint8) * inpaint_mask + image * (1 - inpaint_mask)
|