ihabooe commited on
Commit
02c83f5
·
verified ·
1 Parent(s): c0a712f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +148 -2
app.py CHANGED
@@ -25,7 +25,6 @@ OUTPUT_DIR = "output_images"
25
  os.makedirs(OUTPUT_DIR, exist_ok=True)
26
 
27
  def resize_image(image, max_size=1024):
28
- """Resize image while maintaining aspect ratio"""
29
  width, height = image.size
30
  aspect_ratio = width / height
31
 
@@ -41,7 +40,6 @@ def resize_image(image, max_size=1024):
41
  return image
42
 
43
  def process(image, progress=gr.Progress()):
44
- """Process the image and remove background"""
45
  if image is None:
46
  return None, None
47
 
@@ -89,6 +87,154 @@ def process(image, progress=gr.Progress()):
89
  filepath = os.path.join(OUTPUT_DIR, filename)
90
  new_im.save(filepath, format='PNG', quality=100)
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  progress(1.0, desc="Done!")
93
  return gr.Image.update(value=filepath, visible=True), gr.File.update(value=filepath, visible=True)
94
 
 
25
  os.makedirs(OUTPUT_DIR, exist_ok=True)
26
 
27
  def resize_image(image, max_size=1024):
 
28
  width, height = image.size
29
  aspect_ratio = width / height
30
 
 
40
  return image
41
 
42
  def process(image, progress=gr.Progress()):
 
43
  if image is None:
44
  return None, None
45
 
 
87
  filepath = os.path.join(OUTPUT_DIR, filename)
88
  new_im.save(filepath, format='PNG', quality=100)
89
 
90
+ # Convert to RGBA array for display
91
+ output_array = np.array(new_im.convert('RGBA'))
92
+
93
+ progress(1.0, desc="Done!")
94
+ return output_array, gr.File.update(value=filepath, visible=True)
95
+
96
+ except Exception as e:
97
+ print(f"Error processing image: {str(e)}")
98
+ return None, None
99
+
100
+ # Gradio interface
101
+ with gr.Blocks(css="""
102
+ @import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;500;700&display=swap');
103
+
104
+ .container { max-width: 850px; margin: 0 auto; padding: 20px; }
105
+
106
+ .title-text {
107
+ color: #ff00de;
108
+ font-family: 'Orbitron', sans-serif;
109
+ font-size: 2.5em;
110
+ text-align: center;
111
+ margin: 20px 0;
112
+ text-shadow: 0 0 10px rgba(255, 0, 222, 0.7);
113
+ animation: glow 2s ease-in-out infinite alternate;
114
+ }
115
+
116
+ .subtitle-text {
117
+ color: #00ffff;
118
+ text-align: center;
119
+ margin-bottom: 30px;
120
+ font-size: 1.2em;
121
+ text-shadow: 0 0 8px rgba(0, 255, 255, 0.7);
122
+ }
123
+
124
+ .image-container {
125
+ background: rgba(10, 10, 30, 0.3);
126
+ border-radius: 15px;
127
+ padding: 20px;
128
+ margin: 10px 0;
129
+ border: 2px solid #00ffff;
130
+ box-shadow: 0 0 15px rgba(0, 255, 255, 0.2);
131
+ transition: all 0.3s ease;
132
+ }
133
+
134
+ .image-container:hover {
135
+ box-shadow: 0 0 20px rgba(0, 255, 255, 0.4);
136
+ transform: translateY(-2px);
137
+ }
138
+
139
+ .download-btn {
140
+ background: linear-gradient(45deg, #00ffff, #ff00de);
141
+ border: none;
142
+ padding: 12px 25px;
143
+ border-radius: 8px;
144
+ color: white;
145
+ font-family: 'Orbitron', sans-serif;
146
+ cursor: pointer;
147
+ transition: all 0.3s ease;
148
+ margin-top: 10px;
149
+ text-align: center;
150
+ text-transform: uppercase;
151
+ letter-spacing: 1px;
152
+ }
153
+
154
+ .download-btn:hover {
155
+ transform: translateY(-2px);
156
+ box-shadow: 0 5px 15px rgba(0, 255, 255, 0.4);
157
+ }
158
+
159
+ @keyframes glow {
160
+ from {
161
+ text-shadow: 0 0 5px #ff00de, 0 0 10px #ff00de, 0 0 15px #ff00de;
162
+ }
163
+ to {
164
+ text-shadow: 0 0 10px #ff00de, 0 0 20px #ff00de, 0 0 30px #ff00de;
165
+ }
166
+ }
167
+
168
+ @media (max-width: 768px) {
169
+ .title-text { font-size: 1.8em; }
170
+ .subtitle-text { font-size: 1em; }
171
+ .image-container { padding: 10px; }
172
+ .download-btn { padding: 10px 20px; }
173
+ }
174
+ """) as demo:
175
+ gr.Markdown("""
176
+ <h1 class="title-text">AI Background Removal</h1>
177
+ <p class="subtitle-text">Remove backgrounds instantly using advanced AI technology</p>
178
+ """)
179
+
180
+ with gr.Row():
181
+ with gr.Column():
182
+ input_image = gr.Image(
183
+ label="Upload Image",
184
+ type="numpy",
185
+ elem_classes="image-container"
186
+ )
187
+
188
+ output_image = gr.Image(
189
+ label="Result",
190
+ type="numpy", # Changed from filepath to numpy
191
+ elem_classes="image-container"
192
+ )
193
+
194
+ download_file = gr.File(
195
+ label="Download Processed Image",
196
+ visible=False,
197
+ elem_classes="download-btn"
198
+ )
199
+
200
+ # Automatic processing when image is uploaded
201
+ input_image.change(
202
+ fn=process,
203
+ inputs=input_image,
204
+ outputs=[output_image, download_file],
205
+ api_name="process_image"
206
+ )
207
+
208
+ if __name__ == "__main__":
209
+ demo.launch()
210
+ progress(0.4, desc="Processing with AI model...")
211
+ if torch.cuda.is_available():
212
+ im_tensor = im_tensor.cuda()
213
+
214
+ with torch.no_grad():
215
+ result = net(im_tensor)
216
+
217
+ progress(0.6, desc="Post-processing...")
218
+ result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
219
+ ma = torch.max(result)
220
+ mi = torch.min(result)
221
+ result = (result - mi) / (ma - mi)
222
+
223
+ result_array = (result * 255).cpu().data.numpy().astype(np.uint8)
224
+ pil_mask = Image.fromarray(np.squeeze(result_array))
225
+
226
+ if pil_mask.size != original_size:
227
+ pil_mask = pil_mask.resize(original_size, Image.LANCZOS)
228
+
229
+ new_im = orig_image.copy()
230
+ new_im.putalpha(pil_mask)
231
+
232
+ progress(0.8, desc="Saving result...")
233
+ unique_id = str(uuid.uuid4())[:8]
234
+ filename = f"background_removed_{unique_id}.png"
235
+ filepath = os.path.join(OUTPUT_DIR, filename)
236
+ new_im.save(filepath, format='PNG', quality=100)
237
+
238
  progress(1.0, desc="Done!")
239
  return gr.Image.update(value=filepath, visible=True), gr.File.update(value=filepath, visible=True)
240