import gradio as gr import torch from PIL import Image import os import sys import gc from huggingface_hub import snapshot_download import numpy as np # Add CatVTON to path sys.path.insert(0, './CatVTON') from model.pipeline import CatVTONPipeline from model.cloth_masker import AutoMasker from utils import init_weight_dtype, resize_and_crop, resize_and_padding class CatVTONService: def __init__(self): # Auto-detect device self.device = "cuda" if torch.cuda.is_available() else "cpu" print(f"đĨī¸ Using device: {self.device}") self.pipeline = None self.automasker = None self.models_loaded = False def load_models(self): """Load models once and cache them""" if self.models_loaded: return print("đ Loading CatVTON models (this happens once)...") try: # Download model weights from HuggingFace Hub - CACHED automatically repo_path = snapshot_download( repo_id="zhengchong/CatVTON", cache_dir="./model_cache", resume_download=True, # Resume if interrupted local_files_only=False # Allow downloading ) print(f"â Models downloaded to: {repo_path}") # Determine weight dtype based on device weight_dtype = init_weight_dtype("fp16" if self.device == "cuda" else "fp32") use_tf32 = self.device == "cuda" # Only use TF32 on CUDA print(f"âī¸ Weight dtype: {weight_dtype}, TF32: {use_tf32}") # Initialize pipeline self.pipeline = CatVTONPipeline( base_ckpt="booksforcharlie/stable-diffusion-inpainting", attn_ckpt=repo_path, attn_ckpt_version="mix", weight_dtype=weight_dtype, use_tf32=use_tf32, device=self.device ) # Initialize automasker self.automasker = AutoMasker( densepose_ckpt=os.path.join(repo_path, "DensePose"), schp_ckpt=os.path.join(repo_path, "SCHP"), device=self.device ) self.models_loaded = True print("â CatVTON ready!") except Exception as e: print(f"â Error loading models: {e}") raise def generate_tryon(self, person_image, garment_image, progress=gr.Progress()): """Generate virtual try-on result""" try: # Load models if not already loaded progress(0, desc="Loading models...") self.load_models() # Validate inputs if person_image is None or garment_image is None: return None, "â Please upload both person and garment images!" progress(0.2, desc="Processing images...") # Convert to PIL Images if isinstance(person_image, np.ndarray): person_img = Image.fromarray(person_image).convert("RGB") else: person_img = person_image.convert("RGB") if isinstance(garment_image, np.ndarray): garment_img = Image.fromarray(garment_image).convert("RGB") else: garment_img = garment_image.convert("RGB") # Resize images target_width = 768 target_height = 1024 person_img = resize_and_crop(person_img, (target_width, target_height)) garment_img = resize_and_padding(garment_img, (target_width, target_height)) progress(0.4, desc="Generating body mask...") # Generate mask mask = self.automasker(person_img, "upper")['mask'] # Clear memory gc.collect() if self.device == "cuda": torch.cuda.empty_cache() device_msg = "GPU - ~30-60 seconds" if self.device == "cuda" else "CPU - ~2-5 minutes" progress(0.6, desc=f"Running virtual try-on on {device_msg}...") # Run inference result = self.pipeline( image=person_img, condition_image=garment_img, mask=mask, num_inference_steps=50, guidance_scale=2.5, seed=42, height=target_height, width=target_width )[0] # Clear memory after inference gc.collect() if self.device == "cuda": torch.cuda.empty_cache() progress(1.0, desc="Complete!") return result, f"â Virtual try-on generated successfully on {self.device.upper()}!" except Exception as e: import traceback error_msg = f"â Error: {str(e)}\n\n{traceback.format_exc()}" print(error_msg) # Clear memory on error gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() return None, error_msg # Initialize service print("đ Initializing CatVTON Service...") service = CatVTONService() # Preload models on startup (optional - comment out if you want lazy loading) # try: # service.load_models() # except Exception as e: # print(f"â ī¸ Could not preload models: {e}") # print("Models will be loaded on first request") # Create Gradio Interface def generate_tryon_interface(person_img, garment_img, progress=gr.Progress()): """Wrapper for Gradio""" result, message = service.generate_tryon(person_img, garment_img, progress) return result, message # Build UI with gr.Blocks( title="CatVTON Virtual Try-On", theme=gr.themes.Soft(), css=""" .gradio-container {max-width: 1200px !important} #title {text-align: center; margin-bottom: 1em} #subtitle {text-align: center; color: #666; margin-bottom: 2em} """ ) as demo: device_info = "đĨī¸ GPU" if torch.cuda.is_available() else "đģ CPU" processing_time = "30-60 seconds" if torch.cuda.is_available() else "2-5 minutes" gr.HTML(f"""
Upload a person image and a garment to see how it looks on them!
Device: {device_info} | Processing Time: ~{processing_time}
First run downloads models (~5GB) - subsequent runs are faster!