File size: 4,056 Bytes
684c222
9ee4d75
 
 
 
 
 
76d32b4
9ee4d75
 
 
 
 
 
 
 
 
 
 
 
684c222
9ee4d75
684c222
9ee4d75
 
48a4c12
9ee4d75
0e6a4e7
 
 
 
 
 
6fbab6f
8ec79fc
 
 
0e6a4e7
 
 
 
76d32b4
 
1693385
0e6a4e7
 
 
 
 
 
48a4c12
684c222
0e6a4e7
 
b9f9bb9
684c222
0e6a4e7
9ee4d75
a128a81
48a4c12
a128a81
76d32b4
9ee4d75
684c222
9ee4d75
 
b9f9bb9
 
9ee4d75
 
3a5360b
 
 
 
 
 
 
9ee4d75
 
 
 
 
a128a81
 
9ee4d75
3a5360b
 
9ee4d75
 
 
 
 
 
 
a128a81
9ee4d75
 
 
 
b9f9bb9
9ee4d75
 
 
684c222
 
 
9ee4d75
a128a81
48a4c12
 
 
a128a81
 
 
 
 
 
 
 
48a4c12
a128a81
1d69cfd
a128a81
684c222
 
48a4c12
9ee4d75
 
684c222
9ee4d75
 
684c222
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import patch_gradio
import os
import sys
import torch
import gradio as gr
from PIL import Image
import gc
import traceback
from huggingface_hub import snapshot_download

sys.path.insert(0, '/app/CatVTON')

from model.pipeline import CatVTONPipeline
from model.cloth_masker import AutoMasker
from utils import init_weight_dtype, resize_and_crop, resize_and_padding

# Lazily created model singletons. Both start as None and are filled in by
# load_models(); generate_tryon() calls load_models() before using them.
pipeline = None  # CatVTONPipeline instance once loaded
automasker = None  # AutoMasker instance once loaded

def load_models():
    """Load the CatVTON pipeline and the automatic cloth masker into module globals.

    Downloads (or reuses from the ``/tmp/models`` cache) the
    ``zhengchong/CatVTON`` snapshot, makes sure a black placeholder image
    exists at ``/tmp/NSFW.jpg``, then builds the diffusion pipeline on CUDA
    (fp16, TF32 enabled) and the AutoMasker on CPU.

    The call is a no-op once both models are loaded. Each model is built only
    if it is still missing: if the masker fails after the pipeline succeeded,
    the pipeline is kept. A later retry then does not construct a second CUDA
    pipeline while the first one is still referenced.

    Raises:
        Exception: whatever the download or model construction raised; the
            message and traceback are printed to stderr before re-raising.
    """
    global pipeline, automasker

    if pipeline is not None and automasker is not None:
        return

    print("🔄 Loading models...", file=sys.stderr)

    try:
        repo_path = snapshot_download(
            repo_id="zhengchong/CatVTON",
            cache_dir="/tmp/models",
        )

        # Placeholder image written once. NOTE(review): presumably consumed by
        # the CatVTON pipeline/safety code; confirm where it is read.
        nsfw_path = "/tmp/NSFW.jpg"
        if not os.path.exists(nsfw_path):
            Image.new('RGB', (512, 512), color='black').save(nsfw_path)

        if pipeline is None:
            pipeline = CatVTONPipeline(
                base_ckpt="booksforcharlie/stable-diffusion-inpainting",
                attn_ckpt=repo_path,
                attn_ckpt_version="mix",
                weight_dtype=torch.float16,
                use_tf32=True,
                device='cuda'
            )
        if automasker is None:
            automasker = AutoMasker(
                densepose_ckpt=os.path.join(repo_path, "DensePose"),
                schp_ckpt=os.path.join(repo_path, "SCHP"),
                device='cpu'
            )

        print("✅ Models loaded!", file=sys.stderr)

    except Exception as e:
        print(f"❌ Error: {e}", file=sys.stderr)
        traceback.print_exc()
        raise

def _release_memory():
    """Run Python's garbage collector and, if CUDA is available, empty its cache."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def generate_tryon(person_img, cloth_img, garment_category="upper"):
    """Render the garment from *cloth_img* onto the person in *person_img*.

    Args:
        person_img: PIL image of the person (Gradio ``type="pil"`` input).
        cloth_img: PIL image of the garment.
        garment_category: mask type passed to AutoMasker. The UI offers
            "upper", "lower" and "overall". A None/empty value (cleared
            dropdown, or an API call that omits it) falls back to "upper".

    Returns:
        The first element returned by the CatVTON pipeline for a 768x1024
        request, with 50 inference steps and guidance scale 2.5.

    Raises:
        gr.Error: if either image is missing, or wrapping any exception raised
            during model loading, masking or inference so Gradio can show it.
    """
    print("=" * 50, file=sys.stderr)
    print(
        f"Received - Person: {type(person_img)}, Cloth: {type(cloth_img)}, "
        f"Category: {garment_category}",
        file=sys.stderr,
    )

    if person_img is None or cloth_img is None:
        raise gr.Error("Both images required!")
    if not garment_category:
        garment_category = "upper"

    try:
        print("Images received as PIL", file=sys.stderr)

        load_models()
        _release_memory()

        # convert() always returns a new image, so the caller's objects are not
        # mutated. It also normalises RGBA/palette/greyscale uploads to RGB
        # (assumed to be what the masker and pipeline expect; confirm).
        person_img = person_img.convert("RGB")
        cloth_img = cloth_img.convert("RGB")

        target_width, target_height = 768, 1024
        person_img = resize_and_crop(person_img, (target_width, target_height))
        cloth_img = resize_and_padding(cloth_img, (target_width, target_height))

        mask = automasker(person_img, garment_category)['mask']
        _release_memory()

        result = pipeline(
            image=person_img,
            condition_image=cloth_img,
            mask=mask,
            num_inference_steps=50,
            guidance_scale=2.5,
            seed=None,
            height=target_height,
            width=target_width
        )[0]

        print("✅ Success!", file=sys.stderr)
        return result

    except Exception as e:
        print(f"❌ Error: {e}", file=sys.stderr)
        traceback.print_exc()
        # Chain the original exception so server-side logs keep the real cause.
        raise gr.Error(str(e)) from e

# Gradio UI and API endpoint: a person photo, a garment photo and a garment
# category go in, and the try-on image comes out. The dropdown value is passed
# to generate_tryon() as its garment_category argument.
demo = gr.Interface(
    title="Try-Space Virtual Try-On",
    description="Upload person and garment images. Select category. Processing takes 2-3 minutes on GPU T4.",
    fn=generate_tryon,
    inputs=[
        gr.Image(type="pil", label="Person Image"),
        gr.Image(type="pil", label="Garment Image"),
        gr.Dropdown(
            label="Garment Category",
            info="Select: upper (tops), lower (pants/skirts), overall (dresses/full outfits)",
            choices=["upper", "lower", "overall"],
            value="upper",
        ),
    ],
    outputs=gr.Image(type="pil", label="Result"),
    allow_flagging="never",
    api_name="generate_tryon",
)

if __name__ == "__main__":
    print("🚀 Starting...", file=sys.stderr)
    # Preload the models so the first request is fast. A failure here is
    # deliberately non-fatal: load_models() has already printed the traceback,
    # and generate_tryon() calls load_models() on every request, so the UI still
    # starts and loading is retried later. Catching Exception (not a bare
    # except) lets Ctrl-C / SystemExit during the download stop the process.
    try:
        load_models()
    except Exception:
        print("⚠️ Model preload failed; models will be loaded on the first request.", file=sys.stderr)
    demo.queue().launch(show_error=True)