beertoshi commited on
Commit
da348cc
Β·
verified Β·
1 Parent(s): b6232bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -136
app.py CHANGED
@@ -3,184 +3,152 @@ import torch
3
  from diffusers import StableDiffusionInpaintPipeline
4
  from PIL import Image, ImageDraw, ImageFilter
5
  import numpy as np
6
- import os
7
 
8
- # Clear cache if needed
9
- os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
10
 
11
- # Device setup
12
- device = "cuda" if torch.cuda.is_available() else "cpu"
13
- dtype = torch.float16 if device == "cuda" else torch.float32
14
- print(f"Using device: {device}")
 
 
15
 
16
- # Try loading model with better error handling
17
- MODEL_LOADED = False
18
- pipe = None
19
 
20
- try:
21
- print("Downloading and loading model... This may take a few minutes on first run.")
22
-
23
- # Use the correct model ID and revision
24
- pipe = StableDiffusionInpaintPipeline.from_pretrained(
25
- "runwayml/stable-diffusion-inpainting",
26
- revision="fp16" if device == "cuda" else "main",
27
- torch_dtype=dtype,
28
- safety_checker=None,
29
- requires_safety_checker=False,
30
- use_safetensors=True, # Force safetensors format
31
- local_files_only=False # Allow downloading
32
- )
33
-
34
- pipe = pipe.to(device)
35
- pipe.enable_attention_slicing()
36
-
37
- MODEL_LOADED = True
38
- print("βœ… Model loaded successfully!")
39
-
40
- except Exception as e:
41
- print(f"❌ First attempt failed: {e}")
42
- print("Trying alternative model...")
43
-
44
- try:
45
- # Alternative: Try SD 2.0 inpainting
46
- pipe = StableDiffusionInpaintPipeline.from_pretrained(
47
- "stabilityai/stable-diffusion-2-inpainting",
48
- torch_dtype=dtype,
49
- safety_checker=None,
50
- requires_safety_checker=False
51
- )
52
-
53
- pipe = pipe.to(device)
54
- pipe.enable_attention_slicing()
55
-
56
- MODEL_LOADED = True
57
- print("βœ… Alternative model loaded successfully!")
58
-
59
- except Exception as e2:
60
- print(f"❌ Both models failed to load: {e2}")
61
- MODEL_LOADED = False
62
 
63
  # Clothing prompts
64
- CLOTHES = {
65
- "Indian Sari": "woman wearing elegant red and gold sari, traditional Indian dress, beautiful saree",
66
- "Japanese Kimono": "person wearing beautiful floral kimono, traditional Japanese clothing, silk kimono",
67
- "African Dashiki": "person wearing colorful African dashiki with patterns, traditional clothing",
68
- "Chinese Qipao": "woman wearing elegant qipao cheongsam dress, traditional Chinese dress",
69
- "Scottish Kilt": "man wearing traditional Scottish kilt with tartan pattern",
70
- "Middle Eastern Thobe": "person wearing white thobe robe, traditional Middle Eastern clothing"
71
  }
72
 
73
- def create_mask(image):
74
- """Create body mask"""
75
- w, h = image.size
76
- mask = Image.new('L', (w, h), 0)
77
  draw = ImageDraw.Draw(mask)
78
 
79
  # Body area ellipse
80
- left = w * 0.2
81
- top = h * 0.25
82
- right = w * 0.8
83
- bottom = h * 0.9
84
 
85
  draw.ellipse([left, top, right, bottom], fill=255)
86
  mask = mask.filter(ImageFilter.GaussianBlur(radius=20))
87
 
88
  return mask
89
 
90
- def process_image(image, clothing_type, steps=20):
91
- """Generate traditional clothing"""
 
92
 
93
- if not MODEL_LOADED:
94
- return None, "⚠️ Model is still loading or failed to load. Please refresh the page and try again."
95
-
96
- if image is None:
97
  return None, "Please upload an image first"
