beertoshi committed on
Commit
3af7399
·
verified ·
1 Parent(s): 01847b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -169
app.py CHANGED
@@ -1,140 +1,57 @@
1
  import gradio as gr
2
  import torch
3
  from diffusers import StableDiffusionInpaintPipeline
4
- from PIL import Image, ImageDraw, ImageFilter, ImageEnhance
5
  import numpy as np
6
  import spaces
7
 
8
  # Load model
9
# Build the inpainting pipeline once at import time; every request reuses it.
inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float16,
    safety_checker=None,
    requires_safety_checker=False,
)
# Turn on diffusers' attention-slicing and VAE slicing/tiling modes.
inpaint_pipe.enable_attention_slicing()
inpaint_pipe.enable_vae_slicing()
inpaint_pipe.enable_vae_tiling()

# The status glyph was mis-encoded; print the intended check mark.
print("✅ Model loaded!")
20
-
21
# Base inpainting prompt for each garment offered in the UI dropdown.
# Keys double as the dropdown labels, so they are user-visible text.
CLOTHING_PROMPTS = {
    "Indian Sari": "woman wearing luxurious red and gold silk sari with intricate embroidery, traditional Indian saree, professional fashion photography, studio lighting, ultra detailed fabric texture, 8k quality",
    "Japanese Kimono": "person wearing exquisite silk kimono with cherry blossom patterns, traditional Japanese formal wear, professional portrait, studio lighting, highly detailed fabric, photorealistic",
    "African Dashiki": "person wearing vibrant African dashiki with authentic kente patterns, traditional clothing, professional photography, rich colors, detailed textile work, high resolution",
    "Chinese Qipao": "elegant woman in traditional Chinese qipao cheongsam, silk dress with intricate patterns, professional fashion shoot, studio lighting, ultra high quality",
    "Scottish Kilt": "man wearing traditional Scottish highland kilt with tartan pattern, formal Scottish attire, professional photography, detailed fabric texture",
    "Middle Eastern Thobe": "person wearing flowing white thobe robe, traditional Middle Eastern clothing, elegant fabric, studio portrait, high resolution",
}
29
 
30
def make_divisible_by_8(image):
    """Return *image* with both sides rounded down to a multiple of 8.

    Each side is clamped to at least 8 px: without the clamp a side shorter
    than 8 px rounds down to 0, which is not a valid resize target.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            PIL-style ``resize`` method.

    Returns:
        The same object when no change is needed, otherwise a resized copy
        produced with LANCZOS resampling.
    """
    width, height = image.size
    new_size = (max(8, width - width % 8), max(8, height - height % 8))

    # Skip the resample entirely when the image already fits.
    if new_size == (width, height):
        return image
    return image.resize(new_size, Image.Resampling.LANCZOS)
43
 
44
def resize_with_aspect_ratio(image, target_size):
    """Rescale *image* so its longest side is scaled to ``target_size``.

    The scale factor is ``target_size / longest side`` and is applied to both
    sides, so the aspect ratio is kept up to rounding. Each scaled side is
    truncated to an int, rounded down to a multiple of 8, and raised to at
    least 64 px.

    Args:
        image: PIL-style image exposing ``size`` and ``resize``.
        target_size: Desired length of the longest side, in pixels.

    Returns:
        A resized image produced with LANCZOS resampling.
    """
    width, height = image.size
    scale = target_size / max(width, height)

    def _snap(side):
        # Truncate, drop to a multiple of 8, then enforce the 64 px floor.
        scaled = int(side * scale)
        return max(scaled - (scaled % 8), 64)

    return image.resize((_snap(width), _snap(height)), Image.Resampling.LANCZOS)
