telcom committed on
Commit
f762924
·
verified ·
1 Parent(s): e8d221e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -231
app.py CHANGED
@@ -1,280 +1,172 @@
1
- # app.py
2
- # ============================================================
3
- # SDXL Inpainting with ONE "paint-on-image" input (Gradio ImageEditor)
4
- # ============================================================
5
-
6
  import os
7
  import gc
8
  import random
9
- import warnings
10
- import logging
11
-
12
- # ---- Spaces GPU decorator (must be imported early) ----------
13
- try:
14
- import spaces # noqa: F401
15
- SPACES_AVAILABLE = True
16
- except Exception:
17
- SPACES_AVAILABLE = False
18
-
19
- import gradio as gr
20
  import numpy as np
21
- from PIL import Image, ImageChops
22
-
23
  import torch
24
- from huggingface_hub import login
25
- from diffusers import StableDiffusionXLInpaintPipeline
26
-
27
- # ============================================================
28
- # Config
29
- # ============================================================
30
 
31
- INPAINT_MODEL = os.environ.get(
32
- "INPAINT_MODEL",
33
- "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
34
- ).strip()
35
 
36
- HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
37
- if HF_TOKEN:
38
- login(token=HF_TOKEN)
39
 
40
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
41
- warnings.filterwarnings("ignore")
42
- logging.getLogger("transformers").setLevel(logging.ERROR)
43
 
44
- MAX_SEED = np.iinfo(np.int32).max
45
 
46
  # ============================================================
47
- # Device & dtype
48
  # ============================================================
49
 
50
- cuda_available = torch.cuda.is_available()
51
- device = torch.device("cuda" if cuda_available else "cpu")
 
52
 
53
- if cuda_available and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
54
- dtype = torch.bfloat16
55
- elif cuda_available:
56
- dtype = torch.float16
57
- else:
58
- dtype = torch.float32
 
59
 
60
- MAX_IMAGE_SIZE = 1536 if cuda_available else 768
61
- fallback_msg = "" if cuda_available else "GPU unavailable. Running in CPU fallback mode (slow)."
62
 
63
  # ============================================================
64
- # Load pipeline
65
  # ============================================================
66
 
67
- pipe = None
68
- model_loaded = False
69
- load_error = None
 
 
70
 
71
- try:
72
- fp_kwargs = {"torch_dtype": dtype, "use_safetensors": True}
73
- if HF_TOKEN:
74
- fp_kwargs["token"] = HF_TOKEN
 
 
75
 
76
- pipe = StableDiffusionXLInpaintPipeline.from_pretrained(INPAINT_MODEL, **fp_kwargs).to(device)
77
- try:
78
- pipe.set_progress_bar_config(disable=True)
79
- except Exception:
80
- pass
 
 
 
81
 
82
- model_loaded = True
83
- except Exception as e:
84
- load_error = repr(e)
85
- model_loaded = False
86
 
87
  # ============================================================
88
- # Helpers
89
  # ============================================================
90
 
91
- def make_error_image(w: int, h: int) -> Image.Image:
92
- return Image.new("RGB", (int(w), int(h)), (18, 18, 22))
 
 
 
93
 
94
- def _resize(img: Image.Image, w: int, h: int, is_mask: bool = False) -> Image.Image:
95
- if img.size == (w, h):
96
- return img
97
- return img.resize((w, h), Image.NEAREST if is_mask else Image.LANCZOS)
98
 
99
- def extract_image_and_mask(editor_value):
100
- """
101
- Gradio ImageEditor returns a dict-like object (varies a bit by version),
102
- usually containing:
103
- - "background": PIL image (original)
104
- - "layers": list of PIL images (paint strokes etc.)
105
- - "composite": PIL image (background + paint)
106
- We build a binary-ish mask from the difference between composite and background.
107
- White = edit, Black = keep.
108
- """
109
- if editor_value is None:
110
- return None, None
111
 
112
- background = editor_value.get("background", None)
113
- composite = editor_value.get("composite", None)
 
 
 
 
 
114
 
115
- if not isinstance(background, Image.Image) or not isinstance(composite, Image.Image):
116
- return None, None
117
 
