"""Gradio app for multi-person virtual try-on (VTON).

People are segmented in a group photo with a YOLO segmentation model, each
person is cut out on a white background and optionally passed through the
try-on pipeline with an assigned garment, and the results are composited
back onto an inpainted copy of the original photo.
"""

# NOTE(review): `spaces` is kept as the very first import, exactly as in the
# original file; presumably it has to be imported before CUDA-using libraries
# such as torch on ZeroGPU hardware — confirm before reordering.
import spaces

import subprocess
import sys
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import torch
from fashn_vton import TryOnPipeline
from huggingface_hub import constants as hf_constants
from PIL import Image
from scipy.spatial import cKDTree
from ultralytics import YOLO

from ui import build_demo

# Width (in pixels) every portrait is resized to before detection / try-on.
PORTRAIT_WIDTH = 576
# Category used when the UI does not provide one for a person.
DEFAULT_CATEGORY = "tops"


class MultiPersonVTON:
    """Segment people in a group photo, run try-on per person, recompose."""

    def __init__(self, weights_dir="./weights"):
        """Load the try-on pipeline from *weights_dir* and the YOLO model."""
        print("Initializing Multi-Person VTON pipeline...")
        self.pipeline = TryOnPipeline(weights_dir=weights_dir)
        self.model = YOLO("yolo26n-seg.pt")
        print("Pipeline initialized")

    def get_mask(self, result, H, W):
        """Rasterise the polygons of class-0 detections into boolean masks.

        Args:
            result: a single detection result exposing ``boxes.cls`` and
                ``masks.xy`` (one polygon per detection).
            H, W: height and width of the masks to produce.

        Returns:
            A list of ``(H, W)`` boolean arrays, one per detection whose class
            id is 0 (used as "person" throughout this file). Empty when the
            result carries no masks.
        """
        # Assumes the detector reports "nothing found" as masks/boxes being
        # None (see the Ultralytics Results reference) — TODO confirm.
        if result.masks is None or result.boxes is None:
            return []

        cls_ids = result.boxes.cls.cpu().numpy().astype(int)
        masks = []
        for poly, cls_id in zip(result.masks.xy, cls_ids):
            if cls_id != 0:
                continue
            canvas = np.zeros((H, W), dtype=np.uint8)
            cv2.fillPoly(canvas, [np.round(poly).astype(np.int32)], 1)
            masks.append(canvas.astype(bool))
        return masks

    def extract_people(self, img, masks):
        """Return one PIL cut-out per mask with everything else set to 255."""
        img_np = np.asarray(img)
        people = []
        for mask in masks:
            cutout = img_np.copy()
            cutout[~mask] = 255  # white background outside the person
            people.append(Image.fromarray(cutout))
        return people

    def apply_vton_to_people(self, people, assignments):
        """Apply VTON per person based on individual assignments.

        assignments: list of {"garment": PIL.Image|None, "category": str} per
        person. If garment is None, or there is no assignment for a person,
        that person is kept as-is (skipped).
        """
        vton_people = []
        for idx, person in enumerate(people):
            assignment = assignments[idx] if idx < len(assignments) else None
            garment = assignment["garment"] if assignment else None
            if garment is None:
                vton_people.append(person)
                continue
            result = self.pipeline(
                person_image=person,
                garment_image=garment,
                category=assignment["category"],
            )
            vton_people.append(result.images[0])
        return vton_people

    def get_vton_masks(self, vton_people):
        """Derive a soft uint8 (0-255) foreground mask for each cut-out.

        Pixels with grey level <= 240 count as foreground (the cut-outs have a
        white background); the mask is cleaned with a morphological open and
        close and then lightly blurred.
        """
        kernel = np.ones((5, 5), np.uint8)
        vton_masks = []
        for person in vton_people:
            gray = cv2.cvtColor(np.array(person), cv2.COLOR_RGB2GRAY)
            _, fg = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY_INV)
            fg01 = (fg > 0).astype(np.uint8)
            opened = cv2.morphologyEx(fg01, cv2.MORPH_OPEN, kernel, iterations=1)
            closed = cv2.morphologyEx(opened, cv2.MORPH_CLOSE, kernel, iterations=2)
            vton_masks.append(cv2.GaussianBlur(closed * 255, (3, 3), 1))
        return vton_masks

    def contour_curvature(self, contour, k=5):
        """Return the turning angle (radians) at every contour point.

        For point ``i`` the angle is measured between the unit vectors
        ``p[i] - p[i-k]`` and ``p[i+k] - p[i]`` (indices wrap around).

        Args:
            contour: array of shape ``(N, 1, 2)`` as returned by
                ``cv2.findContours``.
            k: index offset of the neighbours used on each side.

        Returns:
            float64 array of length ``N``.
        """
        pts = contour[:, 0, :].astype(np.float32)
        if len(pts) == 0:
            return np.zeros(0)

        incoming = pts - np.roll(pts, k, axis=0)    # p[i] - p[i-k]
        outgoing = np.roll(pts, -k, axis=0) - pts   # p[i+k] - p[i]
        # The epsilon keeps zero-length segments from dividing by zero.
        incoming /= np.linalg.norm(incoming, axis=1, keepdims=True) + 1e-6
        outgoing /= np.linalg.norm(outgoing, axis=1, keepdims=True) + 1e-6

        cosines = np.clip(np.sum(incoming * outgoing, axis=1), -1, 1)
        return np.arccos(cosines).astype(np.float64)

    def frontness_score(self, mask_a, mask_b):
        """Compare contour curvature of two masks around their overlap.

        Both masks may be boolean or uint8 (possibly blurred); any non-zero
        pixel counts as inside the mask.

        Returns:
            ``mean_curvature_a - mean_curvature_b`` sampled at the contour
            points nearest to each overlapping pixel, or ``0.0`` when the
            overlap is smaller than 50 pixels or a contour is missing.
        """
        # Binarise first: `&` on blurred uint8 masks would be a bitwise AND of
        # grey levels, and `.sum()` would add intensities instead of counting
        # pixels.
        bin_a = np.asarray(mask_a) > 0
        bin_b = np.asarray(mask_b) > 0
        inter = bin_a & bin_b
        if inter.sum() < 50:
            return 0.0

        cnts_a, _ = cv2.findContours(
            bin_a.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
        )
        cnts_b, _ = cv2.findContours(
            bin_b.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
        )
        if not cnts_a or not cnts_b:
            return 0.0

        # Use the contour with the most points for each mask.
        ca = max(cnts_a, key=len)
        cb = max(cnts_b, key=len)
        curv_a = self.contour_curvature(ca)
        curv_b = self.contour_curvature(cb)

        # np.where yields (row, col); contours are (x, y), hence the flip.
        inter_pts = np.column_stack(np.where(inter))[:, ::-1]
        _, idx_a = cKDTree(ca[:, 0, :]).query(inter_pts, k=1)
        _, idx_b = cKDTree(cb[:, 0, :]).query(inter_pts, k=1)

        return curv_a[idx_a].mean() - curv_b[idx_b].mean()

    def estimate_front_to_back_order(self, masks):
        """Rank masks by their summed pairwise frontness score.

        Returns:
            ``(order, scores)`` where ``scores[i]`` is the sum of
            ``frontness_score(masks[i], masks[j])`` over all ``j != i`` and
            ``order`` is the index array sorted by descending score.
        """
        n = len(masks)
        scores = np.zeros(n)
        # frontness_score(a, b) == -frontness_score(b, a), so each unordered
        # pair only needs to be evaluated once.
        for i in range(n):
            for j in range(i + 1, n):
                pair_score = self.frontness_score(masks[i], masks[j])
                scores[i] += pair_score
                scores[j] -= pair_score
        order = np.argsort(-scores)
        return order, scores

    def remove_original_people(self, image, person_masks):
        """Inpaint away every masked person.

        Returns:
            ``(inpainted PIL image, uint8 union mask)``; the union mask is
            dilated before inpainting and returned in its dilated form.
        """
        image_np = np.array(image)
        combined_mask = np.zeros(image_np.shape[:2], dtype=np.uint8)
        for mask in person_masks:
            combined_mask[mask] = 255

        kernel = np.ones((5, 5), np.uint8)
        combined_mask = cv2.dilate(combined_mask, kernel, iterations=2)
        inpainted = cv2.inpaint(image_np, combined_mask, 3, cv2.INPAINT_TELEA)
        return Image.fromarray(inpainted), combined_mask

    def clean_vton_edges_on_overlap(self, img_pil, mask_uint8, other_masks_uint8,
                                    erode_iters=1, edge_dilate=2, inner_erode=2):
        """Repaint a person's edge pixels where they overlap other people.

        Args:
            img_pil: the person's image.
            mask_uint8: the person's uint8 mask.
            other_masks_uint8: masks of all other people.
            erode_iters: 3x3 erosion iterations used to tighten the mask.
            edge_dilate: 3x3 dilation iterations applied to the detected edge.
            inner_erode: 5x5 erosion iterations defining the "inner" region
                whose colours are propagated outwards by inpainting.

        Returns:
            ``(image, mask)``. Without any overlap the inputs are returned
            unchanged; otherwise the mask is the eroded ("tight") mask and the
            image has its edge pixels inside the overlap band replaced.
        """
        src = np.array(img_pil).copy()

        others_union = np.zeros_like(mask_uint8, dtype=np.uint8)
        for other in other_masks_uint8:
            others_union = np.maximum(others_union, other)

        overlap = ((mask_uint8 > 0) & (others_union > 0)).astype(np.uint8) * 255
        if not overlap.any():
            return img_pil, mask_uint8

        kernel = np.ones((3, 3), np.uint8)
        tight_mask = cv2.erode(mask_uint8, kernel, iterations=erode_iters)

        edge = cv2.Canny(tight_mask, 50, 150)
        edge = cv2.dilate(edge, np.ones((3, 3), np.uint8), iterations=edge_dilate)

        # Restrict the edge band to the neighbourhood of the overlap.
        overlap_band = cv2.dilate(overlap, np.ones((5, 5), np.uint8), iterations=1)
        edge = cv2.bitwise_and(edge, overlap_band)
        if not edge.any():
            return img_pil, tight_mask

        inner = cv2.erode(tight_mask, np.ones((5, 5), np.uint8), iterations=inner_erode)
        # Inpaint everything outside the inner region, then copy only the
        # edge pixels back into the source image.
        inner_rgb = cv2.inpaint(src, 255 - inner, 3, cv2.INPAINT_TELEA)
        edge_px = edge > 0
        src[edge_px] = inner_rgb[edge_px]

        return Image.fromarray(src), tight_mask

    def clean_masks(self, vton_people, vton_masks):
        """Run ``clean_vton_edges_on_overlap`` for each person vs. the rest."""
        cleaned_vton_people = []
        cleaned_vton_masks = []
        for i, (person, mask) in enumerate(zip(vton_people, vton_masks)):
            other_masks = [m for j, m in enumerate(vton_masks) if j != i]
            cleaned_img, cleaned_mask = self.clean_vton_edges_on_overlap(
                person, mask, other_masks,
                erode_iters=1, edge_dilate=2, inner_erode=2,
            )
            cleaned_vton_people.append(cleaned_img)
            cleaned_vton_masks.append(cleaned_mask)
        return cleaned_vton_people, cleaned_vton_masks

    def process_group_image(self, group_image, assignments):
        """Process a group image with per-person garment assignments.

        Args:
            group_image: PIL image or numpy array of the group photo.
            assignments: list of {"garment": PIL.Image|None, "category": str}
                per person. The list is not modified; people without an entry
                are skipped (kept as-is).

        Returns:
            ``(final_image, info)`` where ``info`` holds the intermediate
            artefacts ("original", "clean_background", "person_mask",
            "num_people", "individual_people", "vton_results", "masks",
            "vton_masks").

        Raises:
            TypeError: if *group_image* is neither a PIL image nor an ndarray.
        """
        print("Step 1: Loading images...")
        if isinstance(group_image, np.ndarray):
            group_image = Image.fromarray(group_image)
        if not isinstance(group_image, Image.Image):
            # Previously this fell through and read whatever stale
            # "people.png" happened to be on disk.
            raise TypeError("group_image must be a PIL image or a numpy array")

        # The image is round-tripped through "people.png" because the
        # segmentation model below is fed that file path.
        group_image.save("people.png")
        img_bgr = cv2.imread("people.png")
        img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        H, W = img.shape[:2]

        print("Step 2: Getting segmentation masks with YOLO...")
        results = self.model("people.png")
        result = results[0]
        masks = self.get_mask(result, H, W)
        print(f"Found {len(masks)} people")

        if not masks:
            # Nothing to dress: return the input unchanged instead of failing
            # on `vton_people[0]` further down.
            original = Image.fromarray(img)
            return original, {
                "original": original,
                "clean_background": original,
                "person_mask": Image.fromarray(np.zeros((H, W), dtype=np.uint8)),
                "num_people": 0,
                "individual_people": [],
                "vton_results": [],
                "masks": [],
                "vton_masks": [],
            }

        print("Step 3: Extracting individual people...")
        people = self.extract_people(img, masks)

        # Pad a *copy* of the assignments to match the detected people count
        # so the caller's list is left untouched.
        assignments = list(assignments or [])
        while len(assignments) < len(people):
            assignments.append({"garment": None, "category": "tops"})

        print("Step 4: Applying VTON to people...")
        vton_people = self.apply_vton_to_people(people, assignments)

        print("Step 5: Getting masks for VTON results...")
        vton_masks = self.get_vton_masks(vton_people)
        for i in range(len(vton_masks)):
            if assignments[i]["garment"] is None:
                # Skipped people keep their original pixels, so the detector
                # mask is used instead of the white-background threshold.
                yolo_mask = masks[i].astype(np.uint8) * 255
                vton_masks[i] = cv2.GaussianBlur(yolo_mask, (3, 3), 1)

        order, scores = self.estimate_front_to_back_order(vton_masks)
        cleaned_vton_people, cleaned_vton_masks = self.clean_masks(vton_people, vton_masks)

        print("Step 6: Resizing to match dimensions...")
        # NOTE(review): `masks` stay at the original (H, W); if the try-on
        # output ever has a different size the steps below will not line up.
        img = cv2.resize(img, vton_people[0].size)

        print("Step 7: Creating clean background by removing original people...")
        clean_background, person_mask = self.remove_original_people(img, masks)
        clean_background_np = np.array(clean_background)

        print("Step 8: Recomposing final image...")
        # NOTE(review): layers are painted in `order` (highest score first),
        # so the highest-scoring person ends up underneath the others —
        # confirm this is the intended stacking.
        recomposed = clean_background_np.copy()
        for i in order:
            alpha = (cleaned_vton_masks[i].astype(np.float32) / 255.0)[..., None]
            layer = np.array(cleaned_vton_people[i]).astype(np.float32)
            blended = layer * alpha + (1 - alpha) * recomposed.astype(np.float32)
            recomposed = blended.astype(np.uint8)

        final_image = Image.fromarray(recomposed)
        return final_image, {
            "original": Image.fromarray(img),
            "clean_background": clean_background,
            "person_mask": Image.fromarray(person_mask),
            "num_people": len(people),
            "individual_people": people,
            "vton_results": cleaned_vton_people,
            "masks": masks,
            "vton_masks": cleaned_vton_masks,
        }


