# Try-Space-Tryon / app.py
# feylur's picture
# Update app.py
# 0ae5112 verified
import gradio as gr
import torch
from PIL import Image
import os
import sys
import gc
from huggingface_hub import snapshot_download
import numpy as np
# Add CatVTON to path
sys.path.insert(0, './CatVTON')
from model.pipeline import CatVTONPipeline
from model.cloth_masker import AutoMasker
from utils import init_weight_dtype, resize_and_crop, resize_and_padding
class CatVTONService:
    """Lazily loaded CatVTON virtual try-on service.

    Construction only detects the compute device. The model weights are
    downloaded and instantiated on the first call to ``load_models`` (which
    ``generate_tryon`` triggers automatically) and are reused afterwards.
    """

    # Resolution both images are brought to before masking and inference.
    TARGET_WIDTH = 768
    TARGET_HEIGHT = 1024

    def __init__(self):
        # Local import: only this class needs it, and it keeps the block
        # independent of the file-level import list.
        import threading

        # Auto-detect device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"πŸ–₯️ Using device: {self.device}")
        self.pipeline = None
        self.automasker = None
        self.models_loaded = False
        # Serializes the first model load. The Gradio queue in this file is
        # configured to run more than one request at a time, so without the
        # lock two first requests could both start loading the weights.
        self._load_lock = threading.Lock()

    def load_models(self):
        """Load the pipeline and the auto-masker once and cache them.

        Downloads the ``zhengchong/CatVTON`` snapshot into ``./model_cache``,
        then builds ``self.pipeline`` and ``self.automasker``. Weights use
        fp16 on CUDA and fp32 otherwise. Calling this again after a
        successful load is a no-op.

        Raises:
            Exception: whatever the download or model construction raised;
                the error is printed and re-raised unchanged.
        """
        if self.models_loaded:
            return
        with self._load_lock:
            # Re-check: another request may have finished loading while this
            # one was waiting for the lock.
            if self.models_loaded:
                return
            print("πŸ”„ Loading CatVTON models (this happens once)...")
            try:
                # Download model weights from HuggingFace Hub - CACHED automatically
                repo_path = snapshot_download(
                    repo_id="zhengchong/CatVTON",
                    cache_dir="./model_cache",
                    resume_download=True,  # Resume if interrupted
                    local_files_only=False,  # Allow downloading
                )
                print(f"βœ… Models downloaded to: {repo_path}")
                # Determine weight dtype based on device
                on_cuda = self.device == "cuda"
                weight_dtype = init_weight_dtype("fp16" if on_cuda else "fp32")
                use_tf32 = on_cuda  # Only use TF32 on CUDA
                print(f"βš™οΈ Weight dtype: {weight_dtype}, TF32: {use_tf32}")
                # Initialize pipeline
                self.pipeline = CatVTONPipeline(
                    base_ckpt="booksforcharlie/stable-diffusion-inpainting",
                    attn_ckpt=repo_path,
                    attn_ckpt_version="mix",
                    weight_dtype=weight_dtype,
                    use_tf32=use_tf32,
                    device=self.device,
                )
                # Initialize automasker
                self.automasker = AutoMasker(
                    densepose_ckpt=os.path.join(repo_path, "DensePose"),
                    schp_ckpt=os.path.join(repo_path, "SCHP"),
                    device=self.device,
                )
                # Set the flag last so a failed load is retried next time.
                self.models_loaded = True
                print("βœ… CatVTON ready!")
            except Exception as e:
                print(f"❌ Error loading models: {e}")
                raise

    @staticmethod
    def _to_rgb(image):
        """Return *image* (NumPy array or PIL image) as an RGB PIL image."""
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        return image.convert("RGB")

    def _free_memory(self):
        """Run the garbage collector and, on CUDA, empty the allocator cache."""
        gc.collect()
        if self.device == "cuda":
            torch.cuda.empty_cache()

    def generate_tryon(self, person_image, garment_image, progress=gr.Progress()):
        """Generate a virtual try-on result.

        Args:
            person_image: PIL image or NumPy array of the person, or ``None``.
            garment_image: PIL image or NumPy array of the garment, or ``None``.
            progress: callable invoked as ``progress(fraction, desc=...)`` to
                report progress.

        Returns:
            A ``(result, message)`` tuple. ``result`` is the first element of
            the pipeline output (presumably a PIL image - confirm against
            ``CatVTONPipeline``), or ``None`` when an input is missing or an
            error occurred. ``message`` is a status string; on failure it
            contains the error text followed by the traceback.
        """
        # Validate before anything else so a missing upload does not trigger
        # the expensive model download/load.
        if person_image is None or garment_image is None:
            return None, "❌ Please upload both person and garment images!"
        try:
            # Load models if not already loaded
            progress(0, desc="Loading models...")
            self.load_models()

            progress(0.2, desc="Processing images...")
            target_size = (self.TARGET_WIDTH, self.TARGET_HEIGHT)
            person_img = resize_and_crop(self._to_rgb(person_image), target_size)
            garment_img = resize_and_padding(self._to_rgb(garment_image), target_size)

            progress(0.4, desc="Generating body mask...")
            mask = self.automasker(person_img, "upper")['mask']
            # Release what masking left behind before the heavier step.
            self._free_memory()

            device_msg = "GPU - ~30-60 seconds" if self.device == "cuda" else "CPU - ~2-5 minutes"
            progress(0.6, desc=f"Running virtual try-on on {device_msg}...")
            # Run inference
            result = self.pipeline(
                image=person_img,
                condition_image=garment_img,
                mask=mask,
                num_inference_steps=50,
                guidance_scale=2.5,
                seed=42,
                height=self.TARGET_HEIGHT,
                width=self.TARGET_WIDTH,
            )[0]
            self._free_memory()

            progress(1.0, desc="Complete!")
            return result, f"βœ… Virtual try-on generated successfully on {self.device.upper()}!"
        except Exception as e:
            import traceback

            error_msg = f"❌ Error: {e}\n\n{traceback.format_exc()}"
            print(error_msg)
            # Clear memory on error
            self._free_memory()
            return None, error_msg
