Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| import os | |
| import sys | |
| import gc | |
| from huggingface_hub import snapshot_download | |
| import numpy as np | |
| # Add CatVTON to path | |
| sys.path.insert(0, './CatVTON') | |
| from model.pipeline import CatVTONPipeline | |
| from model.cloth_masker import AutoMasker | |
| from utils import init_weight_dtype, resize_and_crop, resize_and_padding | |
class CatVTONService:
    """Lazy-loading wrapper around the CatVTON virtual try-on pipeline.

    Downloads and caches model weights on first use, then keeps the
    diffusion pipeline and auto-masker resident so later requests skip
    all setup work.
    """

    def __init__(self):
        # Auto-detect device: prefer CUDA when available, else fall back to CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"π₯οΈ Using device: {self.device}")
        self.pipeline = None    # CatVTONPipeline, populated by load_models()
        self.automasker = None  # AutoMasker, populated by load_models()
        self.models_loaded = False

    def _clear_memory(self):
        """Free Python garbage and, on CUDA, release cached allocator blocks."""
        gc.collect()
        if self.device == "cuda":
            torch.cuda.empty_cache()

    @staticmethod
    def _to_rgb(image):
        """Coerce a numpy array or PIL image to an RGB PIL image."""
        if isinstance(image, np.ndarray):
            return Image.fromarray(image).convert("RGB")
        return image.convert("RGB")

    def load_models(self):
        """Load models once and cache them.

        Idempotent: returns immediately when the models are already resident.
        Re-raises any download/initialization error after logging it.
        """
        if self.models_loaded:
            return
        print("π Loading CatVTON models (this happens once)...")
        try:
            # Download model weights from HuggingFace Hub - cached automatically.
            # NOTE: `resume_download=True` was dropped — it is deprecated in
            # huggingface_hub (downloads always resume when interrupted now).
            repo_path = snapshot_download(
                repo_id="zhengchong/CatVTON",
                cache_dir="./model_cache",
                local_files_only=False,  # allow downloading
            )
            print(f"β Models downloaded to: {repo_path}")
            # fp16 halves GPU memory use; CPU inference requires fp32.
            weight_dtype = init_weight_dtype("fp16" if self.device == "cuda" else "fp32")
            use_tf32 = self.device == "cuda"  # TF32 only exists on CUDA hardware
            print(f"βοΈ Weight dtype: {weight_dtype}, TF32: {use_tf32}")
            # Initialize the inpainting-based try-on pipeline.
            self.pipeline = CatVTONPipeline(
                base_ckpt="booksforcharlie/stable-diffusion-inpainting",
                attn_ckpt=repo_path,
                attn_ckpt_version="mix",
                weight_dtype=weight_dtype,
                use_tf32=use_tf32,
                device=self.device,
            )
            # AutoMasker segments the person to produce the inpainting mask.
            self.automasker = AutoMasker(
                densepose_ckpt=os.path.join(repo_path, "DensePose"),
                schp_ckpt=os.path.join(repo_path, "SCHP"),
                device=self.device,
            )
            self.models_loaded = True
            print("β CatVTON ready!")
        except Exception as e:
            print(f"β Error loading models: {e}")
            raise

    def generate_tryon(self, person_image, garment_image, progress=gr.Progress()):
        """Generate a virtual try-on result.

        Args:
            person_image: numpy array or PIL image of the person (full body).
            garment_image: numpy array or PIL image of the garment.
            progress: Gradio progress tracker (injected by the UI).

        Returns:
            (PIL.Image, status message) on success, (None, error message) on
            failure — errors are reported in-band rather than raised.
        """
        try:
            progress(0, desc="Loading models...")
            self.load_models()

            # Validate inputs before any expensive work.
            if person_image is None or garment_image is None:
                return None, "β Please upload both person and garment images!"

            progress(0.2, desc="Processing images...")
            person_img = self._to_rgb(person_image)
            garment_img = self._to_rgb(garment_image)

            # CatVTON's native working resolution: 768x1024 portrait.
            target_width = 768
            target_height = 1024
            person_img = resize_and_crop(person_img, (target_width, target_height))
            garment_img = resize_and_padding(garment_img, (target_width, target_height))

            progress(0.4, desc="Generating body mask...")
            # "upper" restricts the mask (and hence the try-on) to upper-body garments.
            mask = self.automasker(person_img, "upper")['mask']

            # Clear memory before the heavy diffusion pass.
            self._clear_memory()

            device_msg = "GPU - ~30-60 seconds" if self.device == "cuda" else "CPU - ~2-5 minutes"
            progress(0.6, desc=f"Running virtual try-on on {device_msg}...")
            # Run inference; the pipeline returns a list of images, take the first.
            result = self.pipeline(
                image=person_img,
                condition_image=garment_img,
                mask=mask,
                num_inference_steps=50,
                guidance_scale=2.5,
                seed=42,  # fixed seed for reproducible output
                height=target_height,
                width=target_width,
            )[0]

            # Clear memory after inference.
            self._clear_memory()
            progress(1.0, desc="Complete!")
            return result, f"β Virtual try-on generated successfully on {self.device.upper()}!"
        except Exception as e:
            import traceback
            error_msg = f"β Error: {str(e)}\n\n{traceback.format_exc()}"
            print(error_msg)
            # Best-effort cleanup on error; query CUDA availability directly
            # since the failure may predate valid device state.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return None, error_msg