118
- background = background.convert("RGB")
119
- composite = composite.convert("RGB")
 
 
 
 
 
 
120
 
121
- # Mask = difference between composite and background (where user painted)
122
- diff = ImageChops.difference(composite, background).convert("L")
 
123
 
124
- # Make it more binary (stronger mask)
125
- # Pixels > threshold become white (edit region)
126
- threshold = 10
127
- mask = diff.point(lambda p: 255 if p > threshold else 0).convert("L")
 
 
128
 
129
- return background, mask
 
 
 
130
 
131
- # ============================================================
132
- # Inference
133
- # ============================================================
134
 
135
- def _infer_impl(
136
- prompt,
137
- negative_prompt,
138
- seed,
139
- randomize_seed,
140
- width,
141
- height,
142
- guidance_scale,
143
- num_inference_steps,
144
- painted, # ImageEditor value
145
- ):
146
- width = int(width)
147
- height = int(height)
148
- seed = int(seed)
149
-
150
- if not model_loaded:
151
- return make_error_image(width, height), f"Model load failed: {load_error}"
152
-
153
- prompt = (prompt or "").strip()
154
- if not prompt:
155
- return make_error_image(width, height), "Error: prompt is empty."
156
-
157
- init_image, mask_image = extract_image_and_mask(painted)
158
- if init_image is None or mask_image is None:
159
- return make_error_image(width, height), "Error: upload an image and paint over the area you want to change."
160
-
161
- if randomize_seed:
162
- seed = random.randint(0, MAX_SEED)
163
-
164
- generator = torch.Generator(device=device).manual_seed(seed)
165
-
166
- neg = (negative_prompt or "").strip()
167
- if not neg:
168
- neg = None
169
-
170
- init_image = _resize(init_image, width, height, is_mask=False)
171
- mask_image = _resize(mask_image, width, height, is_mask=True)
172
-
173
- status = f"Seed: {seed}"
174
- if fallback_msg:
175
- status += f" | {fallback_msg}"
176
-
177
- try:
178
- with torch.inference_mode():
179
- if device.type == "cuda":
180
- with torch.autocast("cuda", dtype=dtype):
181
- out = pipe(
182
- prompt=prompt,
183
- negative_prompt=neg,
184
- image=init_image,
185
- mask_image=mask_image,
186
- width=width,
187
- height=height,
188
- guidance_scale=float(guidance_scale),
189
- num_inference_steps=int(num_inference_steps),
190
- generator=generator,
191
- )
192
- else:
193
- out = pipe(
194
- prompt=prompt,
195
- negative_prompt=neg,
196
- image=init_image,
197
- mask_image=mask_image,
198
- width=width,
199
- height=height,
200
- guidance_scale=float(guidance_scale),
201
- num_inference_steps=int(num_inference_steps),
202
- generator=generator,
203
- )
204
-
205
- return out.images[0], status
206
-
207
- except Exception as e:
208
- return make_error_image(width, height), f"Error: {type(e).__name__}: {e}"
209
-
210
- finally:
211
- gc.collect()
212
- if device.type == "cuda":
213
- torch.cuda.empty_cache()
214
-
215
- if SPACES_AVAILABLE:
216
- @spaces.GPU
217
- def infer(*args, **kwargs):
218
- return _infer_impl(*args, **kwargs)
219
- else:
220
- def infer(*args, **kwargs):
221
- return _infer_impl(*args, **kwargs)
222
 
223
  # ============================================================
224
- # UI
225
  # ============================================================
226
 
227
- CSS = "body { background: #000; color: #fff; }"
 
 
228
 
229
- with gr.Blocks(title="Inpainting (paint-to-edit)") as demo:
230
- gr.HTML(f"<style>{CSS}</style>")
 
231
 
232
- if fallback_msg:
233
- gr.Markdown(f"**{fallback_msg}**")
234
 
235
- if not model_loaded:
236
- gr.Markdown(f"⚠️ Model failed to load:\n\n{load_error}")
 
 
 
 
 
 
 
237
 
238
- gr.Markdown("## Inpainting (paint the area you want to change)")
239
- gr.Markdown("Upload an image, then paint over the clothing area, then describe the new clothing in the prompt.")
 