64
-
65
def create_professional_mask(image, face_margin=0.35):
    """Build a blurred grayscale mask covering the body region of *image*.

    An ellipse is filled with 255 from ``height * face_margin`` down to 98%
    of the height, spanning 10%-90% of the width. The 30 pixel rows above the
    ellipse get a linear ramp from 0 towards 255, and the whole mask is then
    Gaussian-blurred (radius 25).

    Args:
        image: PIL-style image; only ``size`` is read.
        face_margin: Fraction of the height, measured from the top, that is
            left outside the ellipse.

    Returns:
        A mode ``'L'`` mask with the same size as *image*.
    """
    width, height = image.size
    fade_rows = 30

    mask = Image.new('L', (width, height), 0)
    draw = ImageDraw.Draw(mask)

    face_bottom = height * face_margin
    left, right = width * 0.1, width * 0.9

    # Solid body region.
    draw.ellipse([left, face_bottom, right, height * 0.98], fill=255)

    # Fade-in strip just above the body region; rows above the frame are skipped.
    for step in range(fade_rows):
        y = face_bottom - (fade_rows - step)
        if y < 0:
            continue
        draw.rectangle([left, y, right, y + 1], fill=int(255 * (step / fade_rows)))

    return mask.filter(ImageFilter.GaussianBlur(radius=25))
96
 
97
def enhance_image(image):
    """Return *image* with sharpness boosted by 1.2x, then colour by 1.1x."""
    sharpened = ImageEnhance.Sharpness(image).enhance(1.2)
    return ImageEnhance.Color(sharpened).enhance(1.1)
108
-
109
def blend_images(original, generated, mask):
    """Composite *generated* over *original* through a softened *mask*.

    The mask is Gaussian-blurred (radius 40) before compositing so the seam
    between the two images is gradual. Both images are converted to RGBA for
    the composite and the result is returned as RGB.

    Args:
        original: Image used where the mask is dark.
        generated: Image used where the mask is bright.
        mask: Mode ``'L'`` mask; all three inputs must share one size.

    Returns:
        The blended RGB image.
    """
    soft_mask = mask.filter(ImageFilter.GaussianBlur(radius=40))
    base_layer = original.convert("RGBA")
    top_layer = generated.convert("RGBA")
    return Image.composite(top_layer, base_layer, soft_mask).convert("RGB")
122
-
123
@spaces.GPU(duration=120)
def generate_professional(
    input_image,
    clothing_type,
    face_margin=0.35,
    quality_preset="ultra",
):
    """Repaint the body region of a photo with the selected traditional outfit.

    Args:
        input_image: PIL image or NumPy array from the Gradio image input;
            ``None`` when nothing was uploaded.
        clothing_type: Key of ``CLOTHING_PROMPTS`` selecting the prompt.
        face_margin: Fraction of the image height (from the top) kept out of
            the inpainting mask.
        quality_preset: ``"fast"``, ``"balanced"`` or ``"ultra"``; selects the
            working resolution and the number of inference steps.

    Returns:
        ``(image, status)``. ``image`` is the blended result at the input's
        original size, or ``None`` on failure, in which case ``status``
        carries the error text.
    """
    if input_image is None:
        return None, "Please upload an image"

    try:
        # Move to GPU
        inpaint_pipe.to("cuda")

        # Convert to PIL
        if isinstance(input_image, np.ndarray):
            # NOTE(review): this branch body is an unchanged line hidden
            # between diff hunks; reconstructed as the usual array-to-PIL
            # conversion. Confirm against app.py.
            image = Image.fromarray(input_image).convert("RGB")
        else:
            image = input_image.convert("RGB")

        # Keep the untouched pixels for the final blend.
        original_image = image.copy()
        original_size = image.size

        # Longest working side (px) and denoising steps per preset.
        quality_settings = {
            "fast": {"size": 512, "steps": 30},
            "balanced": {"size": 768, "steps": 50},
            "ultra": {"size": 1024, "steps": 70},
        }
        settings = quality_settings[quality_preset]

        # Shrink oversized inputs; otherwise only snap to multiples of 8.
        if max(image.size) > settings["size"]:
            image = resize_with_aspect_ratio(image, settings["size"])
        else:
            image = make_divisible_by_8(image)
        original_resized = original_image.resize(image.size, Image.Resampling.LANCZOS)

        print(f"Processing at: {image.size} (divisible by 8)")

        image = enhance_image(image)
        mask = create_professional_mask(image, face_margin)

        prompt = CLOTHING_PROMPTS[clothing_type]
        negative_prompt = "blurry, low quality, distorted face, bad anatomy, ugly"

        with torch.autocast("cuda"):
            result = inpaint_pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=image,
                mask_image=mask,
                # Without explicit width/height the pipeline falls back to its
                # default resolution, so the presets would have no effect and
                # the output would not match the images blended below.
                width=image.width,
                height=image.height,
                num_inference_steps=settings["steps"],
                guidance_scale=8.0,
                strength=0.88,
            ).images[0]

        # Safety net: the blend needs identically sized inputs.
        if result.size != original_resized.size:
            result = result.resize(original_resized.size, Image.Resampling.LANCZOS)

        final = blend_images(original_resized, result, mask)
        final = enhance_image(final)

        # Resize back to original
        if final.size != original_size:
            final = final.resize(original_size, Image.Resampling.LANCZOS)

        return final, f"✅ {clothing_type} applied successfully!"

    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        return None, f"Error: {str(e)}"

    finally:
        # Release the GPU on every exit path, including failures.
        inpaint_pipe.to("cpu")
        torch.cuda.empty_cache()
