telcom committed on
Commit
fd2d5e7
·
verified ·
1 Parent(s): f762924

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +384 -103
app.py CHANGED
@@ -1,172 +1,453 @@
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import gc
3
  import random
 
 
 
4
  import numpy as np
5
- import torch
6
  import gradio as gr
7
- import cv2
8
-
9
  from PIL import Image
 
 
 
 
 
10
  from diffusers import StableDiffusionXLInpaintPipeline
11
- from huggingface_hub import login
12
 
13
- # --- GroundingDINO ---
14
  from groundingdino.util.inference import load_model, predict
15
 
16
- # --- SAM ---
17
  from segment_anything import sam_model_registry, SamPredictor
18
 
19
 
20
  # ============================================================
21
- # CONFIG
 
 
 
 
 
 
 
 
 
 
22
  # ============================================================
 
 
 
23
 
24
- HF_TOKEN = os.getenv("HF_TOKEN", "")
25
  if HF_TOKEN:
26
- login(HF_TOKEN)
 
 
27
 
28
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
29
- DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
 
 
30
 
31
- INPAINT_MODEL = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
 
32
 
33
- # Clothing keywords (you can tune this)
34
- CLOTHING_PROMPT = "shirt, jacket, coat, dress, hoodie, sweater, t-shirt"
 
 
 
 
 
 
 
35
 
36
 
37
  # ============================================================
38
- # LOAD MODELS
39
  # ============================================================
40
 
41
- # --- GroundingDINO ---
42
- dino = load_model(
43
- "GroundingDINO/groundingdino_swint_ogc.pth",
44
- "GroundingDINO/groundingdino_swint_ogc.cfg.py",
45
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- # --- SAM ---
48
- sam = sam_model_registry["vit_h"](
49
- checkpoint="sam_vit_h_4b8939.pth"
50
- )
51
- sam.to(DEVICE)
52
- sam_predictor = SamPredictor(sam)
53
 
54
- # --- SDXL Inpaint ---
55
- pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
56
- INPAINT_MODEL,
57
- torch_dtype=DTYPE,
58
- use_safetensors=True,
59
- ).to(DEVICE)
60
 
61
- pipe.set_progress_bar_config(disable=True)
 
 
 
 
 
 
62
 
63
 
64
  # ============================================================
65
- # UTILS
66
  # ============================================================
67
 
68
- def pil_to_cv(img):
69
- return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
 
 
 
 
 
 
 
 
 
70
 
71
- def cv_to_pil(img):
72
- return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
73
 
 
 
 
 
 
 
74
 
75
- def detect_clothing_mask(image: Image.Image):
76
- """Automatically detect clothing and return a binary mask"""
77
- img_cv = pil_to_cv(image)
78
- h, w, _ = img_cv.shape
 
 
 
 
 
 
79
 
80
- boxes, _, _ = predict(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  model=dino,
82
- image=img_cv,
83
- caption=CLOTHING_PROMPT,
84
- box_threshold=0.35,
85
- text_threshold=0.25,
86
  )
87
 
88
- if len(boxes) == 0:
89
  return None
90
 
91
- # Convert normalized boxes to pixels
 
92
  boxes_px = []
93
- for box in boxes:
94
- x1 = int((box[0] - box[2] / 2) * w)
95
- y1 = int((box[1] - box[3] / 2) * h)
96
- x2 = int((box[0] + box[2] / 2) * w)
97
- y2 = int((box[1] + box[3] / 2) * h)
98
- boxes_px.append([x1, y1, x2, y2])
 
 
 
 
 
 
 
 
 
99
 
100
- # SAM segmentation
101
- sam_predictor.set_image(img_cv)
102
- masks = []
 
103
 
104
  for box in boxes_px:
105
- mask, _, _ = sam_predictor.predict(
106
- box=np.array(box),
 
 
 
107
  multimask_output=False,
108
  )
109
- masks.append(mask[0])
 
110
 
111
- # Merge all masks
112
- full_mask = np.zeros((h, w), dtype=np.uint8)
113
- for m in masks:
114
- full_mask[m] = 255
 
 
115
 
