Amit Shamsundar committed on
Commit
9285254
·
1 Parent(s): a16aa42

GPU error

Browse files
Files changed (1) hide show
  1. app.py +146 -61
app.py CHANGED
@@ -5,162 +5,243 @@ from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentati
5
  import numpy as np
6
  from diffusers import StableDiffusionInpaintPipeline
7
  import warnings
 
8
  warnings.filterwarnings("ignore")
9
 
 
 
 
 
10
  # Global variables for models
11
  processor = None
12
  model = None
13
  pipe = None
14
 
 
 
 
 
 
 
 
 
15
  def load_models():
16
- """Load models with proper error handling"""
17
  global processor, model, pipe
18
 
19
  try:
20
  print("Loading segmentation model...")
21
  processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
22
  model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
 
 
 
 
23
  print("Segmentation model loaded successfully!")
24
 
25
  print("Loading Stable Diffusion inpainting model...")
26
- # Use CPU for Hugging Face Spaces (most don't have GPU access)
27
- device = "cuda" if torch.cuda.is_available() else "cpu"
28
  pipe = StableDiffusionInpaintPipeline.from_pretrained(
29
  "stabilityai/stable-diffusion-2-inpainting",
30
- torch_dtype=torch.float16 if device == "cuda" else torch.float32,
31
- safety_checker=None, # Disable safety checker to save memory
32
  requires_safety_checker=False,
33
  use_safetensors=True
34
  )
35
- pipe = pipe.to(device)
36
 
37
- # Enable memory efficient attention if available
 
 
 
38
  if hasattr(pipe, 'enable_attention_slicing'):
39
  pipe.enable_attention_slicing()
40
- if hasattr(pipe, 'enable_model_cpu_offload'):
41
- pipe.enable_model_cpu_offload()
 
 
 
 
42
 
43
- print("Models loaded successfully!")
44
  return True
45
 
46
  except Exception as e:
47
  print(f"Error loading models: {str(e)}")
 
 
48
  return False
49
 
50
  def segment_clothes(human_image):
51
- """Segment clothing from human image"""
52
  try:
53
- # Resize image if too large to save memory
 
54
  if human_image.size[0] > 512 or human_image.size[1] > 512:
55
  human_image = human_image.resize((512, 512), Image.Resampling.LANCZOS)
56
 
57
  # Process human image for segmentation
58
  inputs = processor(images=human_image, return_tensors="pt")
59
 
 
 
 
 
 
60
  with torch.no_grad():
61
  outputs = model(**inputs)
62
 
63
  logits = outputs.logits.cpu()
64
  upsampled_logits = torch.nn.functional.interpolate(
65
- logits, size=human_image.size[::-1], mode="bilinear", align_corners=False
 
 
 
66
  )
67
  pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()
68
 
69
- # Create mask for clothes (labels 4, 5, 6, 7 typically represent different clothing items)
70
- clothes_mask = np.isin(pred_seg, [4, 5, 6, 7]).astype(np.uint8) * 255
 
 
 
 
 
 
 
 
 
 
71
 
72
- return Image.fromarray(clothes_mask)
 
 
 
 
 
73
 
74
  except Exception as e:
75
  print(f"Error in segmentation: {str(e)}")
76
- # Return a default mask if segmentation fails
77
- return Image.new('L', human_image.size, 255)
 
 
 
78
 
79
  def try_on_cloth(human_image, cloth_image, progress=gr.Progress()):
80
- """Main function for virtual try-on"""
81
  if human_image is None or cloth_image is None:
82
  return None, "Please upload both human and cloth images."
83
 
84
  if processor is None or model is None or pipe is None:
85
- return None, "Models not loaded. Please wait for initialization to complete."
86
 
87
  try:
88
  progress(0.1, desc="Processing images...")
89
 
90
- # Ensure images are PIL Images and resize for memory efficiency
91
  if isinstance(human_image, np.ndarray):
92
  human_image = Image.fromarray(human_image)
93
  if isinstance(cloth_image, np.ndarray):
94
  cloth_image = Image.fromarray(cloth_image)
95
-
96
- # Resize images to reasonable size for processing
 
 
 
 
 
 
97
  target_size = (512, 512)
98
  human_image = human_image.resize(target_size, Image.Resampling.LANCZOS)
99
  cloth_image = cloth_image.resize(target_size, Image.Resampling.LANCZOS)
100
 
101
  progress(0.3, desc="Generating clothing mask...")
102
 
103
- # Generate clothes mask
104
  mask = segment_clothes(human_image)
105
 
106
- progress(0.6, desc="Generating try-on result...")
107
 