98
 
99
  try:
100
- # Convert to PIL
101
- if isinstance(image, np.ndarray):
102
- image = Image.fromarray(image).convert("RGB")
 
 
 
103
  else:
104
- image = image.convert("RGB")
105
 
106
  # Store original size
107
  original_size = image.size
108
 
109
- # Resize for processing (max 512x512 for stability)
110
- if max(image.size) > 512:
111
- image.thumbnail((512, 512), Image.Resampling.LANCZOS)
 
 
 
112
 
113
  # Create mask
114
- mask = create_mask(image)
115
 
116
- # Generate
117
- prompt = CLOTHES[clothing_type]
118
- negative_prompt = "nude, naked, nsfw, blurry, bad quality"
119
 
120
- with torch.no_grad():
121
- if device == "cuda":
122
- with torch.autocast("cuda"):
123
- result = pipe(
124
- prompt=prompt,
125
- negative_prompt=negative_prompt,
126
- image=image,
127
- mask_image=mask,
128
- num_inference_steps=steps,
129
- guidance_scale=7.5,
130
- strength=0.95
131
- ).images[0]
132
- else:
133
- result = pipe(
134
- prompt=prompt,
135
- negative_prompt=negative_prompt,
136
- image=image,
137
- mask_image=mask,
138
- num_inference_steps=steps,
139
- guidance_scale=7.5,
140
- strength=0.95
141
- ).images[0]
142
 
143
- # Resize back to original
144
  if result.size != original_size:
145
  result = result.resize(original_size, Image.Resampling.LANCZOS)
146
 
 
 
 
 
147
  return result, f"βœ… Successfully added {clothing_type}!"
148
 
149
  except Exception as e:
150
  print(f"Generation error: {e}")
151
- return None, f"❌ Error during generation: {str(e)}"
152
 
153
- # Create UI
154
- with gr.Blocks(title="Traditional Clothing AI") as app:
155
- gr.Markdown(f"""
156
  # πŸ‘˜ Traditional Clothing Addition Tool
157
 
158
- **Device:** {device.upper()} {"πŸš€ Fast" if device == "cuda" else "🐌 Slow"}
159
- **Model:** {"βœ… Ready" if MODEL_LOADED else "❌ Loading failed - please refresh"}
160
 
161
- Transform your photos with traditional clothing from different cultures!
 
162
  """)
163
 
164
  with gr.Row():
165
  with gr.Column():
166
- input_img = gr.Image(
167
  label="Upload Your Photo",
168
  type="pil"
169
  )
170
 
171
- clothing_dropdown = gr.Dropdown(
172
- choices=list(CLOTHES.keys()),
173
  value="Indian Sari",
174
- label="Choose Traditional Clothing"
175
  )
176
 
177
- steps_slider = gr.Slider(
178
- minimum=10,
179
- maximum=50,
180
- value=20,
181
- step=5,
182
- label="Quality (more steps = better but slower)"
183
- )
 
 
 
 
 
 
 
 
 
 
 
184
 
185
  generate_btn = gr.Button(
186
  "🎨 Add Traditional Clothing",
@@ -189,33 +157,38 @@ with gr.Blocks(title="Traditional Clothing AI") as app:
189
  )
190
 
191
  with gr.Column():
192
- output_img = gr.Image(
193
  label="Result"
194
  )
195
 
196
  status_text = gr.Textbox(
197
  label="Status",
198
- interactive=False
199
  )
200
 
201
- # Examples (create dummy examples or remove if no images)
202
  gr.Markdown("""
203
- ### πŸ’‘ Tips:
 
204
  - Use clear, front-facing photos
205
  - Good lighting improves results
206
- - Processing takes 15-30 seconds on GPU, 1-2 minutes on CPU
207
- - Try different clothing types for variety
208
 
209
  ### 🌍 Cultural Note:
210
- This AI creates artistic interpretations of traditional clothing.
211
- Results may not be culturally accurate - please use respectfully.
 
 
 
 
 
212
  """)