208
 
209
  # UI
210
# Gradio front end: upload + options on the left, result + status on the right.
with gr.Blocks(title="Professional Clothing AI") as app:
    gr.Markdown("""
    # 👘 Professional Traditional Clothing AI
    ### Perfect Face Preservation • Studio Quality
    """)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Photo")

            clothing_type = gr.Dropdown(
                choices=list(CLOTHING_PROMPTS.keys()),
                value="Indian Sari",
                label="Traditional Clothing"
            )

            with gr.Accordion("Settings", open=True):
                # Keyword arguments: passing `step` positionally is rejected
                # by Gradio releases where it is keyword-only.
                face_margin = gr.Slider(
                    minimum=0.25,
                    maximum=0.45,
                    value=0.35,
                    step=0.05,
                    label="Face Protection Zone",
                    info="Higher = more face area protected"
                )

                quality_preset = gr.Radio(
                    ["fast", "balanced", "ultra"],
                    value="balanced",
                    label="Quality",
                    info="Ultra = best (2-3 min)"
                )

            generate_btn = gr.Button("🎨 Generate", variant="primary", size="lg")

        with gr.Column():
            output_image = gr.Image(label="Result")
            status = gr.Textbox(label="Status")

    gr.Markdown("""
    ### Tips:
    - Face margin 0.35 = perfect face preservation
    - Balanced mode = good quality in ~1 minute
    - Works with any image size
    """)

    # Input order must match generate_professional's parameter order.
    generate_btn.click(
        generate_professional,
        inputs=[input_image, clothing_type, face_margin, quality_preset],
        outputs=[output_image, status]
    )

app.launch()
 
 
1
  import gradio as gr
2
  import torch
3
  from diffusers import StableDiffusionInpaintPipeline
4
+ from PIL import Image, ImageDraw, ImageFilter
5
  import numpy as np
6
  import spaces
7
 
8
  # Load model
9
# Shared inpainting pipeline, created once at import time and reused by
# every request; weights are loaded in half precision without a safety checker.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    requires_safety_checker=False,
    safety_checker=None,
    torch_dtype=torch.float16,
)
pipe.enable_attention_slicing()
 
 
16
 
