Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -16,7 +16,10 @@ from utils import init_weight_dtype, resize_and_crop, resize_and_padding
|
|
| 16 |
|
| 17 |
class CatVTONService:
|
| 18 |
def __init__(self):
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
| 20 |
self.pipeline = None
|
| 21 |
self.automasker = None
|
| 22 |
self.models_loaded = False
|
|
@@ -28,33 +31,46 @@ class CatVTONService:
|
|
| 28 |
|
| 29 |
print("π Loading CatVTON models (this happens once)...")
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
def generate_tryon(self, person_image, garment_image, progress=gr.Progress()):
|
| 60 |
"""Generate virtual try-on result"""
|
|
@@ -96,7 +112,8 @@ class CatVTONService:
|
|
| 96 |
if self.device == "cuda":
|
| 97 |
torch.cuda.empty_cache()
|
| 98 |
|
| 99 |
-
|
|
|
|
| 100 |
|
| 101 |
# Run inference
|
| 102 |
result = self.pipeline(
|
|
@@ -110,19 +127,38 @@ class CatVTONService:
|
|
| 110 |
width=target_width
|
| 111 |
)[0]
|
| 112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
progress(1.0, desc="Complete!")
|
| 114 |
|
| 115 |
-
return result, "β
Virtual try-on generated successfully!"
|
| 116 |
|
| 117 |
except Exception as e:
|
| 118 |
import traceback
|
| 119 |
error_msg = f"β Error: {str(e)}\n\n{traceback.format_exc()}"
|
| 120 |
print(error_msg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
return None, error_msg
|
| 122 |
|
| 123 |
# Initialize service
|
|
|
|
| 124 |
service = CatVTONService()
|
| 125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
# Create Gradio Interface
|
| 127 |
def generate_tryon_interface(person_img, garment_img, progress=gr.Progress()):
|
| 128 |
"""Wrapper for Gradio"""
|
|
@@ -131,7 +167,7 @@ def generate_tryon_interface(person_img, garment_img, progress=gr.Progress()):
|
|
| 131 |
|
| 132 |
# Build UI
|
| 133 |
with gr.Blocks(
|
| 134 |
-
title="
|
| 135 |
theme=gr.themes.Soft(),
|
| 136 |
css="""
|
| 137 |
.gradio-container {max-width: 1200px !important}
|
|
@@ -140,13 +176,17 @@ with gr.Blocks(
|
|
| 140 |
"""
|
| 141 |
) as demo:
|
| 142 |
|
| 143 |
-
|
|
|
|
|
|
|
|
|
|
| 144 |
<div id="title">
|
| 145 |
-
<h1>π
|
| 146 |
</div>
|
| 147 |
<div id="subtitle">
|
| 148 |
<p>Upload a person image and a garment to see how it looks on them!</p>
|
| 149 |
-
<p><strong
|
|
|
|
| 150 |
</div>
|
| 151 |
""")
|
| 152 |
|
|
@@ -176,6 +216,7 @@ with gr.Blocks(
|
|
| 176 |
- Person should face camera directly
|
| 177 |
- Garment on plain/white background
|
| 178 |
- Works best with shirts, jackets, tops
|
|
|
|
| 179 |
""")
|
| 180 |
|
| 181 |
with gr.Column():
|
|
@@ -186,61 +227,38 @@ with gr.Blocks(
|
|
| 186 |
)
|
| 187 |
status_output = gr.Textbox(
|
| 188 |
label="Status",
|
| 189 |
-
lines=3
|
|
|
|
| 190 |
)
|
| 191 |
|
| 192 |
-
# Examples
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
["examples/
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
label="Try these examples (if available)"
|
| 201 |
-
)
|
| 202 |
-
|
| 203 |
-
# API Documentation
|
| 204 |
-
with gr.Accordion("π‘ API Usage", open=False):
|
| 205 |
-
gr.Markdown(f"""
|
| 206 |
-
### Using the Try-Space API
|
| 207 |
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
url = "https://YOUR_USERNAME-try-space.hf.space/api/predict"
|
| 222 |
-
|
| 223 |
-
response = requests.post(
|
| 224 |
-
url,
|
| 225 |
-
json={{
|
| 226 |
-
"data": [
|
| 227 |
-
encode_image("person.jpg"),
|
| 228 |
-
encode_image("garment.jpg")
|
| 229 |
-
]
|
| 230 |
-
}}
|
| 231 |
-
)
|
| 232 |
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
|
| 237 |
-
**
|
| 238 |
-
|
| 239 |
-
curl -X POST https://YOUR_USERNAME-try-space.hf.space/api/predict \\
|
| 240 |
-
-H "Content-Type: application/json" \\
|
| 241 |
-
-d '{{"data": ["<person_base64>", "<garment_base64>"]}}'
|
| 242 |
-
```
|
| 243 |
-
""")
|
| 244 |
|
| 245 |
# Connect button
|
| 246 |
generate_btn.click(
|
|
@@ -251,9 +269,21 @@ with gr.Blocks(
|
|
| 251 |
|
| 252 |
# Launch app
|
| 253 |
if __name__ == "__main__":
|
| 254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
demo.launch(
|
| 256 |
server_name="0.0.0.0",
|
| 257 |
server_port=7860,
|
| 258 |
-
show_error=True
|
|
|
|
| 259 |
)
|
|
|
|
| 16 |
|
| 17 |
class CatVTONService:
|
| 18 |
def __init__(self):
|
| 19 |
+
# Auto-detect device
|
| 20 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 21 |
+
print(f"π₯οΈ Using device: {self.device}")
|
| 22 |
+
|
| 23 |
self.pipeline = None
|
| 24 |
self.automasker = None
|
| 25 |
self.models_loaded = False
|
|
|
|
| 31 |
|
| 32 |
print("π Loading CatVTON models (this happens once)...")
|
| 33 |
|
| 34 |
+
try:
|
| 35 |
+
# Download model weights from HuggingFace Hub - CACHED automatically
|
| 36 |
+
repo_path = snapshot_download(
|
| 37 |
+
repo_id="zhengchong/CatVTON",
|
| 38 |
+
cache_dir="./model_cache",
|
| 39 |
+
resume_download=True, # Resume if interrupted
|
| 40 |
+
local_files_only=False # Allow downloading
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
print(f"β
Models downloaded to: {repo_path}")
|
| 44 |
+
|
| 45 |
+
# Determine weight dtype based on device
|
| 46 |
+
weight_dtype = init_weight_dtype("fp16" if self.device == "cuda" else "fp32")
|
| 47 |
+
use_tf32 = self.device == "cuda" # Only use TF32 on CUDA
|
| 48 |
+
|
| 49 |
+
print(f"βοΈ Weight dtype: {weight_dtype}, TF32: {use_tf32}")
|
| 50 |
+
|
| 51 |
+
# Initialize pipeline
|
| 52 |
+
self.pipeline = CatVTONPipeline(
|
| 53 |
+
base_ckpt="booksforcharlie/stable-diffusion-inpainting",
|
| 54 |
+
attn_ckpt=repo_path,
|
| 55 |
+
attn_ckpt_version="mix",
|
| 56 |
+
weight_dtype=weight_dtype,
|
| 57 |
+
use_tf32=use_tf32,
|
| 58 |
+
device=self.device
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
# Initialize automasker
|
| 62 |
+
self.automasker = AutoMasker(
|
| 63 |
+
densepose_ckpt=os.path.join(repo_path, "DensePose"),
|
| 64 |
+
schp_ckpt=os.path.join(repo_path, "SCHP"),
|
| 65 |
+
device=self.device
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
self.models_loaded = True
|
| 69 |
+
print("β
CatVTON ready!")
|
| 70 |
+
|
| 71 |
+
except Exception as e:
|
| 72 |
+
print(f"β Error loading models: {e}")
|
| 73 |
+
raise
|
| 74 |
|
| 75 |
def generate_tryon(self, person_image, garment_image, progress=gr.Progress()):
|
| 76 |
"""Generate virtual try-on result"""
|
|
|
|
| 112 |
if self.device == "cuda":
|
| 113 |
torch.cuda.empty_cache()
|
| 114 |
|
| 115 |
+
device_msg = "GPU - ~30-60 seconds" if self.device == "cuda" else "CPU - ~2-5 minutes"
|
| 116 |
+
progress(0.6, desc=f"Running virtual try-on on {device_msg}...")
|
| 117 |
|
| 118 |
# Run inference
|
| 119 |
result = self.pipeline(
|
|
|
|
| 127 |
width=target_width
|
| 128 |
)[0]
|
| 129 |
|
| 130 |
+
# Clear memory after inference
|
| 131 |
+
gc.collect()
|
| 132 |
+
if self.device == "cuda":
|
| 133 |
+
torch.cuda.empty_cache()
|
| 134 |
+
|
| 135 |
progress(1.0, desc="Complete!")
|
| 136 |
|
| 137 |
+
return result, f"β
Virtual try-on generated successfully on {self.device.upper()}!"
|
| 138 |
|
| 139 |
except Exception as e:
|
| 140 |
import traceback
|
| 141 |
error_msg = f"β Error: {str(e)}\n\n{traceback.format_exc()}"
|
| 142 |
print(error_msg)
|
| 143 |
+
|
| 144 |
+
# Clear memory on error
|
| 145 |
+
gc.collect()
|
| 146 |
+
if torch.cuda.is_available():
|
| 147 |
+
torch.cuda.empty_cache()
|
| 148 |
+
|
| 149 |
return None, error_msg
|
| 150 |
|
| 151 |
# Initialize service
|
| 152 |
+
print("π Initializing CatVTON Service...")
|
| 153 |
service = CatVTONService()
|
| 154 |
|
| 155 |
+
# Preload models on startup (optional - comment out if you want lazy loading)
|
| 156 |
+
# try:
|
| 157 |
+
# service.load_models()
|
| 158 |
+
# except Exception as e:
|
| 159 |
+
# print(f"β οΈ Could not preload models: {e}")
|
| 160 |
+
# print("Models will be loaded on first request")
|
| 161 |
+
|
| 162 |
# Create Gradio Interface
|
| 163 |
def generate_tryon_interface(person_img, garment_img, progress=gr.Progress()):
|
| 164 |
"""Wrapper for Gradio"""
|
|
|
|
| 167 |
|
| 168 |
# Build UI
|
| 169 |
with gr.Blocks(
|
| 170 |
+
title="CatVTON Virtual Try-On",
|
| 171 |
theme=gr.themes.Soft(),
|
| 172 |
css="""
|
| 173 |
.gradio-container {max-width: 1200px !important}
|
|
|
|
| 176 |
"""
|
| 177 |
) as demo:
|
| 178 |
|
| 179 |
+
device_info = "π₯οΈ GPU" if torch.cuda.is_available() else "π» CPU"
|
| 180 |
+
processing_time = "30-60 seconds" if torch.cuda.is_available() else "2-5 minutes"
|
| 181 |
+
|
| 182 |
+
gr.HTML(f"""
|
| 183 |
<div id="title">
|
| 184 |
+
<h1>π CatVTON - Virtual Try-On</h1>
|
| 185 |
</div>
|
| 186 |
<div id="subtitle">
|
| 187 |
<p>Upload a person image and a garment to see how it looks on them!</p>
|
| 188 |
+
<p><strong>Device:</strong> {device_info} | <strong>Processing Time:</strong> ~{processing_time}</p>
|
| 189 |
+
<p><em>First run downloads models (~5GB) - subsequent runs are faster!</em></p>
|
| 190 |
</div>
|
| 191 |
""")
|
| 192 |
|
|
|
|
| 216 |
- Person should face camera directly
|
| 217 |
- Garment on plain/white background
|
| 218 |
- Works best with shirts, jackets, tops
|
| 219 |
+
- Avoid images with multiple people
|
| 220 |
""")
|
| 221 |
|
| 222 |
with gr.Column():
|
|
|
|
| 227 |
)
|
| 228 |
status_output = gr.Textbox(
|
| 229 |
label="Status",
|
| 230 |
+
lines=3,
|
| 231 |
+
show_label=True
|
| 232 |
)
|
| 233 |
|
| 234 |
+
# Examples (only show if examples directory exists)
|
| 235 |
+
if os.path.exists("examples"):
|
| 236 |
+
gr.Markdown("### π Example Images")
|
| 237 |
+
example_files = []
|
| 238 |
+
if os.path.exists("examples/person1.jpg") and os.path.exists("examples/garment1.jpg"):
|
| 239 |
+
example_files.append(["examples/person1.jpg", "examples/garment1.jpg"])
|
| 240 |
+
if os.path.exists("examples/person2.jpg") and os.path.exists("examples/garment2.jpg"):
|
| 241 |
+
example_files.append(["examples/person2.jpg", "examples/garment2.jpg"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
+
if example_files:
|
| 244 |
+
gr.Examples(
|
| 245 |
+
examples=example_files,
|
| 246 |
+
inputs=[person_input, garment_input],
|
| 247 |
+
label="Try these examples"
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
# Footer
|
| 251 |
+
gr.Markdown("""
|
| 252 |
+
---
|
| 253 |
+
### βΉοΈ About
|
| 254 |
+
This app uses **CatVTON** (Concatenation-based Attention Virtual Try-On) for realistic garment transfer.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
|
| 256 |
+
- Model: [zhengchong/CatVTON](https://huggingface.co/zhengchong/CatVTON)
|
| 257 |
+
- Based on Stable Diffusion Inpainting
|
| 258 |
+
- Supports upper body garments (shirts, jackets, tops)
|
| 259 |
|
| 260 |
+
**Note:** Processing time depends on hardware. GPU is recommended for faster results.
|
| 261 |
+
""")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
|
| 263 |
# Connect button
|
| 264 |
generate_btn.click(
|
|
|
|
| 269 |
|
| 270 |
# Launch app
|
| 271 |
if __name__ == "__main__":
|
| 272 |
+
print("\n" + "="*60)
|
| 273 |
+
print("π Starting CatVTON Virtual Try-On Server")
|
| 274 |
+
print("="*60)
|
| 275 |
+
print(f"Device: {service.device}")
|
| 276 |
+
print(f"Server: http://0.0.0.0:7860")
|
| 277 |
+
print("="*60 + "\n")
|
| 278 |
+
|
| 279 |
+
demo.queue(
|
| 280 |
+
max_size=20, # Max queue size
|
| 281 |
+
default_concurrency_limit=2 # Limit concurrent requests
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
demo.launch(
|
| 285 |
server_name="0.0.0.0",
|
| 286 |
server_port=7860,
|
| 287 |
+
show_error=True,
|
| 288 |
+
share=False # Don't create public link on HF Spaces
|
| 289 |
)
|