telcom committed on
Commit
59bda41
·
verified ·
1 Parent(s): fd2d5e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -124
app.py CHANGED
@@ -1,11 +1,8 @@
1
  # app.py
2
  # ============================================================
3
- # Automatic clothing replacement:
4
- # 1) Detect clothing boxes with GroundingDINO
5
- # 2) Turn boxes into pixel mask with SAM
6
- # 3) Inpaint mask with SDXL Inpaint
7
- #
8
- # Input: ONE image, NO manual paint, NO manual mask
9
  # ============================================================
10
 
11
  import os
@@ -21,28 +18,24 @@ from PIL import Image
21
  import torch
22
  from huggingface_hub import login, hf_hub_download
23
 
24
- # Diffusers SDXL inpaint
25
  from diffusers import StableDiffusionXLInpaintPipeline
26
-
27
- # GroundingDINO
28
  from groundingdino.util.inference import load_model, predict
29
-
30
- # SAM
31
  from segment_anything import sam_model_registry, SamPredictor
32
 
33
 
34
  # ============================================================
35
- # Spaces GPU decorator (must be imported early)
36
  # ============================================================
37
  try:
38
- import spaces # noqa: F401
39
  SPACES_AVAILABLE = True
40
  except Exception:
 
41
  SPACES_AVAILABLE = False
42
 
43
 
44
  # ============================================================
45
- # Basic config
46
  # ============================================================
47
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
48
  warnings.filterwarnings("ignore")
@@ -54,29 +47,31 @@ if HF_TOKEN:
54
 
55
  MAX_SEED = np.iinfo(np.int32).max
56
 
57
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
58
- DTYPE = torch.bfloat16 if (DEVICE == "cuda" and torch.cuda.is_bf16_supported()) else (torch.float16 if DEVICE == "cuda" else torch.float32)
 
 
 
 
 
 
 
59
 
60
- MAX_IMAGE_SIZE = 1024 if DEVICE == "cuda" else 768
61
 
62
- # You can tune what the detector looks for
63
  DEFAULT_CLOTHING_QUERY = "shirt, t-shirt, jacket, coat, hoodie, sweater, dress, pants, jeans, skirt, clothing"
 
 
64
 
65
- # SDXL inpaint model
66
  INPAINT_MODEL = os.environ.get(
67
  "INPAINT_MODEL",
68
  "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
69
  ).strip()
70
 
71
- # Detection thresholds (tune for your data)
72
- DEFAULT_BOX_THRESHOLD = 0.35
73
- DEFAULT_TEXT_THRESHOLD = 0.25
74
-
75
 
76
  # ============================================================
77
- # Model loading with hf_hub_download (no local file assumptions)
78
  # ============================================================
79
-
80
  model_loaded = False
81
  load_error = None
82
 
@@ -84,13 +79,10 @@ dino = None
84
  sam_predictor = None
85
  pipe = None
86
 
87
- def _download_and_load_models():
88
  global dino, sam_predictor, pipe
89
 
90
- # --------------------------
91
- # 1) GroundingDINO download
92
- # --------------------------
93
- # Official repo commonly used on HF Hub
94
  DINO_REPO = "IDEA-Research/GroundingDINO"
95
  dino_cfg_path = hf_hub_download(
96
  repo_id=DINO_REPO,
@@ -104,41 +96,30 @@ def _download_and_load_models():
104
  )
105
  dino = load_model(dino_cfg_path, dino_ckpt_path)
106
 
107
- # --------------------------
108
- # 2) SAM download
109
- # --------------------------
110
- # Many installs use this HF repo mirror
111
  SAM_REPO = "facebook/sam-vit-huge"
112
  sam_ckpt_path = hf_hub_download(
113
  repo_id=SAM_REPO,
114
  filename="sam_vit_h_4b8939.pth",
115
  token=HF_TOKEN if HF_TOKEN else None,
116
  )
117
-
118
  sam = sam_model_registry["vit_h"](checkpoint=sam_ckpt_path)
119
  sam.to(DEVICE)
120
  sam_predictor = SamPredictor(sam)
121
 