17
# Base inpainting prompt for each garment offered in the UI dropdown.
# Keys double as the dropdown labels, so they are user-visible text.
CLOTHES = {
    "Indian Sari": "woman wearing beautiful red and gold sari, traditional Indian dress, high quality photo",
    "Japanese Kimono": "person wearing elegant kimono, traditional Japanese clothing, professional photo",
    "African Dashiki": "person wearing colorful dashiki, traditional African clothing, detailed",
    "Chinese Qipao": "woman wearing elegant qipao dress, traditional Chinese clothing",
    "Scottish Kilt": "man wearing Scottish kilt, traditional highland dress",
    "Middle Eastern Thobe": "person wearing white thobe, traditional Middle Eastern clothing",
}
25
 
26
def make_divisible_by_8(width, height):
    """Round both dimensions down to the nearest multiple of 8.

    Returns:
        A ``(width, height)`` tuple; a side smaller than 8 becomes 0, so
        callers that need a usable size must apply their own minimum.
    """
    return tuple(side - (side % 8) for side in (width, height))
 
 
 
 
 
 
 
 
 
 
29
 
30
def create_body_mask(image_size, face_margin=0.35, blur_radius=25):
    """Build a blurred grayscale mask covering the body region.

    An ellipse is filled with 255 from ``height * face_margin`` down to 98%
    of the height, spanning 10%-90% of the width, and the mask is then
    Gaussian-blurred to soften its edge.

    Args:
        image_size: ``(width, height)`` of the image the mask is for.
        face_margin: Fraction of the height, measured from the top, left
            outside the ellipse. The default matches the previously
            hard-coded value.
        blur_radius: Radius of the Gaussian blur applied to the mask. The
            default matches the previously hard-coded value.

    Returns:
        A mode ``'L'`` mask of size ``image_size``.

    Raises:
        ValueError: If ``face_margin`` would put the top of the ellipse at or
            below its bottom edge.
    """
    if not 0 <= face_margin < 0.98:
        raise ValueError(f"face_margin must be in [0, 0.98), got {face_margin!r}")

    width, height = image_size
    mask = Image.new('L', (width, height), 0)
    draw = ImageDraw.Draw(mask)

    # Body area (avoiding face)
    top = height * face_margin  # Start below face
    left = width * 0.1
    right = width * 0.9
    bottom = height * 0.98

    draw.ellipse([left, top, right, bottom], fill=255)
    return mask.filter(ImageFilter.GaussianBlur(radius=blur_radius))
46
 
