"""Gradio UI for a multi-person virtual try-on demo.

The module builds a two-tab ``gr.Blocks`` app:

* "Virtual Try-On": pick a portrait, collect garments into a pool, detect
  people, assign a garment to each person and run the try-on callback.
* "Admin - Manage Examples": list/delete curated examples and promote
  results from user uploads to examples.

All persistence goes through the project-local ``storage`` module.
"""

from pathlib import Path

import gradio as gr
from PIL import Image

from storage import (
    save_multi_result,
    delete_image_set,
    promote_to_example,
    list_gallery_urls,
    download_to_local,
    is_dataset_url,
    save_image,
    upload_image,
    make_filename,
    generate_id,
    parse_filename,
    parse_result_filename,
    parse_multi_result_filename,
    file_url,
    LOCAL_DATA,
)

EXAMPLES_PREFIX = "examples"
UPLOADS_PREFIX = "user_uploads"

# Make sure the local mirror directories exist before any handler writes to them.
for subdir in ["portraits", "garments", "results"]:
    (LOCAL_DATA / EXAMPLES_PREFIX / subdir).mkdir(parents=True, exist_ok=True)
    (LOCAL_DATA / UPLOADS_PREFIX / subdir).mkdir(parents=True, exist_ok=True)


def _gallery_images(prefix, subdir):
    """Return the gallery entries that ``storage`` lists for ``prefix/subdir``."""
    return list_gallery_urls(prefix, subdir)


MAX_PEOPLE = 8


def _upload_if_local(path):
    """Upload ``path`` via ``upload_image`` when it exists under ``LOCAL_DATA``.

    Containment is decided with ``Path.relative_to`` instead of a string
    prefix test, so a sibling directory such as ``<LOCAL_DATA>2`` is not
    mistaken for a child of ``LOCAL_DATA``.
    """
    local = Path(path)
    if not local.exists():
        return
    try:
        relative = local.relative_to(LOCAL_DATA)
    except ValueError:
        # Not inside LOCAL_DATA: nothing to upload.
        return
    upload_image(local, str(relative))


def _store_upload(img, subdir, kind):
    """Save ``img`` under the uploads area with a fresh ID; return the path as str."""
    dest = LOCAL_DATA / UPLOADS_PREFIX / subdir / make_filename(generate_id(), kind)
    save_image(img, dest)
    return str(dest)


def _import_from_url(url, subdir, kind):
    """Fetch ``url`` locally and return the path the UI should use for it.

    URLs for which ``is_dataset_url`` is true are used as downloaded; any
    other image is re-saved into the uploads area under a fresh ID.
    """
    local_path = download_to_local(url)
    if is_dataset_url(url):
        return local_path
    dest = LOCAL_DATA / UPLOADS_PREFIX / subdir / make_filename(generate_id(), kind)
    # Context manager closes the underlying file handle once it is re-saved.
    with Image.open(local_path) as img:
        save_image(img, dest)
    return str(dest)


def _append_to_pool(pool, counter, path):
    """Return ``(new_pool, new_counter, pool_image_paths)`` with ``path`` added.

    The input pool is not mutated; the new entry is labelled
    ``"Garment <new_counter>"``.
    """
    new_counter = counter + 1
    new_pool = list(pool or []) + [{"path": path, "label": f"Garment {new_counter}"}]
    return new_pool, new_counter, [g["path"] for g in new_pool]


def _load_example(portrait_path, garment_paths):
    """Build the state/preview values that load an example into the main tab.

    Returns ``(selected_portrait, preview_portrait, pool, counter, pool_images)``.
    """
    pool = [{"path": g, "label": f"Garment {i+1}"} for i, g in enumerate(garment_paths)]
    pool_images = [g["path"] for g in pool]
    return portrait_path, portrait_path, pool, len(pool), pool_images


