feylur commited on
Commit
070ae90
Β·
verified Β·
1 Parent(s): 48a4c12

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +160 -31
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import patch_gradio
2
  import os
3
  import sys
4
  import torch
@@ -6,6 +5,8 @@ import gradio as gr
6
  from PIL import Image
7
  import gc
8
  import traceback
 
 
9
  from huggingface_hub import snapshot_download
10
 
11
  sys.path.insert(0, '/app/CatVTON')
@@ -14,13 +15,17 @@ from model.pipeline import CatVTONPipeline
14
  from model.cloth_masker import AutoMasker
15
  from utils import init_weight_dtype, resize_and_crop, resize_and_padding
16
 
 
17
  pipeline = None
18
  automasker = None
 
19
 
20
  def load_models():
21
- global pipeline, automasker
 
22
 
23
- if pipeline is not None and automasker is not None:
 
24
  return
25
 
26
  print("πŸ”„ Loading models...", file=sys.stderr)
@@ -31,60 +36,139 @@ def load_models():
31
  cache_dir="/tmp/models"
32
  )
33
 
 
34
  nsfw_path = "/tmp/NSFW.jpg"
35
  if not os.path.exists(nsfw_path):
36
  Image.new('RGB', (512, 512), color='black').save(nsfw_path)
37
 
 
38
  pipeline = CatVTONPipeline(
39
  base_ckpt="booksforcharlie/stable-diffusion-inpainting",
40
  attn_ckpt=repo_path,
41
  attn_ckpt_version="mix",
42
  weight_dtype=torch.float16,
43
  use_tf32=True,
44
- device='cuda'
45
  )
46
 
 
47
  automasker = AutoMasker(
48
  densepose_ckpt=os.path.join(repo_path, "DensePose"),
49
  schp_ckpt=os.path.join(repo_path, "SCHP"),
50
  device='cpu'
51
  )
52
 
53
- print("βœ… Models loaded!", file=sys.stderr)
 
 
 
 
 
 
54
 
55
  except Exception as e:
56
- print(f"❌ Error: {e}", file=sys.stderr)
57
- traceback.print_exc()
 
58
  raise
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  def generate_tryon(person_img, cloth_img):
61
- """CRITICAL: Removed progress parameter - causes API issues"""
 
 
 
 
 
62
 
 
 
 
63
  print("="*50, file=sys.stderr)
64
- print(f"Received - Person: {type(person_img)}, Cloth: {type(cloth_img)}", file=sys.stderr)
65
 
66
  if person_img is None or cloth_img is None:
67
- raise gr.Error("Both images required!")
 
 
68
 
69
  try:
70
- # Convert filepaths to PIL
71
- if isinstance(person_img, str):
72
- person_img = Image.open(person_img).convert('RGB')
73
- if isinstance(cloth_img, str):
74
- cloth_img = Image.open(cloth_img).convert('RGB')
 
 
 
 
75
 
76
- print("Images converted to PIL", file=sys.stderr)
77
 
 
78
  load_models()
79
 
 
 
 
 
 
 
80
  target_height = 1024
81
  target_width = 768
 
 
82
  person_img = resize_and_crop(person_img, (target_width, target_height))
83
  cloth_img = resize_and_padding(cloth_img, (target_width, target_height))
84
 
 
 
85
  mask = automasker(person_img, "upper")['mask']
86
  gc.collect()
87
 
 
 
88
  result = pipeline(
89
  image=person_img,
90
  condition_image=cloth_img,
@@ -96,32 +180,77 @@ def generate_tryon(person_img, cloth_img):
96
  width=target_width
97
  )[0]
98
 
99
- print("βœ… Success!", file=sys.stderr)
 
 
 
 
 
100
  return result
101
 
 
 
 
102
  except Exception as e:
103
- print(f"❌ Error: {e}", file=sys.stderr)
104
- traceback.print_exc()
105
- raise gr.Error(str(e))
 
106
 
107
- # CRITICAL: Use gr.Interface instead of gr.Blocks for better API support
108
  demo = gr.Interface(
109
  fn=generate_tryon,
110
  inputs=[
111
- gr.Image(label="Person Image", type="filepath"),
112
- gr.Image(label="Garment Image", type="filepath")
 
 
 
 
 
 
 
 
113
  ],
114
- outputs=gr.Image(label="Result", type="filepath"),
 
 
 
115
  title="Try-Space Virtual Try-On",
116
- description="Upload person and garment images. Processing takes 2-3 minutes on GPU T4.",
117
- api_name="generate_tryon",
118
- allow_flagging="never"
 
 
 
 
 
 
 
 
 
 
119
  )
120
 
121
  if __name__ == "__main__":
122
- print("πŸš€ Starting...", file=sys.stderr)
 
 
123
  try:
 
124
  load_models()
