CI commited on
Commit ·
6933b0e
0
Parent(s):
deploy
Browse files- .github/workflows/deploy.yml +24 -0
- .gitignore +7 -0
- .gitmodules +3 -0
- README.md +14 -0
- app.py +323 -0
- fashn-vton-1.5 +1 -0
- requirements.txt +7 -0
- storage.py +303 -0
- ui.py +493 -0
- vton_gradio_demo.ipynb +0 -0
- vton_gradio_demo_remove_background.ipynb +0 -0
.github/workflows/deploy.yml
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Sync to Hugging Face hub
|
| 2 |
+
on:
|
| 3 |
+
push:
|
| 4 |
+
branches: [deploy]
|
| 5 |
+
# to run this workflow manually from the Actions tab
|
| 6 |
+
workflow_dispatch:
|
| 7 |
+
|
| 8 |
+
jobs:
|
| 9 |
+
sync-to-hub:
|
| 10 |
+
runs-on: ubuntu-latest
|
| 11 |
+
steps:
|
| 12 |
+
- uses: actions/checkout@v3
|
| 13 |
+
with:
|
| 14 |
+
fetch-depth: 1
|
| 15 |
+
- name: Push to hub
|
| 16 |
+
env:
|
| 17 |
+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
| 18 |
+
run: |
|
| 19 |
+
git config user.email "ci@github.com"
|
| 20 |
+
git config user.name "CI"
|
| 21 |
+
git checkout --orphan hf-clean
|
| 22 |
+
git add -A
|
| 23 |
+
git commit -m "deploy"
|
| 24 |
+
git push https://mvp-lab:$HF_TOKEN@huggingface.co/spaces/mvp-lab/VTON_TEST hf-clean:main -f
|
.gitignore
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data/examples/portraits/
|
| 2 |
+
data/examples/garments/
|
| 3 |
+
data/examples/results/
|
| 4 |
+
data/user_uploads/
|
| 5 |
+
__pycache__/
|
| 6 |
+
*.pyc
|
| 7 |
+
.venv/
|
.gitmodules
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[submodule "fashn-vton-1.5"]
|
| 2 |
+
path = fashn-vton-1.5
|
| 3 |
+
url = https://github.com/fashn-AI/fashn-vton-1.5.git
|
README.md
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: MultiSubjectVTON
|
| 3 |
+
emoji: 📊
|
| 4 |
+
colorFrom: red
|
| 5 |
+
colorTo: blue
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 6.7.0
|
| 8 |
+
python_version: '3.12'
|
| 9 |
+
app_file: app.py
|
| 10 |
+
pinned: false
|
| 11 |
+
short_description: Multi-subject VTON model
|
| 12 |
+
fullWidth: false
|
| 13 |
+
logs: true
|
| 14 |
+
---
|
app.py
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import spaces
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import torch
|
| 6 |
+
from fashn_vton import TryOnPipeline
|
| 7 |
+
from ultralytics import YOLO
|
| 8 |
+
import gradio as gr
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import subprocess
|
| 11 |
+
import sys
|
| 12 |
+
from scipy.spatial import cKDTree
|
| 13 |
+
from ui import build_demo
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class MultiPersonVTON:
|
| 17 |
+
def __init__(self, weights_dir="./weights"):
|
| 18 |
+
print("Initializing Multi-Person VTON pipeline...")
|
| 19 |
+
self.pipeline = TryOnPipeline(weights_dir=weights_dir)
|
| 20 |
+
self.model = YOLO("yolo26n-seg.pt")
|
| 21 |
+
print("Pipeline initialized")
|
| 22 |
+
|
| 23 |
+
def get_mask(self, result, H, W):
|
| 24 |
+
cls_ids = result.boxes.cls.cpu().numpy().astype(int)
|
| 25 |
+
person_idxs = cls_ids == 0
|
| 26 |
+
person_polygons = [poly for poly, keep in zip(result.masks.xy, person_idxs) if keep]
|
| 27 |
+
masks = []
|
| 28 |
+
for poly in person_polygons:
|
| 29 |
+
mask = np.zeros((H, W), dtype=np.uint8)
|
| 30 |
+
poly_int = np.round(poly).astype(np.int32)
|
| 31 |
+
cv2.fillPoly(mask, [poly_int], 1)
|
| 32 |
+
masks.append(mask.astype(bool))
|
| 33 |
+
return masks
|
| 34 |
+
|
| 35 |
+
def extract_people(self, img, masks):
|
| 36 |
+
img_np = np.array(img) if isinstance(img, Image.Image) else img.copy()
|
| 37 |
+
people = []
|
| 38 |
+
for mask in masks:
|
| 39 |
+
cutout = img_np.copy()
|
| 40 |
+
cutout[~mask] = 255
|
| 41 |
+
people.append(Image.fromarray(cutout))
|
| 42 |
+
return people
|
| 43 |
+
|
| 44 |
+
def apply_vton_to_people(self, people, assignments):
|
| 45 |
+
"""Apply VTON per person based on individual assignments.
|
| 46 |
+
|
| 47 |
+
assignments: list of {"garment": PIL.Image|None, "category": str} per person.
|
| 48 |
+
If garment is None, person is kept as-is (skipped).
|
| 49 |
+
"""
|
| 50 |
+
vton_people = []
|
| 51 |
+
for i, person in enumerate(people):
|
| 52 |
+
garment = assignments[i]["garment"]
|
| 53 |
+
if garment is not None:
|
| 54 |
+
result = self.pipeline(
|
| 55 |
+
person_image=person,
|
| 56 |
+
garment_image=garment,
|
| 57 |
+
category=assignments[i]["category"]
|
| 58 |
+
)
|
| 59 |
+
vton_people.append(result.images[0])
|
| 60 |
+
else:
|
| 61 |
+
vton_people.append(person)
|
| 62 |
+
return vton_people
|
| 63 |
+
|
| 64 |
+
def get_vton_masks(self, vton_people):
|
| 65 |
+
vton_masks = []
|
| 66 |
+
for people in vton_people:
|
| 67 |
+
people_arr = np.array(people)
|
| 68 |
+
gray = cv2.cvtColor(people_arr, cv2.COLOR_RGB2GRAY)
|
| 69 |
+
_, mask = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY_INV)
|
| 70 |
+
mask = mask.astype(bool)
|
| 71 |
+
kernel = np.ones((5, 5), np.uint8)
|
| 72 |
+
mask_clean = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, kernel, iterations=1)
|
| 73 |
+
mask_clean = cv2.morphologyEx(mask_clean, cv2.MORPH_CLOSE, kernel, iterations=2)
|
| 74 |
+
mask_u8 = (mask_clean.astype(np.uint8) * 255)
|
| 75 |
+
mask_blur = cv2.GaussianBlur(mask_u8, (3, 3), 1)
|
| 76 |
+
vton_masks.append(mask_blur)
|
| 77 |
+
return vton_masks
|
| 78 |
+
|
| 79 |
+
def contour_curvature(self, contour, k=5):
|
| 80 |
+
pts = contour[:, 0, :].astype(np.float32)
|
| 81 |
+
N = len(pts)
|
| 82 |
+
curv = np.zeros(N)
|
| 83 |
+
for i in range(N):
|
| 84 |
+
p_prev = pts[(i - k) % N]
|
| 85 |
+
p = pts[i]
|
| 86 |
+
p_next = pts[(i + k) % N]
|
| 87 |
+
v1 = p - p_prev
|
| 88 |
+
v2 = p_next - p
|
| 89 |
+
v1 /= (np.linalg.norm(v1) + 1e-6)
|
| 90 |
+
v2 /= (np.linalg.norm(v2) + 1e-6)
|
| 91 |
+
angle = np.arccos(np.clip(np.dot(v1, v2), -1, 1))
|
| 92 |
+
curv[i] = angle
|
| 93 |
+
return curv
|
| 94 |
+
|
| 95 |
+
def frontness_score(self, mask_a, mask_b):
|
| 96 |
+
inter = mask_a & mask_b
|
| 97 |
+
if inter.sum() < 50:
|
| 98 |
+
return 0.0
|
| 99 |
+
cnts_a, _ = cv2.findContours(mask_a.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
|
| 100 |
+
cnts_b, _ = cv2.findContours(mask_b.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
|
| 101 |
+
if not cnts_a or not cnts_b:
|
| 102 |
+
return 0.0
|
| 103 |
+
ca = max(cnts_a, key=len)
|
| 104 |
+
cb = max(cnts_b, key=len)
|
| 105 |
+
curv_a = self.contour_curvature(ca)
|
| 106 |
+
curv_b = self.contour_curvature(cb)
|
| 107 |
+
inter_pts = np.column_stack(np.where(inter))[:, ::-1]
|
| 108 |
+
tree_a = cKDTree(ca[:, 0, :])
|
| 109 |
+
tree_b = cKDTree(cb[:, 0, :])
|
| 110 |
+
_, idx_a = tree_a.query(inter_pts, k=1)
|
| 111 |
+
_, idx_b = tree_b.query(inter_pts, k=1)
|
| 112 |
+
score_a = curv_a[idx_a].mean()
|
| 113 |
+
score_b = curv_b[idx_b].mean()
|
| 114 |
+
return score_a - score_b
|
| 115 |
+
|
| 116 |
+
def estimate_front_to_back_order(self, masks):
|
| 117 |
+
n = len(masks)
|
| 118 |
+
scores = np.zeros(n)
|
| 119 |
+
for i in range(n):
|
| 120 |
+
for j in range(n):
|
| 121 |
+
if i == j:
|
| 122 |
+
continue
|
| 123 |
+
scores[i] += self.frontness_score(masks[i], masks[j])
|
| 124 |
+
order = np.argsort(-scores)
|
| 125 |
+
return order, scores
|
| 126 |
+
|
| 127 |
+
def remove_original_people(self, image, person_masks):
|
| 128 |
+
image_np = np.array(image)
|
| 129 |
+
combined_mask = np.zeros(image_np.shape[:2], dtype=np.uint8)
|
| 130 |
+
for mask in person_masks:
|
| 131 |
+
combined_mask[mask] = 255
|
| 132 |
+
kernel = np.ones((5, 5), np.uint8)
|
| 133 |
+
combined_mask = cv2.dilate(combined_mask, kernel, iterations=2)
|
| 134 |
+
inpainted = cv2.inpaint(image_np, combined_mask, 3, cv2.INPAINT_TELEA)
|
| 135 |
+
return Image.fromarray(inpainted), combined_mask
|
| 136 |
+
|
| 137 |
+
def clean_vton_edges_on_overlap(self, img_pil, mask_uint8, other_masks_uint8,
|
| 138 |
+
erode_iters=1, edge_dilate=2, inner_erode=2):
|
| 139 |
+
src = np.array(img_pil).copy()
|
| 140 |
+
others_union = np.zeros_like(mask_uint8, dtype=np.uint8)
|
| 141 |
+
for m in other_masks_uint8:
|
| 142 |
+
others_union = np.maximum(others_union, m)
|
| 143 |
+
overlap = (mask_uint8 > 0) & (others_union > 0)
|
| 144 |
+
overlap = overlap.astype(np.uint8) * 255
|
| 145 |
+
if overlap.sum() == 0:
|
| 146 |
+
return img_pil, mask_uint8
|
| 147 |
+
kernel = np.ones((3, 3), np.uint8)
|
| 148 |
+
tight_mask = cv2.erode(mask_uint8, kernel, iterations=erode_iters)
|
| 149 |
+
edge = cv2.Canny(tight_mask, 50, 150)
|
| 150 |
+
edge = cv2.dilate(edge, np.ones((3, 3), np.uint8), iterations=edge_dilate)
|
| 151 |
+
overlap_band = cv2.dilate(overlap, np.ones((5, 5), np.uint8), iterations=1)
|
| 152 |
+
edge = cv2.bitwise_and(edge, overlap_band)
|
| 153 |
+
if edge.sum() == 0:
|
| 154 |
+
return img_pil, tight_mask
|
| 155 |
+
inner = cv2.erode(tight_mask, np.ones((5, 5), np.uint8), iterations=inner_erode)
|
| 156 |
+
inner_rgb = cv2.inpaint(src, 255 - inner, 3, cv2.INPAINT_TELEA)
|
| 157 |
+
src[edge > 0] = inner_rgb[edge > 0]
|
| 158 |
+
return Image.fromarray(src), tight_mask
|
| 159 |
+
|
| 160 |
+
def clean_masks(self, vton_people, vton_masks):
|
| 161 |
+
cleaned_vton_people = []
|
| 162 |
+
cleaned_vton_masks = []
|
| 163 |
+
for i in range(len(vton_people)):
|
| 164 |
+
other_masks = [m for j, m in enumerate(vton_masks) if j != i]
|
| 165 |
+
cleaned_img, cleaned_mask = self.clean_vton_edges_on_overlap(
|
| 166 |
+
vton_people[i], vton_masks[i], other_masks,
|
| 167 |
+
erode_iters=1, edge_dilate=2, inner_erode=2
|
| 168 |
+
)
|
| 169 |
+
cleaned_vton_people.append(cleaned_img)
|
| 170 |
+
cleaned_vton_masks.append(cleaned_mask)
|
| 171 |
+
return cleaned_vton_people, cleaned_vton_masks
|
| 172 |
+
|
| 173 |
+
def process_group_image(self, group_image, assignments):
|
| 174 |
+
"""Process a group image with per-person garment assignments.
|
| 175 |
+
|
| 176 |
+
assignments: list of {"garment": PIL.Image|None, "category": str} per person.
|
| 177 |
+
"""
|
| 178 |
+
print("Step 1: Loading images...")
|
| 179 |
+
if isinstance(group_image, np.ndarray):
|
| 180 |
+
group_image = Image.fromarray(group_image)
|
| 181 |
+
if isinstance(group_image, Image.Image):
|
| 182 |
+
group_image.save("people.png")
|
| 183 |
+
|
| 184 |
+
img_bgr = cv2.imread("people.png")
|
| 185 |
+
img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
|
| 186 |
+
H, W = img.shape[:2]
|
| 187 |
+
|
| 188 |
+
print("Step 2: Getting segmentation masks with YOLO...")
|
| 189 |
+
results = self.model("people.png")
|
| 190 |
+
result = results[0]
|
| 191 |
+
masks = self.get_mask(result, H, W)
|
| 192 |
+
print(f"Found {len(masks)} people")
|
| 193 |
+
|
| 194 |
+
print("Step 3: Extracting individual people...")
|
| 195 |
+
people = self.extract_people(img, masks)
|
| 196 |
+
|
| 197 |
+
# Pad assignments to match detected people count
|
| 198 |
+
while len(assignments) < len(people):
|
| 199 |
+
assignments.append({"garment": None, "category": "tops"})
|
| 200 |
+
|
| 201 |
+
print("Step 4: Applying VTON to people...")
|
| 202 |
+
vton_people = self.apply_vton_to_people(people, assignments)
|
| 203 |
+
|
| 204 |
+
print("Step 5: Getting masks for VTON results...")
|
| 205 |
+
vton_masks = self.get_vton_masks(vton_people)
|
| 206 |
+
for i in range(len(vton_masks)):
|
| 207 |
+
if assignments[i]["garment"] is None:
|
| 208 |
+
yolo_mask = (masks[i].astype(np.uint8) * 255)
|
| 209 |
+
yolo_mask = cv2.GaussianBlur(yolo_mask, (3, 3), 1)
|
| 210 |
+
vton_masks[i] = yolo_mask
|
| 211 |
+
order, scores = self.estimate_front_to_back_order(vton_masks)
|
| 212 |
+
cleaned_vton_people, cleaned_vton_masks = self.clean_masks(vton_people, vton_masks)
|
| 213 |
+
|
| 214 |
+
print("Step 6: Resizing to match dimensions...")
|
| 215 |
+
img = cv2.resize(img, vton_people[0].size)
|
| 216 |
+
|
| 217 |
+
print("Step 7: Creating clean background by removing original people...")
|
| 218 |
+
clean_background, person_mask = self.remove_original_people(img, masks)
|
| 219 |
+
clean_background_np = np.array(clean_background)
|
| 220 |
+
|
| 221 |
+
print("Step 8: Recomposing final image...")
|
| 222 |
+
recomposed = clean_background_np.copy()
|
| 223 |
+
for i in order:
|
| 224 |
+
vton_mask = cleaned_vton_masks[i]
|
| 225 |
+
img_pil = cleaned_vton_people[i]
|
| 226 |
+
out = recomposed.astype(np.float32)
|
| 227 |
+
src = np.array(img_pil).astype(np.float32)
|
| 228 |
+
alpha = (vton_mask.astype(np.float32) / 255.0)[..., None]
|
| 229 |
+
src = src * alpha
|
| 230 |
+
out = src + (1 - alpha) * out
|
| 231 |
+
recomposed = out.astype(np.uint8)
|
| 232 |
+
|
| 233 |
+
final_image = Image.fromarray(recomposed)
|
| 234 |
+
return final_image, {
|
| 235 |
+
"original": Image.fromarray(img),
|
| 236 |
+
"clean_background": clean_background,
|
| 237 |
+
"person_mask": Image.fromarray(person_mask),
|
| 238 |
+
"num_people": len(people),
|
| 239 |
+
"individual_people": people,
|
| 240 |
+
"vton_results": cleaned_vton_people,
|
| 241 |
+
"masks": masks,
|
| 242 |
+
"vton_masks": cleaned_vton_masks
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
WEIGHTS_DIR = Path("./weights")
|
| 247 |
+
|
| 248 |
+
def ensure_weights():
|
| 249 |
+
if WEIGHTS_DIR.exists() and any(WEIGHTS_DIR.iterdir()):
|
| 250 |
+
print("Weights already present, skipping download.")
|
| 251 |
+
return
|
| 252 |
+
print("Downloading weights...")
|
| 253 |
+
subprocess.check_call([
|
| 254 |
+
sys.executable,
|
| 255 |
+
"fashn-vton-1.5/scripts/download_weights.py",
|
| 256 |
+
"--weights-dir",
|
| 257 |
+
str(WEIGHTS_DIR),
|
| 258 |
+
])
|
| 259 |
+
|
| 260 |
+
ensure_weights()
|
| 261 |
+
|
| 262 |
+
_pipeline = None
|
| 263 |
+
|
| 264 |
+
def get_pipeline():
|
| 265 |
+
global _pipeline
|
| 266 |
+
if _pipeline is None:
|
| 267 |
+
_pipeline = MultiPersonVTON()
|
| 268 |
+
return _pipeline
|
| 269 |
+
|
| 270 |
+
@spaces.GPU
|
| 271 |
+
def detect_people(portrait_path):
|
| 272 |
+
if portrait_path is None:
|
| 273 |
+
raise gr.Error("Please select a portrait first.")
|
| 274 |
+
portrait = Image.open(portrait_path) if isinstance(portrait_path, str) else portrait_path
|
| 275 |
+
new_width = 576
|
| 276 |
+
w, h = portrait.size
|
| 277 |
+
new_height = int(h * new_width / w)
|
| 278 |
+
resized = portrait.resize((new_width, new_height), Image.LANCZOS)
|
| 279 |
+
resized.save("people.png")
|
| 280 |
+
pipeline = get_pipeline()
|
| 281 |
+
results = pipeline.model("people.png")
|
| 282 |
+
result = results[0]
|
| 283 |
+
img = np.array(resized)
|
| 284 |
+
H, W = img.shape[:2]
|
| 285 |
+
masks = pipeline.get_mask(result, H, W)
|
| 286 |
+
people = pipeline.extract_people(img, masks)
|
| 287 |
+
return people
|
| 288 |
+
|
| 289 |
+
@spaces.GPU
|
| 290 |
+
def process_images(selected_portrait, garment_pool, num_detected, *assignment_args):
|
| 291 |
+
if selected_portrait is None:
|
| 292 |
+
raise gr.Error("Please select a portrait.")
|
| 293 |
+
if not garment_pool:
|
| 294 |
+
raise gr.Error("Please add at least one garment to the pool.")
|
| 295 |
+
portrait = Image.open(selected_portrait) if isinstance(selected_portrait, str) else selected_portrait
|
| 296 |
+
pipeline = get_pipeline()
|
| 297 |
+
new_width = 576
|
| 298 |
+
w, h = portrait.size
|
| 299 |
+
new_height = int(h * new_width / w)
|
| 300 |
+
resized = portrait.resize((new_width, new_height), Image.LANCZOS)
|
| 301 |
+
|
| 302 |
+
# Build per-person assignments from dropdown/radio values
|
| 303 |
+
# assignment_args: dd_0, dd_1, ..., dd_7, cat_0, cat_1, ..., cat_7
|
| 304 |
+
n = num_detected if num_detected else 0
|
| 305 |
+
max_p = len(assignment_args) // 2
|
| 306 |
+
pool_by_label = {g["label"]: g for g in garment_pool}
|
| 307 |
+
assignments = []
|
| 308 |
+
for i in range(n):
|
| 309 |
+
dd_val = assignment_args[i]
|
| 310 |
+
cat_val = assignment_args[max_p + i]
|
| 311 |
+
if dd_val == "Skip" or dd_val not in pool_by_label:
|
| 312 |
+
assignments.append({"garment": None, "category": cat_val or "tops"})
|
| 313 |
+
else:
|
| 314 |
+
g = pool_by_label[dd_val]
|
| 315 |
+
garment_img = Image.open(g["path"]) if isinstance(g["path"], str) else g["path"]
|
| 316 |
+
assignments.append({"garment": garment_img, "category": cat_val or "tops"})
|
| 317 |
+
|
| 318 |
+
result, _ = pipeline.process_group_image(resized, assignments)
|
| 319 |
+
return result
|
| 320 |
+
|
| 321 |
+
demo = build_demo(process_images, detect_fn=detect_people, max_people=8)
|
| 322 |
+
from huggingface_hub import constants as hf_constants
|
| 323 |
+
demo.launch(allowed_paths=[hf_constants.HF_HUB_CACHE])
|
fashn-vton-1.5
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Subproject commit 7c0f10af3f91ad4048fe9729c470a13ef905d25a
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
numpY
|
| 3 |
+
Pillow
|
| 4 |
+
Requests
|
| 5 |
+
torch
|
| 6 |
+
ultralytics
|
| 7 |
+
fashn-vton @ git+https://github.com/fashn-AI/fashn-vton-1.5.git
|
storage.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import uuid
|
| 3 |
+
import shutil
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
DATA_REPO = "aj406/vton-data"
|
| 9 |
+
REPO_TYPE = "dataset"
|
| 10 |
+
DATASET_HF_TOKEN = os.environ.get("DATASET_HF_TOKEN")
|
| 11 |
+
LOCAL_DATA = Path("data")
|
| 12 |
+
|
| 13 |
+
IMG_EXTS = {".jpg", ".jpeg", ".png", ".webp"}
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def is_remote():
|
| 17 |
+
return DATASET_HF_TOKEN is not None
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _api():
|
| 21 |
+
from huggingface_hub import HfApi
|
| 22 |
+
return HfApi()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _ensure_repo():
|
| 26 |
+
if not is_remote():
|
| 27 |
+
return
|
| 28 |
+
_api().create_repo(repo_id=DATA_REPO, repo_type=REPO_TYPE, exist_ok=True, token=DATASET_HF_TOKEN)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def save_image(img, local_path):
|
| 32 |
+
local_path = Path(local_path)
|
| 33 |
+
local_path.parent.mkdir(parents=True, exist_ok=True)
|
| 34 |
+
if isinstance(img, np.ndarray):
|
| 35 |
+
img = Image.fromarray(img)
|
| 36 |
+
img.save(local_path, "JPEG", quality=85)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def upload_image(local_path, remote_path):
|
| 40 |
+
if not is_remote():
|
| 41 |
+
return
|
| 42 |
+
_ensure_repo()
|
| 43 |
+
_api().upload_file(
|
| 44 |
+
path_or_fileobj=str(local_path),
|
| 45 |
+
path_in_repo=remote_path,
|
| 46 |
+
repo_id=DATA_REPO,
|
| 47 |
+
repo_type=REPO_TYPE,
|
| 48 |
+
token=DATASET_HF_TOKEN,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def delete_remote_file(remote_path):
|
| 53 |
+
if not is_remote():
|
| 54 |
+
return
|
| 55 |
+
_api().delete_file(
|
| 56 |
+
path_in_repo=remote_path,
|
| 57 |
+
repo_id=DATA_REPO,
|
| 58 |
+
repo_type=REPO_TYPE,
|
| 59 |
+
token=DATASET_HF_TOKEN,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def download_dir(remote_prefix):
|
| 64 |
+
if not is_remote():
|
| 65 |
+
return
|
| 66 |
+
from huggingface_hub import snapshot_download
|
| 67 |
+
snapshot_download(
|
| 68 |
+
repo_id=DATA_REPO,
|
| 69 |
+
repo_type=REPO_TYPE,
|
| 70 |
+
allow_patterns=f"{remote_prefix}/**",
|
| 71 |
+
local_dir=str(LOCAL_DATA),
|
| 72 |
+
token=DATASET_HF_TOKEN,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def generate_id():
|
| 77 |
+
return uuid.uuid4().hex[:8]
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def make_filename(item_id, item_type):
|
| 81 |
+
return f"{item_id}_{item_type}.jpg"
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def parse_filename(filename):
|
| 85 |
+
stem = Path(filename).stem
|
| 86 |
+
parts = stem.rsplit("_", 1)
|
| 87 |
+
if len(parts) != 2:
|
| 88 |
+
return None
|
| 89 |
+
return {"id": parts[0], "type": parts[1]}
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def make_result_filename(portrait_id, garment_id):
|
| 93 |
+
return f"{portrait_id}_{garment_id}_result.jpg"
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def parse_result_filename(filename):
|
| 97 |
+
stem = Path(filename).stem
|
| 98 |
+
parts = stem.rsplit("_", 2)
|
| 99 |
+
if len(parts) != 3 or parts[2] != "result":
|
| 100 |
+
return None
|
| 101 |
+
if "-" in parts[1]:
|
| 102 |
+
return None
|
| 103 |
+
return {"portrait_id": parts[0], "garment_id": parts[1]}
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def make_multi_result_filename(portrait_id, garment_ids):
|
| 107 |
+
"""Build result filename encoding per-person garment assignments.
|
| 108 |
+
|
| 109 |
+
garment_ids: list of garment_id (str) or None per person.
|
| 110 |
+
Example: portrait_id=abc123, garment_ids=["ef12ab34", None, "gh56cd78"]
|
| 111 |
+
→ "abc123_ef12ab34-x-gh56cd78_result.jpg"
|
| 112 |
+
"""
|
| 113 |
+
slots = [gid if gid else "x" for gid in garment_ids]
|
| 114 |
+
code = "-".join(slots)
|
| 115 |
+
return f"{portrait_id}_{code}_result.jpg"
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def parse_multi_result_filename(filename):
|
| 119 |
+
stem = Path(filename).stem
|
| 120 |
+
if not stem.endswith("_result"):
|
| 121 |
+
return None
|
| 122 |
+
stem = stem[:-len("_result")]
|
| 123 |
+
parts = stem.split("_", 1)
|
| 124 |
+
if len(parts) != 2:
|
| 125 |
+
return None
|
| 126 |
+
portrait_id = parts[0]
|
| 127 |
+
code = parts[1]
|
| 128 |
+
slots = code.split("-")
|
| 129 |
+
garment_ids = [None if slot == "x" else slot for slot in slots]
|
| 130 |
+
return {"portrait_id": portrait_id, "garment_ids": garment_ids}
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def list_local_images(directory):
|
| 134 |
+
d = Path(directory)
|
| 135 |
+
if not d.exists():
|
| 136 |
+
return []
|
| 137 |
+
return sorted([str(p) for p in d.iterdir() if p.suffix.lower() in IMG_EXTS])
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def file_url(remote_path):
|
| 141 |
+
"""Return a direct HF URL for a file in the dataset repo (public repo)."""
|
| 142 |
+
return f"https://huggingface.co/datasets/{DATA_REPO}/resolve/main/{remote_path}"
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def list_gallery_urls(prefix, subdir):
|
| 146 |
+
"""List files in dataset repo and return direct URLs for gallery display."""
|
| 147 |
+
if not is_remote():
|
| 148 |
+
return list_local_images(LOCAL_DATA / prefix / subdir)
|
| 149 |
+
try:
|
| 150 |
+
items = _api().list_repo_tree(
|
| 151 |
+
DATA_REPO, repo_type=REPO_TYPE, path_in_repo=f"{prefix}/{subdir}"
|
| 152 |
+
)
|
| 153 |
+
urls = []
|
| 154 |
+
for item in items:
|
| 155 |
+
if hasattr(item, "rfilename"):
|
| 156 |
+
name = item.rfilename
|
| 157 |
+
elif hasattr(item, "path"):
|
| 158 |
+
name = item.path
|
| 159 |
+
else:
|
| 160 |
+
continue
|
| 161 |
+
if Path(name).suffix.lower() in IMG_EXTS:
|
| 162 |
+
urls.append(file_url(name))
|
| 163 |
+
return sorted(urls)
|
| 164 |
+
except Exception:
|
| 165 |
+
return list_local_images(LOCAL_DATA / prefix / subdir)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
HF_URL_PREFIX = f"https://huggingface.co/datasets/{DATA_REPO}/resolve/main/"
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def is_dataset_url(url):
|
| 172 |
+
"""Check if a URL points to our HF dataset repo."""
|
| 173 |
+
return isinstance(url, str) and url.startswith(HF_URL_PREFIX)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def download_to_local(path_or_url):
|
| 177 |
+
"""Download a URL to local path. HF dataset URLs use hf_hub, other URLs use requests."""
|
| 178 |
+
if not isinstance(path_or_url, str):
|
| 179 |
+
return path_or_url
|
| 180 |
+
if is_dataset_url(path_or_url):
|
| 181 |
+
remote_path = path_or_url[len(HF_URL_PREFIX):]
|
| 182 |
+
from huggingface_hub import hf_hub_download
|
| 183 |
+
local = hf_hub_download(
|
| 184 |
+
repo_id=DATA_REPO,
|
| 185 |
+
repo_type=REPO_TYPE,
|
| 186 |
+
filename=remote_path,
|
| 187 |
+
token=DATASET_HF_TOKEN,
|
| 188 |
+
)
|
| 189 |
+
return local
|
| 190 |
+
if path_or_url.startswith(("http://", "https://")):
|
| 191 |
+
import requests
|
| 192 |
+
from io import BytesIO
|
| 193 |
+
resp = requests.get(path_or_url, timeout=30)
|
| 194 |
+
resp.raise_for_status()
|
| 195 |
+
img = Image.open(BytesIO(resp.content))
|
| 196 |
+
tmp_path = LOCAL_DATA / "tmp" / f"{generate_id()}.jpg"
|
| 197 |
+
tmp_path.parent.mkdir(parents=True, exist_ok=True)
|
| 198 |
+
save_image(img, tmp_path)
|
| 199 |
+
return str(tmp_path)
|
| 200 |
+
return path_or_url
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def load_image_sets(prefix):
|
| 204 |
+
"""Scan {prefix}/portraits/ dir, parse filenames, return list of dicts with matched files."""
|
| 205 |
+
local_prefix = LOCAL_DATA / prefix
|
| 206 |
+
if is_remote():
|
| 207 |
+
download_dir(prefix)
|
| 208 |
+
portraits_dir = local_prefix / "portraits"
|
| 209 |
+
if not portraits_dir.exists():
|
| 210 |
+
return []
|
| 211 |
+
sets = {}
|
| 212 |
+
for p in portraits_dir.iterdir():
|
| 213 |
+
if p.suffix.lower() not in IMG_EXTS:
|
| 214 |
+
continue
|
| 215 |
+
parsed = parse_filename(p.name)
|
| 216 |
+
if not parsed:
|
| 217 |
+
continue
|
| 218 |
+
item_id = parsed["id"]
|
| 219 |
+
sets[item_id] = {
|
| 220 |
+
"id": item_id,
|
| 221 |
+
"portrait": str(p),
|
| 222 |
+
}
|
| 223 |
+
garments_dir = local_prefix / "garments"
|
| 224 |
+
results_dir = local_prefix / "results"
|
| 225 |
+
for item_id, entry in sets.items():
|
| 226 |
+
garment = garments_dir / f"{item_id}_garment.jpg"
|
| 227 |
+
result = results_dir / f"{item_id}_result.jpg"
|
| 228 |
+
entry["garment"] = str(garment) if garment.exists() else None
|
| 229 |
+
entry["result"] = str(result) if result.exists() else None
|
| 230 |
+
return [v for v in sets.values() if v["garment"] is not None]
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def save_image_set(prefix, img_portrait, img_garment, img_result=None):
|
| 234 |
+
"""Save a set of images (portrait + garment + optional result) with consistent naming."""
|
| 235 |
+
item_id = generate_id()
|
| 236 |
+
local_prefix = LOCAL_DATA / prefix
|
| 237 |
+
|
| 238 |
+
portrait_name = make_filename(item_id, "portrait")
|
| 239 |
+
garment_name = make_filename(item_id, "garment")
|
| 240 |
+
|
| 241 |
+
portrait_path = local_prefix / "portraits" / portrait_name
|
| 242 |
+
garment_path = local_prefix / "garments" / garment_name
|
| 243 |
+
|
| 244 |
+
save_image(img_portrait, portrait_path)
|
| 245 |
+
save_image(img_garment, garment_path)
|
| 246 |
+
upload_image(portrait_path, f"{prefix}/portraits/{portrait_name}")
|
| 247 |
+
upload_image(garment_path, f"{prefix}/garments/{garment_name}")
|
| 248 |
+
|
| 249 |
+
result_path = None
|
| 250 |
+
if img_result is not None:
|
| 251 |
+
result_name = make_result_filename(item_id, item_id)
|
| 252 |
+
result_path = local_prefix / "results" / result_name
|
| 253 |
+
save_image(img_result, result_path)
|
| 254 |
+
upload_image(result_path, f"{prefix}/results/{result_name}")
|
| 255 |
+
|
| 256 |
+
return item_id, str(portrait_path), str(garment_path), str(result_path) if result_path else None
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def save_result(prefix, portrait_id, garment_id, img_result):
|
| 260 |
+
"""Save a result image encoding both portrait and garment IDs."""
|
| 261 |
+
local_prefix = LOCAL_DATA / prefix
|
| 262 |
+
result_name = make_result_filename(portrait_id, garment_id)
|
| 263 |
+
result_path = local_prefix / "results" / result_name
|
| 264 |
+
save_image(img_result, result_path)
|
| 265 |
+
upload_image(result_path, f"{prefix}/results/{result_name}")
|
| 266 |
+
return str(result_path)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def save_multi_result(prefix, portrait_id, assignments, img_result):
|
| 270 |
+
"""Save a multi-garment result image with assignment-encoded filename."""
|
| 271 |
+
local_prefix = LOCAL_DATA / prefix
|
| 272 |
+
result_name = make_multi_result_filename(portrait_id, assignments)
|
| 273 |
+
result_path = local_prefix / "results" / result_name
|
| 274 |
+
save_image(img_result, result_path)
|
| 275 |
+
upload_image(result_path, f"{prefix}/results/{result_name}")
|
| 276 |
+
return str(result_path)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def delete_image_set(prefix, item_id):
|
| 280 |
+
"""Delete all files for an image set (scans for ID prefix to catch multi-garment files)."""
|
| 281 |
+
local_prefix = LOCAL_DATA / prefix
|
| 282 |
+
for subdir in ("portraits", "garments", "results"):
|
| 283 |
+
d = local_prefix / subdir
|
| 284 |
+
if not d.exists():
|
| 285 |
+
continue
|
| 286 |
+
for f in d.iterdir():
|
| 287 |
+
if f.stem.startswith(item_id):
|
| 288 |
+
f.unlink()
|
| 289 |
+
if is_remote():
|
| 290 |
+
try:
|
| 291 |
+
delete_remote_file(f"{prefix}/{subdir}/{f.name}")
|
| 292 |
+
except Exception:
|
| 293 |
+
pass
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def promote_to_example(result_path):
|
| 297 |
+
"""Copy a result file to examples, preserving its filename for resolution."""
|
| 298 |
+
src = Path(result_path)
|
| 299 |
+
dest = LOCAL_DATA / "examples" / "results" / src.name
|
| 300 |
+
dest.parent.mkdir(parents=True, exist_ok=True)
|
| 301 |
+
shutil.copy2(str(src), str(dest))
|
| 302 |
+
upload_image(dest, f"examples/results/{src.name}")
|
| 303 |
+
return src.stem
|
ui.py
ADDED
|
@@ -0,0 +1,493 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from PIL import Image
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from storage import (
|
| 5 |
+
save_multi_result,
|
| 6 |
+
delete_image_set,
|
| 7 |
+
promote_to_example,
|
| 8 |
+
list_gallery_urls,
|
| 9 |
+
download_to_local,
|
| 10 |
+
is_dataset_url,
|
| 11 |
+
save_image,
|
| 12 |
+
upload_image,
|
| 13 |
+
make_filename,
|
| 14 |
+
generate_id,
|
| 15 |
+
parse_filename,
|
| 16 |
+
parse_result_filename,
|
| 17 |
+
parse_multi_result_filename,
|
| 18 |
+
file_url,
|
| 19 |
+
LOCAL_DATA,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
EXAMPLES_PREFIX = "examples"
|
| 23 |
+
UPLOADS_PREFIX = "user_uploads"
|
| 24 |
+
|
| 25 |
+
for subdir in ["portraits", "garments", "results"]:
|
| 26 |
+
(LOCAL_DATA / EXAMPLES_PREFIX / subdir).mkdir(parents=True, exist_ok=True)
|
| 27 |
+
(LOCAL_DATA / UPLOADS_PREFIX / subdir).mkdir(parents=True, exist_ok=True)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _gallery_images(prefix, subdir):
|
| 31 |
+
return list_gallery_urls(prefix, subdir)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
MAX_PEOPLE = 8
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def build_demo(process_fn, detect_fn=None, max_people=MAX_PEOPLE):
|
| 38 |
+
|
| 39 |
+
def process_and_save(portrait_path, garment_pool, num_detected, *assignment_args):
|
| 40 |
+
if portrait_path is None:
|
| 41 |
+
raise gr.Error("Please select a portrait.")
|
| 42 |
+
if not garment_pool:
|
| 43 |
+
raise gr.Error("Please add at least one garment to the pool.")
|
| 44 |
+
result = process_fn(portrait_path, garment_pool, num_detected, *assignment_args)
|
| 45 |
+
if result and portrait_path:
|
| 46 |
+
p_parsed = parse_filename(Path(portrait_path).name)
|
| 47 |
+
if p_parsed:
|
| 48 |
+
# Upload portrait if local
|
| 49 |
+
local = Path(portrait_path)
|
| 50 |
+
if local.exists() and str(local).startswith(str(LOCAL_DATA)):
|
| 51 |
+
upload_image(local, str(local.relative_to(LOCAL_DATA)))
|
| 52 |
+
# Upload all garments in pool
|
| 53 |
+
for g in garment_pool:
|
| 54 |
+
garment_local = Path(g["path"])
|
| 55 |
+
if garment_local.exists() and str(garment_local).startswith(str(LOCAL_DATA)):
|
| 56 |
+
upload_image(garment_local, str(garment_local.relative_to(LOCAL_DATA)))
|
| 57 |
+
# Build garment ID list for filename
|
| 58 |
+
n = num_detected if num_detected else 0
|
| 59 |
+
max_p = len(assignment_args) // 2
|
| 60 |
+
pool_by_label = {g["label"]: g for g in garment_pool}
|
| 61 |
+
garment_ids = []
|
| 62 |
+
for i in range(n):
|
| 63 |
+
dd_val = assignment_args[i]
|
| 64 |
+
if dd_val == "Skip" or dd_val not in pool_by_label:
|
| 65 |
+
garment_ids.append(None)
|
| 66 |
+
else:
|
| 67 |
+
g = pool_by_label[dd_val]
|
| 68 |
+
g_parsed = parse_filename(Path(g["path"]).name)
|
| 69 |
+
garment_ids.append(g_parsed["id"] if g_parsed else generate_id())
|
| 70 |
+
save_multi_result(UPLOADS_PREFIX, p_parsed["id"], garment_ids, result)
|
| 71 |
+
return result
|
| 72 |
+
|
| 73 |
+
with gr.Blocks(title="Multi-Person Virtual Try-On") as demo:
|
| 74 |
+
with gr.Tabs():
|
| 75 |
+
# ---- Main VTON Tab ----
|
| 76 |
+
with gr.Tab("Virtual Try-On"):
|
| 77 |
+
gr.Markdown("# Multi-Person Virtual Try-On")
|
| 78 |
+
gr.Markdown("Select a portrait, add garments to the pool, detect people, and assign garments.")
|
| 79 |
+
|
| 80 |
+
selected_portrait = gr.State(value=None)
|
| 81 |
+
garment_pool = gr.State(value=[])
|
| 82 |
+
num_detected = gr.State(value=0)
|
| 83 |
+
garment_counter = gr.State(value=0)
|
| 84 |
+
|
| 85 |
+
gr.Markdown("**Step 1:** Select or upload a portrait, and select or upload garments to the pool.")
|
| 86 |
+
with gr.Row():
|
| 87 |
+
with gr.Column():
|
| 88 |
+
gr.Markdown("### Portrait")
|
| 89 |
+
portrait_gallery = gr.Gallery(
|
| 90 |
+
value=_gallery_images(UPLOADS_PREFIX, "portraits"),
|
| 91 |
+
label="Uploaded Portraits",
|
| 92 |
+
columns=4,
|
| 93 |
+
height=200,
|
| 94 |
+
allow_preview=False,
|
| 95 |
+
)
|
| 96 |
+
with gr.Accordion("Upload new portrait", open=False):
|
| 97 |
+
portrait_upload = gr.Image(type="pil", label="Upload Portrait", sources=["upload", "webcam"])
|
| 98 |
+
with gr.Row():
|
| 99 |
+
portrait_url_input = gr.Textbox(label="Or paste image URL", scale=4)
|
| 100 |
+
portrait_url_btn = gr.Button("Load", size="sm", scale=1)
|
| 101 |
+
preview_portrait = gr.Image(label="Selected Portrait", interactive=False, height=250)
|
| 102 |
+
|
| 103 |
+
# Garment pool section
|
| 104 |
+
with gr.Column():
|
| 105 |
+
gr.Markdown("### Garment Pool")
|
| 106 |
+
garment_gallery = gr.Gallery(
|
| 107 |
+
value=_gallery_images(UPLOADS_PREFIX, "garments"),
|
| 108 |
+
label="Available Garments (click to add to pool)",
|
| 109 |
+
columns=4,
|
| 110 |
+
height=200,
|
| 111 |
+
allow_preview=False,
|
| 112 |
+
)
|
| 113 |
+
with gr.Accordion("Upload new garment", open=False):
|
| 114 |
+
garment_upload = gr.Image(type="pil", label="Upload Garment", sources=["upload", "webcam"])
|
| 115 |
+
with gr.Row():
|
| 116 |
+
garment_url_input = gr.Textbox(label="Or paste image URL", scale=4)
|
| 117 |
+
garment_url_btn = gr.Button("Load", size="sm", scale=1)
|
| 118 |
+
garment_pool_gallery = gr.Gallery(
|
| 119 |
+
label="Current Pool",
|
| 120 |
+
columns=6,
|
| 121 |
+
height=250,
|
| 122 |
+
allow_preview=False,
|
| 123 |
+
)
|
| 124 |
+
clear_pool_btn = gr.Button("Clear Pool", size="sm", variant="stop")
|
| 125 |
+
|
| 126 |
+
gr.Markdown("**Step 2:** Detect people in the portrait. This lets you choose which garment goes on each person.")
|
| 127 |
+
detect_btn = gr.Button("Detect People", variant="secondary")
|
| 128 |
+
detect_status = gr.Textbox(interactive=False, show_label=False, value="")
|
| 129 |
+
people_gallery = gr.Gallery(
|
| 130 |
+
label="Detected People",
|
| 131 |
+
columns=6,
|
| 132 |
+
height=300,
|
| 133 |
+
allow_preview=False,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
gr.Markdown("**Step 3:** Assign garments to each person, then click Try On.")
|
| 137 |
+
@gr.render(inputs=[num_detected, garment_pool])
|
| 138 |
+
def render_assignments(n_detected, pool):
|
| 139 |
+
n = n_detected or 0
|
| 140 |
+
choices = ["Skip"] + [g["label"] for g in (pool or [])]
|
| 141 |
+
default_garment = choices[1] if len(choices) > 1 else "Skip"
|
| 142 |
+
|
| 143 |
+
if n > 0:
|
| 144 |
+
gr.Markdown(f"### Assign Garments to {n} {'Person' if n == 1 else 'People'}")
|
| 145 |
+
|
| 146 |
+
dds = []
|
| 147 |
+
cats = []
|
| 148 |
+
for i in range(n):
|
| 149 |
+
with gr.Row():
|
| 150 |
+
dd = gr.Dropdown(
|
| 151 |
+
choices=choices,
|
| 152 |
+
value=default_garment,
|
| 153 |
+
label=f"Person {i + 1} — Garment",
|
| 154 |
+
scale=3,
|
| 155 |
+
interactive=True,
|
| 156 |
+
)
|
| 157 |
+
cat = gr.Radio(
|
| 158 |
+
choices=["tops", "bottoms", "one-pieces"],
|
| 159 |
+
value="tops",
|
| 160 |
+
label="Category",
|
| 161 |
+
scale=2,
|
| 162 |
+
interactive=True,
|
| 163 |
+
)
|
| 164 |
+
dds.append(dd)
|
| 165 |
+
cats.append(cat)
|
| 166 |
+
|
| 167 |
+
submit_btn = gr.Button("Try On", variant="primary")
|
| 168 |
+
result_image = gr.Image(type="pil", label="Result", interactive=False)
|
| 169 |
+
|
| 170 |
+
submit_btn.click(
|
| 171 |
+
process_and_save,
|
| 172 |
+
inputs=[selected_portrait, garment_pool, num_detected] + dds + cats,
|
| 173 |
+
outputs=result_image,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
# Examples section
|
| 177 |
+
gr.Markdown("---")
|
| 178 |
+
gr.Markdown("### Examples")
|
| 179 |
+
example_sets = gr.State(value=[])
|
| 180 |
+
refresh_examples_btn = gr.Button("Refresh Examples", size="sm")
|
| 181 |
+
|
| 182 |
+
@gr.render(inputs=[example_sets])
|
| 183 |
+
def render_examples(sets):
|
| 184 |
+
for i, ex in enumerate(sets or []):
|
| 185 |
+
with gr.Row():
|
| 186 |
+
gr.Image(value=ex["portrait"], label="Portrait", height=200, interactive=False, scale=1)
|
| 187 |
+
for j, g in enumerate(ex["garments"]):
|
| 188 |
+
gr.Image(value=g, label=f"Garment {j+1}", height=200, interactive=False, scale=1)
|
| 189 |
+
gr.Image(value=ex["result"], label="Result", height=200, interactive=False, scale=1)
|
| 190 |
+
use_btn = gr.Button("Use", size="sm", scale=0, min_width=60)
|
| 191 |
+
use_btn.click(
|
| 192 |
+
lambda p=ex["portrait"], gs=ex["garments"]: _load_example(p, gs),
|
| 193 |
+
outputs=[selected_portrait, preview_portrait, garment_pool, garment_counter, garment_pool_gallery],
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# -- Event handlers --
|
| 197 |
+
|
| 198 |
+
def on_portrait_gallery_select(evt: gr.SelectData):
|
| 199 |
+
path = evt.value["image"]["path"]
|
| 200 |
+
local_path = download_to_local(path)
|
| 201 |
+
return local_path, local_path
|
| 202 |
+
|
| 203 |
+
def on_garment_gallery_select(evt: gr.SelectData, pool, counter):
|
| 204 |
+
"""Add selected garment to pool."""
|
| 205 |
+
path = evt.value["image"]["path"]
|
| 206 |
+
local_path = download_to_local(path)
|
| 207 |
+
new_counter = counter + 1
|
| 208 |
+
label = f"Garment {new_counter}"
|
| 209 |
+
new_pool = pool + [{"path": local_path, "label": label}]
|
| 210 |
+
pool_images = [g["path"] for g in new_pool]
|
| 211 |
+
return new_pool, new_counter, pool_images
|
| 212 |
+
|
| 213 |
+
def _load_example(portrait_path, garment_paths):
|
| 214 |
+
pool = [{"path": g, "label": f"Garment {i+1}"} for i, g in enumerate(garment_paths)]
|
| 215 |
+
pool_images = [g["path"] for g in pool]
|
| 216 |
+
return portrait_path, portrait_path, pool, len(pool), pool_images
|
| 217 |
+
|
| 218 |
+
def clear_pool():
|
| 219 |
+
return [], 0, [], _gallery_images(UPLOADS_PREFIX, "garments")
|
| 220 |
+
|
| 221 |
+
def reset_detection():
|
| 222 |
+
return "", [], 0
|
| 223 |
+
|
| 224 |
+
detection_reset_outputs = [detect_status, people_gallery, num_detected]
|
| 225 |
+
|
| 226 |
+
portrait_gallery.select(
|
| 227 |
+
on_portrait_gallery_select, outputs=[selected_portrait, preview_portrait]
|
| 228 |
+
).then(reset_detection, outputs=detection_reset_outputs)
|
| 229 |
+
|
| 230 |
+
garment_gallery.select(
|
| 231 |
+
on_garment_gallery_select,
|
| 232 |
+
inputs=[garment_pool, garment_counter],
|
| 233 |
+
outputs=[garment_pool, garment_counter, garment_pool_gallery],
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
clear_pool_btn.click(
|
| 237 |
+
clear_pool,
|
| 238 |
+
outputs=[garment_pool, garment_counter, garment_pool_gallery, garment_gallery],
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
def on_portrait_upload(img, current_pool):
|
| 242 |
+
if img is None:
|
| 243 |
+
return _gallery_images(UPLOADS_PREFIX, "portraits"), None, None, None
|
| 244 |
+
item_id = generate_id()
|
| 245 |
+
fname = make_filename(item_id, "portrait")
|
| 246 |
+
local_path = LOCAL_DATA / UPLOADS_PREFIX / "portraits" / fname
|
| 247 |
+
save_image(img, local_path)
|
| 248 |
+
path = str(local_path)
|
| 249 |
+
return _gallery_images(UPLOADS_PREFIX, "portraits"), path, path, None
|
| 250 |
+
|
| 251 |
+
def on_garment_upload(img, pool, counter):
|
| 252 |
+
if img is None:
|
| 253 |
+
return _gallery_images(UPLOADS_PREFIX, "garments"), pool, counter, [g["path"] for g in pool], None
|
| 254 |
+
item_id = generate_id()
|
| 255 |
+
fname = make_filename(item_id, "garment")
|
| 256 |
+
local_path = LOCAL_DATA / UPLOADS_PREFIX / "garments" / fname
|
| 257 |
+
save_image(img, local_path)
|
| 258 |
+
path = str(local_path)
|
| 259 |
+
new_counter = counter + 1
|
| 260 |
+
label = f"Garment {new_counter}"
|
| 261 |
+
new_pool = pool + [{"path": path, "label": label}]
|
| 262 |
+
pool_images = [g["path"] for g in new_pool]
|
| 263 |
+
return _gallery_images(UPLOADS_PREFIX, "garments"), new_pool, new_counter, pool_images, None
|
| 264 |
+
|
| 265 |
+
def on_portrait_url(url, pool):
|
| 266 |
+
if not url or not url.strip():
|
| 267 |
+
return _gallery_images(UPLOADS_PREFIX, "portraits"), None, None, ""
|
| 268 |
+
local_path = download_to_local(url.strip())
|
| 269 |
+
if not is_dataset_url(url.strip()):
|
| 270 |
+
from PIL import Image as PILImage
|
| 271 |
+
item_id = generate_id()
|
| 272 |
+
fname = make_filename(item_id, "portrait")
|
| 273 |
+
dest = LOCAL_DATA / UPLOADS_PREFIX / "portraits" / fname
|
| 274 |
+
save_image(PILImage.open(local_path), dest)
|
| 275 |
+
local_path = str(dest)
|
| 276 |
+
return _gallery_images(UPLOADS_PREFIX, "portraits"), local_path, local_path, ""
|
| 277 |
+
|
| 278 |
+
def on_garment_url(url, pool, counter):
|
| 279 |
+
if not url or not url.strip():
|
| 280 |
+
return _gallery_images(UPLOADS_PREFIX, "garments"), pool, counter, [g["path"] for g in pool], ""
|
| 281 |
+
local_path = download_to_local(url.strip())
|
| 282 |
+
if not is_dataset_url(url.strip()):
|
| 283 |
+
from PIL import Image as PILImage
|
| 284 |
+
item_id = generate_id()
|
| 285 |
+
fname = make_filename(item_id, "garment")
|
| 286 |
+
dest = LOCAL_DATA / UPLOADS_PREFIX / "garments" / fname
|
| 287 |
+
save_image(PILImage.open(local_path), dest)
|
| 288 |
+
local_path = str(dest)
|
| 289 |
+
new_counter = counter + 1
|
| 290 |
+
label = f"Garment {new_counter}"
|
| 291 |
+
new_pool = pool + [{"path": local_path, "label": label}]
|
| 292 |
+
pool_images = [g["path"] for g in new_pool]
|
| 293 |
+
return _gallery_images(UPLOADS_PREFIX, "garments"), new_pool, new_counter, pool_images, ""
|
| 294 |
+
|
| 295 |
+
portrait_upload.change(
|
| 296 |
+
on_portrait_upload,
|
| 297 |
+
inputs=[portrait_upload, garment_pool],
|
| 298 |
+
outputs=[portrait_gallery, selected_portrait, preview_portrait, portrait_upload],
|
| 299 |
+
).then(reset_detection, outputs=detection_reset_outputs)
|
| 300 |
+
|
| 301 |
+
garment_upload.change(
|
| 302 |
+
on_garment_upload,
|
| 303 |
+
inputs=[garment_upload, garment_pool, garment_counter],
|
| 304 |
+
outputs=[garment_gallery, garment_pool, garment_counter, garment_pool_gallery, garment_upload],
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
portrait_url_btn.click(
|
| 308 |
+
on_portrait_url,
|
| 309 |
+
inputs=[portrait_url_input, garment_pool],
|
| 310 |
+
outputs=[portrait_gallery, selected_portrait, preview_portrait, portrait_url_input],
|
| 311 |
+
).then(reset_detection, outputs=detection_reset_outputs)
|
| 312 |
+
|
| 313 |
+
garment_url_btn.click(
|
| 314 |
+
on_garment_url,
|
| 315 |
+
inputs=[garment_url_input, garment_pool, garment_counter],
|
| 316 |
+
outputs=[garment_gallery, garment_pool, garment_counter, garment_pool_gallery, garment_url_input],
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
def on_detect(portrait_path, pool):
|
| 320 |
+
if detect_fn is None or portrait_path is None:
|
| 321 |
+
raise gr.Error("Please select a portrait first.")
|
| 322 |
+
people = detect_fn(portrait_path)
|
| 323 |
+
n = len(people)
|
| 324 |
+
return f"Found {n} {'person' if n == 1 else 'people'}", people, n
|
| 325 |
+
|
| 326 |
+
detect_btn.click(
|
| 327 |
+
lambda: "Detecting people...",
|
| 328 |
+
outputs=[detect_status],
|
| 329 |
+
).then(
|
| 330 |
+
on_detect,
|
| 331 |
+
inputs=[selected_portrait, garment_pool],
|
| 332 |
+
outputs=[detect_status, people_gallery, num_detected],
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
def refresh_examples():
|
| 336 |
+
result_urls = _gallery_images(EXAMPLES_PREFIX, "results")
|
| 337 |
+
sets = []
|
| 338 |
+
for r in result_urls:
|
| 339 |
+
portrait, garments, result = _resolve_result_images(UPLOADS_PREFIX, r)
|
| 340 |
+
if portrait and garments:
|
| 341 |
+
sets.append({"portrait": portrait, "garments": garments, "result": result})
|
| 342 |
+
return sets
|
| 343 |
+
|
| 344 |
+
refresh_examples_btn.click(
|
| 345 |
+
refresh_examples,
|
| 346 |
+
outputs=[example_sets],
|
| 347 |
+
)
|
| 348 |
+
demo.load(
|
| 349 |
+
refresh_examples,
|
| 350 |
+
outputs=[example_sets],
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
# ---- Admin Tab ----
|
| 354 |
+
with gr.Tab("Admin - Manage Examples"):
|
| 355 |
+
admin_status = gr.Textbox(label="Status", interactive=False)
|
| 356 |
+
|
| 357 |
+
gr.Markdown("### Current Examples")
|
| 358 |
+
admin_examples_table = gr.Dataframe(
|
| 359 |
+
headers=["ID", "Result Filename"],
|
| 360 |
+
label="Examples",
|
| 361 |
+
interactive=False,
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
with gr.Row():
|
| 365 |
+
delete_id = gr.Textbox(label="Example ID to delete", scale=3)
|
| 366 |
+
delete_btn = gr.Button("Delete", variant="stop", scale=1)
|
| 367 |
+
|
| 368 |
+
def get_examples_table():
|
| 369 |
+
results = list_gallery_urls(EXAMPLES_PREFIX, "results")
|
| 370 |
+
rows = []
|
| 371 |
+
for r in results:
|
| 372 |
+
fname = Path(r).name
|
| 373 |
+
parsed = parse_result_filename(fname) or parse_multi_result_filename(fname)
|
| 374 |
+
rid = parsed["portrait_id"] if parsed else Path(fname).stem
|
| 375 |
+
rows.append([rid, fname])
|
| 376 |
+
return rows
|
| 377 |
+
|
| 378 |
+
def on_admin_delete(ex_id):
|
| 379 |
+
if not ex_id:
|
| 380 |
+
return "Please provide an ID.", get_examples_table()
|
| 381 |
+
delete_image_set(EXAMPLES_PREFIX, ex_id.strip())
|
| 382 |
+
return "Deleted.", get_examples_table()
|
| 383 |
+
|
| 384 |
+
delete_btn.click(
|
| 385 |
+
on_admin_delete,
|
| 386 |
+
inputs=[delete_id],
|
| 387 |
+
outputs=[admin_status, admin_examples_table],
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
# ---- Promote from Uploads ----
|
| 391 |
+
gr.Markdown("---")
|
| 392 |
+
gr.Markdown("## Promote from Uploads")
|
| 393 |
+
gr.Markdown("Select a result to promote. The matching portrait and garment are found automatically.")
|
| 394 |
+
|
| 395 |
+
promote_portrait = gr.State(value=None)
|
| 396 |
+
promote_garments = gr.State(value=[])
|
| 397 |
+
promote_result = gr.State(value=None)
|
| 398 |
+
|
| 399 |
+
promo_result_gallery = gr.Gallery(
|
| 400 |
+
value=_gallery_images(UPLOADS_PREFIX, "results"),
|
| 401 |
+
label="Results",
|
| 402 |
+
columns=4,
|
| 403 |
+
height=200,
|
| 404 |
+
allow_preview=False,
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
with gr.Row():
|
| 408 |
+
promo_preview_portrait = gr.Image(label="Portrait", interactive=False, height=150)
|
| 409 |
+
promo_preview_garments = gr.Gallery(label="Garments", columns=4, height=150, allow_preview=False)
|
| 410 |
+
promo_preview_result = gr.Image(label="Result", interactive=False, height=150)
|
| 411 |
+
|
| 412 |
+
promote_btn = gr.Button("Promote to Example", variant="primary")
|
| 413 |
+
promote_status = gr.Textbox(label="Status", interactive=False)
|
| 414 |
+
|
| 415 |
+
def _resolve_result_images(prefix, path):
|
| 416 |
+
"""Parse a result filename and resolve portrait + all garments."""
|
| 417 |
+
result_local = download_to_local(path)
|
| 418 |
+
fname = Path(result_local).name
|
| 419 |
+
# Try single-garment format
|
| 420 |
+
parsed = parse_result_filename(fname)
|
| 421 |
+
if not parsed:
|
| 422 |
+
parsed = parse_result_filename(Path(path).name)
|
| 423 |
+
if parsed:
|
| 424 |
+
try:
|
| 425 |
+
p_url = file_url(f"{prefix}/portraits/{make_filename(parsed['portrait_id'], 'portrait')}")
|
| 426 |
+
g_url = file_url(f"{prefix}/garments/{make_filename(parsed['garment_id'], 'garment')}")
|
| 427 |
+
return download_to_local(p_url), [download_to_local(g_url)], result_local
|
| 428 |
+
except Exception as e:
|
| 429 |
+
gr.Warning(f"Single-garment resolve failed for {fname}: {e}")
|
| 430 |
+
# Try multi-garment format
|
| 431 |
+
multi = parse_multi_result_filename(fname)
|
| 432 |
+
if not multi:
|
| 433 |
+
multi = parse_multi_result_filename(Path(path).name)
|
| 434 |
+
if multi:
|
| 435 |
+
gids = [gid for gid in multi["garment_ids"] if gid is not None]
|
| 436 |
+
if gids:
|
| 437 |
+
try:
|
| 438 |
+
p_url = file_url(f"{prefix}/portraits/{make_filename(multi['portrait_id'], 'portrait')}")
|
| 439 |
+
portrait_local = download_to_local(p_url)
|
| 440 |
+
garment_locals = []
|
| 441 |
+
for gid in gids:
|
| 442 |
+
g_url = file_url(f"{prefix}/garments/{make_filename(gid, 'garment')}")
|
| 443 |
+
garment_locals.append(download_to_local(g_url))
|
| 444 |
+
return portrait_local, garment_locals, result_local
|
| 445 |
+
except Exception as e:
|
| 446 |
+
gr.Warning(f"Multi-garment resolve failed for {fname} (portrait={multi['portrait_id']}, garments={gids}): {e}")
|
| 447 |
+
else:
|
| 448 |
+
gr.Warning(f"Could not parse result filename: {fname}")
|
| 449 |
+
return None, [], result_local
|
| 450 |
+
|
| 451 |
+
def on_result_select(evt: gr.SelectData):
|
| 452 |
+
path = evt.value["image"]["path"]
|
| 453 |
+
portrait_local, garment_locals, result_local = _resolve_result_images(UPLOADS_PREFIX, path)
|
| 454 |
+
if not portrait_local or not garment_locals:
|
| 455 |
+
gr.Warning(f"Could not find matching portrait/garment for: {Path(path).name}")
|
| 456 |
+
return portrait_local, garment_locals, result_local, portrait_local, garment_locals, result_local
|
| 457 |
+
|
| 458 |
+
promo_result_gallery.select(
|
| 459 |
+
on_result_select,
|
| 460 |
+
outputs=[promote_portrait, promote_garments, promote_result,
|
| 461 |
+
promo_preview_portrait, promo_preview_garments, promo_preview_result],
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
+
def on_promote(portrait_path, garment_paths, result_path):
|
| 465 |
+
if not result_path:
|
| 466 |
+
return "No result to promote.", get_examples_table()
|
| 467 |
+
name = promote_to_example(result_path)
|
| 468 |
+
return f"Promoted: {name}", get_examples_table()
|
| 469 |
+
|
| 470 |
+
promote_btn.click(
|
| 471 |
+
on_promote,
|
| 472 |
+
inputs=[promote_portrait, promote_garments, promote_result],
|
| 473 |
+
outputs=[promote_status, admin_examples_table],
|
| 474 |
+
)
|
| 475 |
+
|
| 476 |
+
refresh_promo_btn = gr.Button("Refresh Results", size="sm")
|
| 477 |
+
refresh_promo_btn.click(
|
| 478 |
+
lambda: _gallery_images(UPLOADS_PREFIX, "results"),
|
| 479 |
+
outputs=[promo_result_gallery],
|
| 480 |
+
)
|
| 481 |
+
|
| 482 |
+
demo.load(get_examples_table, outputs=[admin_examples_table])
|
| 483 |
+
|
| 484 |
+
return demo
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
if __name__ == "__main__":
|
| 488 |
+
def dummy_process(portrait, pool, num_detected, *assignment_args):
|
| 489 |
+
return Image.new("RGB", (512, 512), (200, 200, 200))
|
| 490 |
+
def dummy_detect(portrait_path):
|
| 491 |
+
return [Image.new("RGB", (100, 200), (255, 0, 0)), Image.new("RGB", (100, 200), (0, 255, 0))]
|
| 492 |
+
demo = build_demo(dummy_process, detect_fn=dummy_detect)
|
| 493 |
+
demo.launch()
|
vton_gradio_demo.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
vton_gradio_demo_remove_background.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|