# Source: feylur — "Update app.py" (commit a128a81, verified)
import patch_gradio
import os
import sys
import torch
import gradio as gr
from PIL import Image
import gc
import traceback
from huggingface_hub import snapshot_download
sys.path.insert(0, '/app/CatVTON')
from model.pipeline import CatVTONPipeline
from model.cloth_masker import AutoMasker
from utils import init_weight_dtype, resize_and_crop, resize_and_padding
# Lazily-initialised model singletons; populated once by load_models().
pipeline = None
automasker = None
def load_models():
    """Download CatVTON weights and initialise the global pipeline and masker.

    Idempotent: returns immediately when both globals are already set, so it is
    safe to call from every request as well as at startup.

    Raises:
        Exception: re-raised after logging if the download or model
            initialisation fails.
    """
    global pipeline, automasker
    if pipeline is not None and automasker is not None:
        return
    print("πŸ”„ Loading models...", file=sys.stderr)
    try:
        # Snapshot of the CatVTON repo (attention weights + masker checkpoints).
        repo_path = snapshot_download(
            repo_id="zhengchong/CatVTON",
            cache_dir="/tmp/models"
        )
        # Placeholder image used by the pipeline's safety filter output.
        nsfw_path = "/tmp/NSFW.jpg"
        if not os.path.exists(nsfw_path):
            Image.new('RGB', (512, 512), color='black').save(nsfw_path)
        # Fall back to CPU when no GPU is present instead of crashing at init.
        # fp16 is unreliable on CPU for many diffusion ops, so use fp32 there.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        weight_dtype = torch.float16 if device == 'cuda' else torch.float32
        pipeline = CatVTONPipeline(
            base_ckpt="booksforcharlie/stable-diffusion-inpainting",
            attn_ckpt=repo_path,
            attn_ckpt_version="mix",
            weight_dtype=weight_dtype,
            use_tf32=True,
            device=device
        )
        # The masker stays on CPU to leave GPU memory for the diffusion model.
        automasker = AutoMasker(
            densepose_ckpt=os.path.join(repo_path, "DensePose"),
            schp_ckpt=os.path.join(repo_path, "SCHP"),
            device='cpu'
        )
        print("βœ… Models loaded!", file=sys.stderr)
    except Exception as e:
        print(f"❌ Error: {e}", file=sys.stderr)
        traceback.print_exc()
        raise
def generate_tryon(person_img, cloth_img, garment_category="upper"):
    """Run CatVTON try-on: paint *cloth_img* onto the person in *person_img*.

    Args:
        person_img: PIL image of the person (resized/cropped to 768x1024).
        cloth_img: PIL image of the garment (resized/padded to 768x1024).
        garment_category: which body region to mask — one of
            "upper", "lower", "overall". Defaults to "upper" for
            backward compatibility with callers that omit it.

    Returns:
        The first generated PIL image from the pipeline.

    Raises:
        gr.Error: on missing inputs, an unknown category, or any
            pipeline failure (original exception is logged to stderr).
    """
    print("=" * 50, file=sys.stderr)
    print(f"Received - Person: {type(person_img)}, Cloth: {type(cloth_img)}, Category: {garment_category}", file=sys.stderr)
    if person_img is None or cloth_img is None:
        raise gr.Error("Both images required!")
    # Fail fast on a bad category instead of inside the masker.
    if garment_category not in ("upper", "lower", "overall"):
        raise gr.Error(f"Unknown garment category: {garment_category!r}")
    try:
        print("Images received as PIL", file=sys.stderr)
        load_models()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Copy so in-place resize/crop never mutates Gradio's cached inputs.
        person_img = person_img.copy()
        cloth_img = cloth_img.copy()
        target_height = 1024
        target_width = 768
        person_img = resize_and_crop(person_img, (target_width, target_height))
        cloth_img = resize_and_padding(cloth_img, (target_width, target_height))
        # Mask the selected body region (was hard-coded to "upper").
        mask = automasker(person_img, garment_category)['mask']
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        result = pipeline(
            image=person_img,
            condition_image=cloth_img,
            mask=mask,
            num_inference_steps=50,
            guidance_scale=2.5,
            seed=None,  # None => non-deterministic seed per call
            height=target_height,
            width=target_width
        )[0]
        print("βœ… Success!", file=sys.stderr)
        return result
    except Exception as e:
        print(f"❌ Error: {e}", file=sys.stderr)
        traceback.print_exc()
        raise gr.Error(str(e))
# Updated Interface with category dropdown
# Gradio UI: two PIL image inputs plus a garment-category dropdown, all fed
# positionally to generate_tryon; output is a single PIL image.
demo = gr.Interface(
    fn=generate_tryon,
    inputs=[
        gr.Image(label="Person Image", type="pil"),
        gr.Image(label="Garment Image", type="pil"),
        gr.Dropdown(
            choices=["upper", "lower", "overall"],
            value="upper",
            label="Garment Category",
            info="Select: upper (tops), lower (pants/skirts), overall (dresses/full outfits)"
        )
    ],
    outputs=gr.Image(label="Result", type="pil"),
    title="Try-Space Virtual Try-On",
    description="Upload person and garment images. Select category. Processing takes 2-3 minutes on GPU T4.",
    # Endpoint name used by API/Client callers.
    api_name="generate_tryon",
    # NOTE(review): allow_flagging is deprecated in Gradio 4.x in favour of
    # flagging_mode — confirm the installed Gradio version before changing.
    allow_flagging="never"
)
if __name__ == "__main__":
    print("πŸš€ Starting...", file=sys.stderr)
    try:
        # Warm-up: pre-load models so the first request is fast. A failure is
        # logged and deferred — generate_tryon retries load_models() per call.
        # (Was a bare `except: pass`, which also swallowed SystemExit /
        # KeyboardInterrupt and hid the failure entirely.)
        load_models()
    except Exception:
        print("⚠️ Model pre-load failed; will retry on first request.", file=sys.stderr)
        traceback.print_exc()
    demo.queue().launch(show_error=True)