feylur committed on
Commit
0ae5112
·
verified ·
1 Parent(s): b0ed745

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -86
app.py CHANGED
@@ -16,7 +16,10 @@ from utils import init_weight_dtype, resize_and_crop, resize_and_padding
16
 
17
  class CatVTONService:
18
  def __init__(self):
19
- self.device = "cpu" # Will use CPU on free tier
 
 
 
20
  self.pipeline = None
21
  self.automasker = None
22
  self.models_loaded = False
@@ -28,33 +31,46 @@ class CatVTONService:
28
 
29
  print("πŸ”„ Loading CatVTON models (this happens once)...")
30
 
31
- # Download model weights from HuggingFace Hub - CACHED automatically
32
- repo_path = snapshot_download(
33
- repo_id="zhengchong/CatVTON",
34
- cache_dir="./model_cache"
35
- )
36
-
37
- print("βœ… Models downloaded and cached!")
38
-
39
- # Initialize pipeline
40
- self.pipeline = CatVTONPipeline(
41
- base_ckpt="booksforcharlie/stable-diffusion-inpainting",
42
- attn_ckpt=repo_path,
43
- attn_ckpt_version="mix",
44
- weight_dtype=init_weight_dtype("fp16" if self.device == "cuda" else "fp32"),
45
- use_tf32=False, # CPU doesn't support TF32
46
- device=self.device
47
- )
48
-
49
- # Initialize automasker
50
- self.automasker = AutoMasker(
51
- densepose_ckpt=os.path.join(repo_path, "DensePose"),
52
- schp_ckpt=os.path.join(repo_path, "SCHP"),
53
- device=self.device
54
- )
55
-
56
- self.models_loaded = True
57
- print("βœ… CatVTON ready!")
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  def generate_tryon(self, person_image, garment_image, progress=gr.Progress()):
60
  """Generate virtual try-on result"""
@@ -96,7 +112,8 @@ class CatVTONService:
96
  if self.device == "cuda":
97
  torch.cuda.empty_cache()
98
 
99
- progress(0.6, desc="Running virtual try-on (this may take 2-5 minutes on CPU)...")
 
100
 
101
  # Run inference
102
  result = self.pipeline(
@@ -110,19 +127,38 @@ class CatVTONService:
110
  width=target_width
111
  )[0]
112
 
 
 
 
 
 
113
  progress(1.0, desc="Complete!")
114
 
115
- return result, "βœ… Virtual try-on generated successfully!"
116
 
117
  except Exception as e:
118
  import traceback
119
  error_msg = f"❌ Error: {str(e)}\n\n{traceback.format_exc()}"
120
  print(error_msg)
 
 
 
 
 
 
121
  return None, error_msg
122
 
123
  # Initialize service
 
124
  service = CatVTONService()
125
 
 
 
 
 
 
 
 
126
  # Create Gradio Interface
127
  def generate_tryon_interface(person_img, garment_img, progress=gr.Progress()):
128
  """Wrapper for Gradio"""
@@ -131,7 +167,7 @@ def generate_tryon_interface(person_img, garment_img, progress=gr.Progress()):
131
 
132
  # Build UI
133
  with gr.Blocks(
134
- title="Try-Space - CatVTON Virtual Try-On",
135
  theme=gr.themes.Soft(),
136
  css="""
137
  .gradio-container {max-width: 1200px !important}
@@ -140,13 +176,17 @@ with gr.Blocks(
140
  """
141
  ) as demo:
142
 
143
- gr.HTML("""
 
 
 
144
  <div id="title">
145
- <h1>πŸ‘— Try-Space - Virtual Try-On</h1>
146
  </div>
147
  <div id="subtitle">
148
  <p>Upload a person image and a garment to see how it looks on them!</p>
149
- <p><strong>⚠️ Note:</strong> Processing takes 2-5 minutes on CPU. First run downloads models (~5GB).</p>
 
150
  </div>
151
  """)
152
 