108
- # Use Stable Diffusion Inpainting
109
- prompt = "A person wearing the provided clothing, realistic, high quality, detailed"
110
- negative_prompt = "blurry, low quality, distorted, deformed"
111
 
112
- # Generate result with lower parameters for faster processing
113
- result = pipe(
114
- prompt=prompt,
115
- negative_prompt=negative_prompt,
116
- image=human_image,
117
- mask_image=mask,
118
- num_inference_steps=20, # Reduced for faster processing
119
- strength=0.8,
120
- guidance_scale=7.5,
121
- generator=torch.Generator().manual_seed(42) # For reproducible results
122
- ).images[0]
123
 
124
- progress(1.0, desc="Complete!")
 
 
 
 
 
 
 
 
 
 
 
125
 
126
- return result, "Try-on completed successfully!"
 
127
 
128
  except Exception as e:
129
  error_msg = f"Error during try-on: {str(e)}"
130
  print(error_msg)
131
- return None, error_msg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
- # Initialize models when the app starts
134
- print("Initializing models...")
135
  models_loaded = load_models()
136
 
137
- # Create Gradio interface
138
- with gr.Blocks(title="Virtual Cloth Try-On AI", theme=gr.themes.Soft()) as interface:
139
  gr.Markdown("""
140
- # 🧥 Virtual Cloth Try-On AI
141
 
142
  Upload a photo of a person and a clothing item to see how the outfit would look!
143
 
144
- **Instructions:**
145
- 1. Upload a clear photo of a person (front-facing works best) in the first box
146
- 2. Upload an image of the clothing item you want to try on in the second box
147
- 3. Click "Generate Try-On" and wait for the result
148
 
149
- **Note:** Processing may take 1-2 minutes depending on server load.
 
 
 
150
  """)
151
 
152
  if not models_loaded:
153
- gr.Markdown("⚠️ **Models failed to load. Please refresh the page.**")
 
 
154
 
155
  with gr.Row():
156
  with gr.Column():
157
  human_input = gr.Image(
158
  type="pil",
159
- label="👤 Human Image"
160
  )
161
  cloth_input = gr.Image(
162
  type="pil",
163
- label="👕 Clothing Image"
164
  )
165
 
166
  with gr.Column():
@@ -170,11 +251,12 @@ with gr.Blocks(title="Virtual Cloth Try-On AI", theme=gr.themes.Soft()) as inter
170
  )
171
  status_output = gr.Textbox(
172
  label="Status",
173
- interactive=False
 
174
  )
175
 
176
  generate_btn = gr.Button(
177
- "🎨 Generate Try-On",
178
  variant="primary",
179
  size="lg"
180
  )
@@ -189,10 +271,13 @@ with gr.Blocks(title="Virtual Cloth Try-On AI", theme=gr.themes.Soft()) as inter
189
  gr.Markdown("""
190
  ---
191
  **Tips for better results:**
192
- - Use high-resolution, well-lit images
193
- - Ensure the person is facing forward
194
- - Use clothing images with clear, visible details
195
- - Avoid complex backgrounds when possible
 
 
 
196
  """)
197
 
198
  if __name__ == "__main__":
 
5
  import numpy as np
6
  from diffusers import StableDiffusionInpaintPipeline
7
  import warnings
8
+ import os
9
  warnings.filterwarnings("ignore")
10
 
11
# Hide every CUDA device from PyTorch so the app always runs on CPU,
# avoiding GPU issues on Hugging Face Spaces.
os.environ.update(CUDA_VISIBLE_DEVICES="")
# float32 is the dtype this file uses for CPU inference.
torch.set_default_dtype(torch.float32)

# Module-level model handles; load_models() assigns them via `global`.
processor = model = pipe = None
19
 
20
def get_device():
    """Return the torch device string used by this app.

    The app is deliberately pinned to CPU for stability on Hugging Face
    Spaces, so this always returns ``"cpu"``. Returning a constant cannot
    raise, which is why no exception handling is needed here.

    Returns:
        str: Always ``"cpu"``.
    """
    return "cpu"
28
  def load_models():
29
+ """Load models with CPU-only configuration"""
30
  global processor, model, pipe
31
 
32
  try:
33
  print("Loading segmentation model...")
34
  processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
35
  model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
36
+
37
+ # Ensure segmentation model is on CPU
38
+ model = model.to("cpu")
39
+ model.eval()
40
  print("Segmentation model loaded successfully!")
41
 
42
  print("Loading Stable Diffusion inpainting model...")