# Initialize service
# Module-level instance used by the Gradio callback below. Constructing it
# only detects the device; model weights are loaded later by load_models().
print("πŸš€ Initializing CatVTON Service...")
service = CatVTONService()
# Optional eager preload: uncomment the lines below to load the models at
# startup. While they stay commented out, models load lazily on the first
# request.
# try:
#     service.load_models()
# except Exception as e:
#     print(f"⚠️ Could not preload models: {e}")
#     print("Models will be loaded on first request")
# Create Gradio Interface
def generate_tryon_interface(person_img, garment_img, progress=gr.Progress()):
    """Gradio callback: forward both images to the shared service.

    Returns the ``(result, message)`` pair produced by
    ``service.generate_tryon``.
    """
    output_image, status_text = service.generate_tryon(
        person_img,
        garment_img,
        progress,
    )
    return output_image, status_text
# Build UI
with gr.Blocks(
    title="CatVTON Virtual Try-On",
    theme=gr.themes.Soft(),
    css="""
.gradio-container {max-width: 1200px !important}
#title {text-align: center; margin-bottom: 1em}
#subtitle {text-align: center; color: #666; margin-bottom: 2em}
"""
) as demo:
    # Query CUDA once and derive both header labels from the same answer.
    cuda_available = torch.cuda.is_available()
    device_info = "πŸ–₯️ GPU" if cuda_available else "πŸ’» CPU"
    processing_time = "30-60 seconds" if cuda_available else "2-5 minutes"
    gr.HTML(f"""
<div id="title">
<h1>πŸ‘— CatVTON - Virtual Try-On</h1>
</div>
<div id="subtitle">
<p>Upload a person image and a garment to see how it looks on them!</p>
<p><strong>Device:</strong> {device_info} | <strong>Processing Time:</strong> ~{processing_time}</p>
<p><em>First run downloads models (~5GB) - subsequent runs are faster!</em></p>
</div>
""")
    with gr.Row():
        # Left column: inputs, action button and usage tips.
        with gr.Column():
            gr.Markdown("### πŸ“Έ Step 1: Upload Images")
            person_input = gr.Image(
                label="πŸ‘€ Person Image (full body, front-facing)",
                type="pil",
                sources=["upload", "clipboard"]
            )
            garment_input = gr.Image(
                label="πŸ‘• Garment Image (upper body clothing)",
                type="pil",
                sources=["upload", "clipboard"]
            )
            generate_btn = gr.Button(
                "πŸš€ Generate Virtual Try-On",
                variant="primary",
                size="lg"
            )
            gr.Markdown("""
### πŸ’‘ Tips for Best Results:
- Use well-lit, clear images
- Person should face camera directly
- Garment on plain/white background
- Works best with shirts, jackets, tops
- Avoid images with multiple people
""")
        # Right column: generated image and status text.
        with gr.Column():
            gr.Markdown("### ✨ Result")
            result_output = gr.Image(
                label="Generated Try-On Result",
                type="pil"
            )
            status_output = gr.Textbox(
                label="Status",
                lines=3,
                show_label=True
            )
    # Examples: keep only the pairs whose two files are both present. The
    # heading is rendered together with the gallery, so an "examples"
    # directory without a complete pair no longer leaves an empty section.
    example_files = [
        [person_path, garment_path]
        for person_path, garment_path in (
            ("examples/person1.jpg", "examples/garment1.jpg"),
            ("examples/person2.jpg", "examples/garment2.jpg"),
        )
        if os.path.exists(person_path) and os.path.exists(garment_path)
    ]
    if example_files:
        gr.Markdown("### πŸ“‹ Example Images")
        gr.Examples(
            examples=example_files,
            inputs=[person_input, garment_input],
            label="Try these examples"
        )
    # Footer
    gr.Markdown("""
---
### ℹ️ About
This app uses **CatVTON** (Concatenation-based Attention Virtual Try-On) for realistic garment transfer.
- Model: [zhengchong/CatVTON](https://huggingface.co/zhengchong/CatVTON)
- Based on Stable Diffusion Inpainting
- Supports upper body garments (shirts, jackets, tops)
**Note:** Processing time depends on hardware. GPU is recommended for faster results.
""")
    # Connect button: two image inputs in, result image and status text out.
    generate_btn.click(
        fn=generate_tryon_interface,
        inputs=[person_input, garment_input],
        outputs=[result_output, status_output]
    )
# Launch app
if __name__ == "__main__":
    # Name the bind address once so the banner and launch() cannot drift apart.
    host, port = "0.0.0.0", 7860
    rule = "=" * 60

    print(f"\n{rule}")
    print("🌐 Starting CatVTON Virtual Try-On Server")
    print(rule)
    print(f"Device: {service.device}")
    print(f"Server: http://{host}:{port}")
    print(f"{rule}\n")

    # At most 20 queued requests, at most 2 running at the same time.
    demo.queue(max_size=20, default_concurrency_limit=2)
    # share=False: don't create a public link on HF Spaces.
    demo.launch(
        server_name=host,
        server_port=port,
        show_error=True,
        share=False,
    )