"""Gradio web service for CatVTON virtual try-on.

Wraps the CatVTON diffusion pipeline (garment transfer via inpainting)
behind a Gradio Blocks UI.  Model weights are fetched from the
HuggingFace Hub on first use and cached locally; the service loads the
pipeline lazily on the first request and reuses it afterwards.
"""

import gc
import os
import sys

import gradio as gr
import numpy as np
import torch
from huggingface_hub import snapshot_download
from PIL import Image

# Add CatVTON to path so its package-local modules resolve.
sys.path.insert(0, './CatVTON')

from model.pipeline import CatVTONPipeline
from model.cloth_masker import AutoMasker
from utils import init_weight_dtype, resize_and_crop, resize_and_padding


class CatVTONService:
    """Lazily-loaded CatVTON pipeline + auto-masker with memory hygiene."""

    def __init__(self):
        # Auto-detect device: CUDA when available, CPU otherwise.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"đŸ–Ĩī¸ Using device: {self.device}")
        self.pipeline = None
        self.automasker = None
        self.models_loaded = False

    def load_models(self):
        """Load models once and cache them.

        Idempotent: returns immediately when models are already loaded.

        Raises:
            Exception: re-raises any download/initialization failure after
                logging it, so callers see the real error.
        """
        if self.models_loaded:
            return

        print("🔄 Loading CatVTON models (this happens once)...")
        try:
            # Download model weights from HuggingFace Hub - cached automatically
            # under ./model_cache.  NOTE: `resume_download` is deprecated in
            # huggingface_hub (resuming interrupted downloads is now the
            # default), so it is intentionally not passed here.
            repo_path = snapshot_download(
                repo_id="zhengchong/CatVTON",
                cache_dir="./model_cache",
                local_files_only=False,  # allow downloading on first run
            )
            print(f"✅ Models downloaded to: {repo_path}")

            # fp16 halves memory on GPU; CPU inference needs full fp32.
            weight_dtype = init_weight_dtype("fp16" if self.device == "cuda" else "fp32")
            use_tf32 = self.device == "cuda"  # TF32 matmul is CUDA-only
            print(f"âš™ī¸ Weight dtype: {weight_dtype}, TF32: {use_tf32}")

            # Inpainting base model + CatVTON attention weights ("mix" version).
            self.pipeline = CatVTONPipeline(
                base_ckpt="booksforcharlie/stable-diffusion-inpainting",
                attn_ckpt=repo_path,
                attn_ckpt_version="mix",
                weight_dtype=weight_dtype,
                use_tf32=use_tf32,
                device=self.device,
            )

            # Automatic body-region masker (DensePose + SCHP checkpoints).
            self.automasker = AutoMasker(
                densepose_ckpt=os.path.join(repo_path, "DensePose"),
                schp_ckpt=os.path.join(repo_path, "SCHP"),
                device=self.device,
            )

            self.models_loaded = True
            print("✅ CatVTON ready!")
        except Exception as e:
            print(f"❌ Error loading models: {e}")
            raise

    @staticmethod
    def _to_rgb(image):
        """Coerce a Gradio image (numpy array or PIL image) into RGB PIL."""
        if isinstance(image, np.ndarray):
            return Image.fromarray(image).convert("RGB")
        return image.convert("RGB")

    def _free_memory(self):
        """Release Python garbage and, on CUDA, cached GPU allocator blocks."""
        gc.collect()
        if self.device == "cuda":
            torch.cuda.empty_cache()

    def generate_tryon(self, person_image, garment_image, progress=gr.Progress()):
        """Generate a virtual try-on result.

        Args:
            person_image: person photo (numpy array or PIL image), or None.
            garment_image: garment photo (numpy array or PIL image), or None.
            progress: Gradio progress tracker used to report stage updates.

        Returns:
            Tuple of (result image or None, human-readable status message).
        """
        try:
            # Validate inputs BEFORE loading models, so a request with a
            # missing image never triggers a multi-GB download/load.
            if person_image is None or garment_image is None:
                return None, "❌ Please upload both person and garment images!"

            progress(0, desc="Loading models...")
            self.load_models()

            progress(0.2, desc="Processing images...")
            person_img = self._to_rgb(person_image)
            garment_img = self._to_rgb(garment_image)

            # CatVTON's working resolution (portrait 3:4): crop the person,
            # pad the garment so neither is distorted.
            target_width, target_height = 768, 1024
            person_img = resize_and_crop(person_img, (target_width, target_height))
            garment_img = resize_and_padding(garment_img, (target_width, target_height))

            progress(0.4, desc="Generating body mask...")
            # "upper" restricts masking to upper-body garments, matching the
            # UI labels ("upper body clothing").
            mask = self.automasker(person_img, "upper")['mask']

            # Clear memory before the heavy diffusion pass.
            self._free_memory()

            device_msg = "GPU - ~30-60 seconds" if self.device == "cuda" else "CPU - ~2-5 minutes"
            progress(0.6, desc=f"Running virtual try-on on {device_msg}...")

            result = self.pipeline(
                image=person_img,
                condition_image=garment_img,
                mask=mask,
                num_inference_steps=50,
                guidance_scale=2.5,
                seed=42,  # fixed seed => deterministic output for identical inputs
                height=target_height,
                width=target_width,
            )[0]

            # Clear memory after inference as well.
            self._free_memory()

            progress(1.0, desc="Complete!")
            return result, f"✅ Virtual try-on generated successfully on {self.device.upper()}!"

        except Exception as e:
            import traceback
            error_msg = f"❌ Error: {str(e)}\n\n{traceback.format_exc()}"
            print(error_msg)
            # Best-effort cleanup so a failed request doesn't strand GPU memory.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return None, error_msg