43
+
44
+ # Load with explicit CPU configuration
45
  pipe = StableDiffusionInpaintPipeline.from_pretrained(
46
  "stabilityai/stable-diffusion-2-inpainting",
47
+ torch_dtype=torch.float32, # Use float32 for CPU
48
+ safety_checker=None,
49
  requires_safety_checker=False,
50
  use_safetensors=True
51
  )
 
52
 
53
+ # Explicitly move all components to CPU
54
+ pipe = pipe.to("cpu")
55
+
56
+ # Enable memory efficiency
57
  if hasattr(pipe, 'enable_attention_slicing'):
58
  pipe.enable_attention_slicing()
59
+
60
+ # Set to eval mode
61
+ pipe.unet.eval()
62
+ pipe.vae.eval()
63
+ if hasattr(pipe, 'text_encoder'):
64
+ pipe.text_encoder.eval()
65
 
66
+ print("Stable Diffusion model loaded successfully on CPU!")
67
  return True
68
 
69
  except Exception as e:
70
  print(f"Error loading models: {str(e)}")
71
+ import traceback
72
+ traceback.print_exc()
73
  return False
74
 
75
  def segment_clothes(human_image):
76
+ """Segment clothing from human image with CPU-only operations"""
77
  try:
78
+ # Resize image if too large
79
+ original_size = human_image.size
80
  if human_image.size[0] > 512 or human_image.size[1] > 512:
81
  human_image = human_image.resize((512, 512), Image.Resampling.LANCZOS)
82
 
83
  # Process human image for segmentation
84
  inputs = processor(images=human_image, return_tensors="pt")
85
 
86
+ # Ensure inputs are on CPU
87
+ for key in inputs:
88
+ if torch.is_tensor(inputs[key]):
89
+ inputs[key] = inputs[key].to("cpu")
90
+
91
  with torch.no_grad():
92
  outputs = model(**inputs)
93
 
94
  logits = outputs.logits.cpu()
95
  upsampled_logits = torch.nn.functional.interpolate(
96
+ logits,
97
+ size=human_image.size[::-1],
98
+ mode="bilinear",
99
+ align_corners=False
100
  )
101
  pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()
102
 
