LuJingyi-John committed on
Commit
11c0865
·
1 Parent(s): 6678b47

Simplify for HF Spaces deployment

Browse files

- Remove SAM mask refinement functionality
- Remove output_path file saving features
- Simplify UI to focus on core drag inpainting
- Remove complex dependencies for better HF Spaces compatibility

Files changed (3) hide show
  1. app.py +20 -43
  2. utils/refine_mask.py +0 -168
  3. utils/ui_utils.py +66 -90
app.py CHANGED
@@ -1,4 +1,6 @@
1
  import gradio as gr
 
 
2
  from utils.ui_utils import *
3
 
4
  CANVAS_SIZE = 400
@@ -22,7 +24,6 @@ def create_interface():
22
  canvas = gr.Image(type="numpy", tool="sketch", label=" ", height=CANVAS_SIZE, width=CANVAS_SIZE)
23
  with gr.Row():
24
  fit_btn = gr.Button("Resize Image")
25
- if_sam_box = gr.Checkbox(label='Refine mask (SAM)')
26
 
27
  # Control Points Column
28
  with gr.Column():
@@ -40,20 +41,15 @@ def create_interface():
40
  run_btn = gr.Button("Inpaint")
41
  reset_btn = gr.Button("Reset All")
42
 
43
- # Output Settings
44
- with gr.Row("Generation Parameters"):
45
- sam_ks = gr.Slider(minimum=11, maximum=51, value=21, step=2, label='How much to refine mask with SAM', interactive=True)
46
  inpaint_ks = gr.Slider(minimum=0, maximum=25, value=5, step=1, label='How much to expand inpainting mask', interactive=True)
47
- output_path = gr.Textbox(value='output/app', label="Output path")
48
 
49
  setup_events(
50
  components={
51
  'canvas': canvas,
52
  'input_img': input_img,
53
  'output_img': output_img,
54
- 'output_path': output_path,
55
- 'if_sam_box': if_sam_box,
56
- 'sam_ks': sam_ks,
57
  'inpaint_ks': inpaint_ks,
58
  },
59
  state=state,
@@ -75,21 +71,21 @@ def setup_events(components, state, buttons):
75
  clear_all,
76
  [state['canvas_size']],
77
  [components['canvas'], components['input_img'], components['output_img'],
78
- state['points_list'], components['sam_ks'], components['inpaint_ks'], components['output_path'], state['inpaint_mask']]
79
  )
80
 
81
  components['canvas'].clear(
82
  clear_all,
83
  [state['canvas_size']],
84
  [components['canvas'], components['input_img'], components['output_img'],
85
- state['points_list'], components['sam_ks'], components['inpaint_ks'], components['output_path'], state['inpaint_mask']]
86
  )
87
 
88
  # Image manipulation events
89
  def setup_image_events():