116
- return Image.fromarray(full_mask)
117
 
118
 
119
  # ============================================================
120
- # INFERENCE
121
  # ============================================================
122
 
123
- def replace_clothing(image, prompt, seed):
124
- if image is None or not prompt:
125
- return None, "Upload an image and provide a prompt."
126
-
127
- mask = detect_clothing_mask(image)
128
- if mask is None:
129
- return image, "No clothing detected."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
132
 
133
- with torch.inference_mode():
134
- out = pipe(
135
- prompt=prompt,
136
- image=image,
137
- mask_image=mask,
138
- guidance_scale=7.0,
139
- num_inference_steps=30,
140
- generator=generator,
141
- )
142
-
143
- gc.collect()
144
- if DEVICE == "cuda":
145
- torch.cuda.empty_cache()
 
146
 
147
- return out.images[0], "Clothing replaced automatically."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
 
150
  # ============================================================
151
  # UI
152
  # ============================================================
153
 
154
- with gr.Blocks(title="Auto Clothing Replacement") as demo:
155
- gr.Markdown("## Automatic Clothing Replacement (no mask, no painting)")
156
- gr.Markdown("Upload a photo, describe the new outfit. Everything else is automatic.")
157
-
158
- image = gr.Image(type="pil", label="Input image")
159
- prompt = gr.Textbox(label="New clothing description")
160
- seed = gr.Slider(0, 999999, value=0, label="Seed")
161
-
162
- run = gr.Button("Replace Clothing")
163
- output = gr.Image(label="Result")
164
- status = gr.Markdown()
165
 
166
- run.click(
167
- replace_clothing,
168
- inputs=[image, prompt, seed],
169
- outputs=[output, status],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  )
171
 
172
- demo.launch()
 
 
1
+ # app.py
2
+ # ============================================================
3
+ # Automatic clothing replacement:
4
+ # 1) Detect clothing boxes with GroundingDINO
5
+ # 2) Turn boxes into pixel mask with SAM
6
+ # 3) Inpaint mask with SDXL Inpaint
7
+ #
8
+ # Input: ONE image, NO manual paint, NO manual mask
9
+ # ============================================================
10
+
11
  import os
12
  import gc
13
  import random
14
+ import warnings
15
+ import logging
16
+
17
  import numpy as np
 
18
  import gradio as gr
 
 
19
  from PIL import Image
20
+
21
+ import torch
22
+ from huggingface_hub import login, hf_hub_download
23
+
24
+ # Diffusers SDXL inpaint
25
  from diffusers import StableDiffusionXLInpaintPipeline
 
26
 
27
+ # GroundingDINO
28
  from groundingdino.util.inference import load_model, predict
29
 
30
+ # SAM
31
  from segment_anything import sam_model_registry, SamPredictor
32
 
33
 
34
  # ============================================================
35
+ # Spaces GPU decorator (must be imported early)
36
+ # ============================================================
37
# The `spaces` package only exists on Hugging Face Spaces. Anywhere else the
# app must still start, so a failed import just disables the GPU decorator.
try:
    import spaces  # noqa: F401
    SPACES_AVAILABLE = True
except Exception:
    SPACES_AVAILABLE = False
42
+
43
+
44
+ # ============================================================
45
+ # Basic config
46
  # ============================================================
47
# Quieten noisy libraries before any model code runs.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)

# Optional Hub authentication; an empty/whitespace token means anonymous use.
HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
if HF_TOKEN:
    login(token=HF_TOKEN)

# Upper bound for the seed slider and for random seeds.
MAX_SEED = np.iinfo(np.int32).max

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Precision: bfloat16 on GPUs that support it, float16 on other GPUs,
# float32 on CPU. The bf16 probe is only made when CUDA is present.
if DEVICE != "cuda":
    DTYPE = torch.float32
elif torch.cuda.is_bf16_supported():
    DTYPE = torch.bfloat16
else:
    DTYPE = torch.float16

# Largest width/height offered to the user; smaller on CPU.
MAX_IMAGE_SIZE = 1024 if DEVICE == "cuda" else 768