103
+ # Create mask for clothes
104
+ clothes_labels = [4, 5, 6, 7, 8, 9, 10]
105
+ clothes_mask = np.isin(pred_seg, clothes_labels).astype(np.uint8) * 255
106
+
107
+ # If no clothes detected, create a default mask
108
+ if np.sum(clothes_mask) < 100:
109
+ print("Creating default upper body mask")
110
+ mask = np.zeros_like(pred_seg, dtype=np.uint8)
111
+ h, w = mask.shape
112
+ # Upper body region
113
+ mask[h//4:3*h//4, w//3:2*w//3] = 255
114
+ clothes_mask = mask
115
 
116
+ # Resize back to original size
117
+ mask_image = Image.fromarray(clothes_mask)
118
+ if original_size != mask_image.size:
119
+ mask_image = mask_image.resize(original_size, Image.Resampling.LANCZOS)
120
+
121
+ return mask_image
122
 
123
  except Exception as e:
124
  print(f"Error in segmentation: {str(e)}")
125
+ # Return a default center mask
126
+ h, w = human_image.size[::-1]
127
+ mask = np.zeros((h, w), dtype=np.uint8)
128
+ mask[h//4:3*h//4, w//3:2*w//3] = 255
129
+ return Image.fromarray(mask)
130
 
131
  def try_on_cloth(human_image, cloth_image, progress=gr.Progress()):
132
+ """Main function for virtual try-on with CPU-safe operations"""
133
  if human_image is None or cloth_image is None:
134
  return None, "Please upload both human and cloth images."
135
 
136
  if processor is None or model is None or pipe is None:
137
+ return None, "Models not loaded. Please refresh the page and try again."
138
 
139
  try:
140
  progress(0.1, desc="Processing images...")
141
 
142
+ # Ensure images are PIL Images
143
  if isinstance(human_image, np.ndarray):
144
  human_image = Image.fromarray(human_image)
145
  if isinstance(cloth_image, np.ndarray):
146
  cloth_image = Image.fromarray(cloth_image)
147
+
148
+ # Convert to RGB
149
+ if human_image.mode != 'RGB':
150
+ human_image = human_image.convert('RGB')
151
+ if cloth_image.mode != 'RGB':
152
+ cloth_image = cloth_image.convert('RGB')
153
+
154
+ # Resize for processing
155
  target_size = (512, 512)
156
  human_image = human_image.resize(target_size, Image.Resampling.LANCZOS)
157
  cloth_image = cloth_image.resize(target_size, Image.Resampling.LANCZOS)
158
 
159
  progress(0.3, desc="Generating clothing mask...")
160
 
161
+ # Generate mask
162
  mask = segment_clothes(human_image)
163
 
164
+ progress(0.6, desc="Generating try-on result (this may take a few minutes on CPU)...")
165
 
166
+ # Prepare for inpainting
167
+ prompt = "a person wearing the clothing, realistic, high quality, natural lighting"
168
+ negative_prompt = "blurry, low quality, distorted, deformed, extra limbs"
169
 
170
+ # Create CPU generator
171
+ generator = torch.Generator(device='cpu').manual_seed(42)
 
 
 
 
 
 
 
 
 
172
 
173
+ # Generate with CPU-optimized settings
174
+ with torch.no_grad():
175
+ result = pipe(
176
+ prompt=prompt,
177
+ negative_prompt=negative_prompt,
178
+ image=human_image,
179
+ mask_image=mask,
180
+ num_inference_steps=15, # Reduced for CPU
181
+ strength=0.75,
182
+ guidance_scale=7.0,
183
+ generator=generator
184
+ ).images[0]
185
 
186
+ progress(1.0, desc="Complete!")
187
+ return result, "Try-on completed successfully! (Processed on CPU)"
188
 
189
  except Exception as e:
190
  error_msg = f"Error during try-on: {str(e)}"
191
  print(error_msg)
192
+ import traceback
193
+ traceback.print_exc()
194
+
195
+ # Attempt simple fallback
196
+ try:
197
+ progress(0.8, desc="Attempting simple blend fallback...")
198
+ mask_array = np.array(mask) / 255.0
199
+ cloth_resized = cloth_image.resize(human_image.size)
200
+
201
+ human_array = np.array(human_image).astype(np.float32)
202
+ cloth_array = np.array(cloth_resized).astype(np.float32)
203
+
204
+ mask_3d = np.stack([mask_array] * 3, axis=2)
205
+ result_array = human_array * (1 - mask_3d) + cloth_array * mask_3d
206
+ result = Image.fromarray(result_array.astype(np.uint8))
207
+
208
+ return result, "Used simple blending due to processing error."
209
+ except:
210
+ return None, error_msg
211
 
212
+ # Initialize models
213
+ print("Initializing models for CPU processing...")
214
  models_loaded = load_models()
215
 
216
+ # Gradio interface
217
+ with gr.Blocks(title="Virtual Cloth Try-On AI", theme=gr.themes.Default()) as interface:
218
  gr.Markdown("""
219
+ # 🧥 Virtual Cloth Try-On AI (CPU Version)
220
 
221
  Upload a photo of a person and a clothing item to see how the outfit would look!
222
 
223
+ **⚠️ Note: This app runs on CPU, so processing will take 2-5 minutes per image.**
 
 
 
224
 
225
+ **Instructions:**
226
+ 1. Upload a clear photo of a person (front-facing works best)
227
+ 2. Upload an image of the clothing item you want to try on
228
+ 3. Click "Generate Try-On" and be patient - CPU processing is slow but works!
229
  """)
230
 
231
  if not models_loaded:
232
+ gr.Markdown(" **Models failed to load. Please refresh the page.**")
233
+ else:
234
+ gr.Markdown("✅ **Models loaded successfully! Ready for try-on.**")
235
 
236
  with gr.Row():
237
  with gr.Column():
238
  human_input = gr.Image(
239
  type="pil",
240
+ label="👤 Human Photo"
241
  )
242
  cloth_input = gr.Image(
243
  type="pil",
244
+ label="👕 Clothing Item"
245
  )
246
 
247
  with gr.Column():
 
251
  )
252
  status_output = gr.Textbox(
253
  label="Status",
254
+ interactive=False,
255
+ placeholder="Upload images and click 'Generate Try-On'"
256
  )
257
 
258
  generate_btn = gr.Button(
259
+ "🎨 Generate Try-On (Takes 2-5 minutes)",
260
  variant="primary",
261
  size="lg"
262
  )
 
271
  gr.Markdown("""
272
  ---
273
  **Tips for better results:**
274
+ - Use clear, high-resolution images with good lighting
275
+ - Person should be facing forward with visible torso
276
+ - Clothing items should be clearly visible and unfolded
277
+ - Simple backgrounds work better than busy ones
278
+ - Be patient - CPU processing takes time but produces good results!
279
+
280
+ **Expected processing time: 2-5 minutes per try-on**
281
  """)
282
 
283
  if __name__ == "__main__":