CI commited on
Commit
6933b0e
·
0 Parent(s):
.github/workflows/deploy.yml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Deploys this repo to a Hugging Face Space by force-pushing a single-commit
# orphan snapshot of the `deploy` branch to the Space repo's `main`.
name: Sync to Hugging Face hub
on:
  push:
    branches: [deploy]
  # to run this workflow manually from the Actions tab
  workflow_dispatch:

jobs:
  sync-to-hub:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
        with:
          # Shallow clone is enough: history is discarded by the orphan branch below.
          fetch-depth: 1
      - name: Push to hub
        env:
          # Write-scoped HF token stored as a repo secret; never echoed to logs.
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          git config user.email "ci@github.com"
          git config user.name "CI"
          # Orphan branch drops all history so past large blobs never reach the hub.
          git checkout --orphan hf-clean
          git add -A
          git commit -m "deploy"
          git push https://mvp-lab:$HF_TOKEN@huggingface.co/spaces/mvp-lab/VTON_TEST hf-clean:main -f
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ data/examples/portraits/
2
+ data/examples/garments/
3
+ data/examples/results/
4
+ data/user_uploads/
5
+ __pycache__/
6
+ *.pyc
7
+ .venv/
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "fashn-vton-1.5"]
2
+ path = fashn-vton-1.5
3
+ url = https://github.com/fashn-AI/fashn-vton-1.5.git
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: MultiSubjectVTON
3
+ emoji: 📊
4
+ colorFrom: red
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 6.7.0
8
+ python_version: '3.12'
9
+ app_file: app.py
10
+ pinned: false
11
+ short_description: Multi-subject VTON model
12
+ fullWidth: false
13
+ logs: true
14
+ ---
app.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import cv2
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torch
6
+ from fashn_vton import TryOnPipeline
7
+ from ultralytics import YOLO
8
+ import gradio as gr
9
+ from pathlib import Path
10
+ import subprocess
11
+ import sys
12
+ from scipy.spatial import cKDTree
13
+ from ui import build_demo
14
+
15
+
16
class MultiPersonVTON:
    """Multi-person virtual try-on.

    Pipeline: segment people with a YOLO segmentation model, run the FASHN
    try-on pipeline on each person cutout, estimate front-to-back occlusion
    order from contour curvature in overlap regions, then recompose the
    dressed people onto an inpainted (people-free) background.
    """

    def __init__(self, weights_dir="./weights"):
        print("Initializing Multi-Person VTON pipeline...")
        self.pipeline = TryOnPipeline(weights_dir=weights_dir)
        # Segmentation variant; class id 0 is "person" in the COCO label set.
        self.model = YOLO("yolo26n-seg.pt")
        print("Pipeline initialized")

    def get_mask(self, result, H, W):
        """Return a list of boolean (H, W) masks, one per detected person.

        `result` is a single ultralytics result object. Returns [] when the
        model produced no masks at all (previously this crashed with an
        AttributeError on `result.masks.xy` because `masks` is None).
        """
        if result.masks is None or result.boxes is None:
            return []
        cls_ids = result.boxes.cls.cpu().numpy().astype(int)
        person_idxs = cls_ids == 0
        person_polygons = [poly for poly, keep in zip(result.masks.xy, person_idxs) if keep]
        masks = []
        for poly in person_polygons:
            # Rasterize each polygon into a binary mask.
            mask = np.zeros((H, W), dtype=np.uint8)
            poly_int = np.round(poly).astype(np.int32)
            cv2.fillPoly(mask, [poly_int], 1)
            masks.append(mask.astype(bool))
        return masks

    def extract_people(self, img, masks):
        """Cut each person out of `img`, filling the background with white.

        White (255) is significant: get_vton_masks later recovers the person
        silhouette by thresholding near-white pixels.
        """
        img_np = np.array(img) if isinstance(img, Image.Image) else img.copy()
        people = []
        for mask in masks:
            cutout = img_np.copy()
            cutout[~mask] = 255
            people.append(Image.fromarray(cutout))
        return people

    def apply_vton_to_people(self, people, assignments):
        """Apply VTON per person based on individual assignments.

        assignments: list of {"garment": PIL.Image|None, "category": str} per person.
        If garment is None, person is kept as-is (skipped).
        """
        vton_people = []
        for i, person in enumerate(people):
            garment = assignments[i]["garment"]
            if garment is not None:
                result = self.pipeline(
                    person_image=person,
                    garment_image=garment,
                    category=assignments[i]["category"]
                )
                vton_people.append(result.images[0])
            else:
                vton_people.append(person)
        return vton_people

    def get_vton_masks(self, vton_people):
        """Recover a soft (blurred uint8) silhouette mask for each VTON output.

        Assumes the try-on result keeps the white background produced by
        extract_people; pixels darker than 240 are treated as foreground.
        """
        vton_masks = []
        for people in vton_people:
            people_arr = np.array(people)
            gray = cv2.cvtColor(people_arr, cv2.COLOR_RGB2GRAY)
            _, mask = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY_INV)
            mask = mask.astype(bool)
            # Open removes speckle noise, close fills small holes in the body.
            kernel = np.ones((5, 5), np.uint8)
            mask_clean = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, kernel, iterations=1)
            mask_clean = cv2.morphologyEx(mask_clean, cv2.MORPH_CLOSE, kernel, iterations=2)
            mask_u8 = (mask_clean.astype(np.uint8) * 255)
            # Slight blur gives a soft alpha edge at composition time.
            mask_blur = cv2.GaussianBlur(mask_u8, (3, 3), 1)
            vton_masks.append(mask_blur)
        return vton_masks

    def contour_curvature(self, contour, k=5):
        """Per-point turning angle (radians) along a closed contour.

        `contour` is an OpenCV contour of shape (N, 1, 2). The angle at point i
        is measured between the chords to points i-k and i+k (wrapping around).
        """
        pts = contour[:, 0, :].astype(np.float32)
        N = len(pts)
        curv = np.zeros(N)
        for i in range(N):
            p_prev = pts[(i - k) % N]
            p = pts[i]
            p_next = pts[(i + k) % N]
            v1 = p - p_prev
            v2 = p_next - p
            # Epsilon guards against zero-length chords (repeated points).
            v1 /= (np.linalg.norm(v1) + 1e-6)
            v2 /= (np.linalg.norm(v2) + 1e-6)
            angle = np.arccos(np.clip(np.dot(v1, v2), -1, 1))
            curv[i] = angle
        return curv

    def frontness_score(self, mask_a, mask_b):
        """Positive when mask_a appears to be in front of mask_b.

        Heuristic: in the overlap region, the occluded person's visible contour
        tends to be smoother (it is clipped by the occluder), so higher local
        curvature suggests being in front. Returns 0.0 when overlap is tiny.
        """
        inter = mask_a & mask_b
        if inter.sum() < 50:
            return 0.0
        cnts_a, _ = cv2.findContours(mask_a.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        cnts_b, _ = cv2.findContours(mask_b.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        if not cnts_a or not cnts_b:
            return 0.0
        ca = max(cnts_a, key=len)
        cb = max(cnts_b, key=len)
        curv_a = self.contour_curvature(ca)
        curv_b = self.contour_curvature(cb)
        # np.where gives (row, col); flip to (x, y) to match contour coords.
        inter_pts = np.column_stack(np.where(inter))[:, ::-1]
        tree_a = cKDTree(ca[:, 0, :])
        tree_b = cKDTree(cb[:, 0, :])
        _, idx_a = tree_a.query(inter_pts, k=1)
        _, idx_b = tree_b.query(inter_pts, k=1)
        score_a = curv_a[idx_a].mean()
        score_b = curv_b[idx_b].mean()
        return score_a - score_b

    def estimate_front_to_back_order(self, masks):
        """Rank people front-to-back by summed pairwise frontness scores.

        Returns (order, scores) where order[0] is the front-most index.
        """
        n = len(masks)
        scores = np.zeros(n)
        for i in range(n):
            for j in range(n):
                if i == j:
                    continue
                scores[i] += self.frontness_score(masks[i], masks[j])
        order = np.argsort(-scores)
        return order, scores

    def remove_original_people(self, image, person_masks):
        """Inpaint all detected people out of the image.

        Masks are dilated first so inpainting also covers halo pixels around
        each silhouette. Returns (clean PIL image, combined uint8 mask).
        """
        image_np = np.array(image)
        combined_mask = np.zeros(image_np.shape[:2], dtype=np.uint8)
        for mask in person_masks:
            combined_mask[mask] = 255
        kernel = np.ones((5, 5), np.uint8)
        combined_mask = cv2.dilate(combined_mask, kernel, iterations=2)
        inpainted = cv2.inpaint(image_np, combined_mask, 3, cv2.INPAINT_TELEA)
        return Image.fromarray(inpainted), combined_mask

    def clean_vton_edges_on_overlap(self, img_pil, mask_uint8, other_masks_uint8,
                                    erode_iters=1, edge_dilate=2, inner_erode=2):
        """Suppress edge artifacts where this person's mask overlaps others.

        Where the silhouette edge falls inside an overlap band with other
        people, edge pixels are replaced by colors inpainted from the person's
        interior. Returns (cleaned PIL image, tightened uint8 mask).
        """
        src = np.array(img_pil).copy()
        others_union = np.zeros_like(mask_uint8, dtype=np.uint8)
        for m in other_masks_uint8:
            others_union = np.maximum(others_union, m)
        overlap = (mask_uint8 > 0) & (others_union > 0)
        overlap = overlap.astype(np.uint8) * 255
        if overlap.sum() == 0:
            return img_pil, mask_uint8
        kernel = np.ones((3, 3), np.uint8)
        tight_mask = cv2.erode(mask_uint8, kernel, iterations=erode_iters)
        edge = cv2.Canny(tight_mask, 50, 150)
        edge = cv2.dilate(edge, np.ones((3, 3), np.uint8), iterations=edge_dilate)
        overlap_band = cv2.dilate(overlap, np.ones((5, 5), np.uint8), iterations=1)
        # Restrict cleanup to edges inside the overlap band only.
        edge = cv2.bitwise_and(edge, overlap_band)
        if edge.sum() == 0:
            return img_pil, tight_mask
        inner = cv2.erode(tight_mask, np.ones((5, 5), np.uint8), iterations=inner_erode)
        inner_rgb = cv2.inpaint(src, 255 - inner, 3, cv2.INPAINT_TELEA)
        src[edge > 0] = inner_rgb[edge > 0]
        return Image.fromarray(src), tight_mask

    def clean_masks(self, vton_people, vton_masks):
        """Run clean_vton_edges_on_overlap for every person against the rest."""
        cleaned_vton_people = []
        cleaned_vton_masks = []
        for i in range(len(vton_people)):
            other_masks = [m for j, m in enumerate(vton_masks) if j != i]
            cleaned_img, cleaned_mask = self.clean_vton_edges_on_overlap(
                vton_people[i], vton_masks[i], other_masks,
                erode_iters=1, edge_dilate=2, inner_erode=2
            )
            cleaned_vton_people.append(cleaned_img)
            cleaned_vton_masks.append(cleaned_mask)
        return cleaned_vton_people, cleaned_vton_masks

    def process_group_image(self, group_image, assignments):
        """Process a group image with per-person garment assignments.

        assignments: list of {"garment": PIL.Image|None, "category": str} per person.
        Returns (final PIL image, metadata dict).
        """
        print("Step 1: Loading images...")
        if isinstance(group_image, np.ndarray):
            group_image = Image.fromarray(group_image)
        if isinstance(group_image, Image.Image):
            # NOTE(review): fixed temp filename in cwd — concurrent requests
            # would clobber each other; confirm the Space runs single-worker.
            group_image.save("people.png")

        img_bgr = cv2.imread("people.png")
        img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        H, W = img.shape[:2]

        print("Step 2: Getting segmentation masks with YOLO...")
        results = self.model("people.png")
        result = results[0]
        masks = self.get_mask(result, H, W)
        print(f"Found {len(masks)} people")

        print("Step 3: Extracting individual people...")
        people = self.extract_people(img, masks)

        if not people:
            # No people detected: return the input unchanged instead of
            # crashing on vton_people[0] further down.
            original = Image.fromarray(img)
            return original, {
                "original": original,
                "clean_background": original,
                "person_mask": Image.fromarray(np.zeros((H, W), dtype=np.uint8)),
                "num_people": 0,
                "individual_people": [],
                "vton_results": [],
                "masks": [],
                "vton_masks": [],
            }

        # Pad assignments to match detected people count
        while len(assignments) < len(people):
            assignments.append({"garment": None, "category": "tops"})

        print("Step 4: Applying VTON to people...")
        vton_people = self.apply_vton_to_people(people, assignments)

        print("Step 5: Getting masks for VTON results...")
        vton_masks = self.get_vton_masks(vton_people)
        for i in range(len(vton_masks)):
            if assignments[i]["garment"] is None:
                # Untouched person: the YOLO mask is more reliable than
                # thresholding (the cutout already has the white background).
                yolo_mask = (masks[i].astype(np.uint8) * 255)
                yolo_mask = cv2.GaussianBlur(yolo_mask, (3, 3), 1)
                vton_masks[i] = yolo_mask
        order, scores = self.estimate_front_to_back_order(vton_masks)
        cleaned_vton_people, cleaned_vton_masks = self.clean_masks(vton_people, vton_masks)

        print("Step 6: Resizing to match dimensions...")
        # The try-on pipeline may return a different resolution; follow it.
        img = cv2.resize(img, vton_people[0].size)

        print("Step 7: Creating clean background by removing original people...")
        clean_background, person_mask = self.remove_original_people(img, masks)
        clean_background_np = np.array(clean_background)

        print("Step 8: Recomposing final image...")
        recomposed = clean_background_np.copy()
        # `order` is front-most first; painting in that order means later
        # (more-background) people composite UNDER already-placed pixels only
        # via their soft alpha. Alpha-blend each person over the canvas.
        for i in order:
            vton_mask = cleaned_vton_masks[i]
            img_pil = cleaned_vton_people[i]
            out = recomposed.astype(np.float32)
            src = np.array(img_pil).astype(np.float32)
            alpha = (vton_mask.astype(np.float32) / 255.0)[..., None]
            src = src * alpha
            out = src + (1 - alpha) * out
            recomposed = out.astype(np.uint8)

        final_image = Image.fromarray(recomposed)
        return final_image, {
            "original": Image.fromarray(img),
            "clean_background": clean_background,
            "person_mask": Image.fromarray(person_mask),
            "num_people": len(people),
            "individual_people": people,
            "vton_results": cleaned_vton_people,
            "masks": masks,
            "vton_masks": cleaned_vton_masks
        }
245
+
246
WEIGHTS_DIR = Path("./weights")


def ensure_weights():
    """Fetch model weights via the submodule's download script, once.

    Skips the download whenever the weights directory already contains
    anything at all.
    """
    if WEIGHTS_DIR.exists() and any(WEIGHTS_DIR.iterdir()):
        print("Weights already present, skipping download.")
        return
    print("Downloading weights...")
    download_cmd = [
        sys.executable,
        "fashn-vton-1.5/scripts/download_weights.py",
        "--weights-dir",
        str(WEIGHTS_DIR),
    ]
    subprocess.check_call(download_cmd)
260
# Fetch weights eagerly at import so the first request doesn't pay for it.
ensure_weights()

# Lazily-constructed singleton: the pipeline is heavy, build it only once.
_pipeline = None

def get_pipeline():
    """Return the shared MultiPersonVTON instance, creating it on first use."""
    global _pipeline
    if _pipeline is None:
        _pipeline = MultiPersonVTON()
    return _pipeline
270
@spaces.GPU
def detect_people(portrait_path):
    """Resize the chosen portrait to width 576, run YOLO, return person cutouts."""
    if portrait_path is None:
        raise gr.Error("Please select a portrait first.")
    portrait = Image.open(portrait_path) if isinstance(portrait_path, str) else portrait_path
    target_w = 576
    src_w, src_h = portrait.size
    # Preserve aspect ratio at the fixed working width.
    target_h = int(src_h * target_w / src_w)
    resized = portrait.resize((target_w, target_h), Image.LANCZOS)
    resized.save("people.png")
    pipeline = get_pipeline()
    detection = pipeline.model("people.png")[0]
    img = np.array(resized)
    H, W = img.shape[:2]
    person_masks = pipeline.get_mask(detection, H, W)
    return pipeline.extract_people(img, person_masks)
289
@spaces.GPU
def process_images(selected_portrait, garment_pool, num_detected, *assignment_args):
    """Run the multi-person try-on for the selected portrait.

    assignment_args packs the garment-dropdown values for up to max_people
    persons followed by the same number of category values
    (dd_0..dd_{k-1}, cat_0..cat_{k-1}).
    """
    if selected_portrait is None:
        raise gr.Error("Please select a portrait.")
    if not garment_pool:
        raise gr.Error("Please add at least one garment to the pool.")
    portrait = Image.open(selected_portrait) if isinstance(selected_portrait, str) else selected_portrait
    pipeline = get_pipeline()
    target_w = 576
    src_w, src_h = portrait.size
    resized = portrait.resize((target_w, int(src_h * target_w / src_w)), Image.LANCZOS)

    # Translate the flat dropdown/radio values into per-person assignments.
    person_count = num_detected if num_detected else 0
    slot_count = len(assignment_args) // 2  # first half garments, second half categories
    pool_by_label = {g["label"]: g for g in garment_pool}
    assignments = []
    for idx in range(person_count):
        garment_label = assignment_args[idx]
        category = assignment_args[slot_count + idx] or "tops"
        entry = pool_by_label.get(garment_label)
        if garment_label == "Skip" or entry is None:
            assignments.append({"garment": None, "category": category})
        else:
            garment_img = Image.open(entry["path"]) if isinstance(entry["path"], str) else entry["path"]
            assignments.append({"garment": garment_img, "category": category})

    result, _ = pipeline.process_group_image(resized, assignments)
    return result
321
from huggingface_hub import constants as hf_constants

demo = build_demo(process_images, detect_fn=detect_people, max_people=8)
# Allow Gradio to serve images directly out of the HF hub cache directory.
demo.launch(allowed_paths=[hf_constants.HF_HUB_CACHE])
fashn-vton-1.5 ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 7c0f10af3f91ad4048fe9729c470a13ef905d25a
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
gradio
numpy
Pillow
Requests
scipy
torch
ultralytics
fashn-vton @ git+https://github.com/fashn-AI/fashn-vton-1.5.git
storage.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ import shutil
4
+ from pathlib import Path
5
+ from PIL import Image
6
+ import numpy as np
7
+
8
# Hugging Face dataset repo used as the Space's persistent image storage.
DATA_REPO = "aj406/vton-data"
REPO_TYPE = "dataset"
# Write token for the dataset repo; absence switches storage to local-only mode.
DATASET_HF_TOKEN = os.environ.get("DATASET_HF_TOKEN")
LOCAL_DATA = Path("data")

# Image file extensions recognised by the gallery/listing helpers.
IMG_EXTS = {".jpg", ".jpeg", ".png", ".webp"}


def is_remote():
    """Return True when remote (dataset-repo) persistence is configured."""
    return DATASET_HF_TOKEN is not None
18
+
19
+
20
def _api():
    """Return a fresh HfApi client (hub import kept lazy, off module import)."""
    from huggingface_hub import HfApi
    return HfApi()
23
+
24
+
25
def _ensure_repo():
    """Create the dataset repo if missing; no-op when remote storage is off."""
    if not is_remote():
        return
    _api().create_repo(repo_id=DATA_REPO, repo_type=REPO_TYPE, exist_ok=True, token=DATASET_HF_TOKEN)
29
+
30
+
31
def save_image(img, local_path):
    """Save *img* (PIL image or ndarray) as JPEG quality 85, creating parents.

    Images with an alpha channel or palette (e.g. PNG/webcam uploads in RGBA
    or P mode) are converted to RGB first — JPEG cannot encode those modes and
    PIL would raise OSError on save.
    """
    local_path = Path(local_path)
    local_path.parent.mkdir(parents=True, exist_ok=True)
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    if img.mode not in ("RGB", "L"):
        img = img.convert("RGB")
    img.save(local_path, "JPEG", quality=85)
37
+
38
+
39
def upload_image(local_path, remote_path):
    """Mirror a local file into the dataset repo; no-op without a token."""
    if not is_remote():
        return
    _ensure_repo()
    client = _api()
    client.upload_file(
        path_or_fileobj=str(local_path),
        path_in_repo=remote_path,
        repo_id=DATA_REPO,
        repo_type=REPO_TYPE,
        token=DATASET_HF_TOKEN,
    )
50
+
51
+
52
def delete_remote_file(remote_path):
    """Delete one file from the dataset repo; no-op without a token."""
    if not is_remote():
        return
    client = _api()
    client.delete_file(
        path_in_repo=remote_path,
        repo_id=DATA_REPO,
        repo_type=REPO_TYPE,
        token=DATASET_HF_TOKEN,
    )
61
+
62
+
63
def download_dir(remote_prefix):
    """Mirror one folder of the dataset repo into LOCAL_DATA; no-op when local."""
    if not is_remote():
        return
    from huggingface_hub import snapshot_download

    pattern = f"{remote_prefix}/**"
    snapshot_download(
        repo_id=DATA_REPO,
        repo_type=REPO_TYPE,
        allow_patterns=pattern,
        local_dir=str(LOCAL_DATA),
        token=DATASET_HF_TOKEN,
    )
74
+
75
+
76
def generate_id():
    """Return a short random identifier: the first 8 hex chars of a UUID4."""
    full_hex = uuid.uuid4().hex
    return full_hex[:8]
78
+
79
+
80
def make_filename(item_id, item_type):
    """Build the canonical '<id>_<type>.jpg' storage filename."""
    return "{}_{}.jpg".format(item_id, item_type)
82
+
83
+
84
def parse_filename(filename):
    """Split '<id>_<type>.<ext>' into parts; None when the stem has no '_'."""
    stem = Path(filename).stem
    head, sep, tail = stem.rpartition("_")
    if not sep:
        return None
    return {"id": head, "type": tail}
90
+
91
+
92
def make_result_filename(portrait_id, garment_id):
    """Build '<portrait>_<garment>_result.jpg' for a single-garment result."""
    return "_".join([portrait_id, garment_id, "result"]) + ".jpg"
94
+
95
+
96
def parse_result_filename(filename):
    """Parse '<portrait>_<garment>_result.<ext>'.

    Returns None for anything else, including multi-garment results (whose
    garment slot contains '-' separators).
    """
    stem = Path(filename).stem
    pieces = stem.rsplit("_", 2)
    if len(pieces) < 3:
        return None
    portrait_id, garment_id, suffix = pieces
    if suffix != "result" or "-" in garment_id:
        return None
    return {"portrait_id": portrait_id, "garment_id": garment_id}
104
+
105
+
106
def make_multi_result_filename(portrait_id, garment_ids):
    """Build result filename encoding per-person garment assignments.

    garment_ids: list of garment_id (str) or None per person; 'x' marks a
    skipped person.
    Example: portrait_id=abc123, garment_ids=["ef12ab34", None, "gh56cd78"]
    → "abc123_ef12ab34-x-gh56cd78_result.jpg"
    """
    code = "-".join(gid if gid else "x" for gid in garment_ids)
    return f"{portrait_id}_{code}_result.jpg"
116
+
117
+
118
def parse_multi_result_filename(filename):
    """Invert make_multi_result_filename; None when the name doesn't match."""
    stem = Path(filename).stem
    marker = "_result"
    if not stem.endswith(marker):
        return None
    body = stem[: -len(marker)]
    portrait_id, sep, code = body.partition("_")
    if not sep:
        return None
    garment_ids = []
    for slot in code.split("-"):
        garment_ids.append(None if slot == "x" else slot)
    return {"portrait_id": portrait_id, "garment_ids": garment_ids}
131
+
132
+
133
def list_local_images(directory):
    """Sorted string paths of image files directly inside *directory*.

    Returns [] when the directory does not exist.
    """
    folder = Path(directory)
    if not folder.exists():
        return []
    found = [str(entry) for entry in folder.iterdir() if entry.suffix.lower() in IMG_EXTS]
    return sorted(found)
138
+
139
+
140
def file_url(remote_path):
    """Return a direct HF URL for a file in the dataset repo (public repo)."""
    base = f"https://huggingface.co/datasets/{DATA_REPO}/resolve/main/"
    return base + remote_path
143
+
144
+
145
def list_gallery_urls(prefix, subdir):
    """URLs of image files under {prefix}/{subdir} in the dataset repo.

    Falls back to local files both when remote storage is off and when the
    remote listing fails for any reason.
    """
    fallback_dir = LOCAL_DATA / prefix / subdir
    if not is_remote():
        return list_local_images(fallback_dir)
    try:
        entries = _api().list_repo_tree(
            DATA_REPO, repo_type=REPO_TYPE, path_in_repo=f"{prefix}/{subdir}"
        )
        urls = []
        for entry in entries:
            # Tree entries expose the path as `rfilename` or `path` depending
            # on hub library version.
            if hasattr(entry, "rfilename"):
                name = entry.rfilename
            elif hasattr(entry, "path"):
                name = entry.path
            else:
                continue
            if Path(name).suffix.lower() in IMG_EXTS:
                urls.append(file_url(name))
        return sorted(urls)
    except Exception:
        return list_local_images(fallback_dir)
166
+
167
+
168
# Shared prefix of every file URL served from our dataset repo; used to
# recognise our own URLs in download_to_local.
HF_URL_PREFIX = f"https://huggingface.co/datasets/{DATA_REPO}/resolve/main/"


def is_dataset_url(url):
    """Check if a URL points to our HF dataset repo."""
    return isinstance(url, str) and url.startswith(HF_URL_PREFIX)
174
+
175
+
176
def download_to_local(path_or_url):
    """Resolve *path_or_url* to something openable locally.

    - non-strings pass through untouched
    - URLs into our dataset repo go through hf_hub_download (cached)
    - any other http(s) URL is fetched with requests and re-saved as a tmp JPEG
    - everything else is assumed to already be a local path
    """
    if not isinstance(path_or_url, str):
        return path_or_url

    if is_dataset_url(path_or_url):
        from huggingface_hub import hf_hub_download

        rel_path = path_or_url[len(HF_URL_PREFIX):]
        return hf_hub_download(
            repo_id=DATA_REPO,
            repo_type=REPO_TYPE,
            filename=rel_path,
            token=DATASET_HF_TOKEN,
        )

    if path_or_url.startswith(("http://", "https://")):
        from io import BytesIO

        import requests

        resp = requests.get(path_or_url, timeout=30)
        resp.raise_for_status()
        fetched = Image.open(BytesIO(resp.content))
        tmp_path = LOCAL_DATA / "tmp" / f"{generate_id()}.jpg"
        tmp_path.parent.mkdir(parents=True, exist_ok=True)
        save_image(fetched, tmp_path)
        return str(tmp_path)

    return path_or_url
201
+
202
+
203
def load_image_sets(prefix):
    """Scan {prefix}/portraits/, parse filenames, return matched-set dicts.

    Each entry is {id, portrait, garment, result}; only sets with an existing
    garment file are returned. Remote storage is mirrored down first.
    """
    root = LOCAL_DATA / prefix
    if is_remote():
        download_dir(prefix)
    portraits_dir = root / "portraits"
    if not portraits_dir.exists():
        return []

    entries = {}
    for path in portraits_dir.iterdir():
        if path.suffix.lower() not in IMG_EXTS:
            continue
        meta = parse_filename(path.name)
        if not meta:
            continue
        entries[meta["id"]] = {"id": meta["id"], "portrait": str(path)}

    garments_dir = root / "garments"
    results_dir = root / "results"
    for item_id, entry in entries.items():
        garment = garments_dir / f"{item_id}_garment.jpg"
        result = results_dir / f"{item_id}_result.jpg"
        entry["garment"] = str(garment) if garment.exists() else None
        entry["result"] = str(result) if result.exists() else None

    return [entry for entry in entries.values() if entry["garment"] is not None]
231
+
232
+
233
def save_image_set(prefix, img_portrait, img_garment, img_result=None):
    """Persist a portrait+garment pair (and optional result) under one new ID.

    Files are written locally and mirrored remotely when configured. Returns
    (item_id, portrait_path, garment_path, result_path_or_None) as strings.
    """
    item_id = generate_id()
    root = LOCAL_DATA / prefix

    portrait_name = make_filename(item_id, "portrait")
    garment_name = make_filename(item_id, "garment")
    portrait_path = root / "portraits" / portrait_name
    garment_path = root / "garments" / garment_name

    for image, path, kind in (
        (img_portrait, portrait_path, "portraits"),
        (img_garment, garment_path, "garments"),
    ):
        save_image(image, path)
        upload_image(path, f"{prefix}/{kind}/{path.name}")

    result_path = None
    if img_result is not None:
        # Single-set results reuse the set id for both slots of the name.
        result_name = make_result_filename(item_id, item_id)
        result_path = root / "results" / result_name
        save_image(img_result, result_path)
        upload_image(result_path, f"{prefix}/results/{result_name}")

    return item_id, str(portrait_path), str(garment_path), str(result_path) if result_path else None
257
+
258
+
259
def save_result(prefix, portrait_id, garment_id, img_result):
    """Save a result image named after both portrait and garment IDs."""
    name = make_result_filename(portrait_id, garment_id)
    path = LOCAL_DATA / prefix / "results" / name
    save_image(img_result, path)
    upload_image(path, f"{prefix}/results/{name}")
    return str(path)
267
+
268
+
269
def save_multi_result(prefix, portrait_id, assignments, img_result):
    """Save a multi-garment result with the assignment-encoded filename."""
    name = make_multi_result_filename(portrait_id, assignments)
    path = LOCAL_DATA / prefix / "results" / name
    save_image(img_result, path)
    upload_image(path, f"{prefix}/results/{name}")
    return str(path)
277
+
278
+
279
def delete_image_set(prefix, item_id):
    """Delete every file for a set (matches by ID prefix on the stem so
    multi-garment result files are caught too)."""
    root = LOCAL_DATA / prefix
    for subdir in ("portraits", "garments", "results"):
        folder = root / subdir
        if not folder.exists():
            continue
        for entry in folder.iterdir():
            if not entry.stem.startswith(item_id):
                continue
            entry.unlink()
            if is_remote():
                try:
                    delete_remote_file(f"{prefix}/{subdir}/{entry.name}")
                except Exception:
                    # Best-effort: a failed remote delete must not block
                    # the rest of the local cleanup.
                    pass
294
+
295
+
296
def promote_to_example(result_path):
    """Copy a result into examples/results, keeping its name resolvable."""
    src = Path(result_path)
    dest_dir = LOCAL_DATA / "examples" / "results"
    dest_dir.mkdir(parents=True, exist_ok=True)
    dest = dest_dir / src.name
    shutil.copy2(str(src), str(dest))
    upload_image(dest, f"examples/results/{src.name}")
    return src.stem
ui.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ from pathlib import Path
4
+ from storage import (
5
+ save_multi_result,
6
+ delete_image_set,
7
+ promote_to_example,
8
+ list_gallery_urls,
9
+ download_to_local,
10
+ is_dataset_url,
11
+ save_image,
12
+ upload_image,
13
+ make_filename,
14
+ generate_id,
15
+ parse_filename,
16
+ parse_result_filename,
17
+ parse_multi_result_filename,
18
+ file_url,
19
+ LOCAL_DATA,
20
+ )
21
+
22
# Storage prefixes inside the dataset repo / local data directory.
EXAMPLES_PREFIX = "examples"
UPLOADS_PREFIX = "user_uploads"

# Pre-create the expected folder layout so galleries can list it immediately.
for subdir in ["portraits", "garments", "results"]:
    (LOCAL_DATA / EXAMPLES_PREFIX / subdir).mkdir(parents=True, exist_ok=True)
    (LOCAL_DATA / UPLOADS_PREFIX / subdir).mkdir(parents=True, exist_ok=True)
29
+
30
def _gallery_images(prefix, subdir):
    """Thin alias over list_gallery_urls for readability at gallery call sites."""
    return list_gallery_urls(prefix, subdir)


# Upper bound on people per portrait; the UI pre-builds this many assignment rows.
MAX_PEOPLE = 8
35
+
36
+
37
+ def build_demo(process_fn, detect_fn=None, max_people=MAX_PEOPLE):
38
+
39
+ def process_and_save(portrait_path, garment_pool, num_detected, *assignment_args):
40
+ if portrait_path is None:
41
+ raise gr.Error("Please select a portrait.")
42
+ if not garment_pool:
43
+ raise gr.Error("Please add at least one garment to the pool.")
44
+ result = process_fn(portrait_path, garment_pool, num_detected, *assignment_args)
45
+ if result and portrait_path:
46
+ p_parsed = parse_filename(Path(portrait_path).name)
47
+ if p_parsed:
48
+ # Upload portrait if local
49
+ local = Path(portrait_path)
50
+ if local.exists() and str(local).startswith(str(LOCAL_DATA)):
51
+ upload_image(local, str(local.relative_to(LOCAL_DATA)))
52
+ # Upload all garments in pool
53
+ for g in garment_pool:
54
+ garment_local = Path(g["path"])
55
+ if garment_local.exists() and str(garment_local).startswith(str(LOCAL_DATA)):
56
+ upload_image(garment_local, str(garment_local.relative_to(LOCAL_DATA)))
57
+ # Build garment ID list for filename
58
+ n = num_detected if num_detected else 0
59
+ max_p = len(assignment_args) // 2
60
+ pool_by_label = {g["label"]: g for g in garment_pool}
61
+ garment_ids = []
62
+ for i in range(n):
63
+ dd_val = assignment_args[i]
64
+ if dd_val == "Skip" or dd_val not in pool_by_label:
65
+ garment_ids.append(None)
66
+ else:
67
+ g = pool_by_label[dd_val]
68
+ g_parsed = parse_filename(Path(g["path"]).name)
69
+ garment_ids.append(g_parsed["id"] if g_parsed else generate_id())
70
+ save_multi_result(UPLOADS_PREFIX, p_parsed["id"], garment_ids, result)
71
+ return result
72
+
73
+ with gr.Blocks(title="Multi-Person Virtual Try-On") as demo:
74
+ with gr.Tabs():
75
+ # ---- Main VTON Tab ----
76
+ with gr.Tab("Virtual Try-On"):
77
+ gr.Markdown("# Multi-Person Virtual Try-On")
78
+ gr.Markdown("Select a portrait, add garments to the pool, detect people, and assign garments.")
79
+
80
+ selected_portrait = gr.State(value=None)
81
+ garment_pool = gr.State(value=[])
82
+ num_detected = gr.State(value=0)
83
+ garment_counter = gr.State(value=0)
84
+
85
+ gr.Markdown("**Step 1:** Select or upload a portrait, and select or upload garments to the pool.")
86
+ with gr.Row():
87
+ with gr.Column():
88
+ gr.Markdown("### Portrait")
89
+ portrait_gallery = gr.Gallery(
90
+ value=_gallery_images(UPLOADS_PREFIX, "portraits"),
91
+ label="Uploaded Portraits",
92
+ columns=4,
93
+ height=200,
94
+ allow_preview=False,
95
+ )
96
+ with gr.Accordion("Upload new portrait", open=False):
97
+ portrait_upload = gr.Image(type="pil", label="Upload Portrait", sources=["upload", "webcam"])
98
+ with gr.Row():
99
+ portrait_url_input = gr.Textbox(label="Or paste image URL", scale=4)
100
+ portrait_url_btn = gr.Button("Load", size="sm", scale=1)
101
+ preview_portrait = gr.Image(label="Selected Portrait", interactive=False, height=250)
102
+
103
+ # Garment pool section
104
+ with gr.Column():
105
+ gr.Markdown("### Garment Pool")
106
+ garment_gallery = gr.Gallery(
107
+ value=_gallery_images(UPLOADS_PREFIX, "garments"),
108
+ label="Available Garments (click to add to pool)",
109
+ columns=4,
110
+ height=200,
111
+ allow_preview=False,
112
+ )
113
+ with gr.Accordion("Upload new garment", open=False):
114
+ garment_upload = gr.Image(type="pil", label="Upload Garment", sources=["upload", "webcam"])
115
+ with gr.Row():
116
+ garment_url_input = gr.Textbox(label="Or paste image URL", scale=4)
117
+ garment_url_btn = gr.Button("Load", size="sm", scale=1)
118
+ garment_pool_gallery = gr.Gallery(
119
+ label="Current Pool",
120
+ columns=6,
121
+ height=250,
122
+ allow_preview=False,
123
+ )
124
+ clear_pool_btn = gr.Button("Clear Pool", size="sm", variant="stop")
125
+
126
+ gr.Markdown("**Step 2:** Detect people in the portrait. This lets you choose which garment goes on each person.")
127
+ detect_btn = gr.Button("Detect People", variant="secondary")
128
+ detect_status = gr.Textbox(interactive=False, show_label=False, value="")
129
+ people_gallery = gr.Gallery(
130
+ label="Detected People",
131
+ columns=6,
132
+ height=300,
133
+ allow_preview=False,
134
+ )
135
+
136
+ gr.Markdown("**Step 3:** Assign garments to each person, then click Try On.")
137
@gr.render(inputs=[num_detected, garment_pool])
def render_assignments(n_detected, pool):
    """Dynamically build one (garment dropdown, category radio) row per
    detected person, plus the Try On button and result image.

    Gradio re-runs this render whenever `num_detected` or `garment_pool`
    changes, so the assignment widgets always match the current state.
    """
    n = n_detected or 0
    # "Skip" is always the first choice; the pool's labels follow it.
    choices = ["Skip"] + [g["label"] for g in (pool or [])]
    # Default each person to the first real garment when the pool is non-empty.
    default_garment = choices[1] if len(choices) > 1 else "Skip"

    if n > 0:
        gr.Markdown(f"### Assign Garments to {n} {'Person' if n == 1 else 'People'}")

    dds = []   # per-person garment dropdowns, in person order
    cats = []  # per-person category radios, in person order
    for i in range(n):
        with gr.Row():
            dd = gr.Dropdown(
                choices=choices,
                value=default_garment,
                label=f"Person {i + 1} — Garment",
                scale=3,
                interactive=True,
            )
            cat = gr.Radio(
                choices=["tops", "bottoms", "one-pieces"],
                value="tops",
                label="Category",
                scale=2,
                interactive=True,
            )
        dds.append(dd)
        cats.append(cat)

    submit_btn = gr.Button("Try On", variant="primary")
    result_image = gr.Image(type="pil", label="Result", interactive=False)

    # The assignment widgets only exist inside this render, so the click
    # wiring must live here too. All dropdowns precede all radios in the
    # input list; process_and_save presumably unpacks them in that order
    # from *assignment_args — TODO confirm against its implementation.
    submit_btn.click(
        process_and_save,
        inputs=[selected_portrait, garment_pool, num_detected] + dds + cats,
        outputs=result_image,
    )
175
+
176
+ # Examples section
177
+ gr.Markdown("---")
178
+ gr.Markdown("### Examples")
179
+ example_sets = gr.State(value=[])
180
+ refresh_examples_btn = gr.Button("Refresh Examples", size="sm")
181
+
182
@gr.render(inputs=[example_sets])
def render_examples(sets):
    """Render one row per example set (portrait, garments, result) with a
    "Use" button that loads that set into the main workflow.

    Re-rendered whenever the `example_sets` state changes.
    """
    for i, ex in enumerate(sets or []):
        with gr.Row():
            gr.Image(value=ex["portrait"], label="Portrait", height=200, interactive=False, scale=1)
            for j, g in enumerate(ex["garments"]):
                gr.Image(value=g, label=f"Garment {j+1}", height=200, interactive=False, scale=1)
            gr.Image(value=ex["result"], label="Result", height=200, interactive=False, scale=1)
            use_btn = gr.Button("Use", size="sm", scale=0, min_width=60)
        # Default-argument binding freezes THIS iteration's example; a bare
        # closure over `ex` would late-bind and always load the last row.
        use_btn.click(
            lambda p=ex["portrait"], gs=ex["garments"]: _load_example(p, gs),
            outputs=[selected_portrait, preview_portrait, garment_pool, garment_counter, garment_pool_gallery],
        )
195
+
196
+ # -- Event handlers --
197
+
198
def on_portrait_gallery_select(evt: gr.SelectData):
    """Handle a click in the portrait gallery.

    Downloads the clicked image to a local file and returns that path
    twice: once for the selected-portrait state, once for the preview.
    """
    remote = evt.value["image"]["path"]
    local = download_to_local(remote)
    return local, local
202
+
203
def on_garment_gallery_select(evt: gr.SelectData, pool, counter):
    """Append the clicked gallery garment to the working pool.

    Returns the updated pool state, the bumped counter, and the list of
    image paths used to refresh the pool gallery.
    """
    local = download_to_local(evt.value["image"]["path"])
    next_counter = counter + 1
    entry = {"path": local, "label": f"Garment {next_counter}"}
    updated = pool + [entry]
    return updated, next_counter, [item["path"] for item in updated]
212
+
213
+ def _load_example(portrait_path, garment_paths):
214
+ pool = [{"path": g, "label": f"Garment {i+1}"} for i, g in enumerate(garment_paths)]
215
+ pool_images = [g["path"] for g in pool]
216
+ return portrait_path, portrait_path, pool, len(pool), pool_images
217
+
218
def clear_pool():
    """Empty the garment pool/counter and refresh the garment gallery."""
    refreshed_gallery = _gallery_images(UPLOADS_PREFIX, "garments")
    return [], 0, [], refreshed_gallery
220
+
221
def reset_detection():
    """Blank out detection status, the people gallery, and the count.

    Chained after any portrait change, since assignments depend on a
    fresh detection pass.
    """
    status = ""
    people = []
    return status, people, 0
223
+
224
+ detection_reset_outputs = [detect_status, people_gallery, num_detected]
225
+
226
+ portrait_gallery.select(
227
+ on_portrait_gallery_select, outputs=[selected_portrait, preview_portrait]
228
+ ).then(reset_detection, outputs=detection_reset_outputs)
229
+
230
+ garment_gallery.select(
231
+ on_garment_gallery_select,
232
+ inputs=[garment_pool, garment_counter],
233
+ outputs=[garment_pool, garment_counter, garment_pool_gallery],
234
+ )
235
+
236
+ clear_pool_btn.click(
237
+ clear_pool,
238
+ outputs=[garment_pool, garment_counter, garment_pool_gallery, garment_gallery],
239
+ )
240
+
241
def on_portrait_upload(img, current_pool):
    """Persist an uploaded portrait and select it.

    Returns (portrait gallery, selected path, preview path, cleared
    upload widget). A None image only refreshes the gallery.
    """
    if img is None:
        return _gallery_images(UPLOADS_PREFIX, "portraits"), None, None, None
    fname = make_filename(generate_id(), "portrait")
    dest = LOCAL_DATA / UPLOADS_PREFIX / "portraits" / fname
    save_image(img, dest)
    saved = str(dest)
    # Re-list after saving so the new portrait appears in the gallery.
    return _gallery_images(UPLOADS_PREFIX, "portraits"), saved, saved, None
250
+
251
def on_garment_upload(img, pool, counter):
    """Persist an uploaded garment and append it to the pool.

    Returns (garment gallery, pool state, counter, pool gallery images,
    cleared upload widget). A None image only refreshes the gallery.
    """
    if img is None:
        return _gallery_images(UPLOADS_PREFIX, "garments"), pool, counter, [g["path"] for g in pool], None
    fname = make_filename(generate_id(), "garment")
    dest = LOCAL_DATA / UPLOADS_PREFIX / "garments" / fname
    save_image(img, dest)
    next_counter = counter + 1
    updated = pool + [{"path": str(dest), "label": f"Garment {next_counter}"}]
    images = [g["path"] for g in updated]
    # Re-list after saving so the new garment appears in the gallery.
    return _gallery_images(UPLOADS_PREFIX, "garments"), updated, next_counter, images, None
264
+
265
def on_portrait_url(url, pool):
    """Load a portrait from a pasted URL.

    Returns (portrait gallery, selected path, preview path, cleared URL
    textbox). Blank or whitespace-only URLs only refresh the gallery.
    """
    url = (url or "").strip()  # normalize once instead of re-stripping per use
    if not url:
        return _gallery_images(UPLOADS_PREFIX, "portraits"), None, None, ""
    local_path = download_to_local(url)
    if not is_dataset_url(url):
        # External image: copy it into the uploads area under a fresh id.
        from PIL import Image as PILImage
        fname = make_filename(generate_id(), "portrait")
        dest = LOCAL_DATA / UPLOADS_PREFIX / "portraits" / fname
        # Context manager closes the source file handle instead of
        # leaking it (the original left the opened image unclosed).
        with PILImage.open(local_path) as im:
            save_image(im, dest)
        local_path = str(dest)
    return _gallery_images(UPLOADS_PREFIX, "portraits"), local_path, local_path, ""
277
+
278
def on_garment_url(url, pool, counter):
    """Load a garment from a pasted URL and add it to the pool.

    Returns (garment gallery, pool state, counter, pool gallery images,
    cleared URL textbox). Blank or whitespace-only URLs only refresh the
    gallery and leave the pool untouched.
    """
    url = (url or "").strip()  # normalize once instead of re-stripping per use
    if not url:
        return _gallery_images(UPLOADS_PREFIX, "garments"), pool, counter, [g["path"] for g in pool], ""
    local_path = download_to_local(url)
    if not is_dataset_url(url):
        # External image: copy it into the uploads area under a fresh id.
        from PIL import Image as PILImage
        fname = make_filename(generate_id(), "garment")
        dest = LOCAL_DATA / UPLOADS_PREFIX / "garments" / fname
        # Context manager closes the source file handle instead of
        # leaking it (the original left the opened image unclosed).
        with PILImage.open(local_path) as im:
            save_image(im, dest)
        local_path = str(dest)
    new_counter = counter + 1
    new_pool = pool + [{"path": local_path, "label": f"Garment {new_counter}"}]
    pool_images = [g["path"] for g in new_pool]
    return _gallery_images(UPLOADS_PREFIX, "garments"), new_pool, new_counter, pool_images, ""
294
+
295
+ portrait_upload.change(
296
+ on_portrait_upload,
297
+ inputs=[portrait_upload, garment_pool],
298
+ outputs=[portrait_gallery, selected_portrait, preview_portrait, portrait_upload],
299
+ ).then(reset_detection, outputs=detection_reset_outputs)
300
+
301
+ garment_upload.change(
302
+ on_garment_upload,
303
+ inputs=[garment_upload, garment_pool, garment_counter],
304
+ outputs=[garment_gallery, garment_pool, garment_counter, garment_pool_gallery, garment_upload],
305
+ )
306
+
307
+ portrait_url_btn.click(
308
+ on_portrait_url,
309
+ inputs=[portrait_url_input, garment_pool],
310
+ outputs=[portrait_gallery, selected_portrait, preview_portrait, portrait_url_input],
311
+ ).then(reset_detection, outputs=detection_reset_outputs)
312
+
313
+ garment_url_btn.click(
314
+ on_garment_url,
315
+ inputs=[garment_url_input, garment_pool, garment_counter],
316
+ outputs=[garment_gallery, garment_pool, garment_counter, garment_pool_gallery, garment_url_input],
317
+ )
318
+
319
def on_detect(portrait_path, pool):
    """Run person detection on the selected portrait.

    Returns (status text, detected-people crops, count). Raises gr.Error
    when no portrait is selected or no detector was supplied; `pool` is
    accepted for wiring symmetry but unused here.
    """
    if portrait_path is None or detect_fn is None:
        raise gr.Error("Please select a portrait first.")
    people = detect_fn(portrait_path)
    count = len(people)
    noun = "person" if count == 1 else "people"
    return f"Found {count} {noun}", people, count
325
+
326
+ detect_btn.click(
327
+ lambda: "Detecting people...",
328
+ outputs=[detect_status],
329
+ ).then(
330
+ on_detect,
331
+ inputs=[selected_portrait, garment_pool],
332
+ outputs=[detect_status, people_gallery, num_detected],
333
+ )
334
+
335
def refresh_examples():
    """Rebuild the example sets shown on the main tab.

    Lists stored example results, resolves each one's source portrait and
    garments, and returns a list of {portrait, garments, result} dicts.
    """
    result_urls = _gallery_images(EXAMPLES_PREFIX, "results")
    sets = []
    for r in result_urls:
        # NOTE(review): sources are resolved from UPLOADS_PREFIX even though
        # the results come from EXAMPLES_PREFIX — presumably promoted results
        # keep their portrait/garment files in uploads; confirm.
        portrait, garments, result = _resolve_result_images(UPLOADS_PREFIX, r)
        # Skip results whose source images could not be resolved.
        if portrait and garments:
            sets.append({"portrait": portrait, "garments": garments, "result": result})
    return sets
343
+
344
+ refresh_examples_btn.click(
345
+ refresh_examples,
346
+ outputs=[example_sets],
347
+ )
348
+ demo.load(
349
+ refresh_examples,
350
+ outputs=[example_sets],
351
+ )
352
+
353
+ # ---- Admin Tab ----
354
+ with gr.Tab("Admin - Manage Examples"):
355
+ admin_status = gr.Textbox(label="Status", interactive=False)
356
+
357
+ gr.Markdown("### Current Examples")
358
+ admin_examples_table = gr.Dataframe(
359
+ headers=["ID", "Result Filename"],
360
+ label="Examples",
361
+ interactive=False,
362
+ )
363
+
364
+ with gr.Row():
365
+ delete_id = gr.Textbox(label="Example ID to delete", scale=3)
366
+ delete_btn = gr.Button("Delete", variant="stop", scale=1)
367
+
368
def get_examples_table():
    """Build [ID, result filename] rows for the admin examples table."""
    rows = []
    for url in list_gallery_urls(EXAMPLES_PREFIX, "results"):
        fname = Path(url).name
        meta = parse_result_filename(fname) or parse_multi_result_filename(fname)
        if meta:
            row_id = meta["portrait_id"]
        else:
            # Fall back to the bare filename stem when parsing fails.
            row_id = Path(fname).stem
        rows.append([row_id, fname])
    return rows
377
+
378
def on_admin_delete(ex_id):
    """Delete the example set with the given ID and refresh the table.

    Returns (status message, refreshed table rows).
    """
    ex_id = (ex_id or "").strip()
    # Reject empty AND whitespace-only IDs; the original stripped only
    # after the guard, so "   " reached delete_image_set as "".
    if not ex_id:
        return "Please provide an ID.", get_examples_table()
    delete_image_set(EXAMPLES_PREFIX, ex_id)
    return "Deleted.", get_examples_table()
383
+
384
+ delete_btn.click(
385
+ on_admin_delete,
386
+ inputs=[delete_id],
387
+ outputs=[admin_status, admin_examples_table],
388
+ )
389
+
390
+ # ---- Promote from Uploads ----
391
+ gr.Markdown("---")
392
+ gr.Markdown("## Promote from Uploads")
393
+ gr.Markdown("Select a result to promote. The matching portrait and garment are found automatically.")
394
+
395
+ promote_portrait = gr.State(value=None)
396
+ promote_garments = gr.State(value=[])
397
+ promote_result = gr.State(value=None)
398
+
399
+ promo_result_gallery = gr.Gallery(
400
+ value=_gallery_images(UPLOADS_PREFIX, "results"),
401
+ label="Results",
402
+ columns=4,
403
+ height=200,
404
+ allow_preview=False,
405
+ )
406
+
407
+ with gr.Row():
408
+ promo_preview_portrait = gr.Image(label="Portrait", interactive=False, height=150)
409
+ promo_preview_garments = gr.Gallery(label="Garments", columns=4, height=150, allow_preview=False)
410
+ promo_preview_result = gr.Image(label="Result", interactive=False, height=150)
411
+
412
+ promote_btn = gr.Button("Promote to Example", variant="primary")
413
+ promote_status = gr.Textbox(label="Status", interactive=False)
414
+
415
def _resolve_result_images(prefix, path):
    """Parse a result filename and resolve portrait + all garments.

    Tries the single-garment filename format first, then the multi-garment
    one. Each parse is attempted on both the local filename and the
    original path's name, since download_to_local may store the file under
    a different name. Returns (portrait_local, garment_locals, result_local);
    portrait is None and garments are empty when nothing resolved.
    """
    result_local = download_to_local(path)
    fname = Path(result_local).name
    # Try single-garment format
    parsed = parse_result_filename(fname)
    if not parsed:
        # Local name may differ from the remote one; retry with the latter.
        parsed = parse_result_filename(Path(path).name)
    if parsed:
        try:
            p_url = file_url(f"{prefix}/portraits/{make_filename(parsed['portrait_id'], 'portrait')}")
            g_url = file_url(f"{prefix}/garments/{make_filename(parsed['garment_id'], 'garment')}")
            return download_to_local(p_url), [download_to_local(g_url)], result_local
        except Exception as e:
            # Warn and fall through to the multi-garment attempt.
            gr.Warning(f"Single-garment resolve failed for {fname}: {e}")
    # Try multi-garment format
    multi = parse_multi_result_filename(fname)
    if not multi:
        multi = parse_multi_result_filename(Path(path).name)
    if multi:
        # Garment ids may be None (presumably people whose slot was
        # skipped — TODO confirm against the filename writer); drop those.
        gids = [gid for gid in multi["garment_ids"] if gid is not None]
        if gids:
            try:
                p_url = file_url(f"{prefix}/portraits/{make_filename(multi['portrait_id'], 'portrait')}")
                portrait_local = download_to_local(p_url)
                garment_locals = []
                for gid in gids:
                    g_url = file_url(f"{prefix}/garments/{make_filename(gid, 'garment')}")
                    garment_locals.append(download_to_local(g_url))
                return portrait_local, garment_locals, result_local
            except Exception as e:
                gr.Warning(f"Multi-garment resolve failed for {fname} (portrait={multi['portrait_id']}, garments={gids}): {e}")
    else:
        # Neither format matched at all (an all-None multi parse only
        # falls through silently to the final return).
        gr.Warning(f"Could not parse result filename: {fname}")
    return None, [], result_local
450
+
451
def on_result_select(evt: gr.SelectData):
    """Resolve the clicked result's source images for promotion.

    The resolved triple feeds both the promote-* state components and the
    three preview widgets, so it is returned twice (six outputs total).
    """
    path = evt.value["image"]["path"]
    portrait, garments, result = _resolve_result_images(UPLOADS_PREFIX, path)
    if not (portrait and garments):
        gr.Warning(f"Could not find matching portrait/garment for: {Path(path).name}")
    triple = (portrait, garments, result)
    return triple + triple
457
+
458
+ promo_result_gallery.select(
459
+ on_result_select,
460
+ outputs=[promote_portrait, promote_garments, promote_result,
461
+ promo_preview_portrait, promo_preview_garments, promo_preview_result],
462
+ )
463
+
464
def on_promote(portrait_path, garment_paths, result_path):
    """Promote the selected result into the examples set.

    portrait_path/garment_paths are accepted for wiring symmetry; only
    the result path is needed by promote_to_example. Returns (status
    message, refreshed admin table rows).
    """
    if result_path:
        promoted = promote_to_example(result_path)
        status = f"Promoted: {promoted}"
    else:
        status = "No result to promote."
    return status, get_examples_table()
469
+
470
+ promote_btn.click(
471
+ on_promote,
472
+ inputs=[promote_portrait, promote_garments, promote_result],
473
+ outputs=[promote_status, admin_examples_table],
474
+ )
475
+
476
+ refresh_promo_btn = gr.Button("Refresh Results", size="sm")
477
+ refresh_promo_btn.click(
478
+ lambda: _gallery_images(UPLOADS_PREFIX, "results"),
479
+ outputs=[promo_result_gallery],
480
+ )
481
+
482
+ demo.load(get_examples_table, outputs=[admin_examples_table])
483
+
484
+ return demo
485
+
486
+
487
if __name__ == "__main__":
    # Standalone smoke-run: wire the UI to stub functions so it can be
    # exercised without the real VTON pipeline.
    def dummy_process(portrait, pool, num_detected, *assignment_args):
        # Stand-in for the try-on processor: returns a flat grey image.
        return Image.new("RGB", (512, 512), (200, 200, 200))
    def dummy_detect(portrait_path):
        # Stand-in detector: always "finds" two people (red/green crops).
        return [Image.new("RGB", (100, 200), (255, 0, 0)), Image.new("RGB", (100, 200), (0, 255, 0))]
    demo = build_demo(dummy_process, detect_fn=dummy_detect)
    demo.launch()
vton_gradio_demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
vton_gradio_demo_remove_background.ipynb ADDED
The diff for this file is too large to render. See raw diff