125
- except:
126
- pass
127
- demo.queue().launch(show_error=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import sys
3
  import torch
 
5
  from PIL import Image
6
  import gc
7
  import traceback
8
+ import base64
9
+ import io
10
  from huggingface_hub import snapshot_download
11
 
12
  sys.path.insert(0, '/app/CatVTON')
 
15
  from model.cloth_masker import AutoMasker
16
  from utils import init_weight_dtype, resize_and_crop, resize_and_padding
17
 
18
+ # Global model variables
19
  pipeline = None
20
  automasker = None
21
+ models_loaded = False
22
 
23
  def load_models():
24
+ """Load CatVTON models if not already loaded"""
25
+ global pipeline, automasker, models_loaded
26
 
27
+ if models_loaded and pipeline is not None and automasker is not None:
28
+ print("βœ… Models already loaded", file=sys.stderr)
29
  return
30
 
31
  print("πŸ”„ Loading models...", file=sys.stderr)
 
36
  cache_dir="/tmp/models"
37
  )
38
 
39
+ # Create NSFW placeholder if needed
40
  nsfw_path = "/tmp/NSFW.jpg"
41
  if not os.path.exists(nsfw_path):
42
  Image.new('RGB', (512, 512), color='black').save(nsfw_path)
43
 
44
+ # Initialize pipeline
45
  pipeline = CatVTONPipeline(
46
  base_ckpt="booksforcharlie/stable-diffusion-inpainting",
47
  attn_ckpt=repo_path,
48
  attn_ckpt_version="mix",
49
  weight_dtype=torch.float16,
50
  use_tf32=True,
51
+ device='cuda' if torch.cuda.is_available() else 'cpu'
52
  )
53
 
54
+ # Initialize automasker
55
  automasker = AutoMasker(
56
  densepose_ckpt=os.path.join(repo_path, "DensePose"),
57
  schp_ckpt=os.path.join(repo_path, "SCHP"),
58
  device='cpu'
59
  )
60
 
61
+ models_loaded = True
62
+ print("βœ… Models loaded successfully!", file=sys.stderr)
63
+
64
+ # Force garbage collection after loading
65
+ gc.collect()
66
+ if torch.cuda.is_available():
67
+ torch.cuda.empty_cache()
68
 
69
  except Exception as e:
70
+ print(f"❌ Error loading models: {e}", file=sys.stderr)
71
+ traceback.print_exc(file=sys.stderr)
72
+ models_loaded = False
73
  raise
74
 
75
+ def _convert_to_pil_image(image_input):
76
+ """Convert various input types to PIL Image"""
77
+ if image_input is None:
78
+ return None
79
+
80
+ # Already a PIL Image
81
+ if isinstance(image_input, Image.Image):
82
+ return image_input.convert('RGB')
83
+
84
+ # File path (string)
85
+ if isinstance(image_input, str):
86
+ # Check if it's a base64 string
87
+ if image_input.startswith('data:image') or (len(image_input) > 100 and not os.path.exists(image_input)):
88
+ # Try to decode as base64
89
+ try:
90
+ if ',' in image_input:
91
+ # Remove data URI prefix
92
+ base64_data = image_input.split(',')[1]
93
+ else:
94
+ base64_data = image_input
95
+
96
+ image_bytes = base64.b64decode(base64_data)
97
+ return Image.open(io.BytesIO(image_bytes)).convert('RGB')
98
+ except Exception as e:
99
+ print(f"⚠️ Failed to decode base64, trying as file path: {e}", file=sys.stderr)
100
+
101
+ # Try as file path
102
+ if os.path.exists(image_input):
103
+ return Image.open(image_input).convert('RGB')
104
+ else:
105
+ raise ValueError(f"Image path does not exist: {image_input}")
106
+
107
+ # Bytes or bytearray
108
+ if isinstance(image_input, (bytes, bytearray)):
109
+ return Image.open(io.BytesIO(image_input)).convert('RGB')
110
+
111
+ # Try to convert using PIL's open
112
+ try:
113
+ return Image.open(image_input).convert('RGB')
114
+ except Exception as e:
115
+ raise ValueError(f"Unable to convert input to PIL Image: {type(image_input)}, error: {e}")
116
+
117
  def generate_tryon(person_img, cloth_img):
118
+ """
119
+ Generate virtual try-on result from person and garment images.
120
+
121
+ Args:
122
+ person_img: Person image (file path, PIL Image, base64 string, or bytes)
123
+ cloth_img: Garment image (file path, PIL Image, base64 string, or bytes)
124
 
125
+ Returns:
126
+ PIL Image of the try-on result
127
+ """
128
  print("="*50, file=sys.stderr)
129
+ print(f"πŸ“₯ Received inputs - Person: {type(person_img)}, Cloth: {type(cloth_img)}", file=sys.stderr)
130
 