@@ -176,6 +216,7 @@ with gr.Blocks(
176
  - Person should face camera directly
177
  - Garment on plain/white background
178
  - Works best with shirts, jackets, tops
 
179
  """)
180
 
181
  with gr.Column():
@@ -186,61 +227,38 @@ with gr.Blocks(
186
  )
187
  status_output = gr.Textbox(
188
  label="Status",
189
- lines=3
 
190
  )
191
 
192
- # Examples
193
- gr.Markdown("### πŸ“‹ Example Images")
194
- gr.Examples(
195
- examples=[
196
- ["examples/person1.jpg", "examples/garment1.jpg"],
197
- ["examples/person2.jpg", "examples/garment2.jpg"],
198
- ],
199
- inputs=[person_input, garment_input],
200
- label="Try these examples (if available)"
201
- )
202
-
203
- # API Documentation
204
- with gr.Accordion("πŸ“‘ API Usage", open=False):
205
- gr.Markdown(f"""
206
- ### Using the Try-Space API
207
 
208
- **Endpoint:** `https://YOUR_USERNAME-try-space.hf.space/api/predict`
209
-
210
- **Python Example:**
211
- ```python
212
- import requests
213
- import base64
214
- from io import BytesIO
215
- from PIL import Image
216
-
217
- def encode_image(image_path):
218
- with open(image_path, "rb") as f:
219
- return base64.b64encode(f.read()).decode()
220
-
221
- url = "https://YOUR_USERNAME-try-space.hf.space/api/predict"
222
-
223
- response = requests.post(
224
- url,
225
- json={{
226
- "data": [
227
- encode_image("person.jpg"),
228
- encode_image("garment.jpg")
229
- ]
230
- }}
231
- )
232
 
233
- result = response.json()
234
- print(result)
235
- ```
236
 
237
- **cURL Example:**
238
- ```bash
239
- curl -X POST https://YOUR_USERNAME-try-space.hf.space/api/predict \\
240
- -H "Content-Type: application/json" \\
241
- -d '{{"data": ["<person_base64>", "<garment_base64>"]}}'
242
- ```
243
- """)
244
 
245
  # Connect button
246
  generate_btn.click(
@@ -251,9 +269,21 @@ with gr.Blocks(
251
 
252
  # Launch app
253
  if __name__ == "__main__":
254
- demo.queue(max_size=20) # Enable queue for long-running tasks
 
 
 
 
 
 
 
 
 
 
 
255
  demo.launch(
256
  server_name="0.0.0.0",
257
  server_port=7860,
258
- show_error=True
 
259
  )
 
16
 
17
  class CatVTONService:
18
  def __init__(self):
19
+ # Auto-detect device
20
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
21
+ print(f"πŸ–₯️ Using device: {self.device}")
22
+
23
  self.pipeline = None
24
  self.automasker = None
25
  self.models_loaded = False
 
31
 
32
  print("πŸ”„ Loading CatVTON models (this happens once)...")
33
 
34
+ try:
35
+ # Download model weights from HuggingFace Hub - CACHED automatically
36
+ repo_path = snapshot_download(
37
+ repo_id="zhengchong/CatVTON",
38
+ cache_dir="./model_cache",
39
+ resume_download=True, # Resume if interrupted
40
+ local_files_only=False # Allow downloading
41
+ )
42
+
43
+ print(f"βœ… Models downloaded to: {repo_path}")
44
+
45
+ # Determine weight dtype based on device
46
+ weight_dtype = init_weight_dtype("fp16" if self.device == "cuda" else "fp32")
47
+ use_tf32 = self.device == "cuda" # Only use TF32 on CUDA
48
+
49
+ print(f"βš™οΈ Weight dtype: {weight_dtype}, TF32: {use_tf32}")
50
+
51
+ # Initialize pipeline
52
+ self.pipeline = CatVTONPipeline(
53
+ base_ckpt="booksforcharlie/stable-diffusion-inpainting",
54
+ attn_ckpt=repo_path,
55
+ attn_ckpt_version="mix",
56
+ weight_dtype=weight_dtype,
57
+ use_tf32=use_tf32,
58
+ device=self.device
59
+ )
60
+
61
+ # Initialize automasker
62
+ self.automasker = AutoMasker(
63
+ densepose_ckpt=os.path.join(repo_path, "DensePose"),
64
+ schp_ckpt=os.path.join(repo_path, "SCHP"),
65
+ device=self.device
66
+ )
67
+
68
+ self.models_loaded = True
69
+ print("βœ… CatVTON ready!")
70
+
71
+ except Exception as e:
72
+ print(f"❌ Error loading models: {e}")
73
+ raise
74
 
75
  def generate_tryon(self, person_image, garment_image, progress=gr.Progress()):
76
  """Generate virtual try-on result"""
 
112
  if self.device == "cuda":
113
  torch.cuda.empty_cache()
114
 
115
+ device_msg = "GPU - ~30-60 seconds" if self.device == "cuda" else "CPU - ~2-5 minutes"
116
+ progress(0.6, desc=f"Running virtual try-on on {device_msg}...")
117
 
118
  # Run inference
119
  result = self.pipeline(
 
127
  width=target_width
128
  )[0]
129
 
130
+ # Clear memory after inference
131
+ gc.collect()
132
+ if self.device == "cuda":
133
+ torch.cuda.empty_cache()
134
+
135
  progress(1.0, desc="Complete!")
136
 
137
+ return result, f"βœ… Virtual try-on generated successfully on {self.device.upper()}!"
138
 
139
  except Exception as e:
140
  import traceback
141
  error_msg = f"❌ Error: {str(e)}\n\n{traceback.format_exc()}"
142
  print(error_msg)
143
+
144
+ # Clear memory on error
145
+ gc.collect()
146
+ if torch.cuda.is_available():
147
+ torch.cuda.empty_cache()
148
+
149
  return None, error_msg
150
 
151
  # Initialize service
152
+ print("πŸš€ Initializing CatVTON Service...")
153
  service = CatVTONService()
154
 
155
+ # Preload models on startup (optional - comment out if you want lazy loading)
156
+ # try:
157
+ # service.load_models()
158
+ # except Exception as e:
159
+ # print(f"⚠️ Could not preload models: {e}")
160
+ # print("Models will be loaded on first request")
161
+
162
  # Create Gradio Interface
163
  def generate_tryon_interface(person_img, garment_img, progress=gr.Progress()):
164
  """Wrapper for Gradio"""
 
167
 
168
  # Build UI
169
  with gr.Blocks(
170
+ title="CatVTON Virtual Try-On",
171
  theme=gr.themes.Soft(),
172
  css="""
173
  .gradio-container {max-width: 1200px !important}
 
176
  """
177
  ) as demo:
178
 
179
+ device_info = "πŸ–₯️ GPU" if torch.cuda.is_available() else "πŸ’» CPU"
180
+ processing_time = "30-60 seconds" if torch.cuda.is_available() else "2-5 minutes"
181
+
182
+ gr.HTML(f"""
183
  <div id="title">
184
+ <h1>πŸ‘— CatVTON - Virtual Try-On</h1>
185
  </div>
186
  <div id="subtitle">
187
  <p>Upload a person image and a garment to see how it looks on them!</p>
188
+ <p><strong>Device:</strong> {device_info} | <strong>Processing Time:</strong> ~{processing_time}</p>
189
+ <p><em>First run downloads models (~5GB) - subsequent runs are faster!</em></p>
190
  </div>
191
  """)
192
 
 
216
  - Person should face camera directly
217
  - Garment on plain/white background
218
  - Works best with shirts, jackets, tops
219
+ - Avoid images with multiple people
220
  """)