WEIGHTS_DIR = Path("./weights")


def ensure_weights():
    """Download the model weights unless WEIGHTS_DIR already has content."""
    if WEIGHTS_DIR.exists() and any(WEIGHTS_DIR.iterdir()):
        print("Weights already present, skipping download.")
        return
    print("Downloading weights...")
    subprocess.check_call([
        sys.executable,
        "fashn-vton-1.5/scripts/download_weights.py",
        "--weights-dir",
        str(WEIGHTS_DIR),
    ])


ensure_weights()

_pipeline = None


def get_pipeline():
    """Return the process-wide MultiPersonVTON, creating it on first use."""
    global _pipeline
    if _pipeline is None:
        _pipeline = MultiPersonVTON()
    return _pipeline


def _load_resized_portrait(portrait, width=PORTRAIT_WIDTH):
    """Load *portrait* as RGB and resize it to *width*, keeping aspect ratio.

    *portrait* may be a file path, a numpy array or a PIL image.
    """
    if isinstance(portrait, (str, Path)):
        # The context manager releases the file handle once pixels are read.
        with Image.open(portrait) as opened:
            image = opened.convert("RGB")
    elif isinstance(portrait, np.ndarray):
        image = Image.fromarray(portrait).convert("RGB")
    else:
        image = portrait.convert("RGB")

    w, h = image.size
    height = max(1, int(h * width / w))  # never collapse to a zero height
    return image.resize((width, height), Image.LANCZOS)