90
  buttons['fit'].click(
91
  clear_point,
92
- [components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box'], components['output_path']],
93
  [components['input_img']]
94
  ).then(
95
  resize,
@@ -101,41 +97,21 @@ def setup_events(components, state, buttons):
101
  def setup_canvas_events():
102
  components['canvas'].edit(
103
  visualize_user_drag,
104
- [components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box'], components['output_path']],
105
  [components['input_img']]
106
  ).then(
107
  preview_out_image,
108
- [components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
109
  [components['output_img'], state['inpaint_mask']]
110
  )
111
 
112
- components['if_sam_box'].change(
113
- visualize_user_drag,
114
- [components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box']],
115
- [components['input_img']]
116
- ).then(
117
- preview_out_image,
118
- [components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
119
- [components['output_img'], state['inpaint_mask']]
120
- )
121
-
122
- components['sam_ks'].change(
123
- visualize_user_drag,
124
- [components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box']],
125
- [components['input_img']]
126
- ).then(
127
- preview_out_image,
128
- [components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
129
- [components['output_img'], state['inpaint_mask']]
130
- )
131
-
132
  components['inpaint_ks'].change(
133
  visualize_user_drag,
134
- [components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box']],
135
  [components['input_img']]
136
  ).then(
137
  preview_out_image,
138
- [components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
139
  [components['output_img'], state['inpaint_mask']]
140
  )
141
 
@@ -143,11 +119,11 @@ def setup_events(components, state, buttons):
143
  def setup_input_events():
144
  components['input_img'].select(
145
  add_point,
146
- [components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box'], components['output_path']],
147
  [components['input_img']]
148
  ).then(
149
  preview_out_image,
150
- [components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
151
  [components['output_img'], state['inpaint_mask']]
152
  )
153
 
@@ -155,21 +131,21 @@ def setup_events(components, state, buttons):
155
  def setup_point_events():
156
  buttons['undo'].click(
157
  undo_point,
158
- [components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box'], components['output_path']],
159
  [components['input_img']]
160
  ).then(
161
  preview_out_image,
162
- [components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
163
  [components['output_img'], state['inpaint_mask']]
164
  )
165
 
166
  buttons['clear'].click(
167
  clear_point,
168
- [components['canvas'], state['points_list'], components['sam_ks'], components['if_sam_box'], components['output_path']],
169
  [components['input_img']]
170
  ).then(
171
  preview_out_image,
172
- [components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
173
  [components['output_img'], state['inpaint_mask']]
174
  )
175
 
@@ -177,7 +153,7 @@ def setup_events(components, state, buttons):
177
  def setup_processing_events():
178
  buttons['run'].click(
179
  preview_out_image,
180
- [components['canvas'], state['points_list'], components['sam_ks'], components['inpaint_ks'], components['if_sam_box'], components['output_path']],
181
  [components['output_img'], state['inpaint_mask']]
182
  ).then(
183
  inpaint,
@@ -195,7 +171,8 @@ def setup_events(components, state, buttons):
195
 
196
  def main():
197
  app = create_interface()
198
- app.queue().launch(share=True, debug=True)
 
199
 
200
  if __name__ == '__main__':
201
  main()
 
1
  import gradio as gr
2
+ import tempfile
3
+ import os
4
  from utils.ui_utils import *
5
 
6
  CANVAS_SIZE = 400
 
24
  canvas = gr.Image(type="numpy", tool="sketch", label=" ", height=CANVAS_SIZE, width=CANVAS_SIZE)
25
  with gr.Row():
26
  fit_btn = gr.Button("Resize Image")
 
27
 
28
  # Control Points Column
29
  with gr.Column():
 
41
  run_btn = gr.Button("Inpaint")
42
  reset_btn = gr.Button("Reset All")
43
 
44
+ # Generation Parameters
45
+ with gr.Row():
 
46
  inpaint_ks = gr.Slider(minimum=0, maximum=25, value=5, step=1, label='How much to expand inpainting mask', interactive=True)
 
47
 
48
  setup_events(
49
  components={
50
  'canvas': canvas,
51
  'input_img': input_img,
52
  'output_img': output_img,
 
 
 
53
  'inpaint_ks': inpaint_ks,
54
  },
55
  state=state,
 
71
  clear_all,
72
  [state['canvas_size']],
73
  [components['canvas'], components['input_img'], components['output_img'],
74
+ state['points_list'], components['inpaint_ks'], state['inpaint_mask']]
75
  )
76
 
77
  components['canvas'].clear(
78
  clear_all,
79
  [state['canvas_size']],
80
  [components['canvas'], components['input_img'], components['output_img'],
81
+ state['points_list'], components['inpaint_ks'], state['inpaint_mask']]
82
  )
83
 
84
  # Image manipulation events
85
  def setup_image_events():
86
  buttons['fit'].click(
87
  clear_point,
88
+ [components['canvas'], state['points_list'], components['inpaint_ks']],
89
  [components['input_img']]
90
  ).then(
91
  resize,
 
97
  def setup_canvas_events():
98
  components['canvas'].edit(
99
  visualize_user_drag,
100
+ [components['canvas'], state['points_list']],
101
  [components['input_img']]
102
  ).then(
103
  preview_out_image,
104
+ [components['canvas'], state['points_list'], components['inpaint_ks']],
105
  [components['output_img'], state['inpaint_mask']]
106
  )
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  components['inpaint_ks'].change(
109
  visualize_user_drag,
110
+ [components['canvas'], state['points_list']],
111
  [components['input_img']]
112
  ).then(
113
  preview_out_image,
114
+ [components['canvas'], state['points_list'], components['inpaint_ks']],
115
  [components['output_img'], state['inpaint_mask']]
116
  )
117
 
 
119
  def setup_input_events():
120
  components['input_img'].select(
121
  add_point,
122
+ [components['canvas'], state['points_list'], components['inpaint_ks']],
123
  [components['input_img']]
124
  ).then(
125
  preview_out_image,
126
+ [components['canvas'], state['points_list'], components['inpaint_ks']],
127
  [components['output_img'], state['inpaint_mask']]
128
  )
129
 
 
131
  def setup_point_events():
132
  buttons['undo'].click(
133
  undo_point,
134
+ [components['canvas'], state['points_list'], components['inpaint_ks']],
135
  [components['input_img']]
136
  ).then(
137
  preview_out_image,
138
+ [components['canvas'], state['points_list'], components['inpaint_ks']],
139
  [components['output_img'], state['inpaint_mask']]
140
  )
141
 
142
  buttons['clear'].click(
143
  clear_point,
144
+ [components['canvas'], state['points_list'], components['inpaint_ks']],
145
  [components['input_img']]
146
  ).then(
147
  preview_out_image,
148
+ [components['canvas'], state['points_list'], components['inpaint_ks']],
149
  [components['output_img'], state['inpaint_mask']]
150
  )
151
 
 
153
  def setup_processing_events():
154
  buttons['run'].click(
155
  preview_out_image,
156
+ [components['canvas'], state['points_list'], components['inpaint_ks']],
157
  [components['output_img'], state['inpaint_mask']]
158
  ).then(
159
  inpaint,
 
171
 
172
  def main():
173
  app = create_interface()
174
+ # HF Space compatible launch
175
+ app.queue().launch()
176
 
177
  if __name__ == '__main__':
178
  main()
utils/refine_mask.py DELETED
@@ -1,168 +0,0 @@
1
- import os
2
- import urllib.request
3
- from typing import Optional
4
-
5
- import cv2
6
- import numpy as np
7
- import torch
8
- import torch.nn as nn
9
-
10
-
11
- def download_model(checkpoint_path: str, model_name: str = "efficientvit_sam_l0.pt") -> str:
12
- """
13
- Download the model checkpoint if not found locally.
14
-
15
- Args:
16
- checkpoint_path: Local path where model should be saved
17
- model_name: Name of the model file to download
18
-
19
- Returns:
20
- str: Path to the downloaded checkpoint
21
- """
22
- os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
23
-
24
- base_url = "https://huggingface.co/mit-han-lab/efficientvit-sam/resolve/main"
25
- model_url = f"{base_url}/{model_name}"
26
-
27
- try:
28
- print(f"Downloading model from {model_url}...")
29
- urllib.request.urlretrieve(model_url, checkpoint_path)
30
- print(f"Model successfully downloaded to {checkpoint_path}")
31
- return checkpoint_path
32
- except Exception as e:
33
- raise RuntimeError(f"Failed to download model: {str(e)}")
34
-
35
-
36
- class SamMaskRefiner(nn.Module):
37
- CHECKPOINT_DIR = 'checkpoints'
38
- MODEL_CONFIGS = {
39
- 'l0': 'efficientvit_sam_l0.pt',
40
- 'l1': 'efficientvit_sam_l1.pt',
41
- 'l2': 'efficientvit_sam_l2.pt'
42
- }
43
-
44
- def __init__(self, model_name: str = 'l0') -> None:
45
- """
46
- Initialize SAM predictor with specified model version.
47
-
48
- Args:
49
- model_name: Model version to use ('l0', 'l1', or 'l2'). Defaults to 'l0'.
50
-
51
- Raises:
52
- ValueError: If invalid model_name is provided
53
- RuntimeError: If model loading fails after download attempt
54
- """
55
- super().__init__()
56
-
57
- if model_name not in self.MODEL_CONFIGS:
58
- raise ValueError(f"Invalid model_name. Choose from: {list(self.MODEL_CONFIGS.keys())}")
59
-
60
- model_filename = self.MODEL_CONFIGS[model_name]
61
- checkpoint_path = os.path.join(self.CHECKPOINT_DIR, model_filename)
62
-
63
- try:
64
- from efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
65
- from efficientvit.sam_model_zoo import create_efficientvit_sam_model
66
- except ImportError:
67
- raise ImportError(
68
- "Failed to import EfficientViT modules. Please ensure the package is installed:\n"
69
- "pip install git+https://github.com/mit-han-lab/efficientvit.git"
70
- )
71
-
72
- if not os.path.exists(checkpoint_path):
73
- print(f"Checkpoint not found at {checkpoint_path}. Attempting to download...")
74
- checkpoint_path = download_model(checkpoint_path, model_filename)
75
-
76
- try:
77
- model_type = f'efficientvit-sam-{model_name}'
78
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
79
- self.model = create_efficientvit_sam_model(model_type, True, checkpoint_path).eval()
80
- self.model = self.model.requires_grad_(False).to(device)
81
- self.predictor = EfficientViTSamPredictor(self.model)
82
- print(f"\033[92mEfficientViT-SAM model loaded from: {checkpoint_path}\033[0m")
83
- except Exception as e:
84
- raise RuntimeError(f"Failed to load model: {str(e)}")
85
-
86
- def sample_points_from_mask(self, mask: np.ndarray, max_points: int = 128) -> np.ndarray:
87
- """
88
- Sample points uniformly from masked regions.
89
-
90
- Args:
91
- mask: Binary mask array of shape (H, W) with 0-1 values.
92
- max_points: Maximum number of points to sample.
93
-
94
- Returns:
95
- np.ndarray: Array of shape (N, 2) containing [x,y] coordinates.
96
- """
97
- y_indices, x_indices = np.where(mask > 0.5)
98
- total_points = len(y_indices)
99
-
100
- if total_points <= max_points:
101
- return np.stack([x_indices, y_indices], axis=1)
102
-
103
- y_min, y_max = y_indices.min(), y_indices.max()
104
- x_min, x_max = x_indices.min(), x_indices.max()
105
-
106
- aspect_ratio = (x_max - x_min) / max(y_max - y_min, 1)
107
- ny = int(np.sqrt(max_points / aspect_ratio))
108
- nx = int(ny * aspect_ratio)
109
-
110
- x_bins = np.linspace(x_min, x_max + 1, nx + 1, dtype=np.int32)
111
- y_bins = np.linspace(y_min, y_max + 1, ny + 1, dtype=np.int32)
112
-
113
- x_dig = np.digitize(x_indices, x_bins) - 1
114
- y_dig = np.digitize(y_indices, y_bins) - 1
115
- bin_indices = y_dig * nx + x_dig
116
- unique_bins = np.unique(bin_indices)
117
-
118
- points = []
119
- for idx in unique_bins:
120
- bin_y = idx // nx
121
- bin_x = idx % nx
122
- mask = (y_dig == bin_y) & (x_dig == bin_x)
123
-
124
- if np.any(mask):
125
- px = int(np.mean(x_indices[mask]))
126
- py = int(np.mean(y_indices[mask]))
127
- points.append([px, py])
128
-
129
- points = np.array(points)
130
-
131
- if len(points) > max_points:
132
- indices = np.linspace(0, len(points) - 1, max_points, dtype=int)
133
- points = points[indices]
134
-
135
- return points
136
-
137
- def refine_mask(self, image: np.ndarray, input_mask: np.ndarray, kernel_size: int = 21) -> np.ndarray:
138
- """
139
- Refine an input mask using the SAM (Segment Anything Model) model.
140
-
141
- Args:
142
- image: RGB image, shape (H, W, 3), values in [0, 255]
143
- input_mask: Binary mask, shape (H, W), values in {0, 1}
144
- kernel_size: Size of morphological kernel (default: 21)
145
-
146
- Returns:
147
- Refined binary mask, shape (H, W), values in {0, 1}
148
- """
149
- points = self.sample_points_from_mask(input_mask, max_points=128)
150
- if len(points) == 0:
151
- return input_mask
152
-
153
- self.predictor.set_image(image)
154
- masks_pred, _, _ = self.predictor.predict(
155
- point_coords=points,
156
- point_labels=np.ones(len(points)),
157
- multimask_output=False
158
- )
159
- sam_mask = masks_pred[0]
160
-
161
- kernel = np.ones((kernel_size, kernel_size), np.uint8)
162
- expanded_input = cv2.dilate(input_mask.astype(np.uint8), kernel)
163
- preserved_input = cv2.erode(input_mask.astype(np.uint8), kernel)
164
-
165
- sam_mask = np.logical_and(expanded_input, sam_mask).astype(np.uint8)
166
- sam_mask = np.logical_or(preserved_input, sam_mask).astype(np.uint8)
167
-
168
- return sam_mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/ui_utils.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import pickle
3
  from time import perf_counter
 
4
 
5
  import cv2
6
  import gradio as gr
@@ -10,7 +11,6 @@ from PIL import Image
10
  from diffusers import AutoPipelineForInpainting, AutoencoderTiny, LCMScheduler
11
 
12
  from utils.drag import bi_warp
13
- from utils.refine_mask import SamMaskRefiner
14
 
15
 
16
  __all__ = [
@@ -19,10 +19,13 @@ __all__ = [
19
  'add_point', 'undo_point', 'clear_point',
20
  ]
21
 
 
 
 
22
  # UI functions
23
  def clear_all(length):
24
  """Reset UI by clearing all input images and parameters."""
25
- return (gr.Image(value=None, height=length, width=length),) * 3 + ([], 21, 2, "output/app", None)
26
 
27
  def resize(canvas, gen_length, canvas_length):
28
  """Resize canvas while maintaining aspect ratio."""
@@ -51,57 +54,35 @@ def process_canvas(canvas):
51
  return image, mask
52
 
53
  # Point manipulation functions
54
- def add_point(canvas, points, sam_ks, if_sam, output_path, evt: gr.SelectData):
55
  """Add selected point to points list and update image."""
56
  if canvas is None:
57
  return None
58
  points.append(evt.index)
59
- return visualize_user_drag(canvas, points, sam_ks, if_sam, output_path)
60
 
61
- def undo_point(canvas, points, sam_ks, if_sam, output_path):
62
  """Remove last point and update image."""
63
  if canvas is None:
64
  return None
65
  if len(points) > 0:
66
  points.pop()
67
- return visualize_user_drag(canvas, points, sam_ks, if_sam, output_path)
68
 
69
- def clear_point(canvas, points, sam_ks, if_sam, output_path):
70
  """Clear all points and update image."""
71
  if canvas is None:
72
  return None
73
  points.clear()
74
- return visualize_user_drag(canvas, points, sam_ks, if_sam, output_path)
75
 
76
  # Visualization tools
77
- def refine_mask(image, mask, kernel_size):
78
- """Refine mask using SAM model if available."""
79
- global sam_refiner
80
- try:
81
- if 'sam_refiner' not in globals():
82
- sam_refiner = SamMaskRefiner()
83
- return sam_refiner.refine_mask(image, mask, kernel_size)
84
- except ImportError:
85
- gr.Warning("EfficientVit not installed. Please install with: pip install git+https://github.com/mit-han-lab/efficientvit.git")
86
- return mask
87
- except Exception as e:
88
- gr.Warning(f"Error refining mask: {str(e)}")
89
- return mask
90
-
91
- def visualize_user_drag(canvas, points, sam_ks, if_sam=False, output_path=None):
92
- """Visualize control points and motion vectors on the input image.
93
-
94
- Args:
95
- canvas (dict): Gradio canvas containing image and mask
96
- points (list): List of (x,y) coordinate pairs for control points
97
- sam_ks (int): Kernel size for SAM mask refinement
98
- if_sam (bool): Whether to use SAM refinement on mask
99
- """
100
  if canvas is None:
101
  return None
102
 
103
  image, mask = process_canvas(canvas)
104
- mask = refine_mask(image, mask, sam_ks) if if_sam and mask.sum() > 0 else mask
105
 
106
  # Apply colored mask overlay
107
  result = image.copy()
@@ -120,29 +101,11 @@ def visualize_user_drag(canvas, points, sam_ks, if_sam=False, output_path=None):
120
  else:
121
  cv2.circle(image, tuple(point), 10, (255, 0, 0), -1) # Start point
122
  prev_point = point
123
-
124
- if output_path:
125
- os.makedirs(output_path, exist_ok=True)
126
- Image.fromarray(image).save(os.path.join(output_path, 'user_drag_i4p.png'))
127
  return image
128
 
129
- def preview_out_image(canvas, points, sam_ks, inpaint_ks, if_sam=False, output_path=None):
130
- """Preview warped image result and generate inpainting mask.
131
-
132
- Args:
133
- canvas (dict): Gradio canvas containing the input image and mask
134
- points (list): List of (x,y) coordinate pairs defining source and target positions for warping
135
- sam_ks (int): Kernel size parameter for SAM mask refinement
136
- inpaint_ks (int): Kernel size parameter for inpainting mask generation
137
- if_sam (bool): Whether to use SAM model for mask refinement
138
- output_path (str, optional): Directory path to save original image and metadata
139
-
140
- Returns:
141
- tuple:
142
- - ndarray: Warped image with grid pattern overlay on regions needing inpainting
143
- - ndarray: Binary mask (255 for inpainting regions, 0 elsewhere)
144
- - (None, None): If canvas is empty or fewer than 2 control points provided
145
- """
146
  if canvas is None:
147
  return None, None
148
 
@@ -155,15 +118,7 @@ def preview_out_image(canvas, points, sam_ks, inpaint_ks, if_sam=False, output_p
155
  size_valid = all(max(x.shape[:2] if len(x.shape) > 2 else x.shape) == 512 for x in (image, mask))
156
  if not (shapes_valid and size_valid):
157
  gr.Warning('Click Resize Image Button first.')
158
-
159
- mask = refine_mask(image, mask, sam_ks) if if_sam and mask.sum() > 0 else mask
160
-
161
- if output_path:
162
- os.makedirs(output_path, exist_ok=True)
163
- Image.fromarray(image).save(os.path.join(output_path, 'original_image.png'))
164
- metadata = {'mask': mask, 'points': points}
165
- with open(os.path.join(output_path, 'meta_data_i4p.pkl'), 'wb') as f:
166
- pickle.dump(metadata, f)
167
 
168
  handle_pts, target_pts, inpaint_mask = bi_warp(mask, points, inpaint_ks)
169
  image[target_pts[:, 1], target_pts[:, 0]] = image[handle_pts[:, 1], handle_pts[:, 0]]
@@ -172,9 +127,6 @@ def preview_out_image(canvas, points, sam_ks, inpaint_ks, if_sam=False, output_p
172
  background = np.ones_like(mask) * 255
173
  background[::10] = background[:, ::10] = 0
174
  image = np.where(inpaint_mask[..., np.newaxis]==1, background[..., np.newaxis], image)
175
-
176
- if output_path:
177
- Image.fromarray(image).save(os.path.join(output_path, 'preview_image.png'))
178
 
179
  return image, (inpaint_mask * 255).astype(np.uint8)
180
 
@@ -187,11 +139,26 @@ def setup_pipeline(device='cuda', model_version='v1-5'):
187
  }
188
  model_id, lora_id, vae_id = MODEL_CONFIGS[model_version]
189
 
190
- pipe = AutoPipelineForInpainting.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16", safety_checker=None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
192
  pipe.load_lora_weights(lora_id)
193
  pipe.fuse_lora()
194
- pipe.vae = AutoencoderTiny.from_pretrained(vae_id, torch_dtype=torch.float16)
195
  pipe = pipe.to(device)
196
 
197
  # Pre-compute prompt embeddings during setup
@@ -206,19 +173,20 @@ def setup_pipeline(device='cuda', model_version='v1-5'):
206
 
207
  return pipe
208
 
209
- pipe = setup_pipeline(model_version='v1-5')
210
- pipe.cached_prompt_embeds = pipe.encode_prompt('', 'cuda', 1, False)[0]
 
 
 
 
 
 
 
 
 
211
 
212
  def inpaint(image, inpaint_mask):
213
- """Perform efficient inpainting on masked regions using Stable Diffusion.
214
-
215
- Args:
216
- image (ndarray): Input RGB image array (warped preview image)
217
- inpaint_mask (ndarray): Binary mask array where 255 indicates regions to inpaint
218
-
219
- Returns:
220
- ndarray: Inpainted image with masked regions filled in
221
- """
222
  if image is None:
223
  return None
224
 
@@ -226,6 +194,10 @@ def inpaint(image, inpaint_mask):
226
  return image
227
 
228
  start = perf_counter()
 
 
 
 
229
  pipe_id = 'xl' if 'xl' in pipe.config._name_or_path else 'v1-5'
230
  inpaint_strength = 0.99 if pipe_id == 'xl' else 1.0
231
 
@@ -254,18 +226,22 @@ def inpaint(image, inpaint_mask):
254
  }
255
 
256
  # Run pipeline
257
- if pipe_id == 'v1-5':
258
- inpainted = pipe(
259
- prompt_embeds=pipe.cached_prompt_embeds,
260
- **common_params
261
- ).images[0]
262
- else:
263
- inpainted = pipe(
264
- prompt_embeds=pipe.cached_prompt_embeds,
265
- pooled_prompt_embeds=pipe.cached_pooled_prompt_embeds,
266
- **common_params
267
- ).images[0]
 
 
 
 
268
 
269
  # Post-process results
270
  inpaint_mask = (inpaint_mask[..., np.newaxis] / 255).astype(np.uint8)
271
- return (inpainted * 255).astype(np.uint8) * inpaint_mask + image * (1 - inpaint_mask)
 
1
  import os
2
  import pickle
3
  from time import perf_counter
4
+ import tempfile
5
 
6
  import cv2
7
  import gradio as gr
 
11
  from diffusers import AutoPipelineForInpainting, AutoencoderTiny, LCMScheduler
12
 
13
  from utils.drag import bi_warp
 
14
 
15
 
16
  __all__ = [
 
19
  'add_point', 'undo_point', 'clear_point',
20
  ]
21
 
22
+ # Global variables for lazy loading
23
+ pipe = None
24
+
25
  # UI functions
26
  def clear_all(length):
27
  """Reset UI by clearing all input images and parameters."""
28
+ return (gr.Image(value=None, height=length, width=length),) * 3 + ([], 2, None)
29
 
30
  def resize(canvas, gen_length, canvas_length):
31
  """Resize canvas while maintaining aspect ratio."""
 
54
  return image, mask
55
 
56
  # Point manipulation functions
57
+ def add_point(canvas, points, inpaint_ks, evt: gr.SelectData):
58
  """Add selected point to points list and update image."""
59
  if canvas is None:
60
  return None
61
  points.append(evt.index)
62
+ return visualize_user_drag(canvas, points)
63
 
64
+ def undo_point(canvas, points, inpaint_ks):
65
  """Remove last point and update image."""
66
  if canvas is None:
67
  return None
68
  if len(points) > 0:
69
  points.pop()
70
+ return visualize_user_drag(canvas, points)
71
 
72
+ def clear_point(canvas, points, inpaint_ks):
73
  """Clear all points and update image."""
74
  if canvas is None:
75
  return None
76
  points.clear()
77
+ return visualize_user_drag(canvas, points)
78
 
79
  # Visualization tools
80
+ def visualize_user_drag(canvas, points):
81
+ """Visualize control points and motion vectors on the input image."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  if canvas is None:
83
  return None
84
 
85
  image, mask = process_canvas(canvas)
 
86
 
87
  # Apply colored mask overlay
88
  result = image.copy()
 
101
  else:
102
  cv2.circle(image, tuple(point), 10, (255, 0, 0), -1) # Start point
103
  prev_point = point
104
+
 
 
 
105
  return image
106
 
107
+ def preview_out_image(canvas, points, inpaint_ks):
108
+ """Preview warped image result and generate inpainting mask."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  if canvas is None:
110
  return None, None
111
 
 
118
  size_valid = all(max(x.shape[:2] if len(x.shape) > 2 else x.shape) == 512 for x in (image, mask))
119
  if not (shapes_valid and size_valid):
120
  gr.Warning('Click Resize Image Button first.')
121
+ return image, None
 
 
 
 
 
 
 
 
122
 
123
  handle_pts, target_pts, inpaint_mask = bi_warp(mask, points, inpaint_ks)
124
  image[target_pts[:, 1], target_pts[:, 0]] = image[handle_pts[:, 1], handle_pts[:, 0]]
 
127
  background = np.ones_like(mask) * 255
128
  background[::10] = background[:, ::10] = 0
129
  image = np.where(inpaint_mask[..., np.newaxis]==1, background[..., np.newaxis], image)
 
 
 
130
 
131
  return image, (inpaint_mask * 255).astype(np.uint8)
132
 
 
139
  }
140
  model_id, lora_id, vae_id = MODEL_CONFIGS[model_version]
141
 
142
+ # Check if CUDA is available, fallback to CPU
143
+ if not torch.cuda.is_available():
144
+ device = 'cpu'
145
+ torch_dtype = torch.float32
146
+ variant = None
147
+ else:
148
+ torch_dtype = torch.float16
149
+ variant = "fp16"
150
+
151
+ gr.Info('Loading inpainting pipeline...')
152
+ pipe = AutoPipelineForInpainting.from_pretrained(
153
+ model_id,
154
+ torch_dtype=torch_dtype,
155
+ variant=variant,
156
+ safety_checker=None
157
+ )
158
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
159
  pipe.load_lora_weights(lora_id)
160
  pipe.fuse_lora()
161
+ pipe.vae = AutoencoderTiny.from_pretrained(vae_id, torch_dtype=torch_dtype)
162
  pipe = pipe.to(device)
163
 
164
  # Pre-compute prompt embeddings during setup
 
173
 
174
  return pipe
175
 
176
+ def get_pipeline():
177
+ """Lazy load pipeline only when needed."""
178
+ global pipe
179
+ if pipe is None:
180
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
181
+ pipe = setup_pipeline(device=device, model_version='v1-5')
182
+ if device == 'cuda':
183
+ pipe.cached_prompt_embeds = pipe.encode_prompt('', 'cuda', 1, False)[0]
184
+ else:
185
+ pipe.cached_prompt_embeds = pipe.encode_prompt('', 'cpu', 1, False)[0]
186
+ return pipe
187
 
188
  def inpaint(image, inpaint_mask):
189
+ """Perform efficient inpainting on masked regions using Stable Diffusion."""
 
 
 
 
 
 
 
 
190
  if image is None:
191
  return None
192
 
 
194
  return image
195
 
196
  start = perf_counter()
197
+
198
+ # Get pipeline (lazy loading)
199
+ pipe = get_pipeline()
200
+
201
  pipe_id = 'xl' if 'xl' in pipe.config._name_or_path else 'v1-5'
202
  inpaint_strength = 0.99 if pipe_id == 'xl' else 1.0
203
 
 
226
  }
227
 
228
  # Run pipeline
229
+ try:
230
+ if pipe_id == 'v1-5':
231
+ inpainted = pipe(
232
+ prompt_embeds=pipe.cached_prompt_embeds,
233
+ **common_params
234
+ ).images[0]
235
+ else:
236
+ inpainted = pipe(
237
+ prompt_embeds=pipe.cached_prompt_embeds,
238
+ pooled_prompt_embeds=pipe.cached_pooled_prompt_embeds,
239
+ **common_params
240
+ ).images[0]
241
+ except Exception as e:
242
+ gr.Warning(f"Inpainting failed: {str(e)}")
243
+ return image
244
 
245
  # Post-process results
246
  inpaint_mask = (inpaint_mask[..., np.newaxis] / 255).astype(np.uint8)
247
+ return (inpainted * 255).astype(np.uint8) * inpaint_mask + image * (1 - inpaint_mask)