213
 
214
  # Connect button
215
  generate_btn.click(
216
- fn=process_image,
217
- inputs=[input_img, clothing_dropdown, steps_slider],
218
- outputs=[output_img, status_text]
219
  )
220
 
221
  if __name__ == "__main__":
 
3
  from diffusers import StableDiffusionInpaintPipeline
4
  from PIL import Image, ImageDraw, ImageFilter
5
  import numpy as np
6
+ import spaces
7
 
8
+ # Initialize model globally (loaded on CPU first)
9
+ print("Loading model on CPU first (ZeroGPU will move it to GPU when needed)...")
10
 
11
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
12
+ "stabilityai/stable-diffusion-2-inpainting",
13
+ torch_dtype=torch.float16,
14
+ safety_checker=None,
15
+ requires_safety_checker=False
16
+ )
17
 
18
+ # Don't move to GPU yet - ZeroGPU will handle this
19
+ pipe.enable_attention_slicing()
 
20
 
21
+ print("βœ… Model loaded! ZeroGPU will activate when generating.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  # Clothing prompts
24
+ CLOTHING_PROMPTS = {
25
+ "Indian Sari": "woman wearing beautiful red and gold silk sari, traditional Indian saree dress, intricate embroidery",
26
+ "Japanese Kimono": "person wearing elegant floral kimono with obi belt, traditional Japanese clothing, silk fabric",
27
+ "African Dashiki": "person wearing colorful African dashiki with geometric patterns, traditional clothing, vibrant",
28
+ "Chinese Qipao": "woman wearing elegant red qipao cheongsam dress, traditional Chinese dress, silk with gold patterns",
29
+ "Scottish Kilt": "man wearing traditional Scottish kilt with tartan pattern, highland dress, sporran",
30
+ "Middle Eastern Thobe": "person wearing white thobe robe, traditional Middle Eastern clothing, flowing fabric"
31
  }
32
 
33
+ def create_body_mask(image):
34
+ """Create mask for clothing area"""
35
+ width, height = image.size
36
+ mask = Image.new('L', (width, height), 0)
37
  draw = ImageDraw.Draw(mask)
38
 
39
  # Body area ellipse
40
+ left = width * 0.2
41
+ top = height * 0.25
42
+ right = width * 0.8
43
+ bottom = height * 0.95
44
 
45
  draw.ellipse([left, top, right, bottom], fill=255)
46
  mask = mask.filter(ImageFilter.GaussianBlur(radius=20))
47
 
48
  return mask
49
 
50
+ @spaces.GPU(duration=60) # Request GPU for 60 seconds
51
+ def generate_clothing(input_image, clothing_type, num_steps=25, guidance_scale=7.5):
52
+ """Generate traditional clothing with ZeroGPU"""
53
 
54
+ if input_image is None:
 
 
 
55
  return None, "Please upload an image first"
56
 
57
  try:
58
+ # Move model to GPU (ZeroGPU allocates it now)
59
+ pipe.to("cuda")
60
+
61
+ # Convert to PIL if needed
62
+ if isinstance(input_image, np.ndarray):
63
+ image = Image.fromarray(input_image).convert("RGB")
64
  else:
65
+ image = input_image.convert("RGB")
66
 
67
  # Store original size
68
  original_size = image.size
69
 
70
+ # Resize for processing
71
+ max_size = 512
72
+ if max(image.size) > max_size:
73
+ ratio = max_size / max(image.size)
74
+ new_size = tuple(int(dim * ratio) for dim in image.size)
75
+ image = image.resize(new_size, Image.Resampling.LANCZOS)
76
 
77
  # Create mask
78
+ mask = create_body_mask(image)
79
 
80
+ # Get prompt
81
+ prompt = CLOTHING_PROMPTS[clothing_type]
82
+ negative_prompt = "nude, naked, nsfw, bad quality, blurry, distorted"
83
 
84
+ # Generate with GPU
85
+ with torch.autocast("cuda"):
86
+ result = pipe(
87
+ prompt=prompt,
88
+ negative_prompt=negative_prompt,
89
+ image=image,
90
+ mask_image=mask,
91
+ num_inference_steps=num_steps,
92
+ guidance_scale=guidance_scale,
93
+ strength=0.95
94
+ ).images[0]
 
 
 
 
 
 
 
 
 
 
 
95
 
96
+ # Resize back
97
  if result.size != original_size:
98
  result = result.resize(original_size, Image.Resampling.LANCZOS)
99
 
100
+ # Move model back to CPU to free GPU
101
+ pipe.to("cpu")
102
+ torch.cuda.empty_cache()
103
+
104
  return result, f"βœ… Successfully added {clothing_type}!"
105
 
106
  except Exception as e:
107
  print(f"Generation error: {e}")
108
+ return None, f"Error: {str(e)}"
109
 
110
+ # Create interface
111
+ with gr.Blocks(title="Traditional Clothing AI - ZeroGPU", theme=gr.themes.Soft()) as app:
112
+ gr.Markdown("""
113
  # πŸ‘˜ Traditional Clothing Addition Tool
114
 
115
+ **Powered by ZeroGPU** πŸš€ - Free GPU acceleration!
 
116
 
117
+ Add beautiful traditional clothing from various cultures to your photos.
118
+ Generation takes about 30-45 seconds per image.
119
  """)
120
 
121
  with gr.Row():
122
  with gr.Column():
123
+ input_image = gr.Image(
124
  label="Upload Your Photo",
125
  type="pil"
126
  )
127
 
128
+ clothing_type = gr.Dropdown(
129
+ choices=list(CLOTHING_PROMPTS.keys()),
130
  value="Indian Sari",
131
+ label="Select Traditional Clothing"
132
  )
133
 
134
+ with gr.Accordion("Advanced Settings", open=False):
135
+ num_steps = gr.Slider(
136
+ minimum=15,
137
+ maximum=50,
138
+ value=25,
139
+ step=5,
140
+ label="Quality Steps",
141
+ info="More steps = better quality but slower"
142
+ )
143
+
144
+ guidance_scale = gr.Slider(
145
+ minimum=5,
146
+ maximum=15,
147
+ value=7.5,
148
+ step=0.5,
149
+ label="Guidance Scale",
150
+ info="Higher = more adherence to prompt"
151
+ )
152
 
153
  generate_btn = gr.Button(
154
  "🎨 Add Traditional Clothing",
 
157
  )
158
 
159
  with gr.Column():
160
+ output_image = gr.Image(
161
  label="Result"
162
  )
163
 
164
  status_text = gr.Textbox(
165
  label="Status",
166
+ placeholder="Upload an image and click generate..."
167
  )
168
 
 
169
  gr.Markdown("""
170
+ ---
171
+ ### πŸ’‘ Tips for Best Results:
172
  - Use clear, front-facing photos
173
  - Good lighting improves results
174
+ - The person should be fully visible
175
+ - Processing uses free GPU via ZeroGPU
176
 
177
  ### 🌍 Cultural Note:
178
+ This tool celebrates cultural diversity through traditional clothing.
179
+ AI-generated results are artistic interpretations.
180
+ Please use respectfully.
181
+
182
+ ### ⚑ About ZeroGPU:
183
+ This Space uses Hugging Face's free ZeroGPU feature.
184
+ GPU is allocated only during generation, which saves resources!
185
  """)
186
 
187
  # Connect button
188
  generate_btn.click(
189
+ fn=generate_clothing,
190
+ inputs=[input_image, clothing_type, num_steps, guidance_scale],
191
+ outputs=[output_image, status_text]
192
  )
193
 
194
  if __name__ == "__main__":