131
  if person_img is None or cloth_img is None:
132
+ error_msg = "Both person and garment images are required!"
133
+ print(f"❌ {error_msg}", file=sys.stderr)
134
+ raise gr.Error(error_msg)
135
 
136
  try:
137
+ # Convert inputs to PIL Images
138
+ print("πŸ”„ Converting inputs to PIL Images...", file=sys.stderr)
139
+ person_img = _convert_to_pil_image(person_img)
140
+ cloth_img = _convert_to_pil_image(cloth_img)
141
+
142
+ if person_img is None or cloth_img is None:
143
+ error_msg = "Failed to convert images to PIL format!"
144
+ print(f"❌ {error_msg}", file=sys.stderr)
145
+ raise gr.Error(error_msg)
146
 
147
+ print(f"βœ… Images converted - Person: {person_img.size}, Cloth: {cloth_img.size}", file=sys.stderr)
148
 
149
+ # Load models if not already loaded
150
  load_models()
151
 
152
+ if pipeline is None or automasker is None:
153
+ error_msg = "Failed to load models. Please try again."
154
+ print(f"❌ {error_msg}", file=sys.stderr)
155
+ raise gr.Error(error_msg)
156
+
157
+ # Resize images to target dimensions
158
  target_height = 1024
159
  target_width = 768
160
+ print(f"πŸ”„ Resizing images to {target_width}x{target_height}...", file=sys.stderr)
161
+
162
  person_img = resize_and_crop(person_img, (target_width, target_height))
163
  cloth_img = resize_and_padding(cloth_img, (target_width, target_height))
164
 
165
+ # Generate mask
166
+ print("πŸ”„ Generating mask...", file=sys.stderr)
167
  mask = automasker(person_img, "upper")['mask']
168
  gc.collect()
169
 
170
+ # Generate try-on result
171
+ print("πŸ”„ Generating try-on result (this may take 2-3 minutes)...", file=sys.stderr)
172
  result = pipeline(
173
  image=person_img,
174
  condition_image=cloth_img,
 
180
  width=target_width
181
  )[0]
182
 
183
+ # Clean up
184
+ gc.collect()
185
+ if torch.cuda.is_available():
186
+ torch.cuda.empty_cache()
187
+
188
+ print("βœ… Try-on generation completed successfully!", file=sys.stderr)
189
  return result
190
 
191
+ except gr.Error:
192
+ # Re-raise Gradio errors as-is
193
+ raise
194
  except Exception as e:
195
+ error_msg = f"Error during try-on generation: {str(e)}"
196
+ print(f"❌ {error_msg}", file=sys.stderr)
197
+ traceback.print_exc(file=sys.stderr)
198
+ raise gr.Error(error_msg)
199
 
200
+ # Create Gradio Interface with proper API configuration
201
  demo = gr.Interface(
202
  fn=generate_tryon,
203
  inputs=[
204
+ gr.Image(
205
+ label="Person Image",
206
+ type="filepath", # Accepts file paths, but we handle other types in the function
207
+ sources=["upload", "webcam"],
208
+ ),
209
+ gr.Image(
210
+ label="Garment Image",
211
+ type="filepath",
212
+ sources=["upload", "webcam"],
213
+ )
214
  ],
215
+ outputs=gr.Image(
216
+ label="Try-On Result",
217
+ type="pil" # Return PIL Image for better API compatibility
218
+ ),
219
  title="Try-Space Virtual Try-On",
220
+ description="""
221
+ Upload person and garment images to generate a virtual try-on result.
222
+
223
+ **Processing Time:** 2-3 minutes on GPU T4
224
+
225
+ **Tips:**
226
+ - Use clear, well-lit images
227
+ - Person should be facing forward
228
+ - Garment should be on a plain background
229
+ """,
230
+ api_name="generate_tryon", # Named endpoint for API access
231
+ allow_flagging="never",
232
+ examples=None, # Can add examples later if needed
233
  )
234
 
235
  if __name__ == "__main__":
236
+ print("πŸš€ Starting Try-Space Virtual Try-On Space...", file=sys.stderr)
237
+
238
+ # Try to load models at startup (non-blocking)
239
  try:
240
+ print("πŸ”„ Pre-loading models...", file=sys.stderr)
241
  load_models()
242
+ except Exception as e:
243
+ print(f"⚠️ Failed to pre-load models: {e}", file=sys.stderr)
244
+ print("⚠️ Models will be loaded on first request", file=sys.stderr)
245
+
246
+ # Launch with queue for better API handling
247
+ demo.queue(
248
+ max_size=10, # Limit queue size
249
+ default_concurrency_limit=1 # Process one request at a time
250
+ ).launch(
251
+ server_name="0.0.0.0",
252
+ server_port=7860,
253
+ show_error=True,
254
+ share=False,
255
+ enable_queue=True
256
+ )