def _resolve_result_images(prefix, path):
    """Parse a result filename and resolve portrait + all garments.

    Tries the single-garment filename format first, then the multi-garment
    format, parsing both the downloaded file name and the original name.

    Returns ``(portrait_local, garment_locals, result_local)``; the portrait
    is ``None`` and the garment list empty when nothing could be resolved.
    Failures are reported with ``gr.Warning`` rather than raised, because
    this is a best-effort lookup used to populate galleries.
    """
    result_local = download_to_local(path)
    fname = Path(result_local).name

    # Try single-garment format
    parsed = parse_result_filename(fname) or parse_result_filename(Path(path).name)
    if parsed:
        try:
            p_url = file_url(f"{prefix}/portraits/{make_filename(parsed['portrait_id'], 'portrait')}")
            g_url = file_url(f"{prefix}/garments/{make_filename(parsed['garment_id'], 'garment')}")
            return download_to_local(p_url), [download_to_local(g_url)], result_local
        except Exception as e:
            gr.Warning(f"Single-garment resolve failed for {fname}: {e}")

    # Try multi-garment format
    multi = parse_multi_result_filename(fname) or parse_multi_result_filename(Path(path).name)
    if multi:
        # Skipped people are recorded as None and have no garment to fetch.
        gids = [gid for gid in multi["garment_ids"] if gid is not None]
        if gids:
            try:
                p_url = file_url(f"{prefix}/portraits/{make_filename(multi['portrait_id'], 'portrait')}")
                portrait_local = download_to_local(p_url)
                garment_locals = []
                for gid in gids:
                    g_url = file_url(f"{prefix}/garments/{make_filename(gid, 'garment')}")
                    garment_locals.append(download_to_local(g_url))
                return portrait_local, garment_locals, result_local
            except Exception as e:
                gr.Warning(f"Multi-garment resolve failed for {fname} (portrait={multi['portrait_id']}, garments={gids}): {e}")
    else:
        gr.Warning(f"Could not parse result filename: {fname}")

    return None, [], result_local


