beertoshi committed on
Commit
b6232bb
·
verified ·
1 Parent(s): e6379ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -105
app.py CHANGED
@@ -5,190 +5,217 @@ from PIL import Image, ImageDraw, ImageFilter
5
  import numpy as np
6
  import os
7
 
8
- # Set memory allocation
9
- os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
10
 
11
- # Check device - wait for GPU
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
- print(f"Device detected: {device}")
 
14
 
15
- if device == "cuda":
16
- print(f"GPU: {torch.cuda.get_device_name(0)}")
17
- print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
18
 
19
- # Initialize model
20
  try:
21
- print("Loading Stable Diffusion Inpainting model...")
22
-
23
- model_id = "runwayml/stable-diffusion-inpainting"
24
 
 
25
  pipe = StableDiffusionInpaintPipeline.from_pretrained(
26
- model_id,
27
- torch_dtype=torch.float16 if device == "cuda" else torch.float32,
28
  revision="fp16" if device == "cuda" else "main",
 
29
  safety_checker=None,
30
- requires_safety_checker=False
 
 
31
  )
32
 
33
- # Move to GPU if available, otherwise keep on CPU
34
- if device == "cuda":
35
- pipe = pipe.to(device)
36
- pipe.enable_attention_slicing()
37
- print("Model loaded on GPU!")
38
- else:
39
- # Don't use CPU offload on HF Spaces - it causes issues
40
- print("Warning: Running on CPU - will be slow!")
41
 
42
  MODEL_LOADED = True
 
43
 
44
  except Exception as e:
45
- print(f"Error loading model: {e}")
46
- MODEL_LOADED = False
47
- pipe = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  # Clothing prompts
50
- CLOTHING_PROMPTS = {
51
- "Indian Sari": "woman wearing elegant red and gold sari, traditional Indian saree dress",
52
- "Japanese Kimono": "person wearing beautiful kimono with floral patterns, traditional Japanese clothing",
53
- "African Dashiki": "person wearing colorful dashiki with patterns, traditional African clothing",
54
  "Chinese Qipao": "woman wearing elegant qipao cheongsam dress, traditional Chinese dress",
55
- "Scottish Kilt": "man wearing Scottish kilt with tartan pattern, traditional highland dress",
56
  "Middle Eastern Thobe": "person wearing white thobe robe, traditional Middle Eastern clothing"
57
  }
58
 
59
- def create_body_mask(image):
60
- """Create mask for clothing area"""
61
- width, height = image.size
62
- mask = Image.new('L', (width, height), 0)
63
  draw = ImageDraw.Draw(mask)
64
 
65
- # Body area
66
- left = width * 0.2
67
- top = height * 0.25
68
- right = width * 0.8
69
- bottom = height * 0.95
70
 
71
  draw.ellipse([left, top, right, bottom], fill=255)
72
- mask = mask.filter(ImageFilter.GaussianBlur(radius=15))
73
 
74
  return mask
75
 
76
- def add_traditional_clothing(input_image, clothing_type, num_steps=20, guidance_scale=7.5):
77
- """Add traditional clothing to image"""
78
-
79
- if input_image is None:
80
- return None, "Please upload an image"
81
 
82
  if not MODEL_LOADED:
83
- return None, "Model failed to load. Please refresh the page."
 
 
 
84
 
85
  try:
86
  # Convert to PIL
87
- if isinstance(input_image, np.ndarray):
88
- image = Image.fromarray(input_image).convert("RGB")
89
  else:
90
- image = input_image.convert("RGB")
91
 
92
  # Store original size
93
  original_size = image.size
94
 
95
- # Resize for processing
96
- max_size = 512
97
- if max(image.size) > max_size:
98
- ratio = max_size / max(image.size)
99
- new_size = tuple(int(dim * ratio) for dim in image.size)
100
- image = image.resize(new_size, Image.Resampling.LANCZOS)
101
 
102
  # Create mask
103
- mask = create_body_mask(image)
104
-
105
- # Get prompt
106
- prompt = CLOTHING_PROMPTS[clothing_type]
107
- negative_prompt = "nude, naked, nsfw, bad quality, blurry"
108
 
109
  # Generate
