feylur commited on
Commit
48a4c12
Β·
verified Β·
1 Parent(s): 76d32b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -127
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import patch_gradio # Add this as first import
2
  import os
3
  import sys
4
  import torch
@@ -8,41 +8,33 @@ import gc
8
  import traceback
9
  from huggingface_hub import snapshot_download
10
 
11
- # Add CatVTON to path
12
  sys.path.insert(0, '/app/CatVTON')
13
 
14
  from model.pipeline import CatVTONPipeline
15
  from model.cloth_masker import AutoMasker
16
  from utils import init_weight_dtype, resize_and_crop, resize_and_padding
17
 
18
- # Global variables
19
  pipeline = None
20
  automasker = None
21
 
22
  def load_models():
23
- """Load models once at startup"""
24
  global pipeline, automasker
25
 
26
  if pipeline is not None and automasker is not None:
27
  return
28
 
29
- print("πŸ”„ Downloading/Loading CatVTON models (first time may take 5-10 mins)...")
30
 
31
  try:
32
- # Download and cache models
33
  repo_path = snapshot_download(
34
  repo_id="zhengchong/CatVTON",
35
  cache_dir="/tmp/models"
36
  )
37
 
38
- print(f"βœ… Models downloaded to: {repo_path}")
39
-
40
- # Create NSFW placeholder in writable directory
41
  nsfw_path = "/tmp/NSFW.jpg"
42
  if not os.path.exists(nsfw_path):
43
  Image.new('RGB', (512, 512), color='black').save(nsfw_path)
44
 
45
- print("Initializing pipeline...")
46
  pipeline = CatVTONPipeline(
47
  base_ckpt="booksforcharlie/stable-diffusion-inpainting",
48
  attn_ckpt=repo_path,
@@ -51,76 +43,48 @@ def load_models():
51
  use_tf32=True,
52
  device='cuda'
53
  )
54
- print("βœ… Pipeline loaded!")
55
 
56
- print("Initializing automasker...")
57
  automasker = AutoMasker(
58
  densepose_ckpt=os.path.join(repo_path, "DensePose"),
59
  schp_ckpt=os.path.join(repo_path, "SCHP"),
60
  device='cpu'
61
  )
62
- print("βœ… Automasker loaded!")
 
63
 
64
  except Exception as e:
65
- print(f"❌ Error loading models: {e}", file=sys.stderr)
66
  traceback.print_exc()
67
  raise
68
 
69
- def generate_tryon(person_img, cloth_img, progress=gr.Progress()):
70
- """Generate virtual try-on"""
71
 
72
- # ADD EXTENSIVE LOGGING FOR API DEBUGGING
73
- print("=" * 50, file=sys.stderr)
74
- print(f"API CALL RECEIVED", file=sys.stderr)
75
- print(f"Person image type: {type(person_img)}", file=sys.stderr)
76
- print(f"Cloth image type: {type(cloth_img)}", file=sys.stderr)
77
 
78
  if person_img is None or cloth_img is None:
79
- error_msg = "Please upload both person and garment images!"
80
- print(f"ERROR: {error_msg}", file=sys.stderr)
81
- raise gr.Error(error_msg)
82
 
83
  try:
84
- # HANDLE DIFFERENT INPUT TYPES (filepath or PIL)
85
  if isinstance(person_img, str):
86
- print(f"Converting person_img from filepath: {person_img}", file=sys.stderr)
87
  person_img = Image.open(person_img).convert('RGB')
88
- elif not isinstance(person_img, Image.Image):
89
- print(f"Converting person_img from array", file=sys.stderr)
90
- person_img = Image.fromarray(person_img).convert('RGB')
91
-
92
  if isinstance(cloth_img, str):
93
- print(f"Converting cloth_img from filepath: {cloth_img}", file=sys.stderr)
94
  cloth_img = Image.open(cloth_img).convert('RGB')
95
- elif not isinstance(cloth_img, Image.Image):
96
- print(f"Converting cloth_img from array", file=sys.stderr)
97
- cloth_img = Image.fromarray(cloth_img).convert('RGB')
98
 
99
- print(f"Images loaded successfully", file=sys.stderr)
100
 
101
- # Load models
102
- progress(0.05, desc="Loading models...")
103
  load_models()
104
 
