File size: 12,893 Bytes
6933b0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
import spaces
import cv2
import numpy as np
from PIL import Image
import torch
from fashn_vton import TryOnPipeline
from ultralytics import YOLO
import gradio as gr
from pathlib import Path
import subprocess
import sys
from scipy.spatial import cKDTree
from ui import build_demo


class MultiPersonVTON:
    """Multi-person virtual try-on.

    Segments every person in a group photo with YOLO, runs each person
    through the FASHN try-on pipeline with their assigned garment, then
    recomposites the results over an inpainted background in an estimated
    occlusion order so overlapping people layer plausibly.
    """

    def __init__(self, weights_dir="./weights"):
        """Load the try-on pipeline and the YOLO segmentation model."""
        print("Initializing Multi-Person VTON pipeline...")
        self.pipeline = TryOnPipeline(weights_dir=weights_dir)
        self.model = YOLO("yolo26n-seg.pt")
        print("Pipeline initialized")

    def get_mask(self, result, H, W):
        """Rasterize YOLO person polygons into boolean masks.

        Args:
            result: one ultralytics result (boxes + segmentation polygons).
            H, W: output mask height/width in pixels.

        Returns:
            List of (H, W) boolean arrays, one per detection of class id 0
            (person). Empty list when the result carries no masks at all.
        """
        if result.masks is None:
            # Nothing was segmented (e.g. empty scene); avoid AttributeError.
            return []
        cls_ids = result.boxes.cls.cpu().numpy().astype(int)
        is_person = cls_ids == 0
        masks = []
        for poly, keep in zip(result.masks.xy, is_person):
            if not keep:
                continue
            mask = np.zeros((H, W), dtype=np.uint8)
            poly_int = np.round(poly).astype(np.int32)
            cv2.fillPoly(mask, [poly_int], 1)
            masks.append(mask.astype(bool))
        return masks

    def extract_people(self, img, masks):
        """Cut each person out of the image onto a white background.

        Returns one full-size PIL image per mask.
        """
        img_np = np.array(img) if isinstance(img, Image.Image) else img.copy()
        people = []
        for mask in masks:
            cutout = img_np.copy()
            cutout[~mask] = 255  # white out everything except this person
            people.append(Image.fromarray(cutout))
        return people

    def apply_vton_to_people(self, people, assignments):
        """Apply VTON per person based on individual assignments.

        assignments: list of {"garment": PIL.Image|None, "category": str} per person.
        If garment is None, person is kept as-is (skipped).
        """
        vton_people = []
        for i, person in enumerate(people):
            garment = assignments[i]["garment"]
            if garment is not None:
                result = self.pipeline(
                    person_image=person,
                    garment_image=garment,
                    category=assignments[i]["category"]
                )
                vton_people.append(result.images[0])
            else:
                vton_people.append(person)
        return vton_people

    def get_vton_masks(self, vton_people):
        """Estimate a soft foreground mask for each VTON result.

        The try-on output has a near-white background, so threshold on
        brightness, remove speckle with morphological open/close, then
        feather the edge with a small blur.

        Returns uint8 masks with values in [0, 255].
        """
        vton_masks = []
        for person_img in vton_people:
            arr = np.array(person_img)
            gray = cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)
            # Pixels darker than 240 count as foreground.
            _, fg = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY_INV)
            binary = (fg > 0).astype(np.uint8)
            kernel = np.ones((5, 5), np.uint8)
            cleaned = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=1)
            cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_CLOSE, kernel, iterations=2)
            mask_u8 = cleaned * 255
            mask_blur = cv2.GaussianBlur(mask_u8, (3, 3), 1)
            vton_masks.append(mask_blur)
        return vton_masks

    def contour_curvature(self, contour, k=5):
        """Per-point turning angle (radians) along a closed contour.

        Args:
            contour: (N, 1, 2) array as returned by cv2.findContours.
            k: point offset used to form the two tangent vectors.

        Returns:
            (N,) array; 0 for locally straight runs, up to pi for a
            full reversal. Vectorized with np.roll over the closed loop
            (equivalent to the modular prev/next indexing of a per-point loop).
        """
        pts = contour[:, 0, :].astype(np.float32)
        prev_pts = np.roll(pts, k, axis=0)    # pts[(i - k) % N]
        next_pts = np.roll(pts, -k, axis=0)   # pts[(i + k) % N]
        v1 = pts - prev_pts
        v2 = next_pts - pts
        v1 /= np.linalg.norm(v1, axis=1, keepdims=True) + 1e-6
        v2 /= np.linalg.norm(v2, axis=1, keepdims=True) + 1e-6
        dots = np.clip((v1 * v2).sum(axis=1), -1.0, 1.0)
        return np.arccos(dots)

    def frontness_score(self, mask_a, mask_b):
        """Score the relative layering of two overlapping person masks.

        Computes, for contour points nearest the overlap region, the mean
        curvature of each mask's outer contour and returns
        ``mean_curv_a - mean_curv_b``. Heuristic: the silhouette that is
        more curved near the overlap is treated as the occluding one.
        Returns 0.0 when the overlap is negligible (< 50 px) or a contour
        is missing.
        """
        inter = mask_a & mask_b
        if inter.sum() < 50:
            return 0.0
        cnts_a, _ = cv2.findContours(mask_a.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        cnts_b, _ = cv2.findContours(mask_b.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        if not cnts_a or not cnts_b:
            return 0.0
        ca = max(cnts_a, key=len)
        cb = max(cnts_b, key=len)
        curv_a = self.contour_curvature(ca)
        curv_b = self.contour_curvature(cb)
        # np.where yields (row, col); flip to (x, y) to match contour coords.
        inter_pts = np.column_stack(np.where(inter))[:, ::-1]
        tree_a = cKDTree(ca[:, 0, :])
        tree_b = cKDTree(cb[:, 0, :])
        _, idx_a = tree_a.query(inter_pts, k=1)
        _, idx_b = tree_b.query(inter_pts, k=1)
        score_a = curv_a[idx_a].mean()
        score_b = curv_b[idx_b].mean()
        return score_a - score_b

    def estimate_front_to_back_order(self, masks):
        """Rank person masks by accumulated pairwise frontness scores.

        Returns:
            (order, scores): ``order`` is an index array sorted by
            descending score; ``scores`` are the raw accumulated values.
            Note frontness_score is antisymmetric, so each pair
            contributes with opposite signs to the two people.
        """
        n = len(masks)
        scores = np.zeros(n)
        for i in range(n):
            for j in range(n):
                if i == j:
                    continue
                scores[i] += self.frontness_score(masks[i], masks[j])
        order = np.argsort(-scores)
        return order, scores

    def remove_original_people(self, image, person_masks):
        """Inpaint all detected people out of the image.

        Returns:
            (clean background as PIL image, combined uint8 mask of the
            removed/inpainted region).
        """
        image_np = np.array(image)
        combined_mask = np.zeros(image_np.shape[:2], dtype=np.uint8)
        for mask in person_masks:
            combined_mask[mask] = 255
        # Dilate so inpainting also covers halo pixels around each person.
        kernel = np.ones((5, 5), np.uint8)
        combined_mask = cv2.dilate(combined_mask, kernel, iterations=2)
        inpainted = cv2.inpaint(image_np, combined_mask, 3, cv2.INPAINT_TELEA)
        return Image.fromarray(inpainted), combined_mask

    def clean_vton_edges_on_overlap(self, img_pil, mask_uint8, other_masks_uint8,
                                    erode_iters=1, edge_dilate=2, inner_erode=2):
        """Suppress ragged silhouette edges where this person overlaps others.

        Where ``mask_uint8`` intersects the union of ``other_masks_uint8``,
        the mask is tightened by erosion and the pixels along its edge band
        are replaced with colors inpainted from the person's interior, so
        the composite seam is less visible.

        Returns:
            (possibly modified PIL image, possibly tightened uint8 mask).
            Inputs are returned unchanged when there is no overlap; only
            the tightened mask is returned when no edge falls in the band.
        """
        src = np.array(img_pil).copy()
        others_union = np.zeros_like(mask_uint8, dtype=np.uint8)
        for m in other_masks_uint8:
            others_union = np.maximum(others_union, m)
        overlap = (mask_uint8 > 0) & (others_union > 0)
        overlap = overlap.astype(np.uint8) * 255
        if overlap.sum() == 0:
            return img_pil, mask_uint8
        kernel = np.ones((3, 3), np.uint8)
        tight_mask = cv2.erode(mask_uint8, kernel, iterations=erode_iters)
        edge = cv2.Canny(tight_mask, 50, 150)
        edge = cv2.dilate(edge, np.ones((3, 3), np.uint8), iterations=edge_dilate)
        # Restrict the fix to a band around the actual overlap.
        overlap_band = cv2.dilate(overlap, np.ones((5, 5), np.uint8), iterations=1)
        edge = cv2.bitwise_and(edge, overlap_band)
        if edge.sum() == 0:
            return img_pil, tight_mask
        # Inpaint everything outside the eroded interior, then copy only
        # the inpainted colors back onto the edge band.
        inner = cv2.erode(tight_mask, np.ones((5, 5), np.uint8), iterations=inner_erode)
        inner_rgb = cv2.inpaint(src, 255 - inner, 3, cv2.INPAINT_TELEA)
        src[edge > 0] = inner_rgb[edge > 0]
        return Image.fromarray(src), tight_mask

    def clean_masks(self, vton_people, vton_masks):
        """Run clean_vton_edges_on_overlap for every person against the rest.

        Returns (cleaned images, cleaned masks) in the same order.
        """
        cleaned_vton_people = []
        cleaned_vton_masks = []
        for i in range(len(vton_people)):
            other_masks = [m for j, m in enumerate(vton_masks) if j != i]
            cleaned_img, cleaned_mask = self.clean_vton_edges_on_overlap(
                vton_people[i], vton_masks[i], other_masks,
                erode_iters=1, edge_dilate=2, inner_erode=2
            )
            cleaned_vton_people.append(cleaned_img)
            cleaned_vton_masks.append(cleaned_mask)
        return cleaned_vton_people, cleaned_vton_masks

    def process_group_image(self, group_image, assignments):
        """Process a group image with per-person garment assignments.

        assignments: list of {"garment": PIL.Image|None, "category": str} per person.
        Missing entries are padded as "skip"; the caller's list is not mutated.

        Returns:
            (final composited PIL image, dict of intermediate artifacts).
        """
        print("Step 1: Loading images...")
        if isinstance(group_image, np.ndarray):
            group_image = Image.fromarray(group_image)
        if isinstance(group_image, Image.Image):
            # YOLO below is fed a file path, so round-trip through disk.
            group_image.save("people.png")

        img_bgr = cv2.imread("people.png")
        img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        H, W = img.shape[:2]

        print("Step 2: Getting segmentation masks with YOLO...")
        results = self.model("people.png")
        result = results[0]
        masks = self.get_mask(result, H, W)
        print(f"Found {len(masks)} people")

        print("Step 3: Extracting individual people...")
        people = self.extract_people(img, masks)
        if not people:
            # No one detected: return the input unchanged instead of crashing
            # later on vton_people[0].
            original = Image.fromarray(img)
            return original, {
                "original": original,
                "clean_background": original,
                "person_mask": Image.fromarray(np.zeros((H, W), dtype=np.uint8)),
                "num_people": 0,
                "individual_people": [],
                "vton_results": [],
                "masks": [],
                "vton_masks": []
            }

        # Pad a shallow copy so the caller's assignments list is not mutated.
        assignments = list(assignments)
        while len(assignments) < len(people):
            assignments.append({"garment": None, "category": "tops"})

        print("Step 4: Applying VTON to people...")
        vton_people = self.apply_vton_to_people(people, assignments)

        print("Step 5: Getting masks for VTON results...")
        vton_masks = self.get_vton_masks(vton_people)
        for i in range(len(vton_masks)):
            if assignments[i]["garment"] is None:
                # Untouched people keep their (feathered) YOLO mask; the
                # white-background threshold heuristic does not apply to them.
                yolo_mask = (masks[i].astype(np.uint8) * 255)
                yolo_mask = cv2.GaussianBlur(yolo_mask, (3, 3), 1)
                vton_masks[i] = yolo_mask
        order, scores = self.estimate_front_to_back_order(vton_masks)
        cleaned_vton_people, cleaned_vton_masks = self.clean_masks(vton_people, vton_masks)

        print("Step 6: Resizing to match dimensions...")
        # VTON output size may differ from the input; PIL .size is (w, h),
        # which is exactly cv2.resize's dsize convention.
        img = cv2.resize(img, vton_people[0].size)

        print("Step 7: Creating clean background by removing original people...")
        clean_background, person_mask = self.remove_original_people(img, masks)
        clean_background_np = np.array(clean_background)

        print("Step 8: Recomposing final image...")
        # Later paints overwrite earlier ones, so the composite expects
        # `order` to run back-to-front here.
        # NOTE(review): the method name says front-to-back — confirm the
        # sign convention of frontness_score matches this paint order.
        recomposed = clean_background_np.copy()
        for i in order:
            vton_mask = cleaned_vton_masks[i]
            img_pil = cleaned_vton_people[i]
            out = recomposed.astype(np.float32)
            src = np.array(img_pil).astype(np.float32)
            alpha = (vton_mask.astype(np.float32) / 255.0)[..., None]
            src = src * alpha
            out = src + (1 - alpha) * out
            recomposed = out.astype(np.uint8)

        final_image = Image.fromarray(recomposed)
        return final_image, {
            "original": Image.fromarray(img),
            "clean_background": clean_background,
            "person_mask": Image.fromarray(person_mask),
            "num_people": len(people),
            "individual_people": people,
            "vton_results": cleaned_vton_people,
            "masks": masks,
            "vton_masks": cleaned_vton_masks
        }


WEIGHTS_DIR = Path("./weights")

def ensure_weights():
    """Download model weights unless the weights directory is already populated."""
    already_there = WEIGHTS_DIR.exists() and next(WEIGHTS_DIR.iterdir(), None) is not None
    if already_there:
        print("Weights already present, skipping download.")
        return
    print("Downloading weights...")
    download_cmd = [
        sys.executable,
        "fashn-vton-1.5/scripts/download_weights.py",
        "--weights-dir",
        str(WEIGHTS_DIR),
    ]
    subprocess.check_call(download_cmd)

ensure_weights()

_pipeline = None

def get_pipeline():
    """Return the process-wide MultiPersonVTON instance, building it lazily."""
    global _pipeline
    if _pipeline is not None:
        return _pipeline
    _pipeline = MultiPersonVTON()
    return _pipeline

@spaces.GPU
def detect_people(portrait_path):
    """Segment the chosen portrait and return one cutout image per person."""
    if portrait_path is None:
        raise gr.Error("Please select a portrait first.")
    if isinstance(portrait_path, str):
        portrait = Image.open(portrait_path)
    else:
        portrait = portrait_path
    # Normalize to a fixed width, preserving aspect ratio.
    target_w = 576
    src_w, src_h = portrait.size
    target_h = int(src_h * target_w / src_w)
    resized = portrait.resize((target_w, target_h), Image.LANCZOS)
    resized.save("people.png")
    vton = get_pipeline()
    detection = vton.model("people.png")[0]
    img_arr = np.array(resized)
    height, width = img_arr.shape[:2]
    person_masks = vton.get_mask(detection, height, width)
    return vton.extract_people(img_arr, person_masks)

@spaces.GPU
def process_images(selected_portrait, garment_pool, num_detected, *assignment_args):
    """Run the multi-person try-on for the selected portrait.

    Args:
        selected_portrait: file path or PIL image of the group photo.
        garment_pool: list of {"label": str, "path": str|PIL.Image} garments.
        num_detected: number of people detected by detect_people.
        *assignment_args: dd_0..dd_{max_p-1} garment labels followed by
            cat_0..cat_{max_p-1} category values, one pair per UI slot.

    Returns:
        The final composited PIL image.

    Raises:
        gr.Error: when no portrait is selected or the garment pool is empty.
    """
    if selected_portrait is None:
        raise gr.Error("Please select a portrait.")
    if not garment_pool:
        raise gr.Error("Please add at least one garment to the pool.")
    portrait = Image.open(selected_portrait) if isinstance(selected_portrait, str) else selected_portrait
    pipeline = get_pipeline()
    # Same resize as detect_people so masks line up with the UI previews.
    new_width = 576
    w, h = portrait.size
    new_height = int(h * new_width / w)
    resized = portrait.resize((new_width, new_height), Image.LANCZOS)

    # Build per-person assignments from dropdown/radio values.
    # assignment_args: dd_0, dd_1, ..., dd_{max_p-1}, cat_0, ..., cat_{max_p-1}
    max_p = len(assignment_args) // 2
    # Clamp: YOLO may detect more people than the UI exposes controls for,
    # which would otherwise index past the end of assignment_args.
    n = min(int(num_detected or 0), max_p)
    pool_by_label = {g["label"]: g for g in garment_pool}
    assignments = []
    for i in range(n):
        dd_val = assignment_args[i]
        cat_val = assignment_args[max_p + i]
        if dd_val == "Skip" or dd_val not in pool_by_label:
            assignments.append({"garment": None, "category": cat_val or "tops"})
        else:
            g = pool_by_label[dd_val]
            garment_img = Image.open(g["path"]) if isinstance(g["path"], str) else g["path"]
            assignments.append({"garment": garment_img, "category": cat_val or "tops"})

    result, _ = pipeline.process_group_image(resized, assignments)
    return result

# Build the Gradio UI, wiring in the callbacks defined above; max_people
# controls how many per-person assignment slots the UI creates.
demo = build_demo(process_images, detect_fn=detect_people, max_people=8)
from huggingface_hub import constants as hf_constants
# Allow Gradio to serve files from the HF Hub cache (e.g. example assets).
demo.launch(allowed_paths=[hf_constants.HF_HUB_CACHE])