122
- # --------------------------
123
- # 3) SDXL Inpaint pipeline
124
- # --------------------------
125
- fp_kwargs = {
126
- "torch_dtype": DTYPE,
127
- "use_safetensors": True,
128
- }
129
  if HF_TOKEN:
130
  fp_kwargs["token"] = HF_TOKEN
131
 
132
  pipe = StableDiffusionXLInpaintPipeline.from_pretrained(INPAINT_MODEL, **fp_kwargs).to(DEVICE)
133
-
134
  try:
135
  pipe.set_progress_bar_config(disable=True)
136
  except Exception:
137
  pass
138
 
139
-
140
  try:
141
- _download_and_load_models()
142
  model_loaded = True
143
  except Exception as e:
144
  model_loaded = False
@@ -146,33 +127,30 @@ except Exception as e:
146
 
147
 
148
  # ============================================================
149
- # Image helpers
150
  # ============================================================
151
-
152
  def make_error_image(w: int, h: int) -> Image.Image:
153
  return Image.new("RGB", (int(w), int(h)), (18, 18, 22))
154
 
155
- def _fit_to_multiple_of_64(w: int, h: int):
156
- # SDXL likes multiples of 64
157
  w = max(256, (w // 64) * 64)
158
  h = max(256, (h // 64) * 64)
159
  return w, h
160
 
161
- def _resize_rgb(img: Image.Image, w: int, h: int) -> Image.Image:
162
  return img.convert("RGB").resize((w, h), Image.LANCZOS)
163
 
164
- def _resize_mask(mask: Image.Image, w: int, h: int) -> Image.Image:
165
  return mask.convert("L").resize((w, h), Image.NEAREST)
166
 
167
- def _dilate_mask(mask_np: np.ndarray, radius: int) -> np.ndarray:
168
  if radius <= 0:
169
  return mask_np
170
  import cv2
171
  kernel = np.ones((radius * 2 + 1, radius * 2 + 1), np.uint8)
172
  return cv2.dilate(mask_np, kernel, iterations=1)
173
 
174
- def _largest_component(mask_np: np.ndarray) -> np.ndarray:
175
- # Optional cleanup: keep only largest connected region
176
  import cv2
177
  num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask_np, connectivity=8)
178
  if num_labels <= 1:
@@ -183,10 +161,6 @@ def _largest_component(mask_np: np.ndarray) -> np.ndarray:
183
  return out
184
 
185
 
186
- # ============================================================
187
- # Detect clothing and create a mask
188
- # ============================================================
189
-
190
  def detect_clothing_mask(
191
  image: Image.Image,
192
  clothing_query: str,
@@ -195,19 +169,14 @@ def detect_clothing_mask(
195
  dilate_radius: int,
196
  keep_largest: bool,
197
  ):
198
- """
199
- Returns a PIL L mask: white = edit, black = keep
200
- """
201
  if image is None:
202
  return None
203
 
204
  img_rgb = image.convert("RGB")
205
  w, h = img_rgb.size
206
-
207
- # GroundingDINO expects numpy image (H,W,3) in RGB usually
208
  img_np = np.array(img_rgb)
209
 
210
- boxes, logits, phrases = predict(
211
  model=dino,
212
  image=img_np,
213
  caption=clothing_query,
@@ -218,11 +187,9 @@ def detect_clothing_mask(
218
  if boxes is None or len(boxes) == 0:
219
  return None
220
 
221
- # Convert boxes to pixel coords
222
- # GroundingDINO returns boxes as [cx, cy, w, h] normalized (0..1)
223
  boxes_px = []
224
  for b in boxes:
225
- cx, cy, bw, bh = float(b[0]), float(b[1]), float(b[2]), float(b[3])
226
  x1 = int((cx - bw / 2.0) * w)
227
  y1 = int((cy - bh / 2.0) * h)
228
  x2 = int((cx + bw / 2.0) * w)
@@ -237,37 +204,27 @@ def detect_clothing_mask(
237
  if not boxes_px:
238
  return None
239
 
240
- # SAM segmentation on original resolution
241
  sam_predictor.set_image(img_np)
242
 
243
  full_mask = np.zeros((h, w), dtype=np.uint8)
244
-
245
  for box in boxes_px:
246
- # SAM expects box in XYXY pixel coords
247
  box_arr = np.array(box, dtype=np.float32)
248
-
249
- masks, scores, _ = sam_predictor.predict(
250
- box=box_arr,
251
- multimask_output=False,
252
- )
253
  m = masks[0].astype(np.uint8) * 255
254
  full_mask = np.maximum(full_mask, m)
255
 
256
- # Optional cleanup
257
  if keep_largest:
258
- full_mask = _largest_component(full_mask)
259
 
260
- # Optional dilation to cover seams and edges
261
- full_mask = _dilate_mask(full_mask, int(dilate_radius))
262
 
263
  return Image.fromarray(full_mask, mode="L")
264
 
265
 
266
  # ============================================================
267
- # Inference
268
  # ============================================================
269
-
270
- def _infer_impl(
271
  image,
272
  prompt,
273
  negative_prompt,
@@ -290,7 +247,7 @@ def _infer_impl(
290
  return make_error_image(width, height), f"Model load failed: {load_error}"
291
 
292
  if image is None:
293
- return make_error_image(width, height), "Error: please upload an image."
294
 
295
  prompt = (prompt or "").strip()
296
  if not prompt:
@@ -300,11 +257,8 @@ def _infer_impl(
300
  if not neg:
301
  neg = None
302
 
303
- clothing_query = (clothing_query or "").strip()
304
- if not clothing_query:
305
- clothing_query = DEFAULT_CLOTHING_QUERY
306
 
307
- # Seed
308
  if randomize_seed:
309
  seed = random.randint(0, MAX_SEED)
310
  else:
@@ -312,12 +266,10 @@ def _infer_impl(
312
 
313
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
314
 
315
- # Fit resolution
316
- width, height = _fit_to_multiple_of_64(width, height)
317
  width = min(width, MAX_IMAGE_SIZE)
318
  height = min(height, MAX_IMAGE_SIZE)
319
 
320
- # Detect clothing mask on original image
321
  mask = detect_clothing_mask(
322
  image=image,
323
  clothing_query=clothing_query,
@@ -328,19 +280,18 @@ def _infer_impl(
328
  )
329
 
330
  if mask is None:
331
- return image, f"Seed: {seed}. No clothing region detected, try adjusting thresholds or query."
332
 
333
- # Resize image and mask to target size
334
- img_resized = _resize_rgb(image, width, height)
335
- mask_resized = _resize_mask(mask, width, height)
336
 
337
  status = f"Seed: {seed}"
338
- if DEVICE != "cuda":
339
- status += " | Running on CPU, this will be slow."
340
 
341
  try:
342
  with torch.inference_mode():
343
- if DEVICE == "cuda":
344
  with torch.autocast("cuda", dtype=DTYPE):
345
  out = pipe(
346
  prompt=prompt,
@@ -362,73 +313,79 @@ def _infer_impl(
362
  generator=generator,
363
  )
364
 
365
- result = out.images[0]
366
- return result, status
367
 
368
  except Exception as e:
369
  return make_error_image(width, height), f"Error: {type(e).__name__}: {e}"
370
 
371
  finally:
372
  gc.collect()
373
- if DEVICE == "cuda":
374
  torch.cuda.empty_cache()
375
 
376
 
 
 
 
 
377
  if SPACES_AVAILABLE:
378
  @spaces.GPU
379
  def infer(*args, **kwargs):
380
- return _infer_impl(*args, **kwargs)
381
  else:
382
  def infer(*args, **kwargs):
383
- return _infer_impl(*args, **kwargs)
384
 
385
 
386
  # ============================================================
387
  # UI
388
  # ============================================================
389
-
390
- CSS = """
391
- body { background: #000; color: #fff; }
392
- """
393
 
394
  with gr.Blocks(title="Auto Clothing Replacement") as demo:
395
  gr.HTML(f"<style>{CSS}</style>")
396
-
397
  gr.Markdown("## Automatic Clothing Replacement (no paint, no manual mask)")
398
- gr.Markdown("Upload a photo, describe the new clothing. The system detects clothing and inpaints automatically.")
399
 
400
  if not model_loaded:
401
  gr.Markdown(f"⚠️ Model failed to load:\n\n{load_error}")
402
 
403
- with gr.Row():
404
- image = gr.Image(type="pil", label="Input image")
405
 
406
- prompt = gr.Textbox(label="Prompt (describe new clothing)", lines=2, placeholder="e.g., a red leather jacket with zipper, realistic fabric folds")
407
- negative_prompt = gr.Textbox(label="Negative prompt (optional)", lines=2, placeholder="e.g., blurry, deformed, low quality")
 
 
 
 
 
 
 
 
408
 
409
- run_button = gr.Button("Replace Clothing")
410
- result = gr.Image(label="Result")
411
  status = gr.Markdown("")
412
 
413
  with gr.Accordion("Advanced settings", open=False):
414
- clothing_query = gr.Textbox(label="Detection query (what counts as clothing)", value=DEFAULT_CLOTHING_QUERY)
415
 
416
  seed = gr.Slider(0, MAX_SEED, step=1, value=0, label="Seed")
417
  randomize_seed = gr.Checkbox(value=True, label="Randomize seed")
418
 
419
- width = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=768 if DEVICE != "cuda" else 1024, label="Width")
420
- height = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=768 if DEVICE != "cuda" else 1024, label="Height")
421
 
422
  guidance_scale = gr.Slider(0.0, 15.0, step=0.1, value=7.0, label="Guidance scale")
423
  num_inference_steps = gr.Slider(1, 80, step=1, value=30, label="Steps")
424
 
425
- box_threshold = gr.Slider(0.05, 0.90, step=0.01, value=DEFAULT_BOX_THRESHOLD, label="Box threshold (GroundingDINO)")
426
- text_threshold = gr.Slider(0.05, 0.90, step=0.01, value=DEFAULT_TEXT_THRESHOLD, label="Text threshold (GroundingDINO)")
427
 
428
- dilate_radius = gr.Slider(0, 30, step=1, value=8, label="Mask dilation radius (cover edges)")
429
- keep_largest = gr.Checkbox(value=True, label="Keep only largest clothing region")
430
 
431
- run_button.click(
432
  fn=infer,
433
  inputs=[
434
  image,
@@ -446,8 +403,8 @@ with gr.Blocks(title="Auto Clothing Replacement") as demo:
446
  dilate_radius,
447
  keep_largest,
448
  ],
449
- outputs=[result, status],
450
  )
451
 
452
  if __name__ == "__main__":
453
- demo.queue().launch(ssr_mode=False)
 
1
  # app.py
2
  # ============================================================
3
+ # Automatic clothing replacement (no paint, no manual mask)
4
+ # GroundingDINO -> SAM -> SDXL Inpaint
5
+ # Fixes: Spaces requires @spaces.GPU function at startup
 
 
 
6
  # ============================================================
7
 
8
  import os
 
18
  import torch
19
  from huggingface_hub import login, hf_hub_download
20
 
 
21
  from diffusers import StableDiffusionXLInpaintPipeline
 
 
22
  from groundingdino.util.inference import load_model, predict
 
 
23
  from segment_anything import sam_model_registry, SamPredictor
24
 
25
 
26
  # ============================================================
27
+ # Spaces import (do not hide the decorated function)
28
  # ============================================================
29
# Optional Hugging Face Spaces integration.
# Start from the "not available" state and only flip it once the import has
# actually succeeded, so `spaces` is always a defined name further down.
spaces = None
SPACES_AVAILABLE = False
try:
    import spaces
except Exception:
    # Any import-time failure is treated as "package not usable".
    pass
else:
    SPACES_AVAILABLE = True
35
 
36
 
37
  # ============================================================
38
+ # Config
39
  # ============================================================
40
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
41
  warnings.filterwarnings("ignore")
 
47
 
48
  MAX_SEED = np.iinfo(np.int32).max
49
 
50
# Probe CUDA once; every device-dependent setting below keys off this flag.
CUDA_OK = torch.cuda.is_available()
DEVICE = "cpu" if not CUDA_OK else "cuda"

# bfloat16 is only chosen when this torch build exposes the capability probe
# and the probe reports support.
_BF16_OK = (
    CUDA_OK
    and hasattr(torch.cuda, "is_bf16_supported")
    and torch.cuda.is_bf16_supported()
)

if not CUDA_OK:
    # CPU path: full precision.
    DTYPE = torch.float32
else:
    DTYPE = torch.bfloat16 if _BF16_OK else torch.float16

# Upper bound for the requested output size; smaller when running on CPU.
MAX_IMAGE_SIZE = 768 if not CUDA_OK else 1024
61
 
 
62
  DEFAULT_CLOTHING_QUERY = "shirt, t-shirt, jacket, coat, hoodie, sweater, dress, pants, jeans, skirt, clothing"
63
+ DEFAULT_BOX_THRESHOLD = 0.35
64
+ DEFAULT_TEXT_THRESHOLD = 0.25
65
 
 
66
  INPAINT_MODEL = os.environ.get(
67
  "INPAINT_MODEL",
68
  "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
69
  ).strip()
70
 
 
 
 
 
71
 
72
  # ============================================================
73
+ # Load models (download from HF Hub)
74
  # ============================================================
 
75
  model_loaded = False
76
  load_error = None
77
 
 
79
  sam_predictor = None
80
  pipe = None
81
 
82
+ def download_and_load_models():
83
  global dino, sam_predictor, pipe
84
 
85
+ # ---- GroundingDINO ----
 
 
 
86
  DINO_REPO = "IDEA-Research/GroundingDINO"
87
  dino_cfg_path = hf_hub_download(
88
  repo_id=DINO_REPO,
 
96
  )
97
  dino = load_model(dino_cfg_path, dino_ckpt_path)
98
 
99
+ # ---- SAM ----
 
 
 
100
  SAM_REPO = "facebook/sam-vit-huge"
101
  sam_ckpt_path = hf_hub_download(
102
  repo_id=SAM_REPO,
103
  filename="sam_vit_h_4b8939.pth",
104
  token=HF_TOKEN if HF_TOKEN else None,
105
  )
 
106
  sam = sam_model_registry["vit_h"](checkpoint=sam_ckpt_path)
107
  sam.to(DEVICE)
108
  sam_predictor = SamPredictor(sam)
109
 
110
+ # ---- SDXL Inpaint ----
111
+ fp_kwargs = {"torch_dtype": DTYPE, "use_safetensors": True}
 
 
 
 
 
112
  if HF_TOKEN:
113
  fp_kwargs["token"] = HF_TOKEN
114
 
115
  pipe = StableDiffusionXLInpaintPipeline.from_pretrained(INPAINT_MODEL, **fp_kwargs).to(DEVICE)
 
116
  try:
117
  pipe.set_progress_bar_config(disable=True)
118
  except Exception:
119
  pass
120
 
 
121
  try:
122
+ download_and_load_models()
123
  model_loaded = True
124
  except Exception as e:
125
  model_loaded = False
 
127
 
128
 
129
  # ============================================================
130
+ # Helpers
131
  # ============================================================
 
132
def make_error_image(w: int, h: int) -> Image.Image:
    """Return a solid dark RGB placeholder image of size ``w`` x ``h``.

    Both dimensions are passed through ``int()`` before the image is
    created, so numeric values that are not already ints are accepted.
    """
    size = (int(w), int(h))
    background = (18, 18, 22)  # near-black fill colour
    return Image.new("RGB", size, background)
134
 
135
def fit64(w: int, h: int):
    """Snap a requested size down to multiples of 64 with a 256 px floor.

    Each dimension is truncated to an ``int`` first, so float values (for
    example from a UI slider) cannot leak through as float dimensions into
    code that needs integer pixel sizes.

    Args:
        w: Requested width in pixels.
        h: Requested height in pixels.

    Returns:
        A ``(width, height)`` tuple of plain ints, each a multiple of 64 and
        at least 256.
    """

    def _snap(value) -> int:
        # Round down to the nearest multiple of 64, never below 256.
        return max(256, (int(value) // 64) * 64)

    return _snap(w), _snap(h)
139
 
140
def resize_rgb(img: Image.Image, w: int, h: int) -> Image.Image:
    """Convert ``img`` to RGB and resize it to ``w`` x ``h`` with LANCZOS."""
    rgb = img.convert("RGB")
    target_size = (w, h)
    return rgb.resize(target_size, Image.LANCZOS)
142
 
143
def resize_mask(mask: Image.Image, w: int, h: int) -> Image.Image:
    """Convert ``mask`` to mode "L" and resize it to ``w`` x ``h``.

    NEAREST resampling is used, so no new intermediate grey levels are
    introduced by the resize.
    """
    grey = mask.convert("L")
    target_size = (w, h)
    return grey.resize(target_size, Image.NEAREST)
145
 
146
def dilate_mask(mask_np: np.ndarray, radius: int) -> np.ndarray:
    """Dilate ``mask_np`` with a square kernel of side ``2 * radius + 1``.

    A non-positive ``radius`` is a no-op: the input array is returned
    unchanged (the same object, not a copy).
    """
    if radius > 0:
        # Imported lazily so OpenCV is only needed when dilation is requested.
        import cv2

        side = 2 * radius + 1
        structuring = np.ones((side, side), np.uint8)
        return cv2.dilate(mask_np, structuring, iterations=1)
    return mask_np
152
 
153
+ def largest_component(mask_np: np.ndarray) -> np.ndarray:
 
154
  import cv2
155
  num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask_np, connectivity=8)
156
  if num_labels <= 1:
 
161
  return out
162
 
163
 
 
 
 
 
164
  def detect_clothing_mask(
165
  image: Image.Image,
166
  clothing_query: str,
 
169
  dilate_radius: int,
170
  keep_largest: bool,
171
  ):
 
 
 
172
  if image is None:
173
  return None
174
 
175
  img_rgb = image.convert("RGB")
176
  w, h = img_rgb.size
 
 
177
  img_np = np.array(img_rgb)
178
 
179
+ boxes, _, _ = predict(
180
  model=dino,
181
  image=img_np,
182
  caption=clothing_query,
 
187
  if boxes is None or len(boxes) == 0:
188
  return None
189
 
 
 
190
  boxes_px = []
191
  for b in boxes:
192
+ cx, cy, bw, bh = map(float, b)
193
  x1 = int((cx - bw / 2.0) * w)
194
  y1 = int((cy - bh / 2.0) * h)
195
  x2 = int((cx + bw / 2.0) * w)
 
204
  if not boxes_px:
205
  return None
206
 
 
207
  sam_predictor.set_image(img_np)
208
 
209
  full_mask = np.zeros((h, w), dtype=np.uint8)
 
210
  for box in boxes_px:
 
211
  box_arr = np.array(box, dtype=np.float32)
212
+ masks, _, _ = sam_predictor.predict(box=box_arr, multimask_output=False)
 
 
 
 
213
  m = masks[0].astype(np.uint8) * 255
214
  full_mask = np.maximum(full_mask, m)
215
 
 
216
  if keep_largest:
217
+ full_mask = largest_component(full_mask)
218
 
219
+ full_mask = dilate_mask(full_mask, int(dilate_radius))
 
220
 
221
  return Image.fromarray(full_mask, mode="L")
222
 
223
 
224
  # ============================================================
225
+ # Core inference (no decorator here)
226
  # ============================================================
227
+ def infer_core(
 
228
  image,
229
  prompt,
230
  negative_prompt,
 
247
  return make_error_image(width, height), f"Model load failed: {load_error}"
248
 
249
  if image is None:
250
+ return make_error_image(width, height), "Error: upload an image."
251
 
252
  prompt = (prompt or "").strip()
253
  if not prompt:
 
257
  if not neg:
258
  neg = None
259
 
260
+ clothing_query = (clothing_query or "").strip() or DEFAULT_CLOTHING_QUERY
 
 
261
 
 
262
  if randomize_seed:
263
  seed = random.randint(0, MAX_SEED)
264
  else:
 
266
 
267
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
268
 
269
+ width, height = fit64(width, height)
 
270
  width = min(width, MAX_IMAGE_SIZE)
271
  height = min(height, MAX_IMAGE_SIZE)
272
 
 
273
  mask = detect_clothing_mask(
274
  image=image,
275
  clothing_query=clothing_query,
 
280
  )
281
 
282
  if mask is None:
283
+ return image, f"Seed: {seed}. No clothing detected, try lowering thresholds or changing query."
284
 
285
+ img_resized = resize_rgb(image, width, height)
286
+ mask_resized = resize_mask(mask, width, height)
 
287
 
288
  status = f"Seed: {seed}"
289
+ if not CUDA_OK:
290
+ status += " | CPU only (slow)."
291
 
292
  try:
293
  with torch.inference_mode():
294
+ if CUDA_OK:
295
  with torch.autocast("cuda", dtype=DTYPE):
296
  out = pipe(
297
  prompt=prompt,
 
313
  generator=generator,
314
  )
315
 
316
+ return out.images[0], status
 
317
 
318
  except Exception as e:
319
  return make_error_image(width, height), f"Error: {type(e).__name__}: {e}"
320
 
321
  finally:
322
  gc.collect()
323
+ if CUDA_OK:
324
  torch.cuda.empty_cache()
325
 
326
 
327
+ # ============================================================
328
+ # IMPORTANT: Always define a @spaces.GPU function if spaces imports
329
+ # (Spaces startup checker requires it)
330
+ # ============================================================
331
def infer(*args, **kwargs):
    """UI entry point: forward every argument unchanged to ``infer_core``."""
    return infer_core(*args, **kwargs)


# When the `spaces` package imported successfully, wrap the entry point with
# its GPU decorator at module import time; otherwise keep the plain function.
if SPACES_AVAILABLE:
    infer = spaces.GPU(infer)
338
 
339
 
340
  # ============================================================
341
  # UI
342
  # ============================================================
343
+ CSS = "body { background: #000; color: #fff; }"
 
 
 
344
 
345
  with gr.Blocks(title="Auto Clothing Replacement") as demo:
346
  gr.HTML(f"<style>{CSS}</style>")
 
347
  gr.Markdown("## Automatic Clothing Replacement (no paint, no manual mask)")
348
+ gr.Markdown("Upload a photo, describe the new clothing. Detection and masking is automatic.")
349
 
350
  if not model_loaded:
351
  gr.Markdown(f"⚠️ Model failed to load:\n\n{load_error}")
352
 
353
+ image = gr.Image(type="pil", label="Input image")
 
354
 
355
+ prompt = gr.Textbox(
356
+ label="Prompt (describe new clothing)",
357
+ lines=2,
358
+ placeholder="e.g., a navy business suit jacket, realistic fabric folds, studio lighting",
359
+ )
360
+ negative_prompt = gr.Textbox(
361
+ label="Negative prompt (optional)",
362
+ lines=2,
363
+ placeholder="e.g., blurry, deformed, low quality",
364
+ )
365
 
366
+ run = gr.Button("Replace Clothing")
367
+ out_img = gr.Image(label="Result")
368
  status = gr.Markdown("")
369
 
370
  with gr.Accordion("Advanced settings", open=False):
371
+ clothing_query = gr.Textbox(label="Detection query", value=DEFAULT_CLOTHING_QUERY)
372
 
373
  seed = gr.Slider(0, MAX_SEED, step=1, value=0, label="Seed")
374
  randomize_seed = gr.Checkbox(value=True, label="Randomize seed")
375
 
376
+ width = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=768 if not CUDA_OK else 1024, label="Width")
377
+ height = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=768 if not CUDA_OK else 1024, label="Height")
378
 
379
  guidance_scale = gr.Slider(0.0, 15.0, step=0.1, value=7.0, label="Guidance scale")
380
  num_inference_steps = gr.Slider(1, 80, step=1, value=30, label="Steps")
381
 
382
+ box_threshold = gr.Slider(0.05, 0.90, step=0.01, value=DEFAULT_BOX_THRESHOLD, label="Box threshold (DINO)")
383
+ text_threshold = gr.Slider(0.05, 0.90, step=0.01, value=DEFAULT_TEXT_THRESHOLD, label="Text threshold (DINO)")
384
 
385
+ dilate_radius = gr.Slider(0, 30, step=1, value=8, label="Mask dilation radius")
386
+ keep_largest = gr.Checkbox(value=True, label="Keep only largest region")
387
 
388
+ run.click(
389
  fn=infer,
390
  inputs=[
391
  image,
 
403
  dilate_radius,
404
  keep_largest,
405
  ],
406
+ outputs=[out_img, status],
407
  )
408
 
409
  if __name__ == "__main__":
410
+ demo.queue().launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)