feylur commited on
Commit
684c222
Β·
verified Β·
1 Parent(s): 070ae90

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -160
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import sys
3
  import torch
@@ -5,8 +6,6 @@ import gradio as gr
5
  from PIL import Image
6
  import gc
7
  import traceback
8
- import base64
9
- import io
10
  from huggingface_hub import snapshot_download
11
 
12
  sys.path.insert(0, '/app/CatVTON')
@@ -15,17 +14,13 @@ from model.pipeline import CatVTONPipeline
15
  from model.cloth_masker import AutoMasker
16
  from utils import init_weight_dtype, resize_and_crop, resize_and_padding
17
 
18
- # Global model variables
19
  pipeline = None
20
  automasker = None
21
- models_loaded = False
22
 
23
  def load_models():
24
- """Load CatVTON models if not already loaded"""
25
- global pipeline, automasker, models_loaded
26
 
27
- if models_loaded and pipeline is not None and automasker is not None:
28
- print("βœ… Models already loaded", file=sys.stderr)
29
  return
30
 
31
  print("πŸ”„ Loading models...", file=sys.stderr)
@@ -36,139 +31,60 @@ def load_models():
36
  cache_dir="/tmp/models"
37
  )
38
 
39
- # Create NSFW placeholder if needed
40
  nsfw_path = "/tmp/NSFW.jpg"
41
  if not os.path.exists(nsfw_path):
42
  Image.new('RGB', (512, 512), color='black').save(nsfw_path)
43
 
44
- # Initialize pipeline
45
  pipeline = CatVTONPipeline(
46
  base_ckpt="booksforcharlie/stable-diffusion-inpainting",
47
  attn_ckpt=repo_path,
48
  attn_ckpt_version="mix",
49
  weight_dtype=torch.float16,
50
  use_tf32=True,
51
- device='cuda' if torch.cuda.is_available() else 'cpu'
52
  )
53
 
54
- # Initialize automasker
55
  automasker = AutoMasker(
56
  densepose_ckpt=os.path.join(repo_path, "DensePose"),
57
  schp_ckpt=os.path.join(repo_path, "SCHP"),
58
  device='cpu'
59
  )
60
 
61
- models_loaded = True
62
- print("βœ… Models loaded successfully!", file=sys.stderr)
63
-
64
- # Force garbage collection after loading
65
- gc.collect()
66
- if torch.cuda.is_available():
67
- torch.cuda.empty_cache()
68
 
69
  except Exception as e:
70
- print(f"❌ Error loading models: {e}", file=sys.stderr)
71
- traceback.print_exc(file=sys.stderr)
72
- models_loaded = False
73
  raise
74
 
75
- def _convert_to_pil_image(image_input):
76
- """Convert various input types to PIL Image"""
77
- if image_input is None:
78
- return None
79
-
80
- # Already a PIL Image
81
- if isinstance(image_input, Image.Image):
82
- return image_input.convert('RGB')
83
-
84
- # File path (string)
85
- if isinstance(image_input, str):
86
- # Check if it's a base64 string
87
- if image_input.startswith('data:image') or (len(image_input) > 100 and not os.path.exists(image_input)):
88
- # Try to decode as base64
89
- try:
90
- if ',' in image_input:
91
- # Remove data URI prefix
92
- base64_data = image_input.split(',')[1]
93
- else:
94
- base64_data = image_input
95
-
96
- image_bytes = base64.b64decode(base64_data)
97
- return Image.open(io.BytesIO(image_bytes)).convert('RGB')
98
- except Exception as e:
99
- print(f"⚠️ Failed to decode base64, trying as file path: {e}", file=sys.stderr)
100
-
101
- # Try as file path
102
- if os.path.exists(image_input):
103
- return Image.open(image_input).convert('RGB')
104
- else:
105
- raise ValueError(f"Image path does not exist: {image_input}")
106
-
107
- # Bytes or bytearray
108
- if isinstance(image_input, (bytes, bytearray)):
109
- return Image.open(io.BytesIO(image_input)).convert('RGB')
110
-
111
- # Try to convert using PIL's open
112
- try:
113
- return Image.open(image_input).convert('RGB')
114
- except Exception as e:
115
- raise ValueError(f"Unable to convert input to PIL Image: {type(image_input)}, error: {e}")
116
-
117
  def generate_tryon(person_img, cloth_img):
118
- """
119
- Generate virtual try-on result from person and garment images.
120
-
121
- Args:
122
- person_img: Person image (file path, PIL Image, base64 string, or bytes)
123
- cloth_img: Garment image (file path, PIL Image, base64 string, or bytes)
124
 
125
- Returns:
126
- PIL Image of the try-on result
127
- """
128
  print("="*50, file=sys.stderr)
