feylur committed on
Commit
9ee4d75
·
verified ·
1 Parent(s): 51d8d1a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +185 -0
app.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import gradio as gr
5
+ from PIL import Image
6
+ import gc
7
+ from huggingface_hub import snapshot_download
8
+
9
+ # Add CatVTON to path
10
+ sys.path.insert(0, '/app/CatVTON')
11
+
12
+ from model.pipeline import CatVTONPipeline
13
+ from model.cloth_masker import AutoMasker
14
+ from utils import init_weight_dtype, resize_and_crop, resize_and_padding
15
+
16
# Module-level model handles. Both start out empty and are populated
# together by load_models() the first time it runs.
pipeline = automasker = None
19
+
20
def load_models():
    """Load the CatVTON pipeline and auto-masker into the module globals.

    The first successful call downloads the ``zhengchong/CatVTON`` snapshot
    (cached under ``/tmp/models``), builds the try-on pipeline and the
    auto-masker for CPU, and stores them in ``pipeline`` / ``automasker``.
    Later calls return immediately.

    If any step raises, neither global is modified, so a later call retries
    the whole load instead of leaving a pipeline without a masker.
    """
    global pipeline, automasker

    # Both objects are required for inference; only skip when both exist.
    if pipeline is not None and automasker is not None:
        return

    print("🔄 Downloading/Loading CatVTON models (first time may take 5-10 mins)...")

    # Download and cache models
    repo_path = snapshot_download(
        repo_id="zhengchong/CatVTON",
        cache_dir="/tmp/models",  # For CPU basic tier
    )

    print("✅ Models downloaded! Initializing pipeline...")

    # Build into locals first so a failure in the second constructor cannot
    # leave the globals half-initialised.
    loaded_pipeline = CatVTONPipeline(
        base_ckpt="booksforcharlie/stable-diffusion-inpainting",
        attn_ckpt=repo_path,
        attn_ckpt_version="mix",
        # CPU inference runs in full precision. The dtype is passed directly
        # rather than via init_weight_dtype("fp32").
        # NOTE(review): the upstream helper appears to accept only
        # "no"/"fp16"/"bf16" - confirm against /app/CatVTON/utils.py.
        weight_dtype=torch.float32,
        use_tf32=False,  # CPU doesn't support TF32
        device='cpu',
    )

    loaded_automasker = AutoMasker(
        densepose_ckpt=os.path.join(repo_path, "DensePose"),
        schp_ckpt=os.path.join(repo_path, "SCHP"),
        device='cpu',
    )

    # Publish both handles together once everything has loaded.
    pipeline, automasker = loaded_pipeline, loaded_automasker

    print("✅ Models loaded successfully!")
54
+
55
def generate_tryon(person_img, cloth_img, progress=gr.Progress()):
    """Generate a virtual try-on image for one person/garment pair.

    Args:
        person_img: PIL image of the person, or None when nothing was
            uploaded.
        cloth_img: PIL image of the garment, or None when nothing was
            uploaded.
        progress: Gradio progress tracker, called with a fraction and a
            description at each stage.

    Returns:
        The first element of the pipeline output for the resized inputs.

    Raises:
        gr.Error: If either image is missing, or if any step of loading,
            masking or inference fails (the original exception is chained
            as the cause).
    """
    if person_img is None or cloth_img is None:
        raise gr.Error("Please upload both person and garment images!")

    try:
        # Load models
        progress(0.05, desc="Loading models...")
        load_models()

        progress(0.15, desc="Processing images...")

        # Resize images
        target_height = 1024
        target_width = 768
        # Uploads may be RGBA, palette or grayscale (e.g. PNGs); normalise to
        # three channels before they are resized and fed to the models.
        person_img = resize_and_crop(
            person_img.convert("RGB"), (target_width, target_height)
        )
        cloth_img = resize_and_padding(
            cloth_img.convert("RGB"), (target_width, target_height)
        )

        progress(0.35, desc="Generating body mask...")

        # Generate mask
        mask = automasker(person_img, "upper")['mask']

        # Clear memory
        gc.collect()

        progress(0.50, desc="Running virtual try-on (this may take 5-10 mins on CPU)...")

        # Fixed seed so repeated runs on the same inputs are reproducible.
        # NOTE(review): upstream CatVTONPipeline.__call__ appears to take a
        # `generator` argument and to swallow unknown keywords such as `seed`
        # - confirm against /app/CatVTON/model/pipeline.py.
        generator = torch.Generator(device='cpu').manual_seed(42)

        # Run inference
        result = pipeline(
            image=person_img,
            condition_image=cloth_img,
            mask=mask,
            num_inference_steps=50,
            guidance_scale=2.5,
            generator=generator,
            height=target_height,
            width=target_width
        )[0]

        progress(1.0, desc="Complete! ✨")

        return result

    except Exception as e:
        # Surface the failure in the UI while keeping the original traceback
        # attached for the server log.
        raise gr.Error(f"Error: {str(e)}") from e
102
+
103
# Create Gradio UI
# Two-column layout: inputs and action buttons on the left, result on the
# right, followed by usage tips. Event wiring happens inside the Blocks
# context.
with gr.Blocks(
    title="CatVTON Virtual Try-On",
    theme=gr.themes.Soft()
) as demo:

    gr.Markdown("""
    # 🎨 CatVTON Virtual Try-On
    ### Upload a person image and garment to see the magic! ✨

    ⚠️ **Note:** Running on CPU - processing takes 5-10 minutes per image.
    """)

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 📸 Inputs")
            # type="pil" makes Gradio hand PIL images to the click handler.
            person_input = gr.Image(
                label="👤 Person Image (full body, front-facing)",
                type="pil",
                height=350
            )
            cloth_input = gr.Image(
                label="👕 Garment Image (flat, white background)",
                type="pil",
                height=350
            )

            with gr.Row():
                # Resets the two input images only.
                clear_btn = gr.ClearButton(
                    [person_input, cloth_input],
                    value="🗑️ Clear"
                )
                submit_btn = gr.Button(
                    "🚀 Generate Try-On",
                    variant="primary",
                    size="lg"
                )

        with gr.Column():
            gr.Markdown("### ✨ Result")
            output_img = gr.Image(
                label="Virtual Try-On Result",
                height=700
            )

    gr.Markdown("""
    ---
    ### 💡 Tips for Best Results:
    - ✅ Use well-lit, clear images
    - ✅ Person should face the camera directly
    - ✅ Garment should be flat or on white background
    - ✅ Works best with shirts, jackets, or tops
    - ✅ Avoid extreme angles or poses

    ### ⏱️ Processing Time:
    - **CPU Basic:** ~5-10 minutes per generation
    - **Upgrade to GPU:** Reduces to ~2-3 minutes
    """)

    # Event handler
    submit_btn.click(
        fn=generate_tryon,
        inputs=[person_input, cloth_input],
        outputs=output_img
    )
168
+
169
# Launch app
if __name__ == "__main__":
    print("🚀 Starting CatVTON Virtual Try-On...")
    print("⏳ First run will download models (~2-3 GB)...")

    # Pre-load models at startup. This is best-effort: a failure here is
    # reported but must not stop the server from coming up, because
    # generate_tryon() calls load_models() again on each request.
    try:
        load_models()
    except Exception as e:
        print(f"⚠️ Model loading will happen on first inference: {e}")

    # Launch with a queue capped at five pending requests, listening on all
    # interfaces on port 7860.
    demo.queue(max_size=5).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )