# Hugging Face Spaces — status: Running
import gc
import os
import threading
import traceback
from typing import Optional

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
# ─────────────────────────── CONFIG ──────────────────────────── #
# All paths are resolved relative to this script, not the current working directory.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
WEIGHTS_DIR = os.path.join(SCRIPT_DIR, "weights")
EXAMPLES_DIR = os.path.join(SCRIPT_DIR, "examples", "data")

# Choices offered in the UI dropdowns; the selected value is forwarded to the
# try-on pipeline as `category` / `garment_photo_type`.
CATEGORIES = ["tops", "bottoms", "one-pieces"]
GARMENT_PHOTO_TYPES = ["model", "flat-lay"]
# ──────────────────────── WEIGHT DOWNLOAD ────────────────────── #
def download_weights():
    """Download model weights from the Hugging Face Hub into WEIGHTS_DIR.

    Files that already exist locally are skipped, so repeated calls are cheap.

    Resulting layout::

        WEIGHTS_DIR/model.safetensors           (repo fashn-ai/fashn-vton-1.5)
        WEIGHTS_DIR/dwpose/yolox_l.onnx          (repo fashn-ai/DWPose)
        WEIGHTS_DIR/dwpose/dw-ll_ucoco_384.onnx  (repo fashn-ai/DWPose)
    """
    dwpose_dir = os.path.join(WEIGHTS_DIR, "dwpose")
    # makedirs also creates WEIGHTS_DIR, the parent of dwpose_dir.
    os.makedirs(dwpose_dir, exist_ok=True)

    # (repo_id, filename, target directory, label used in the progress message)
    downloads = [
        ("fashn-ai/fashn-vton-1.5", "model.safetensors", WEIGHTS_DIR, "TryOnModel weights"),
        ("fashn-ai/DWPose", "yolox_l.onnx", dwpose_dir, "DWPose/yolox_l.onnx"),
        ("fashn-ai/DWPose", "dw-ll_ucoco_384.onnx", dwpose_dir, "DWPose/dw-ll_ucoco_384.onnx"),
    ]
    for repo_id, filename, target_dir, label in downloads:
        if os.path.exists(os.path.join(target_dir, filename)):
            continue
        print(f"Downloading {label}...")
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir=target_dir)
    print("All weights ready!")
# Download weights at start-up (import time) so they are on disk before
# get_pipeline() builds the pipeline from WEIGHTS_DIR.
download_weights()
# ──────────────────────── PIPELINE LOADER ────────────────────── #
_pipeline_lock = threading.Lock()
# Shared pipeline instance; stays None until the first get_pipeline() call.
_pipeline: Optional[object] = None


def get_pipeline():
    """Return the shared CPU ``TryOnPipeline``, constructing it on the first call.

    Construction happens while holding ``_pipeline_lock``, so concurrent callers
    cannot build two pipelines; later calls return the cached instance.
    """
    global _pipeline
    with _pipeline_lock:
        if _pipeline is None:
            # Imported lazily: the package is only loaded when the pipeline is first needed.
            from fashn_vton import TryOnPipeline

            print("Loading pipeline on CPU...")
            _pipeline = TryOnPipeline(weights_dir=WEIGHTS_DIR, device="cpu")
            print("Pipeline ready!")
    return _pipeline
# ─────────────────────────── INFERENCE ───────────────────────── #
def _to_pil(x):
    """Convert an image value (PIL image, numpy array, or anything Image.open accepts) to RGB PIL."""
    if isinstance(x, np.ndarray):
        x = Image.fromarray(x)
    if isinstance(x, Image.Image):
        return x.convert("RGB")
    return Image.open(x).convert("RGB")


def try_on(
    person_image,
    garment_image,
    category: str,
    garment_photo_type: str,
    num_timesteps: int,
    guidance_scale: float,
    seed: int,
    segmentation_free: bool,
):
    """Run virtual try-on inference for one person/garment pair.

    Args:
        person_image: Person photo (PIL image, numpy array, or path).
        garment_image: Garment photo (same accepted types).
        category: Garment category, forwarded to the pipeline.
        garment_photo_type: Garment photo type, forwarded to the pipeline.
        num_timesteps: Number of sampling steps; coerced to ``int`` because
            UI sliders may deliver floats.
        guidance_scale: Guidance scale; coerced to ``float``.
        seed: Random seed; ``None`` or a negative value falls back to 42.
        segmentation_free: Forwarded unchanged to the pipeline.

    Returns:
        ``(image, status)``: the first generated image and a success message,
        or ``(None, error message)`` if loading the pipeline or running
        inference raises. In the error case the traceback is printed to the
        server log.

    Raises:
        gr.Error: If either input image is missing.
    """
    if person_image is None:
        raise gr.Error("Please upload a person image.")
    if garment_image is None:
        raise gr.Error("Please upload a garment image.")

    # Normalise seed
    if seed is None or seed < 0:
        seed = 42
    seed = int(seed)

    person_img = _to_pil(person_image)
    garment_img = _to_pil(garment_image)

    try:
        # Loaded inside the try so a pipeline-load failure is reported in the
        # status box exactly like an inference failure.
        pipeline = get_pipeline()
        result = pipeline(
            person_image=person_img,
            garment_image=garment_img,
            category=category,
            garment_photo_type=garment_photo_type,
            num_samples=1,
            num_timesteps=int(num_timesteps),
            guidance_scale=float(guidance_scale),
            seed=seed,
            segmentation_free=segmentation_free,
        )
        return result.images[0], "✅ Done!"
    except Exception as e:
        # UI boundary: show the message to the user, but keep the full
        # traceback in the server log instead of discarding it.
        traceback.print_exc()
        return None, f"❌ Error: {e}"