240
 
241
- prompt = gr.Textbox(label="Prompt (describe new clothing)", lines=2)
242
- negative_prompt = gr.Textbox(label="Negative prompt (optional)", lines=2)
243
 
244
- painted = gr.ImageEditor(
245
- label="Image editor (paint where you want to edit)",
246
- type="pil",
247
- )
248
 
249
- run_button = gr.Button("Inpaint")
250
- result = gr.Image(label="Result")
251
- status = gr.Markdown("")
252
-
253
- with gr.Accordion("Advanced Settings", open=False):
254
- seed = gr.Slider(0, MAX_SEED, step=1, value=0, label="Seed")
255
- randomize_seed = gr.Checkbox(value=True, label="Randomize seed")
256
-
257
- width = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=1024, label="Width")
258
- height = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=1024, label="Height")
259
-
260
- guidance_scale = gr.Slider(0.0, 15.0, step=0.1, value=7.0, label="Guidance scale")
261
- num_inference_steps = gr.Slider(1, 80, step=1, value=30, label="Steps")
262
-
263
- run_button.click(
264
- fn=infer,
265
- inputs=[
266
- prompt,
267
- negative_prompt,
268
- seed,
269
- randomize_seed,
270
- width,
271
- height,
272
- guidance_scale,
273
- num_inference_steps,
274
- painted,
275
- ],
276
- outputs=[result, status],
277
  )
278
 
279
- if __name__ == "__main__":
280
- demo.queue().launch(ssr_mode=False)
 
 
 
 
 
 
1
  import os
2
  import gc
3
  import random
 
 
 
 
 
 
 
 
 
 
 
4
  import numpy as np
 
 
5
  import torch
6
+ import gradio as gr
7
+ import cv2
 
 
 
 
8
 
9
+ from PIL import Image
10
+ from diffusers import StableDiffusionXLInpaintPipeline
11
+ from huggingface_hub import login
 
12
 
13
+ # --- GroundingDINO ---
14
+ from groundingdino.util.inference import load_model, predict
 
15
 
16
+ # --- SAM ---
17
+ from segment_anything import sam_model_registry, SamPredictor
 
18
 
 
19
 
20
  # ============================================================
21
+ # CONFIG
22
  # ============================================================
23
 
24
# Hugging Face auth token (optional): enables access to gated models.
HF_TOKEN = os.getenv("HF_TOKEN", "")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Prefer GPU; fall back to CPU when CUDA is unavailable.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 halves GPU memory use; CPU inference requires fp32.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

INPAINT_MODEL = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"

# Clothing keywords fed to GroundingDINO (tune as needed).
CLOTHING_PROMPT = "shirt, jacket, coat, dress, hoodie, sweater, t-shirt"
35
 
 
 
36
 
37
  # ============================================================
38
+ # LOAD MODELS
39
  # ============================================================
40
 
41
# --- GroundingDINO (open-vocabulary object detector) ---
# NOTE: load_model's signature is (model_config_path, model_checkpoint_path);
# the config file must be the FIRST argument. The original call passed the
# .pth checkpoint first, which fails at load time.
dino = load_model(
    "GroundingDINO/groundingdino_swint_ogc.cfg.py",
    "GroundingDINO/groundingdino_swint_ogc.pth",
)

# --- SAM (promptable segmenter used to turn boxes into masks) ---
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to(DEVICE)
sam_predictor = SamPredictor(sam)

# --- SDXL inpainting pipeline ---
pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
    INPAINT_MODEL,
    torch_dtype=DTYPE,
    use_safetensors=True,
).to(DEVICE)

# Progress bars clutter Space logs; disable them.
pipe.set_progress_bar_config(disable=True)
62
 
 
 
 
 
63
 
64
  # ============================================================
65
+ # UTILS
66
  # ============================================================
67
 
68
def pil_to_cv(img):
    """Convert a PIL RGB image to an OpenCV BGR ndarray."""
    rgb_array = np.array(img)
    return cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR)


def cv_to_pil(img):
    """Convert an OpenCV BGR ndarray back to a PIL RGB image."""
    rgb_array = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb_array)
73
 
 
 
 
 
