import patch_gradio
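# patch_gradio (above) is a local helper module assumed to ship with this
# Space; it is imported before gradio so any monkey-patches apply first.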
import os
import sys
import torch
import gradio as gr
from PIL import Image
import gc
import traceback
from huggingface_hub import snapshot_download
sys.path.insert(0, '/app/CatVTON')
from model.pipeline import CatVTONPipeline
from model.cloth_masker import AutoMasker
from utils import init_weight_dtype, resize_and_crop, resize_and_padding
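# Module-level handles so the heavy models are loaded once and reused across requests.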
pipeline = None
automasker = None
def load_models():
    global pipeline, automasker
    if pipeline is not None and automasker is not None:
        return
    print("🚀 Loading models...", file=sys.stderr)
    try:
        # Fetch the CatVTON weights; snapshot_download caches them under /tmp/models.
        repo_path = snapshot_download(
            repo_id="zhengchong/CatVTON",
            cache_dir="/tmp/models"
        )
        # Black placeholder image (presumably used when the safety checker flags an output).
        nsfw_path = "/tmp/NSFW.jpg"
        if not os.path.exists(nsfw_path):
            Image.new('RGB', (512, 512), color='black').save(nsfw_path)
        pipeline = CatVTONPipeline(
            base_ckpt="booksforcharlie/stable-diffusion-inpainting",
            attn_ckpt=repo_path,
            attn_ckpt_version="mix",
            weight_dtype=torch.float16,
            use_tf32=True,
            device='cuda'
        )
        # AutoMasker runs on CPU so the GPU stays free for the diffusion pipeline.
        automasker = AutoMasker(
            densepose_ckpt=os.path.join(repo_path, "DensePose"),
            schp_ckpt=os.path.join(repo_path, "SCHP"),
            device='cpu'
        )
        print("✅ Models loaded!", file=sys.stderr)
    except Exception as e:
        print(f"❌ Error: {e}", file=sys.stderr)
        traceback.print_exc()
        raise
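# generate_tryon is the Gradio handler: it resizes both inputs to 768x1024,
# masks the selected garment region with AutoMasker, then inpaints the garment
# onto the person with the CatVTON pipeline.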
def generate_tryon(person_img, cloth_img, garment_category):
    print("=" * 50, file=sys.stderr)
    print(f"Received - Person: {type(person_img)}, Cloth: {type(cloth_img)}, Category: {garment_category}", file=sys.stderr)
    if person_img is None or cloth_img is None:
        raise gr.Error("Both images are required!")
    try:
        print("Images received as PIL", file=sys.stderr)
        load_models()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Work on copies so Gradio's input images are not mutated in place.
        person_img = person_img.copy()
        cloth_img = cloth_img.copy()
        target_height = 1024
        target_width = 768
        person_img = resize_and_crop(person_img, (target_width, target_height))
        cloth_img = resize_and_padding(cloth_img, (target_width, target_height))
        # Mask the region picked in the dropdown: "upper", "lower", or "overall".
        mask = automasker(person_img, garment_category)['mask']
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        result = pipeline(
            image=person_img,
            condition_image=cloth_img,
            mask=mask,
            num_inference_steps=50,
            guidance_scale=2.5,
            seed=None,
            height=target_height,
            width=target_width
        )[0]
        print("✅ Success!", file=sys.stderr)
        return result
    except Exception as e:
        print(f"❌ Error: {e}", file=sys.stderr)
        traceback.print_exc()
        raise gr.Error(str(e))
# Gradio interface with a garment-category dropdown.
demo = gr.Interface(
    fn=generate_tryon,
    inputs=[
        gr.Image(label="Person Image", type="pil"),
        gr.Image(label="Garment Image", type="pil"),
        gr.Dropdown(
            choices=["upper", "lower", "overall"],
            value="upper",
            label="Garment Category",
            info="Select: upper (tops), lower (pants/skirts), overall (dresses/full outfits)"
        )
    ],
    outputs=gr.Image(label="Result", type="pil"),
    title="Try-Space Virtual Try-On",
    description="Upload person and garment images, then select a category. Processing takes 2-3 minutes on a T4 GPU.",
    api_name="generate_tryon",
    allow_flagging="never"
)
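# Example of calling this endpoint from Python - a sketch, assuming a
# placeholder Space id ("user/try-space") and a gradio_client version that
# provides handle_file (older versions accept plain file paths instead):
#
#   from gradio_client import Client, handle_file
#
#   client = Client("user/try-space")  # placeholder Space id
#   result_path = client.predict(
#       handle_file("person.jpg"),    # Person Image
#       handle_file("garment.jpg"),   # Garment Image
#       "upper",                      # Garment Category
#       api_name="/generate_tryon",
#   )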
if __name__ == "__main__":
    print("🚀 Starting...", file=sys.stderr)
    try:
        # Warm-load the models at startup so the first request does not pay the cost.
        load_models()
    except Exception:
        # If warm-loading fails, log it and fall back to lazy loading inside generate_tryon.
        traceback.print_exc()
    demo.queue().launch(show_error=True)