129
- print(f"πŸ“₯ Received inputs - Person: {type(person_img)}, Cloth: {type(cloth_img)}", file=sys.stderr)
130
 
131
  if person_img is None or cloth_img is None:
132
- error_msg = "Both person and garment images are required!"
133
- print(f"❌ {error_msg}", file=sys.stderr)
134
- raise gr.Error(error_msg)
135
 
136
  try:
137
- # Convert inputs to PIL Images
138
- print("πŸ”„ Converting inputs to PIL Images...", file=sys.stderr)
139
- person_img = _convert_to_pil_image(person_img)
140
- cloth_img = _convert_to_pil_image(cloth_img)
141
-
142
- if person_img is None or cloth_img is None:
143
- error_msg = "Failed to convert images to PIL format!"
144
- print(f"❌ {error_msg}", file=sys.stderr)
145
- raise gr.Error(error_msg)
146
 
147
- print(f"βœ… Images converted - Person: {person_img.size}, Cloth: {cloth_img.size}", file=sys.stderr)
148
 
149
- # Load models if not already loaded
150
  load_models()
151
 
152
- if pipeline is None or automasker is None:
153
- error_msg = "Failed to load models. Please try again."
154
- print(f"❌ {error_msg}", file=sys.stderr)
155
- raise gr.Error(error_msg)
156
-
157
- # Resize images to target dimensions
158
  target_height = 1024
159
  target_width = 768
160
- print(f"πŸ”„ Resizing images to {target_width}x{target_height}...", file=sys.stderr)
161
-
162
  person_img = resize_and_crop(person_img, (target_width, target_height))
163
  cloth_img = resize_and_padding(cloth_img, (target_width, target_height))
164
 
165
- # Generate mask
166
- print("πŸ”„ Generating mask...", file=sys.stderr)
167
  mask = automasker(person_img, "upper")['mask']
168
  gc.collect()
169
 
170
- # Generate try-on result
171
- print("πŸ”„ Generating try-on result (this may take 2-3 minutes)...", file=sys.stderr)
172
  result = pipeline(
173
  image=person_img,
174
  condition_image=cloth_img,
@@ -180,77 +96,32 @@ def generate_tryon(person_img, cloth_img):
180
  width=target_width
181
  )[0]
182
 
183
- # Clean up
184
- gc.collect()
185
- if torch.cuda.is_available():
186
- torch.cuda.empty_cache()
187
-
188
- print("βœ… Try-on generation completed successfully!", file=sys.stderr)
189
  return result
190
 
191
- except gr.Error:
192
- # Re-raise Gradio errors as-is
193
- raise
194
  except Exception as e:
195
- error_msg = f"Error during try-on generation: {str(e)}"
196
- print(f"❌ {error_msg}", file=sys.stderr)
197
- traceback.print_exc(file=sys.stderr)
198
- raise gr.Error(error_msg)
199
 
200
- # Create Gradio Interface with proper API configuration
201
  demo = gr.Interface(
202
  fn=generate_tryon,
203
  inputs=[
204
- gr.Image(
205
- label="Person Image",
206
- type="filepath", # Accepts file paths, but we handle other types in the function
207
- sources=["upload", "webcam"],
208
- ),
209
- gr.Image(
210
- label="Garment Image",
211
- type="filepath",
212
- sources=["upload", "webcam"],
213
- )
214
  ],
215
- outputs=gr.Image(
216
- label="Try-On Result",
217
- type="pil" # Return PIL Image for better API compatibility
218
- ),
219
  title="Try-Space Virtual Try-On",
220
- description="""
221
- Upload person and garment images to generate a virtual try-on result.
222
-
223
- **Processing Time:** 2-3 minutes on GPU T4
224
-
225
- **Tips:**
226
- - Use clear, well-lit images
227
- - Person should be facing forward
228
- - Garment should be on a plain background
229
- """,
230
- api_name="generate_tryon", # Named endpoint for API access
231
- allow_flagging="never",
232
- examples=None, # Can add examples later if needed
233
  )
234
 
235
  if __name__ == "__main__":
236
- print("πŸš€ Starting Try-Space Virtual Try-On Space...", file=sys.stderr)
237
-
238
- # Try to load models at startup (non-blocking)
239
  try:
240
- print("πŸ”„ Pre-loading models...", file=sys.stderr)
241
  load_models()
242
- except Exception as e:
243
- print(f"⚠️ Failed to pre-load models: {e}", file=sys.stderr)
244
- print("⚠️ Models will be loaded on first request", file=sys.stderr)
245
-
246
- # Launch with queue for better API handling
247
- demo.queue(
248
- max_size=10, # Limit queue size
249
- default_concurrency_limit=1 # Process one request at a time
250
- ).launch(
251
- server_name="0.0.0.0",
252
- server_port=7860,
253
- show_error=True,
254
- share=False,
255
- enable_queue=True
256
- )
 