# Initialize service (models load lazily on the first request).
print("🚀 Initializing CatVTON Service...")
service = CatVTONService()

# Preload models on startup (optional - uncomment for eager loading):
# try:
#     service.load_models()
# except Exception as e:
#     print(f"âš ī¸ Could not preload models: {e}")
#     print("Models will be loaded on first request")


def generate_tryon_interface(person_img, garment_img, progress=gr.Progress()):
    """Thin Gradio wrapper delegating to the shared service instance."""
    result, message = service.generate_tryon(person_img, garment_img, progress)
    return result, message


# Build UI
with gr.Blocks(
    title="CatVTON Virtual Try-On",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {max-width: 1200px !important}
    #title {text-align: center; margin-bottom: 1em}
    #subtitle {text-align: center; color: #666; margin-bottom: 2em}
    """,
) as demo:
    device_info = "đŸ–Ĩī¸ GPU" if torch.cuda.is_available() else "đŸ’ģ CPU"
    processing_time = "30-60 seconds" if torch.cuda.is_available() else "2-5 minutes"

    gr.HTML(f"""
    <div>
        <h1 id="title">👗 CatVTON - Virtual Try-On</h1>
        <p id="subtitle">Upload a person image and a garment to see how it looks on them!</p>
        <p style="text-align: center;">Device: {device_info} | Processing Time: ~{processing_time}</p>
        <p style="text-align: center;"><em>First run downloads models (~5GB) - subsequent runs are faster!</em></p>
    </div>
    """)

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 📸 Step 1: Upload Images")
            person_input = gr.Image(
                label="👤 Person Image (full body, front-facing)",
                type="pil",
                sources=["upload", "clipboard"],
            )
            garment_input = gr.Image(
                label="👕 Garment Image (upper body clothing)",
                type="pil",
                sources=["upload", "clipboard"],
            )
            generate_btn = gr.Button(
                "🚀 Generate Virtual Try-On",
                variant="primary",
                size="lg",
            )
            gr.Markdown("""
            ### 💡 Tips for Best Results:
            - Use well-lit, clear images
            - Person should face camera directly
            - Garment on plain/white background
            - Works best with shirts, jackets, tops
            - Avoid images with multiple people
            """)

        with gr.Column():
            gr.Markdown("### ✨ Result")
            result_output = gr.Image(
                label="Generated Try-On Result",
                type="pil",
            )
            status_output = gr.Textbox(
                label="Status",
                lines=3,
                show_label=True,
            )

    # Examples (only shown when the examples directory exists).
    if os.path.exists("examples"):
        gr.Markdown("### 📋 Example Images")
        example_files = []
        if os.path.exists("examples/person1.jpg") and os.path.exists("examples/garment1.jpg"):
            example_files.append(["examples/person1.jpg", "examples/garment1.jpg"])
        if os.path.exists("examples/person2.jpg") and os.path.exists("examples/garment2.jpg"):
            example_files.append(["examples/person2.jpg", "examples/garment2.jpg"])
        if example_files:
            gr.Examples(
                examples=example_files,
                inputs=[person_input, garment_input],
                label="Try these examples",
            )

    # Footer
    gr.Markdown("""
    ---
    ### â„šī¸ About
    This app uses **CatVTON** (Concatenation-based Attention Virtual Try-On) for realistic garment transfer.

    - Model: [zhengchong/CatVTON](https://huggingface.co/zhengchong/CatVTON)
    - Based on Stable Diffusion Inpainting
    - Supports upper body garments (shirts, jackets, tops)

    **Note:** Processing time depends on hardware. GPU is recommended for faster results.
    """)

    # Connect button
    generate_btn.click(
        fn=generate_tryon_interface,
        inputs=[person_input, garment_input],
        outputs=[result_output, status_output],
    )

# Launch app
if __name__ == "__main__":
    print("\n" + "=" * 60)
    print("🌐 Starting CatVTON Virtual Try-On Server")
    print("=" * 60)
    print(f"Device: {service.device}")
    print("Server: http://0.0.0.0:7860")
    print("=" * 60 + "\n")

    demo.queue(
        max_size=20,  # max queued requests
        default_concurrency_limit=2,  # limit concurrent generations
    )
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        share=False,  # don't create a public link on HF Spaces
    )