74
 
75
def detect_clothing_mask(image: Image.Image):
    """Detect clothing in *image* and return a binary PIL mask ("L" mode).

    White (255) marks clothing pixels to repaint; black (0) is kept.
    Returns None when GroundingDINO finds nothing above threshold.
    """
    # Work in RGB throughout: SamPredictor.set_image expects RGB by default,
    # and GroundingDINO's normalization constants are defined for RGB.
    # (The original passed a BGR OpenCV array to both, which is wrong.)
    img_rgb = np.asarray(image.convert("RGB"))
    h, w, _ = img_rgb.shape

    # GroundingDINO's predict() expects a normalized CHW float tensor
    # (ImageNet mean/std), not a raw HWC uint8 array.
    chw = torch.from_numpy(img_rgb).permute(2, 0, 1).float() / 255.0
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    img_tensor = (chw - mean) / std

    boxes, _, _ = predict(
        model=dino,
        image=img_tensor,
        caption=CLOTHING_PROMPT,
        box_threshold=0.35,
        text_threshold=0.25,
    )

    if len(boxes) == 0:
        return None

    # Boxes come back normalized as (cx, cy, w, h); convert to pixel
    # corner coordinates and clamp to the image bounds.
    boxes_px = []
    for box in boxes:
        cx, cy, bw, bh = (float(v) for v in box)
        x1 = max(0, int((cx - bw / 2) * w))
        y1 = max(0, int((cy - bh / 2) * h))
        x2 = min(w, int((cx + bw / 2) * w))
        y2 = min(h, int((cy + bh / 2) * h))
        boxes_px.append([x1, y1, x2, y2])

    # SAM segmentation: one mask per detected box, merged by union.
    sam_predictor.set_image(img_rgb)
    full_mask = np.zeros((h, w), dtype=np.uint8)
    for box in boxes_px:
        mask, _, _ = sam_predictor.predict(
            box=np.array(box),
            multimask_output=False,
        )
        full_mask[mask[0]] = 255  # boolean-index union into the full mask

    return Image.fromarray(full_mask)
 
 
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  # ============================================================
120
+ # INFERENCE
121
  # ============================================================
122
 
123
def replace_clothing(image, prompt, seed):
    """Detect clothing in *image* and inpaint it according to *prompt*.

    Returns (result_image, status_message). On a missing input or failed
    detection, the original image (or None) is returned with a status.
    """
    if image is None or not prompt:
        return None, "Upload an image and provide a prompt."

    # The inpainting pipeline expects a 3-channel image (uploads may be RGBA/L).
    image = image.convert("RGB")

    mask = detect_clothing_mask(image)
    if mask is None:
        return image, "No clothing detected."

    # Gradio sliders deliver floats; Generator.manual_seed requires an int.
    generator = torch.Generator(device=DEVICE).manual_seed(int(seed))

    try:
        with torch.inference_mode():
            out = pipe(
                prompt=prompt,
                image=image,
                mask_image=mask,
                guidance_scale=7.0,
                num_inference_steps=30,
                generator=generator,
            )
    finally:
        # Free transient allocations even when generation raises.
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    return out.images[0], "Clothing replaced automatically."
 
148
 
 
 
 
 
149
 
150
# ============================================================
# UI
# ============================================================

with gr.Blocks(title="Auto Clothing Replacement") as demo:
    gr.Markdown("## Automatic Clothing Replacement (no mask, no painting)")
    gr.Markdown("Upload a photo, describe the new outfit. Everything else is automatic.")

    image = gr.Image(type="pil", label="Input image")
    prompt = gr.Textbox(label="New clothing description")
    # step=1 keeps the seed integral in the UI (manual_seed needs an int).
    seed = gr.Slider(0, 999999, value=0, step=1, label="Seed")

    run = gr.Button("Replace Clothing")
    output = gr.Image(label="Result")
    status = gr.Markdown()

    run.click(
        replace_clothing,
        inputs=[image, prompt, seed],
        outputs=[output, status],
    )

# Launch only when executed as a script so the module stays importable
# (the previous revision had this guard; it was dropped in this commit).
if __name__ == "__main__":
    demo.launch()