1
+ import patch_gradio
2
  import os
3
  import sys
4
  import torch
 
6
  from PIL import Image
7
  import gc
8
  import traceback
 
 
9
  from huggingface_hub import snapshot_download
10
 
11
  sys.path.insert(0, '/app/CatVTON')
 
14
  from model.cloth_masker import AutoMasker
15
  from utils import init_weight_dtype, resize_and_crop, resize_and_padding
16
 
 
17
  pipeline = None
18
  automasker = None
 
19
 
20
  def load_models():
21
+ global pipeline, automasker
 
22
 
23
+ if pipeline is not None and automasker is not None:
 
24
  return
25
 
26
  print("πŸ”„ Loading models...", file=sys.stderr)
 
31
  cache_dir="/tmp/models"
32
  )
33
 
 
34
  nsfw_path = "/tmp/NSFW.jpg"
35
  if not os.path.exists(nsfw_path):
36
  Image.new('RGB', (512, 512), color='black').save(nsfw_path)
37
 
 
38
  pipeline = CatVTONPipeline(
39
  base_ckpt="booksforcharlie/stable-diffusion-inpainting",
40
  attn_ckpt=repo_path,
41
  attn_ckpt_version="mix",
42
  weight_dtype=torch.float16,
43
  use_tf32=True,
44
+ device='cuda'
45
  )
46
 
 
47
  automasker = AutoMasker(
48
  densepose_ckpt=os.path.join(repo_path, "DensePose"),
49
  schp_ckpt=os.path.join(repo_path, "SCHP"),
50
  device='cpu'
51
  )
52
 
53
+ print("βœ… Models loaded!", file=sys.stderr)
 
 
 
 
 
 
54
 
55
  except Exception as e:
56
+ print(f"❌ Error: {e}", file=sys.stderr)
57
+ traceback.print_exc()
 
58
  raise
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  def generate_tryon(person_img, cloth_img):
61
+ """CRITICAL: Removed progress parameter - causes API issues"""
 
 
 
 
 
62
 
 
 
 
63
  print("="*50, file=sys.stderr)
64
+ print(f"Received - Person: {type(person_img)}, Cloth: {type(cloth_img)}", file=sys.stderr)
65
 
66
  if person_img is None or cloth_img is None:
67
+ raise gr.Error("Both images required!")
 
 
68
 
69
  try:
70
+ # Convert filepaths to PIL
71
+ if isinstance(person_img, str):
72
+ person_img = Image.open(person_img).convert('RGB')
73
+ if isinstance(cloth_img, str):
74
+ cloth_img = Image.open(cloth_img).convert('RGB')
 
 
 
 
75
 
76
+ print("Images converted to PIL", file=sys.stderr)
77
 
 
78
  load_models()
79
 
 
 
 
 
 
 
80
  target_height = 1024
81
  target_width = 768
 
 
82
  person_img = resize_and_crop(person_img, (target_width, target_height))
83
  cloth_img = resize_and_padding(cloth_img, (target_width, target_height))
84
 
 
 
85
  mask = automasker(person_img, "upper")['mask']
86
  gc.collect()
87
 
 
 
88
  result = pipeline(
89
  image=person_img,
90
  condition_image=cloth_img,
 
96
  width=target_width
97
  )[0]
98
 
99
+ print("βœ… Success!", file=sys.stderr)
 
 
 
 
 
100
  return result
101
 
 
 
 
102
  except Exception as e:
103
+ print(f"❌ Error: {e}", file=sys.stderr)
104
+ traceback.print_exc()
105
+ raise gr.Error(str(e))
 
106
 
107
+ # CRITICAL: Use gr.Interface instead of gr.Blocks for better API support
108
  demo = gr.Interface(
109
  fn=generate_tryon,
110
  inputs=[
111
+ gr.Image(label="Person Image", type="filepath"),
112
+ gr.Image(label="Garment Image", type="filepath")
 
 
 
 
 
 
 
 
113
  ],
114
+ outputs=gr.Image(label="Result", type="filepath"),
 
 
 
115
  title="Try-Space Virtual Try-On",
116
+ description="Upload person and garment images. Processing takes 2-3 minutes on GPU T4.",
117
+ api_name="generate_tryon",
118
+ allow_flagging="never"
 
 
 
 
 
 
 
 
 
 
119
  )
120
 
121
  if __name__ == "__main__":
122
+ print("πŸš€ Starting...", file=sys.stderr)
 
 
123
  try:
 
124
  load_models()
125
+ except:
126
+ pass
127
+ demo.queue().launch(show_error=True)