# ─────────────────────────── GRADIO UI ───────────────────────── #
# Styles for the layout: `.contain` and `.status-box` are applied through
# elem_classes, `#run-btn` through elem_id.
CUSTOM_CSS = """
body { font-family: 'Inter', sans-serif; }
.contain img {
    object-fit: contain !important;
    max-height: 520px !important;
}
#run-btn {
    background: linear-gradient(135deg, #6366f1 0%, #8b5cf6 100%) !important;
    border: none !important;
    color: white !important;
    font-size: 1.1rem !important;
    font-weight: 600 !important;
    padding: 0.75rem !important;
    border-radius: 12px !important;
    transition: opacity 0.2s;
}
#run-btn:hover { opacity: 0.85; }
.status-box textarea {
    font-size: 0.9rem !important;
    color: #a3e635 !important;
    background: #1e1e2e !important;
    border-radius: 8px !important;
}
.gr-accordion { border-radius: 10px !important; }
"""

# Markdown header; lines are kept at column 0 because indented Markdown renders as a code block.
BANNER_MD = """
# 👗 FASHN VTON — Virtual Try-On

Upload a **person image** and a **garment image**, choose the garment category and hit **Try On**!

> ⚠️ Running on **CPU** — inference may take a few minutes. Reduce *Sampling Steps* for faster results.
"""

# Single-row strip of usage tips shown under the banner.
TIPS_HTML = """
<div style="display: flex; justify-content: center; align-items: center; gap: 1rem; flex-wrap: wrap; margin-bottom: 20px; font-size: 0.95rem; color: #a1a1aa;">
    <div style="font-weight: 600; color: #e4e4e7;">💡 Tips for best results:</div>
    <div>👤 Single person, clearly visible</div>
    <div style="color: #52525b;">|</div>
    <div>👕 Match category to garment type</div>
    <div style="color: #52525b;">|</div>
    <div>📸 Use "flat-lay" for product shots</div>
    <div style="color: #52525b;">|</div>
    <div>📐 2:3 aspect ratio optimal</div>
</div>
"""
# Bundled sample inputs; their example galleries are only shown if the files exist.
person_example = os.path.join(EXAMPLES_DIR, "model.jpeg")
garment_example = os.path.join(EXAMPLES_DIR, "garment.jpeg")

with gr.Blocks(css=CUSTOM_CSS, title="FASHN VTON — Virtual Try-On") as demo:
    gr.Markdown(BANNER_MD)
    gr.HTML(TIPS_HTML)

    with gr.Row(equal_height=False):
        # ── Column 1 : Person ──────────────────────────────────
        with gr.Column(scale=1):
            person_in = gr.Image(
                label="Person Image",
                type="pil",
                sources=["upload", "clipboard"],
                elem_classes=["contain"],
            )
            if os.path.exists(person_example):
                gr.Examples(
                    examples=[[person_example]],
                    inputs=[person_in],
                    label="Person Example",
                )

        # ── Column 2 : Garment ─────────────────────────────────
        with gr.Column(scale=1):
            garment_in = gr.Image(
                label="Garment Image",
                type="pil",
                sources=["upload", "clipboard"],
                elem_classes=["contain"],
            )
            with gr.Row():
                category = gr.Dropdown(
                    choices=CATEGORIES,
                    value="tops",
                    label="Category",
                )
                garment_photo_type = gr.Dropdown(
                    choices=GARMENT_PHOTO_TYPES,
                    value="model",
                    label="Photo Type",
                )
            if os.path.exists(garment_example):
                gr.Examples(
                    examples=[[garment_example]],
                    inputs=[garment_in],
                    label="Garment Example",
                )

        # ── Column 3 : Result ──────────────────────────────────
        with gr.Column(scale=1):
            result_img = gr.Image(
                label="Try-On Result",
                type="pil",
                interactive=False,
                elem_classes=["contain"],
            )
            status = gr.Textbox(
                value="Ready",
                label="Status",
                interactive=False,
                elem_classes=["status-box"],
            )
            run_btn = gr.Button("👗 Try On", variant="primary", elem_id="run-btn")
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                num_timesteps = gr.Slider(
                    minimum=10, maximum=50, value=30, step=5,
                    label="Sampling Steps",
                    info="Higher = better quality but slower. 30 is a good balance.",
                )
                guidance_scale = gr.Slider(
                    minimum=1.0, maximum=3.0, value=1.5, step=0.1,
                    label="Guidance Scale",
                    info="How closely to follow the garment details. 1.5 recommended.",
                )
                seed = gr.Number(
                    value=42, label="Seed", precision=0,
                    info="Change seed to get a different variation of the result.",
                )
                segmentation_free = gr.Checkbox(
                    value=True,
                    label="Segmentation-Free (Recommended)",
                    info="Preserves body features and allows unconstrained garment volume.",
                )

    # ── Event ──────────────────────────────────────────────────
    # The order of `inputs` must match try_on's positional parameters.
    run_btn.click(
        fn=try_on,
        inputs=[
            person_in, garment_in,
            category, garment_photo_type,
            num_timesteps, guidance_scale,
            seed, segmentation_free,
        ],
        outputs=[result_img, status],
    )
# Run at most one request at a time (pipeline runs on CPU); up to 10 more may wait in the queue.
demo.queue(default_concurrency_limit=1, max_size=10)

if __name__ == "__main__":
    # Bind on all interfaces so the app is reachable from outside the container.
    demo.launch(server_name="0.0.0.0", share=False)