def build_demo(process_fn, detect_fn=None, max_people=MAX_PEOPLE):
    """Build and return the ``gr.Blocks`` try-on application.

    Args:
        process_fn: Callable invoked as
            ``process_fn(portrait_path, garment_pool, num_detected, *assignment_args)``
            where ``assignment_args`` holds the per-person garment dropdown
            values followed by the per-person category values. Its return
            value is shown in the result image.
        detect_fn: Optional callable ``detect_fn(portrait_path)`` returning a
            sequence with one entry per detected person (shown in a gallery).
        max_people: Currently unused by this function; kept so existing
            callers that pass it keep working.

    Returns:
        The constructed ``gr.Blocks`` instance (not launched).
    """
    # NOTE(review): the nesting of layout containers below was reconstructed
    # from a whitespace-collapsed source; confirm it matches the intended UI.

    def process_and_save(portrait_path, garment_pool, num_detected, *assignment_args):
        """Run ``process_fn`` and persist the inputs/result for later promotion.

        Raises:
            gr.Error: if no portrait is selected or the garment pool is empty.
        """
        if portrait_path is None:
            raise gr.Error("Please select a portrait.")
        if not garment_pool:
            raise gr.Error("Please add at least one garment to the pool.")

        result = process_fn(portrait_path, garment_pool, num_detected, *assignment_args)
        if not result or not portrait_path:
            return result

        p_parsed = parse_filename(Path(portrait_path).name)
        if not p_parsed:
            # Portrait name carries no ID, so the result cannot be filed.
            return result

        # Upload portrait and all pooled garments that live in the local mirror.
        _upload_if_local(portrait_path)
        for g in garment_pool:
            _upload_if_local(g["path"])

        # Build garment ID list for filename. assignment_args is laid out as
        # [dropdowns..., categories...]; clamp to the dropdown half so a stale
        # num_detected can never index into the category values.
        n = min(num_detected or 0, len(assignment_args) // 2)
        pool_by_label = {g["label"]: g for g in garment_pool}
        garment_ids = []
        for dd_val in assignment_args[:n]:
            chosen = None if dd_val == "Skip" else pool_by_label.get(dd_val)
            if chosen is None:
                garment_ids.append(None)
                continue
            g_parsed = parse_filename(Path(chosen["path"]).name)
            garment_ids.append(g_parsed["id"] if g_parsed else generate_id())

        save_multi_result(UPLOADS_PREFIX, p_parsed["id"], garment_ids, result)
        return result

    with gr.Blocks(title="Multi-Person Virtual Try-On") as demo:
        with gr.Tabs():
            # ---- Main VTON Tab ----
            with gr.Tab("Virtual Try-On"):
                gr.Markdown("# Multi-Person Virtual Try-On")
                gr.Markdown("Select a portrait, add garments to the pool, detect people, and assign garments.")

                selected_portrait = gr.State(value=None)
                garment_pool = gr.State(value=[])
                num_detected = gr.State(value=0)
                garment_counter = gr.State(value=0)

                gr.Markdown("**Step 1:** Select or upload a portrait, and select or upload garments to the pool.")
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Portrait")
                        portrait_gallery = gr.Gallery(
                            value=_gallery_images(UPLOADS_PREFIX, "portraits"),
                            label="Uploaded Portraits",
                            columns=4,
                            height=200,
                            allow_preview=False,
                        )
                        with gr.Accordion("Upload new portrait", open=False):
                            portrait_upload = gr.Image(type="pil", label="Upload Portrait", sources=["upload", "webcam"])
                            with gr.Row():
                                portrait_url_input = gr.Textbox(label="Or paste image URL", scale=4)
                                portrait_url_btn = gr.Button("Load", size="sm", scale=1)
                        preview_portrait = gr.Image(label="Selected Portrait", interactive=False, height=250)

                    # Garment pool section
                    with gr.Column():
                        gr.Markdown("### Garment Pool")
                        garment_gallery = gr.Gallery(
                            value=_gallery_images(UPLOADS_PREFIX, "garments"),
                            label="Available Garments (click to add to pool)",
                            columns=4,
                            height=200,
                            allow_preview=False,
                        )
                        with gr.Accordion("Upload new garment", open=False):
                            garment_upload = gr.Image(type="pil", label="Upload Garment", sources=["upload", "webcam"])
                            with gr.Row():
                                garment_url_input = gr.Textbox(label="Or paste image URL", scale=4)
                                garment_url_btn = gr.Button("Load", size="sm", scale=1)
                        garment_pool_gallery = gr.Gallery(
                            label="Current Pool",
                            columns=6,
                            height=250,
                            allow_preview=False,
                        )
                        clear_pool_btn = gr.Button("Clear Pool", size="sm", variant="stop")

                gr.Markdown("**Step 2:** Detect people in the portrait. This lets you choose which garment goes on each person.")
                detect_btn = gr.Button("Detect People", variant="secondary")
                detect_status = gr.Textbox(interactive=False, show_label=False, value="")
                people_gallery = gr.Gallery(
                    label="Detected People",
                    columns=6,
                    height=300,
                    allow_preview=False,
                )

                gr.Markdown("**Step 3:** Assign garments to each person, then click Try On.")

                @gr.render(inputs=[num_detected, garment_pool])
                def render_assignments(n_detected, pool):
                    """Re-render one garment/category row per detected person."""
                    n = n_detected or 0
                    choices = ["Skip"] + [g["label"] for g in (pool or [])]
                    # Pre-select the first pooled garment when there is one.
                    default_garment = choices[1] if len(choices) > 1 else "Skip"

                    if n > 0:
                        gr.Markdown(f"### Assign Garments to {n} {'Person' if n == 1 else 'People'}")

                    dds = []
                    cats = []
                    for i in range(n):
                        with gr.Row():
                            dd = gr.Dropdown(
                                choices=choices,
                                value=default_garment,
                                label=f"Person {i + 1} — Garment",
                                scale=3,
                                interactive=True,
                            )
                            cat = gr.Radio(
                                choices=["tops", "bottoms", "one-pieces"],
                                value="tops",
                                label="Category",
                                scale=2,
                                interactive=True,
                            )
                        dds.append(dd)
                        cats.append(cat)

                    submit_btn = gr.Button("Try On", variant="primary")
                    result_image = gr.Image(type="pil", label="Result", interactive=False)
                    # Input order defines assignment_args: dropdowns first, then categories.
                    submit_btn.click(
                        process_and_save,
                        inputs=[selected_portrait, garment_pool, num_detected] + dds + cats,
                        outputs=result_image,
                    )

                # Examples section
                gr.Markdown("---")
                gr.Markdown("### Examples")
                example_sets = gr.State(value=[])
                refresh_examples_btn = gr.Button("Refresh Examples", size="sm")

                @gr.render(inputs=[example_sets])
                def render_examples(sets):
                    """Render one row (portrait, garments, result, Use button) per example."""
                    for ex in sets or []:
                        with gr.Row():
                            gr.Image(value=ex["portrait"], label="Portrait", height=200, interactive=False, scale=1)
                            for j, g in enumerate(ex["garments"]):
                                gr.Image(value=g, label=f"Garment {j+1}", height=200, interactive=False, scale=1)
                            gr.Image(value=ex["result"], label="Result", height=200, interactive=False, scale=1)
                            use_btn = gr.Button("Use", size="sm", scale=0, min_width=60)
                        # Defaults bind this row's values (avoids late-binding closures).
                        # Detection is reset afterwards because the portrait changed,
                        # matching the other portrait-selection paths.
                        use_btn.click(
                            lambda p=ex["portrait"], gs=ex["garments"]: _load_example(p, gs),
                            outputs=[selected_portrait, preview_portrait, garment_pool, garment_counter, garment_pool_gallery],
                        ).then(reset_detection, outputs=detection_reset_outputs)

                # -- Event handlers --
                def on_portrait_gallery_select(evt: gr.SelectData):
                    """Select the clicked portrait and show it in the preview."""
                    path = evt.value["image"]["path"]
                    local_path = download_to_local(path)
                    return local_path, local_path

                def on_garment_gallery_select(evt: gr.SelectData, pool, counter):
                    """Add selected garment to pool."""
                    path = evt.value["image"]["path"]
                    return _append_to_pool(pool, counter, download_to_local(path))

                def clear_pool():
                    """Empty the pool, reset its counter and refresh the garment gallery."""
                    return [], 0, [], _gallery_images(UPLOADS_PREFIX, "garments")

                def reset_detection():
                    """Clear detection status, detected-people gallery and count."""
                    return "", [], 0

                detection_reset_outputs = [detect_status, people_gallery, num_detected]

                portrait_gallery.select(
                    on_portrait_gallery_select,
                    outputs=[selected_portrait, preview_portrait],
                ).then(reset_detection, outputs=detection_reset_outputs)

                garment_gallery.select(
                    on_garment_gallery_select,
                    inputs=[garment_pool, garment_counter],
                    outputs=[garment_pool, garment_counter, garment_pool_gallery],
                )

                clear_pool_btn.click(
                    clear_pool,
                    outputs=[garment_pool, garment_counter, garment_pool_gallery, garment_gallery],
                )

                def on_portrait_upload(img, current_pool):
                    """Store an uploaded portrait and select it.

                    The handler clears the upload widget after storing, which
                    fires ``change`` again with ``img is None``; that second
                    call must leave the current selection untouched, hence the
                    ``gr.update()`` no-ops.
                    """
                    if img is None:
                        return _gallery_images(UPLOADS_PREFIX, "portraits"), gr.update(), gr.update(), gr.update()
                    path = _store_upload(img, "portraits", "portrait")
                    return _gallery_images(UPLOADS_PREFIX, "portraits"), path, path, None

                def on_garment_upload(img, pool, counter):
                    """Store an uploaded garment and append it to the pool."""
                    if img is None:
                        return _gallery_images(UPLOADS_PREFIX, "garments"), pool, counter, [g["path"] for g in pool], None
                    path = _store_upload(img, "garments", "garment")
                    new_pool, new_counter, pool_images = _append_to_pool(pool, counter, path)
                    return _gallery_images(UPLOADS_PREFIX, "garments"), new_pool, new_counter, pool_images, None

                def on_portrait_url(url, pool):
                    """Load a portrait from a URL and select it.

                    An empty URL leaves the current selection unchanged.
                    """
                    url = (url or "").strip()
                    if not url:
                        return _gallery_images(UPLOADS_PREFIX, "portraits"), gr.update(), gr.update(), ""
                    local_path = _import_from_url(url, "portraits", "portrait")
                    return _gallery_images(UPLOADS_PREFIX, "portraits"), local_path, local_path, ""

                def on_garment_url(url, pool, counter):
                    """Load a garment from a URL and append it to the pool."""
                    url = (url or "").strip()
                    if not url:
                        return _gallery_images(UPLOADS_PREFIX, "garments"), pool, counter, [g["path"] for g in pool], ""
                    local_path = _import_from_url(url, "garments", "garment")
                    new_pool, new_counter, pool_images = _append_to_pool(pool, counter, local_path)
                    return _gallery_images(UPLOADS_PREFIX, "garments"), new_pool, new_counter, pool_images, ""

                portrait_upload.change(
                    on_portrait_upload,
                    inputs=[portrait_upload, garment_pool],
                    outputs=[portrait_gallery, selected_portrait, preview_portrait, portrait_upload],
                ).then(reset_detection, outputs=detection_reset_outputs)

                garment_upload.change(
                    on_garment_upload,
                    inputs=[garment_upload, garment_pool, garment_counter],
                    outputs=[garment_gallery, garment_pool, garment_counter, garment_pool_gallery, garment_upload],
                )

                portrait_url_btn.click(
                    on_portrait_url,
                    inputs=[portrait_url_input, garment_pool],
                    outputs=[portrait_gallery, selected_portrait, preview_portrait, portrait_url_input],
                ).then(reset_detection, outputs=detection_reset_outputs)

                garment_url_btn.click(
                    on_garment_url,
                    inputs=[garment_url_input, garment_pool, garment_counter],
                    outputs=[garment_gallery, garment_pool, garment_counter, garment_pool_gallery, garment_url_input],
                )

                def on_detect(portrait_path, pool):
                    """Run ``detect_fn`` on the selected portrait.

                    Raises:
                        gr.Error: if no detector was supplied or no portrait
                            is selected.
                    """
                    if detect_fn is None:
                        raise gr.Error("People detection is not available.")
                    if portrait_path is None:
                        raise gr.Error("Please select a portrait first.")
                    people = detect_fn(portrait_path)
                    n = len(people)
                    return f"Found {n} {'person' if n == 1 else 'people'}", people, n

                detect_btn.click(
                    lambda: "Detecting people...",
                    outputs=[detect_status],
                ).then(
                    on_detect,
                    inputs=[selected_portrait, garment_pool],
                    outputs=[detect_status, people_gallery, num_detected],
                )

                def refresh_examples():
                    """Collect every example result whose inputs can be resolved."""
                    # NOTE(review): results are listed under EXAMPLES_PREFIX but their
                    # portrait/garments are resolved under UPLOADS_PREFIX - confirm
                    # this is intended.
                    sets = []
                    for r in _gallery_images(EXAMPLES_PREFIX, "results"):
                        portrait, garments, result = _resolve_result_images(UPLOADS_PREFIX, r)
                        if portrait and garments:
                            sets.append({"portrait": portrait, "garments": garments, "result": result})
                    return sets

                refresh_examples_btn.click(
                    refresh_examples,
                    outputs=[example_sets],
                )
                demo.load(
                    refresh_examples,
                    outputs=[example_sets],
                )

            # ---- Admin Tab ----
            with gr.Tab("Admin - Manage Examples"):
                admin_status = gr.Textbox(label="Status", interactive=False)

                gr.Markdown("### Current Examples")
                admin_examples_table = gr.Dataframe(
                    headers=["ID", "Result Filename"],
                    label="Examples",
                    interactive=False,
                )
                with gr.Row():
                    delete_id = gr.Textbox(label="Example ID to delete", scale=3)
                    delete_btn = gr.Button("Delete", variant="stop", scale=1)

                def get_examples_table():
                    """Return ``[id, filename]`` rows for every example result."""
                    rows = []
                    for r in list_gallery_urls(EXAMPLES_PREFIX, "results"):
                        fname = Path(r).name
                        parsed = parse_result_filename(fname) or parse_multi_result_filename(fname)
                        # Fall back to the bare file stem when the name is not parseable.
                        rid = parsed["portrait_id"] if parsed else Path(fname).stem
                        rows.append([rid, fname])
                    return rows

                def on_admin_delete(ex_id):
                    """Delete the example set with the given ID and refresh the table."""
                    # Strip before validating so whitespace-only input is rejected
                    # instead of being passed on as an empty ID.
                    ex_id = (ex_id or "").strip()
                    if not ex_id:
                        return "Please provide an ID.", get_examples_table()
                    delete_image_set(EXAMPLES_PREFIX, ex_id)
                    return "Deleted.", get_examples_table()

                delete_btn.click(
                    on_admin_delete,
                    inputs=[delete_id],
                    outputs=[admin_status, admin_examples_table],
                )

                # ---- Promote from Uploads ----
                gr.Markdown("---")
                gr.Markdown("## Promote from Uploads")
                gr.Markdown("Select a result to promote. The matching portrait and garment are found automatically.")

                promote_portrait = gr.State(value=None)
                promote_garments = gr.State(value=[])
                promote_result = gr.State(value=None)

                promo_result_gallery = gr.Gallery(
                    value=_gallery_images(UPLOADS_PREFIX, "results"),
                    label="Results",
                    columns=4,
                    height=200,
                    allow_preview=False,
                )
                with gr.Row():
                    promo_preview_portrait = gr.Image(label="Portrait", interactive=False, height=150)
                    promo_preview_garments = gr.Gallery(label="Garments", columns=4, height=150, allow_preview=False)
                    promo_preview_result = gr.Image(label="Result", interactive=False, height=150)
                promote_btn = gr.Button("Promote to Example", variant="primary")
                promote_status = gr.Textbox(label="Status", interactive=False)

                def on_result_select(evt: gr.SelectData):
                    """Resolve the clicked result's inputs and fill states + previews."""
                    path = evt.value["image"]["path"]
                    portrait_local, garment_locals, result_local = _resolve_result_images(UPLOADS_PREFIX, path)
                    if not portrait_local or not garment_locals:
                        gr.Warning(f"Could not find matching portrait/garment for: {Path(path).name}")
                    # Same triple twice: once for the State holders, once for the previews.
                    return portrait_local, garment_locals, result_local, portrait_local, garment_locals, result_local

                promo_result_gallery.select(
                    on_result_select,
                    outputs=[
                        promote_portrait,
                        promote_garments,
                        promote_result,
                        promo_preview_portrait,
                        promo_preview_garments,
                        promo_preview_result,
                    ],
                )

                def on_promote(portrait_path, garment_paths, result_path):
                    """Promote the selected result; only ``result_path`` is used."""
                    if not result_path:
                        return "No result to promote.", get_examples_table()
                    name = promote_to_example(result_path)
                    return f"Promoted: {name}", get_examples_table()

                promote_btn.click(
                    on_promote,
                    inputs=[promote_portrait, promote_garments, promote_result],
                    outputs=[promote_status, admin_examples_table],
                )

                refresh_promo_btn = gr.Button("Refresh Results", size="sm")
                refresh_promo_btn.click(
                    lambda: _gallery_images(UPLOADS_PREFIX, "results"),
                    outputs=[promo_result_gallery],
                )

        demo.load(get_examples_table, outputs=[admin_examples_table])

    return demo


if __name__ == "__main__":
    # Stand-in callbacks so the UI can be exercised without the real models.
    def dummy_process(portrait, pool, num_detected, *assignment_args):
        return Image.new("RGB", (512, 512), (200, 200, 200))

    def dummy_detect(portrait_path):
        return [Image.new("RGB", (100, 200), (255, 0, 0)), Image.new("RGB", (100, 200), (0, 255, 0))]

    demo = build_demo(dummy_process, detect_fn=dummy_detect)
    demo.launch()