105
- progress(0.15, desc="Processing images...")
106
-
107
- # Resize images
108
  target_height = 1024
109
  target_width = 768
110
  person_img = resize_and_crop(person_img, (target_width, target_height))
111
  cloth_img = resize_and_padding(cloth_img, (target_width, target_height))
112
 
113
- progress(0.35, desc="Generating body mask...")
114
-
115
- # Generate mask
116
  mask = automasker(person_img, "upper")['mask']
117
-
118
- # Clear memory
119
  gc.collect()
120
 
121
- progress(0.50, desc="Running on GPU T4 - processing takes 2-3 minutes per image.")
122
-
123
- # Run inference
124
  result = pipeline(
125
  image=person_img,
126
  condition_image=cloth_img,
@@ -132,92 +96,32 @@ def generate_tryon(person_img, cloth_img, progress=gr.Progress()):
132
  width=target_width
133
  )[0]
134
 
135
- progress(1.0, desc="Complete! ✨")
136
-
137
- print("SUCCESS: Try-on generated successfully", file=sys.stderr)
138
  return result
139
 
140
  except Exception as e:
141
- error_msg = f"Error during try-on: {str(e)}"
142
- print(f"ERROR: {error_msg}", file=sys.stderr)
143
  traceback.print_exc()
144
- raise gr.Error(error_msg)
145
 
146
- # Create Gradio UI
147
- with gr.Blocks(
 
 
 
 
 
 
148
  title="Try-Space Virtual Try-On",
149
- theme=gr.themes.Soft()
150
- ) as demo:
151
-
152
- gr.Markdown("""
153
- # 🎨 Try-Space Virtual Try-On
154
- ### Upload a person image and garment to see the magic! ✨
155
-
156
- ⚠️ **Note:** Running on GPU T4 - processing takes 2-3 minutes per image.
157
- """)
158
-
159
- with gr.Row():
160
- with gr.Column():
161
- gr.Markdown("### πŸ“Έ Inputs")
162
- person_input = gr.Image(
163
- label="πŸ‘€ Person Image (full body, front-facing)",
164
- type="filepath", # CHANGED FROM "pil" TO "filepath"
165
- height=350
166
- )
167
- cloth_input = gr.Image(
168
- label="πŸ‘• Garment Image (flat, white background)",
169
- type="filepath", # CHANGED FROM "pil" TO "filepath"
170
- height=350
171
- )
172
-
173
- with gr.Row():
174
- clear_btn = gr.ClearButton(
175
- [person_input, cloth_input],
176
- value="πŸ—‘οΈ Clear"
177
- )
178
- submit_btn = gr.Button(
179
- "πŸš€ Generate Try-On",
180
- variant="primary",
181
- size="lg"
182
- )
183
-
184
- with gr.Column():
185
- gr.Markdown("### ✨ Result")
186
- output_img = gr.Image(
187
- label="Virtual Try-On Result",
188
- type="filepath", # ADDED type
189
- height=700
190
- )
191
-
192
- gr.Markdown("""
193
- ---
194
- ### πŸ’‘ Tips for Best Results:
195
- - βœ… Use well-lit, clear images
196
- - βœ… Person should face the camera directly
197
- - βœ… Garment should be flat or on white background
198
- - βœ… Works best with shirts, jackets, or tops
199
- - βœ… Avoid extreme angles or poses
200
-
201
- ### ⏱️ Processing Time:
202
- - **GPU T4:** ~2-3 minutes per generation
203
- """)
204
-
205
- # Event handler with API name
206
- submit_btn.click(
207
- fn=generate_tryon,
208
- inputs=[person_input, cloth_input],
209
- outputs=output_img,
210
- api_name="generate_tryon" # ADDED THIS - CRITICAL FOR API ACCESS
211
- )
212
 
213
- # Launch app
214
  if __name__ == "__main__":
215
- print("πŸš€ Starting Try-Space Virtual Try-On...")
216
- # Pre-load models at startup
217
  try:
218
  load_models()
219
- except Exception as e:
220
- print(f"⚠️ Model loading will happen on first inference: {e}")
221
-
222
- # Launch with queue and show_error for better debugging
223
- demo.queue().launch(show_error=True) # ADDED show_error=True
 
1
+ import patch_gradio
2
  import os
3
  import sys
4
  import torch
 
8
  import traceback