110
- if device == "cuda":
111
- with torch.autocast("cuda"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  result = pipe(
113
  prompt=prompt,
114
  negative_prompt=negative_prompt,
115
  image=image,
116
  mask_image=mask,
117
- num_inference_steps=num_steps,
118
- guidance_scale=guidance_scale,
119
  strength=0.95
120
  ).images[0]
121
- else:
122
- result = pipe(
123
- prompt=prompt,
124
- negative_prompt=negative_prompt,
125
- image=image,
126
- mask_image=mask,
127
- num_inference_steps=num_steps,
128
- guidance_scale=guidance_scale,
129
- strength=0.95
130
- ).images[0]
131
 
132
- # Resize back
133
  if result.size != original_size:
134
  result = result.resize(original_size, Image.Resampling.LANCZOS)
135
 
136
- return result, f"Successfully added {clothing_type}!"
137
 
138
  except Exception as e:
139
  print(f"Generation error: {e}")
140
- return None, f"Error: {str(e)}"
141
 
142
- # Create interface
143
- with gr.Blocks(title="Traditional Clothing Addition") as app:
144
  gr.Markdown(f"""
145
  # 👘 Traditional Clothing Addition Tool
146
 
147
- **Status:** Running on {device.upper()} {"✅" if device == "cuda" else "⚠️ (Slow)"}
 
148
 
149
- Add beautiful traditional clothing to your photos using AI.
150
  """)
151
 
152
  with gr.Row():
153
  with gr.Column():
154
- input_image = gr.Image(label="Upload Photo", type="pil")
 
 
 
155
 
156
- clothing_type = gr.Dropdown(
157
- choices=list(CLOTHING_PROMPTS.keys()),
158
  value="Indian Sari",
159
- label="Select Traditional Clothing"
160
  )
161
 
162
- with gr.Accordion("Settings", open=False):
163
- num_steps = gr.Slider(
164
- minimum=10, maximum=50, value=20, step=5,
165
- label="Steps (more = better quality)"
166
- )
167
- guidance_scale = gr.Slider(
168
- minimum=5, maximum=15, value=7.5, step=0.5,
169
- label="Guidance Scale"
170
- )
171
 
172
- generate_btn = gr.Button("🎨 Add Traditional Clothing", variant="primary")
 
 
 
 
173
 
174
  with gr.Column():
175
- output_image = gr.Image(label="Result")
176
- status = gr.Textbox(label="Status")
 
 
 
 
 
 
177
 
 
178
  gr.Markdown("""
179
- ### Tips:
180
- - Processing takes 15-30 seconds on GPU
181
  - Use clear, front-facing photos
182
- - Try different settings for variety
 
 
183
 
184
- ### Note:
185
- AI-generated traditional clothing may not be culturally accurate. Use respectfully.
 
186
  """)
187
 
 
188
  generate_btn.click(
189
- add_traditional_clothing,
190
- inputs=[input_image, clothing_type, num_steps, guidance_scale],
191
- outputs=[output_image, status]
192
  )
193
 
194
  if __name__ == "__main__":
 
5
  import numpy as np
6
  import os
7
 
8
+ # Clear cache if needed
9
+ os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
10
 
11
+ # Device setup
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ dtype = torch.float16 if device == "cuda" else torch.float32
14
+ print(f"Using device: {device}")
15
 
16
+ # Try loading model with better error handling
17
+ MODEL_LOADED = False
18
+ pipe = None
19
 
 
20
  try:
21
+ print("Downloading and loading model... This may take a few minutes on first run.")
 
 
22
 
23
+ # Use the correct model ID and revision
24
  pipe = StableDiffusionInpaintPipeline.from_pretrained(
25
+ "runwayml/stable-diffusion-inpainting",
 
26
  revision="fp16" if device == "cuda" else "main",
27
+ torch_dtype=dtype,
28
  safety_checker=None,
29
+ requires_safety_checker=False,
30
+ use_safetensors=True, # Force safetensors format
31
+ local_files_only=False # Allow downloading
32
  )
33
 
34
+ pipe = pipe.to(device)
35
+ pipe.enable_attention_slicing()
 
 
 
 
 
 
36
 
37
  MODEL_LOADED = True
38
+ print("✅ Model loaded successfully!")
39
 
40
  except Exception as e:
41
+ print(f"❌ First attempt failed: {e}")
42
+ print("Trying alternative model...")
43
+
44
+ try:
45
+ # Alternative: Try SD 2.0 inpainting
46
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
47
+ "stabilityai/stable-diffusion-2-inpainting",
48
+ torch_dtype=dtype,
49
+ safety_checker=None,
50
+ requires_safety_checker=False
51
+ )
52
+
53
+ pipe = pipe.to(device)
54
+ pipe.enable_attention_slicing()
55
+
56
+ MODEL_LOADED = True
57
+ print("✅ Alternative model loaded successfully!")
58
+
59
+ except Exception as e2:
60
+ print(f"❌ Both models failed to load: {e2}")
61
+ MODEL_LOADED = False
62
 
63
  # Clothing prompts
64
+ CLOTHES = {
65
+ "Indian Sari": "woman wearing elegant red and gold sari, traditional Indian dress, beautiful saree",
66
+ "Japanese Kimono": "person wearing beautiful floral kimono, traditional Japanese clothing, silk kimono",
67
+ "African Dashiki": "person wearing colorful African dashiki with patterns, traditional clothing",
68
  "Chinese Qipao": "woman wearing elegant qipao cheongsam dress, traditional Chinese dress",
69
+ "Scottish Kilt": "man wearing traditional Scottish kilt with tartan pattern",
70
  "Middle Eastern Thobe": "person wearing white thobe robe, traditional Middle Eastern clothing"
71
  }
72
 
73
+ def create_mask(image):
74
+ """Create body mask"""
75
+ w, h = image.size
76
+ mask = Image.new('L', (w, h), 0)
77
  draw = ImageDraw.Draw(mask)