221
 
222
  with gr.Column():
 
227
  )
228
  status_output = gr.Textbox(
229
  label="Status",
230
+ lines=3,
231
+ show_label=True
232
  )
233
 
234
+ # Examples (only show if examples directory exists)
235
+ if os.path.exists("examples"):
236
+ gr.Markdown("### πŸ“‹ Example Images")
237
+ example_files = []
238
+ if os.path.exists("examples/person1.jpg") and os.path.exists("examples/garment1.jpg"):
239
+ example_files.append(["examples/person1.jpg", "examples/garment1.jpg"])
240
+ if os.path.exists("examples/person2.jpg") and os.path.exists("examples/garment2.jpg"):
241
+ example_files.append(["examples/person2.jpg", "examples/garment2.jpg"])
 
 
 
 
 
 
 
242
 
243
+ if example_files:
244
+ gr.Examples(
245
+ examples=example_files,
246
+ inputs=[person_input, garment_input],
247
+ label="Try these examples"
248
+ )
249
+
250
+ # Footer
251
+ gr.Markdown("""
252
+ ---
253
+ ### ℹ️ About
254
+ This app uses **CatVTON** (Concatenation-based Attention Virtual Try-On) for realistic garment transfer.
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
+ - Model: [zhengchong/CatVTON](https://huggingface.co/zhengchong/CatVTON)
257
+ - Based on Stable Diffusion Inpainting
258
+ - Supports upper body garments (shirts, jackets, tops)
259
 
260
+ **Note:** Processing time depends on hardware. GPU is recommended for faster results.
261
+ """)
 
 
 
 
 
262
 
263
  # Connect button
264
  generate_btn.click(
 
269
 
270
  # Launch app
271
  if __name__ == "__main__":
272
+ print("\n" + "="*60)
273
+ print("🌐 Starting CatVTON Virtual Try-On Server")
274
+ print("="*60)
275
+ print(f"Device: {service.device}")
276
+ print(f"Server: http://0.0.0.0:7860")
277
+ print("="*60 + "\n")
278
+
279
+ demo.queue(
280
+ max_size=20, # Max queue size
281
+ default_concurrency_limit=2 # Limit concurrent requests
282
+ )
283
+
284
  demo.launch(
285
  server_name="0.0.0.0",
286
  server_port=7860,
287
+ show_error=True,
288
+ share=False # Don't create public link on HF Spaces
289
  )