@spaces.GPU
def detect_people(portrait_path):
    """Detect people in a portrait and return one white-background cut-out each.

    Raises:
        gr.Error: if no portrait was provided.
    """
    if portrait_path is None:
        raise gr.Error("Please select a portrait first.")

    resized = _load_resized_portrait(portrait_path)
    resized.save("people.png")

    pipeline = get_pipeline()
    result = pipeline.model("people.png")[0]

    img = np.array(resized)
    H, W = img.shape[:2]
    masks = pipeline.get_mask(result, H, W)
    return pipeline.extract_people(img, masks)


@spaces.GPU
def process_images(selected_portrait, garment_pool, num_detected, *assignment_args):
    """Run the full try-on for a portrait using the UI's per-person choices.

    Args:
        selected_portrait: portrait path or image.
        garment_pool: list of dicts with at least "label" and "path" keys.
        num_detected: number of people found by ``detect_people``.
        *assignment_args: garment dropdown values for every UI slot followed
            by the category values for every slot
            (dd_0, ..., dd_k, cat_0, ..., cat_k).

    Returns:
        The recomposed PIL image.

    Raises:
        gr.Error: if the portrait or the garment pool is missing.
    """
    if selected_portrait is None:
        raise gr.Error("Please select a portrait.")
    if not garment_pool:
        raise gr.Error("Please add at least one garment to the pool.")

    resized = _load_resized_portrait(selected_portrait)
    pipeline = get_pipeline()

    # Build per-person assignments from dropdown/radio values.
    max_p = len(assignment_args) // 2
    # Never read past the slots the UI actually sent.
    n = min(int(num_detected) if num_detected else 0, max_p)
    pool_by_label = {g["label"]: g for g in garment_pool}

    assignments = []
    for i in range(n):
        dd_val = assignment_args[i]
        category = assignment_args[max_p + i] or DEFAULT_CATEGORY
        if dd_val == "Skip" or dd_val not in pool_by_label:
            assignments.append({"garment": None, "category": category})
            continue
        path = pool_by_label[dd_val]["path"]
        garment_img = Image.open(path) if isinstance(path, str) else path
        assignments.append({"garment": garment_img, "category": category})

    result, _ = pipeline.process_group_image(resized, assignments)
    return result


demo = build_demo(process_images, detect_fn=detect_people, max_people=8)

demo.launch(allowed_paths=[hf_constants.HF_HUB_CACHE])