78
 
79
+ # Body area ellipse
80
+ left = w * 0.2
81
+ top = h * 0.25
82
+ right = w * 0.8
83
+ bottom = h * 0.9
84
 
85
  draw.ellipse([left, top, right, bottom], fill=255)
86
+ mask = mask.filter(ImageFilter.GaussianBlur(radius=20))
87
 
88
  return mask
89
 
90
+ def process_image(image, clothing_type, steps=20):
91
+ """Generate traditional clothing"""
 
 
 
92
 
93
  if not MODEL_LOADED:
94
+ return None, "⚠️ Model is still loading or failed to load. Please refresh the page and try again."
95
+
96
+ if image is None:
97
+ return None, "Please upload an image first"
98
 
99
  try:
100
  # Convert to PIL
101
+ if isinstance(image, np.ndarray):
102
+ image = Image.fromarray(image).convert("RGB")
103
  else:
104
+ image = image.convert("RGB")
105
 
106
  # Store original size
107
  original_size = image.size
108
 
109
+ # Resize for processing (max 512x512 for stability)
110
+ if max(image.size) > 512:
111
+ image.thumbnail((512, 512), Image.Resampling.LANCZOS)
 
 
 
112
 
113
  # Create mask
114
+ mask = create_mask(image)
 
 
 
 
115
 
116
  # Generate
117
+ prompt = CLOTHES[clothing_type]
118
+ negative_prompt = "nude, naked, nsfw, blurry, bad quality"
119
+
120
+ with torch.no_grad():
121
+ if device == "cuda":
122
+ with torch.autocast("cuda"):
123
+ result = pipe(
124
+ prompt=prompt,
125
+ negative_prompt=negative_prompt,
126
+ image=image,
127
+ mask_image=mask,
128
+ num_inference_steps=steps,
129
+ guidance_scale=7.5,
130
+ strength=0.95
131
+ ).images[0]
132
+ else:
133
  result = pipe(
134
  prompt=prompt,
135
  negative_prompt=negative_prompt,
136
  image=image,
137
  mask_image=mask,
138
+ num_inference_steps=steps,
139
+ guidance_scale=7.5,
140
  strength=0.95
141
  ).images[0]
 
 
 
 
 
 
 
 
 
 
142
 
143
+ # Resize back to original
144
  if result.size != original_size:
145
  result = result.resize(original_size, Image.Resampling.LANCZOS)
146
 
147
+ return result, f"✅ Successfully added {clothing_type}!"
148
 
149
  except Exception as e:
150
  print(f"Generation error: {e}")
151
+ return None, f"❌ Error during generation: {str(e)}"
152
 
153
+ # Create UI
154
+ with gr.Blocks(title="Traditional Clothing AI") as app:
155
  gr.Markdown(f"""
156
  # 👘 Traditional Clothing Addition Tool
157
 
158
+ **Device:** {device.upper()} {"🚀 Fast" if device == "cuda" else "🐌 Slow"}
159
+ **Model:** {"✅ Ready" if MODEL_LOADED else "❌ Loading failed - please refresh"}
160
 
161
+ Transform your photos with traditional clothing from different cultures!
162
  """)
163
 
164
  with gr.Row():
165
  with gr.Column():
166
+ input_img = gr.Image(
167
+ label="Upload Your Photo",
168
+ type="pil"
169
+ )
170
 
171
+ clothing_dropdown = gr.Dropdown(
172
+ choices=list(CLOTHES.keys()),
173
  value="Indian Sari",
174
+ label="Choose Traditional Clothing"
175
  )
176
 
177
+ steps_slider = gr.Slider(
178
+ minimum=10,
179
+ maximum=50,
180
+ value=20,
181
+ step=5,
182
+ label="Quality (more steps = better but slower)"
183
+ )
 
 
184
 
185
+ generate_btn = gr.Button(
186
+ "🎨 Add Traditional Clothing",
187
+ variant="primary",
188
+ size="lg"
189
+ )
190
 
191
  with gr.Column():
192
+ output_img = gr.Image(
193
+ label="Result"
194
+ )
195
+
196
+ status_text = gr.Textbox(
197
+ label="Status",
198
+ interactive=False
199
+ )
200
 
201
+ # Examples (create dummy examples or remove if no images)
202
  gr.Markdown("""
203
+ ### 💡 Tips:
 
204
  - Use clear, front-facing photos
205
+ - Good lighting improves results
206
+ - Processing takes 15-30 seconds on GPU, 1-2 minutes on CPU
207
+ - Try different clothing types for variety
208
 
209
+ ### 🌍 Cultural Note:
210
+ This AI creates artistic interpretations of traditional clothing.
211
+ Results may not be culturally accurate - please use respectfully.
212
  """)
213
 
214
+ # Connect button
215
  generate_btn.click(
216
+ fn=process_image,
217
+ inputs=[input_img, clothing_dropdown, steps_slider],
218
+ outputs=[output_img, status_text]
219
  )
220
 
221
  if __name__ == "__main__":