# You can tune what the detector looks for
DEFAULT_CLOTHING_QUERY = "shirt, t-shirt, jacket, coat, hoodie, sweater, dress, pants, jeans, skirt, clothing"

# SDXL inpaint model (overridable through the INPAINT_MODEL env variable)
INPAINT_MODEL = os.environ.get(
    "INPAINT_MODEL",
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
).strip()

# Detection thresholds (tune for your data)
DEFAULT_BOX_THRESHOLD = 0.35
DEFAULT_TEXT_THRESHOLD = 0.25
74
 
75
 
76
  # ============================================================
77
+ # Model loading with hf_hub_download (no local file assumptions)
78
  # ============================================================
79
 
80
# Load status: checked before every inference call and shown in the UI.
model_loaded = False
load_error = None

# Model handles, filled in by _download_and_load_models().
dino = None
sam_predictor = None
pipe = None
86
+
87
def _download_and_load_models():
    """Download every checkpoint from the Hub and build the three models.

    Fills the module-level ``dino``, ``sam_predictor`` and ``pipe`` handles.
    Download and load errors propagate to the caller; only the cosmetic
    progress-bar tweak at the end is best-effort.
    """
    global dino, sam_predictor, pipe

    # hf_hub_download wants None (not "") for anonymous access.
    token = HF_TOKEN if HF_TOKEN else None

    # --------------------------
    # 1) GroundingDINO download
    # --------------------------
    # NOTE(review): repo id and file names are kept as written; confirm that
    # they actually exist on the Hub.
    dino_repo = "IDEA-Research/GroundingDINO"
    dino_cfg_path = hf_hub_download(
        repo_id=dino_repo,
        filename="groundingdino/config/GroundingDINO_SwinT_OGC.py",
        token=token,
    )
    dino_ckpt_path = hf_hub_download(
        repo_id=dino_repo,
        filename="groundingdino_swint_ogc.pth",
        token=token,
    )
    # load_model defaults to device="cuda"; pass DEVICE so CPU-only hosts work.
    dino = load_model(dino_cfg_path, dino_ckpt_path, device=DEVICE)

    # --------------------------
    # 2) SAM download
    # --------------------------
    # NOTE(review): confirm this repo really ships the original .pth file.
    sam_repo = "facebook/sam-vit-huge"
    sam_ckpt_path = hf_hub_download(
        repo_id=sam_repo,
        filename="sam_vit_h_4b8939.pth",
        token=token,
    )

    sam = sam_model_registry["vit_h"](checkpoint=sam_ckpt_path)
    sam.to(DEVICE)
    sam_predictor = SamPredictor(sam)

    # --------------------------
    # 3) SDXL Inpaint pipeline
    # --------------------------
    fp_kwargs = {
        "torch_dtype": DTYPE,
        "use_safetensors": True,
    }
    if HF_TOKEN:
        fp_kwargs["token"] = HF_TOKEN

    pipe = StableDiffusionXLInpaintPipeline.from_pretrained(INPAINT_MODEL, **fp_kwargs).to(DEVICE)

    # Hiding the progress bar is purely cosmetic: never fail the load over it,
    # but leave a trace instead of swallowing the error silently.
    try:
        pipe.set_progress_bar_config(disable=True)
    except Exception as exc:
        logging.getLogger(__name__).warning(
            "Could not disable the pipeline progress bar: %r", exc
        )
 
 
138
 
139
+
140
# Load everything once at import time. A failure must not prevent the UI from
# starting: it is recorded here and reported to the user on every request.
try:
    _download_and_load_models()
    model_loaded = True
except Exception as e:
    model_loaded = False
    load_error = repr(e)
    # Keep the full traceback in the logs; the UI only shows repr(e).
    logging.getLogger(__name__).exception("Model loading failed")
146
 
147
 
148
  # ============================================================
149
+ # Image helpers
150
  # ============================================================
151
 
152
def make_error_image(w: int, h: int) -> Image.Image:
    """Return a solid dark RGB placeholder shown when a request fails.

    Sizes are coerced to int and clamped to at least 1 pixel so that a bad
    width/height can never make the error path itself raise.
    """
    size = (max(1, int(w)), max(1, int(h)))
    return Image.new("RGB", size, (18, 18, 22))