9
  from huggingface_hub import snapshot_download
10
 
 
11
  sys.path.insert(0, '/app/CatVTON')
12
 
13
  from model.pipeline import CatVTONPipeline
14
  from model.cloth_masker import AutoMasker
15
  from utils import init_weight_dtype, resize_and_crop, resize_and_padding
16
 
 
17
  pipeline = None
18
  automasker = None
19
 
20
  def load_models():
 
21
  global pipeline, automasker
22
 
23
  if pipeline is not None and automasker is not None:
24
  return
25
 
26
+ print("πŸ”„ Loading models...", file=sys.stderr)
27
 
28
  try:
 
29
  repo_path = snapshot_download(
30
  repo_id="zhengchong/CatVTON",
31
  cache_dir="/tmp/models"
32
  )
33
 
 
 
 
34
  nsfw_path = "/tmp/NSFW.jpg"
35
  if not os.path.exists(nsfw_path):
36
  Image.new('RGB', (512, 512), color='black').save(nsfw_path)
37
 
 
38
  pipeline = CatVTONPipeline(
39
  base_ckpt="booksforcharlie/stable-diffusion-inpainting",
40
  attn_ckpt=repo_path,
 
43
  use_tf32=True,
44
  device='cuda'
45
  )
 
46
 
 
47
  automasker = AutoMasker(
48
  densepose_ckpt=os.path.join(repo_path, "DensePose"),
49
  schp_ckpt=os.path.join(repo_path, "SCHP"),
50
  device='cpu'
51
  )
52
+
53
+ print("βœ… Models loaded!", file=sys.stderr)
54
 
55
  except Exception as e:
56
+ print(f"❌ Error: {e}", file=sys.stderr)
57
  traceback.print_exc()
58
  raise
59
 
60
+ def generate_tryon(person_img, cloth_img):
61
+ """CRITICAL: Removed progress parameter - causes API issues"""
62
 
63
+ print("="*50, file=sys.stderr)
64
+ print(f"Received - Person: {type(person_img)}, Cloth: {type(cloth_img)}", file=sys.stderr)
 
 
 
65
 
66
  if person_img is None or cloth_img is None:
67
+ raise gr.Error("Both images required!")
 
 
68
 
69
  try:
70
+ # Convert filepaths to PIL
71
  if isinstance(person_img, str):
 
72
  person_img = Image.open(person_img).convert('RGB')
 
 
 
 
73
  if isinstance(cloth_img, str):
 
74
  cloth_img = Image.open(cloth_img).convert('RGB')
 
 
 
75
 
76
+ print("Images converted to PIL", file=sys.stderr)
77
 
 
 
78
  load_models()
79
 
 
 
 
80
  target_height = 1024
81
  target_width = 768
82
  person_img = resize_and_crop(person_img, (target_width, target_height))
83
  cloth_img = resize_and_padding(cloth_img, (target_width, target_height))
84
 
 
 
 
85
  mask = automasker(person_img, "upper")['mask']
 
 
86
  gc.collect()
87
 
 
 
 
88
  result = pipeline(
89
  image=person_img,
90
  condition_image=cloth_img,
 
96
  width=target_width
97
  )[0]
98
 
99
+ print("βœ… Success!", file=sys.stderr)
 
 
100
  return result
101
 
102
  except Exception as e:
103
+ print(f"❌ Error: {e}", file=sys.stderr)
 
104
  traceback.print_exc()
105
+ raise gr.Error(str(e))
106
 
107
+ # CRITICAL: Use gr.Interface instead of gr.Blocks for better API support
108
+ demo = gr.Interface(
109
+ fn=generate_tryon,
110
+ inputs=[
111
+ gr.Image(label="Person Image", type="filepath"),
112
+ gr.Image(label="Garment Image", type="filepath")
113
+ ],
114
+ outputs=gr.Image(label="Result", type="filepath"),
115
  title="Try-Space Virtual Try-On",
116
+ description="Upload person and garment images. Processing takes 2-3 minutes on GPU T4.",
117
+ api_name="generate_tryon",
118
+ allow_flagging="never"
119
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
 
121
  if __name__ == "__main__":
122
+ print("πŸš€ Starting...", file=sys.stderr)
 
123
  try:
124
  load_models()
125
+ except:
126
+ pass
127
+ demo.queue().launch(show_error=True)