# Create the singleton service shared by all requests.
print("π Initializing CatVTON Service...")
service = CatVTONService()

# Optional eager model loading at startup; left disabled so the first
# request triggers the download instead (lazy loading).
# try:
#     service.load_models()
# except Exception as e:
#     print(f"β οΈ Could not preload models: {e}")
#     print("Models will be loaded on first request")


def generate_tryon_interface(person_img, garment_img, progress=gr.Progress()):
    """Thin Gradio wrapper delegating to the shared service instance."""
    return service.generate_tryon(person_img, garment_img, progress)
# Build UI: two-column layout — inputs/tips on the left, result on the right.
with gr.Blocks(
    title="CatVTON Virtual Try-On",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {max-width: 1200px !important}
    #title {text-align: center; margin-bottom: 1em}
    #subtitle {text-align: center; color: #666; margin-bottom: 2em}
    """
) as demo:
    # Banner text varies with the detected hardware.
    device_info = "π₯οΈ GPU" if torch.cuda.is_available() else "π» CPU"
    processing_time = "30-60 seconds" if torch.cuda.is_available() else "2-5 minutes"
    gr.HTML(f"""
    <div id="title">
        <h1>π CatVTON - Virtual Try-On</h1>
    </div>
    <div id="subtitle">
        <p>Upload a person image and a garment to see how it looks on them!</p>
        <p><strong>Device:</strong> {device_info} | <strong>Processing Time:</strong> ~{processing_time}</p>
        <p><em>First run downloads models (~5GB) - subsequent runs are faster!</em></p>
    </div>
    """)
    with gr.Row():
        # Left column: image inputs, generate button, usage tips.
        with gr.Column():
            gr.Markdown("### πΈ Step 1: Upload Images")
            person_input = gr.Image(
                label="π€ Person Image (full body, front-facing)",
                type="pil",
                sources=["upload", "clipboard"]
            )
            garment_input = gr.Image(
                label="π Garment Image (upper body clothing)",
                type="pil",
                sources=["upload", "clipboard"]
            )
            generate_btn = gr.Button(
                "π Generate Virtual Try-On",
                variant="primary",
                size="lg"
            )
            gr.Markdown("""
            ### π‘ Tips for Best Results:
            - Use well-lit, clear images
            - Person should face camera directly
            - Garment on plain/white background
            - Works best with shirts, jackets, tops
            - Avoid images with multiple people
            """)
        # Right column: generated image and status text.
        with gr.Column():
            gr.Markdown("### β¨ Result")
            result_output = gr.Image(
                label="Generated Try-On Result",
                type="pil"
            )
            status_output = gr.Textbox(
                label="Status",
                lines=3,
                show_label=True
            )
    # Examples section (only shown if the examples directory exists).
    if os.path.exists("examples"):
        gr.Markdown("### π Example Images")
        example_files = []
        if os.path.exists("examples/person1.jpg") and os.path.exists("examples/garment1.jpg"):
            example_files.append(["examples/person1.jpg", "examples/garment1.jpg"])
        if os.path.exists("examples/person2.jpg") and os.path.exists("examples/garment2.jpg"):
            example_files.append(["examples/person2.jpg", "examples/garment2.jpg"])
        if example_files:
            gr.Examples(
                examples=example_files,
                inputs=[person_input, garment_input],
                label="Try these examples"
            )
    # Footer / about section.
    gr.Markdown("""
    ---
    ### βΉοΈ About
    This app uses **CatVTON** (Concatenation-based Attention Virtual Try-On) for realistic garment transfer.
    - Model: [zhengchong/CatVTON](https://huggingface.co/zhengchong/CatVTON)
    - Based on Stable Diffusion Inpainting
    - Supports upper body garments (shirts, jackets, tops)
    **Note:** Processing time depends on hardware. GPU is recommended for faster results.
    """)
    # Wire the generate button to the inference wrapper.
    generate_btn.click(
        fn=generate_tryon_interface,
        inputs=[person_input, garment_input],
        outputs=[result_output, status_output]
    )
# Script entry point: start the Gradio server.
if __name__ == "__main__":
    banner = "=" * 60
    print("\n" + banner)
    print("π Starting CatVTON Virtual Try-On Server")
    print(banner)
    print(f"Device: {service.device}")
    print(f"Server: http://0.0.0.0:7860")
    print(banner + "\n")
    # Queue requests: cap the backlog and limit concurrent inference runs.
    demo.queue(max_size=20, default_concurrency_limit=2)
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        share=False,  # no public tunnel on HF Spaces
    )