154
+
155
+ def _fit_to_multiple_of_64(w: int, h: int):
156
+ # SDXL likes multiples of 64
157
+ w = max(256, (w // 64) * 64)
158
+ h = max(256, (h // 64) * 64)
159
+ return w, h
160
+
161
def _resize_rgb(img: Image.Image, w: int, h: int) -> Image.Image:
    """Convert ``img`` to RGB and resize it to exactly ``w`` x ``h``.

    The aspect ratio is not preserved. LANCZOS resampling is used.
    """
    rgb = img.convert("RGB")
    return rgb.resize((w, h), Image.LANCZOS)
163
 
164
def _resize_mask(mask: Image.Image, w: int, h: int) -> Image.Image:
    """Convert ``mask`` to 8-bit grayscale and resize it to ``w`` x ``h``.

    NEAREST resampling keeps the mask binary: no new gray values are created
    at the edges.
    """
    gray = mask.convert("L")
    return gray.resize((w, h), Image.NEAREST)
166
 
167
+ def _dilate_mask(mask_np: np.ndarray, radius: int) -> np.ndarray:
168
+ if radius <= 0:
169
+ return mask_np
170
+ import cv2
171
+ kernel = np.ones((radius * 2 + 1, radius * 2 + 1), np.uint8)
172
+ return cv2.dilate(mask_np, kernel, iterations=1)
173
 
174
def _largest_component(mask_np: np.ndarray) -> np.ndarray:
    """Keep only the largest 8-connected foreground region of a mask.

    Args:
        mask_np: Mask array handed to ``cv2.connectedComponentsWithStats``.

    Returns:
        A new array of the same shape/dtype with the largest component set to
        255 and everything else 0, or ``mask_np`` unchanged when it contains
        no foreground component.
    """
    import cv2

    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask_np, connectivity=8)
    # Label 0 is the background, so <= 1 label means there is no foreground.
    if num_labels <= 1:
        return mask_np

    # Skip the background row when ranking by area, then shift the index back.
    largest = 1 + int(np.argmax(stats[1:, cv2.CC_STAT_AREA]))
    out = np.zeros_like(mask_np)
    out[labels == largest] = 255
    return out
184
 
185
+
186
+ # ============================================================
187
+ # Detect clothing and create a mask
188
+ # ============================================================
189
+
190
def detect_clothing_mask(
    image: Image.Image,
    clothing_query: str,
    box_threshold: float,
    text_threshold: float,
    dilate_radius: int,
    keep_largest: bool,
):
    """Detect clothing in ``image`` and return it as an edit mask.

    GroundingDINO proposes boxes for ``clothing_query``; SAM turns each box
    into a pixel mask; the masks are merged, optionally reduced to the largest
    connected region, and finally dilated.

    Args:
        image: Input picture (any PIL mode; converted to RGB internally).
        clothing_query: Text prompt describing what to detect.
        box_threshold: GroundingDINO box confidence threshold.
        text_threshold: GroundingDINO text confidence threshold.
        dilate_radius: Dilation radius in pixels (0 disables dilation).
        keep_largest: Keep only the largest connected mask region.

    Returns:
        A PIL "L" mask at the input resolution (white = edit, black = keep),
        or ``None`` when there is no image or nothing usable was detected.
    """
    if image is None:
        return None

    # Function-scope import: only needed for the detector's preprocessing.
    import groundingdino.datasets.transforms as T

    img_rgb = image.convert("RGB")
    w, h = img_rgb.size
    img_np = np.array(img_rgb)

    # predict() expects the normalized tensor produced by GroundingDINO's own
    # preprocessing, not a raw numpy array. This mirrors what
    # groundingdino.util.inference.load_image does for files on disk.
    preprocess = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    img_tensor, _ = preprocess(img_rgb, None)

    boxes, _, _ = predict(
        model=dino,
        image=img_tensor,
        caption=clothing_query,
        box_threshold=float(box_threshold),
        text_threshold=float(text_threshold),
        device=DEVICE,  # predict() defaults to "cuda", which breaks CPU hosts
    )

    if boxes is None or len(boxes) == 0:
        return None

    # Convert boxes to pixel coords.
    # GroundingDINO returns boxes as [cx, cy, w, h] normalized (0..1), so they
    # map directly onto the original resolution.
    boxes_px = []
    for b in boxes:
        cx, cy, bw, bh = float(b[0]), float(b[1]), float(b[2]), float(b[3])
        x1 = max(0, min(w - 1, int((cx - bw / 2.0) * w)))
        y1 = max(0, min(h - 1, int((cy - bh / 2.0) * h)))
        x2 = max(0, min(w - 1, int((cx + bw / 2.0) * w)))
        y2 = max(0, min(h - 1, int((cy + bh / 2.0) * h)))
        # Drop boxes that collapse to a line or point after clamping.
        if x2 > x1 and y2 > y1:
            boxes_px.append([x1, y1, x2, y2])

    if not boxes_px:
        return None

    # SAM segmentation on original resolution
    sam_predictor.set_image(img_np)

    full_mask = np.zeros((h, w), dtype=np.uint8)
    for box in boxes_px:
        # SAM expects box in XYXY pixel coords
        masks, _, _ = sam_predictor.predict(
            box=np.array(box, dtype=np.float32),
            multimask_output=False,
        )
        # Union of all per-box masks, encoded as 0/255.
        full_mask = np.maximum(full_mask, masks[0].astype(np.uint8) * 255)

    # Optional cleanup
    if keep_largest:
        full_mask = _largest_component(full_mask)

    # Optional dilation to cover seams and edges
    full_mask = _dilate_mask(full_mask, int(dilate_radius))

    return Image.fromarray(full_mask, mode="L")
264
 
265
 
266
  # ============================================================
267
+ # Inference
268
  # ============================================================
269
 
270
def _infer_impl(
    image,
    prompt,
    negative_prompt,
    clothing_query,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    box_threshold,
    text_threshold,
    dilate_radius,
    keep_largest,
):
    """Run the full detect -> segment -> inpaint pipeline for one request.

    Args mirror the Gradio inputs one-to-one (same order as in the UI wiring).

    Returns:
        ``(image, status)``: the inpainted picture and a status line. On
        failure the image is a dark placeholder and the status carries the
        error. When no clothing is found the original image is returned.
    """
    # Function-scope import keeps this block self-contained.
    from contextlib import nullcontext

    width = int(width)
    height = int(height)

    if not model_loaded:
        return make_error_image(width, height), f"Model load failed: {load_error}"

    if image is None:
        return make_error_image(width, height), "Error: please upload an image."

    prompt = (prompt or "").strip()
    if not prompt:
        return make_error_image(width, height), "Error: prompt is empty."

    # An empty negative prompt is passed as None, not as "".
    neg = (negative_prompt or "").strip() or None

    clothing_query = (clothing_query or "").strip() or DEFAULT_CLOTHING_QUERY

    # Seed
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    generator = torch.Generator(device=DEVICE).manual_seed(seed)

    # Fit resolution
    width, height = _fit_to_multiple_of_64(width, height)
    width = min(width, MAX_IMAGE_SIZE)
    height = min(height, MAX_IMAGE_SIZE)

    status = f"Seed: {seed}"
    if DEVICE != "cuda":
        status += " | Running on CPU, this will be slow."

    # Mixed precision only makes sense on GPU; on CPU use a no-op context so
    # the pipeline call below is written exactly once.
    if DEVICE == "cuda":
        autocast_ctx = torch.autocast("cuda", dtype=DTYPE)
    else:
        autocast_ctx = nullcontext()

    # Detection runs inside the try as well, so its failures are reported the
    # same way as inpainting failures and GPU memory is always released.
    try:
        # Detect clothing mask on original image
        mask = detect_clothing_mask(
            image=image,
            clothing_query=clothing_query,
            box_threshold=float(box_threshold),
            text_threshold=float(text_threshold),
            dilate_radius=int(dilate_radius),
            keep_largest=bool(keep_largest),
        )

        if mask is None:
            return image, f"Seed: {seed}. No clothing region detected, try adjusting thresholds or query."

        # Resize image and mask to target size
        img_resized = _resize_rgb(image, width, height)
        mask_resized = _resize_mask(mask, width, height)

        with torch.inference_mode(), autocast_ctx:
            out = pipe(
                prompt=prompt,
                negative_prompt=neg,
                image=img_resized,
                mask_image=mask_resized,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(num_inference_steps),
                generator=generator,
            )

        return out.images[0], status

    except Exception as e:
        return make_error_image(width, height), f"Error: {type(e).__name__}: {e}"

    finally:
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
375
+
376
+
377
def infer(*args, **kwargs):
    """Gradio entry point; forwards every argument to ``_infer_impl``."""
    return _infer_impl(*args, **kwargs)


# On Hugging Face Spaces the entry point must be wrapped by spaces.GPU.
# Applying the decorator by hand avoids defining the same function twice.
if SPACES_AVAILABLE:
    infer = spaces.GPU(infer)
384
 
385
 
386
  # ============================================================
387
  # UI
388
  # ============================================================
389
 
390
CSS = """
body { background: #000; color: #fff; }
"""

with gr.Blocks(title="Auto Clothing Replacement") as demo:
    # The stylesheet is injected as a raw <style> tag.
    gr.HTML(f"<style>{CSS}</style>")

    gr.Markdown("## Automatic Clothing Replacement (no paint, no manual mask)")
    gr.Markdown("Upload a photo, describe the new clothing. The system detects clothing and inpaints automatically.")

    # Surface a start-up failure right at the top of the page.
    if not model_loaded:
        gr.Markdown(f"⚠️ Model failed to load:\n\n{load_error}")

    # NOTE(review): indentation was lost in this copy; the row is assumed to
    # contain only the input image - confirm against the real layout.
    with gr.Row():
        image = gr.Image(type="pil", label="Input image")

    prompt = gr.Textbox(label="Prompt (describe new clothing)", lines=2, placeholder="e.g., a red leather jacket with zipper, realistic fabric folds")
    negative_prompt = gr.Textbox(label="Negative prompt (optional)", lines=2, placeholder="e.g., blurry, deformed, low quality")

    run_button = gr.Button("Replace Clothing")
    result = gr.Image(label="Result")
    status = gr.Markdown("")

    with gr.Accordion("Advanced settings", open=False):
        clothing_query = gr.Textbox(label="Detection query (what counts as clothing)", value=DEFAULT_CLOTHING_QUERY)

        seed = gr.Slider(0, MAX_SEED, step=1, value=0, label="Seed")
        randomize_seed = gr.Checkbox(value=True, label="Randomize seed")

        width = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=768 if DEVICE != "cuda" else 1024, label="Width")
        height = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=768 if DEVICE != "cuda" else 1024, label="Height")

        guidance_scale = gr.Slider(0.0, 15.0, step=0.1, value=7.0, label="Guidance scale")
        num_inference_steps = gr.Slider(1, 80, step=1, value=30, label="Steps")

        box_threshold = gr.Slider(0.05, 0.90, step=0.01, value=DEFAULT_BOX_THRESHOLD, label="Box threshold (GroundingDINO)")
        text_threshold = gr.Slider(0.05, 0.90, step=0.01, value=DEFAULT_TEXT_THRESHOLD, label="Text threshold (GroundingDINO)")

        dilate_radius = gr.Slider(0, 30, step=1, value=8, label="Mask dilation radius (cover edges)")
        keep_largest = gr.Checkbox(value=True, label="Keep only largest clothing region")

    # The input order must match the parameter order of _infer_impl.
    run_button.click(
        fn=infer,
        inputs=[
            image,
            prompt,
            negative_prompt,
            clothing_query,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            box_threshold,
            text_threshold,
            dilate_radius,
            keep_largest,
        ],
        outputs=[result, status],
    )

if __name__ == "__main__":
    demo.queue().launch(ssr_mode=False)