47
@spaces.GPU(duration=90)
def generate_clothing(input_image, clothing_type, quality_mode="balanced"):
    """Repaint the body region of a photo with the selected traditional outfit.

    Args:
        input_image: PIL image or NumPy array from the Gradio image input;
            ``None`` when nothing was uploaded.
        clothing_type: Key of ``CLOTHES`` selecting the base prompt.
        quality_mode: ``"fast"``, ``"balanced"`` or ``"ultra"``; selects the
            working resolution and the number of inference steps.

    Returns:
        ``(image, status)``. ``image`` is the blended result at the input's
        original size, or ``None`` on failure, in which case ``status``
        carries the error text.
    """
    if input_image is None:
        return None, "Please upload an image"

    try:
        # Move to GPU
        pipe.to("cuda")

        # Convert to PIL
        if isinstance(input_image, np.ndarray):
            # NOTE(review): this branch body is an unchanged line hidden
            # between diff hunks; reconstructed as the usual array-to-PIL
            # conversion. Confirm against app.py.
            image = Image.fromarray(input_image).convert("RGB")
        else:
            image = input_image.convert("RGB")

        # Store original size
        original_size = image.size

        # Longest working side (px) and denoising steps per preset.
        quality_settings = {
            "fast": {"size": 512, "steps": 25},
            "balanced": {"size": 768, "steps": 40},
            "ultra": {"size": 1024, "steps": 60},
        }
        settings = quality_settings[quality_mode]
        target_size = settings["size"]

        # Shrink oversized inputs while keeping the aspect ratio; smaller
        # inputs keep their size.
        if max(image.size) > target_size:
            scale = target_size / max(image.size)
            new_width = int(image.width * scale)
            new_height = int(image.height * scale)
        else:
            new_width = image.width
            new_height = image.height

        # Multiples of 8, with a 64 px floor so tiny inputs stay usable.
        new_width, new_height = make_divisible_by_8(new_width, new_height)
        new_width = max(new_width, 64)
        new_height = max(new_height, 64)

        # Image and mask share one working size.
        working_size = (new_width, new_height)
        image_resized = image.resize(working_size, Image.Resampling.LANCZOS)
        mask = create_body_mask(working_size)

        print(f"Processing at size: {working_size}")

        # Generate
        prompt = CLOTHES[clothing_type] + ", professional photography, preserve facial features"
        negative_prompt = "blurry, low quality, distorted face, bad anatomy"

        with torch.autocast("cuda"):
            result = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=image_resized,
                mask_image=mask,
                # Without explicit width/height the pipeline falls back to its
                # default resolution, so the quality presets would have no
                # effect and the output would only be rescaled afterwards.
                width=new_width,
                height=new_height,
                num_inference_steps=settings["steps"],
                guidance_scale=7.5,
                strength=0.85,
            ).images[0]

        # Safety net: the composite below needs identically sized inputs.
        # image_resized and the mask are already at working_size, so this
        # resize replaces the former assert (asserts vanish under -O).
        if result.size != working_size:
            result = result.resize(working_size, Image.Resampling.LANCZOS)

        # Blend with the original through an extra-soft mask to preserve the face.
        blend_mask = mask.filter(ImageFilter.GaussianBlur(radius=40))
        final = Image.composite(result, image_resized, blend_mask)

        # Resize back to original size
        if final.size != original_size:
            final = final.resize(original_size, Image.Resampling.LANCZOS)

        return final, f"✅ Successfully added {clothing_type}!"

    except Exception as e:
        # UI boundary: log, then report the failure in the status box.
        print(f"Error details: {str(e)}")
        return None, f"Error: {str(e)}"

    finally:
        # Release the GPU on every exit path, including failures.
        pipe.to("cpu")
        torch.cuda.empty_cache()
142
 
143
  # UI
144
# Gradio front end: upload + options on the left, result + status on the right.
# Emoji in the user-facing strings were mis-encoded; they are restored here.
with gr.Blocks(title="Traditional Clothing AI", theme=gr.themes.Soft()) as app:
    gr.Markdown("""
    # 👘 Traditional Clothing AI - Face Preserved

    Add traditional clothing while keeping your face perfectly intact.
    """)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                type="pil",
                label="Upload Your Photo"
            )

            clothing_type = gr.Dropdown(
                choices=list(CLOTHES.keys()),
                value="Indian Sari",
                label="Select Traditional Clothing"
            )

            quality_mode = gr.Radio(
                choices=["fast", "balanced", "ultra"],
                value="balanced",
                label="Quality Mode",
                info="Higher quality = longer processing time"
            )

            generate_btn = gr.Button(
                "🎨 Add Traditional Clothing",
                variant="primary",
                size="lg"
            )

        with gr.Column():
            output_image = gr.Image(
                label="Result"
            )

            status_text = gr.Textbox(
                label="Status",
                placeholder="Upload an image and click generate..."
            )

    gr.Markdown("""
    ### How it works:
    - 🎯 Only modifies clothing area (below face)
    - 😊 Your face remains untouched
    - 🎨 Smooth blending for natural results
    - ⚡ Fast mode: ~30 seconds
    - 🔬 Ultra mode: ~2 minutes (best quality)
    """)

    # Input order must match generate_clothing's parameter order.
    generate_btn.click(
        fn=generate_clothing,
        inputs=[input_image, clothing_type, quality_mode],
        outputs=[output_image, status_text]
    )

if __name__ == "__main__":
    app.launch()