ihabooe commited on
Commit
14ec6bc
·
verified ·
1 Parent(s): 270ef81

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -232
app.py CHANGED
@@ -12,7 +12,6 @@ import time
12
  import uuid
13
  import shutil
14
 
15
- # Load the pre-trained model
16
  print("Loading model...")
17
  net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -20,14 +19,12 @@ net.to(device)
20
  net.eval()
21
  print(f"Model loaded on {device}")
22
 
23
- # Create output directory if it doesn't exist
24
  OUTPUT_DIR = "output_images"
25
  os.makedirs(OUTPUT_DIR, exist_ok=True)
26
 
27
  def resize_image(image, max_size=1024):
28
  width, height = image.size
29
  aspect_ratio = width / height
30
-
31
  if width > max_size or height > max_size:
32
  if width > height:
33
  new_width = max_size
@@ -35,16 +32,12 @@ def resize_image(image, max_size=1024):
35
  else:
36
  new_height = max_size
37
  new_width = int(max_size * aspect_ratio)
38
- image = image.resize((new_width, new_height), Image.LANCZOS)
39
-
40
  return image
41
 
42
- # ... existing imports and model loading code ...
43
-
44
  def process(image, progress=gr.Progress()):
45
  if image is None:
46
  return None, None
47
-
48
  try:
49
  progress(0, desc="Starting processing...")
50
  orig_image = Image.fromarray(image)
@@ -54,7 +47,6 @@ def process(image, progress=gr.Progress()):
54
  process_image = resize_image(orig_image)
55
  w, h = process_image.size
56
 
57
- # Convert image to tensor
58
  im_np = np.array(process_image)
59
  im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
60
  im_tensor = torch.unsqueeze(im_tensor, 0)
@@ -89,238 +81,91 @@ def process(image, progress=gr.Progress()):
89
  filepath = os.path.join(OUTPUT_DIR, filename)
90
  new_im.save(filepath, format='PNG', quality=100)
91
 
92
- # Convert to RGBA array for display
93
  output_array = np.array(new_im.convert('RGBA'))
94
 
95
  progress(1.0, desc="Done!")
96
  return output_array, gr.File.update(value=filepath, visible=True)
97
-
98
  except Exception as e:
99
  print(f"Error processing image: {str(e)}")
100
  return None, None
101
 
102
- # ... rest of the code remains the same ...
 
103
 
104
- # Gradio interface
105
- with gr.Blocks(css="""
106
- @import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;500;700&display=swap');
107
-
108
- .container { max-width: 850px; margin: 0 auto; padding: 20px; }
109
-
110
- .title-text {
111
- color: #ff00de;
112
- font-family: 'Orbitron', sans-serif;
113
- font-size: 2.5em;
114
- text-align: center;
115
- margin: 20px 0;
116
- text-shadow: 0 0 10px rgba(255, 0, 222, 0.7);
117
- animation: glow 2s ease-in-out infinite alternate;
118
- }
119
-
120
- .subtitle-text {
121
- color: #00ffff;
122
- text-align: center;
123
- margin-bottom: 30px;
124
- font-size: 1.2em;
125
- text-shadow: 0 0 8px rgba(0, 255, 255, 0.7);
126
- }
127
-
128
- .image-container {
129
- background: rgba(10, 10, 30, 0.3);
130
- border-radius: 15px;
131
- padding: 20px;
132
- margin: 10px 0;
133
- border: 2px solid #00ffff;
134
- box-shadow: 0 0 15px rgba(0, 255, 255, 0.2);
135
- transition: all 0.3s ease;
136
- }
137
-
138
- .image-container:hover {
139
- box-shadow: 0 0 20px rgba(0, 255, 255, 0.4);
140
- transform: translateY(-2px);
141
- }
142
-
143
- .download-btn {
144
- background: linear-gradient(45deg, #00ffff, #ff00de);
145
- border: none;
146
- padding: 12px 25px;
147
- border-radius: 8px;
148
- color: white;
149
- font-family: 'Orbitron', sans-serif;
150
- cursor: pointer;
151
- transition: all 0.3s ease;
152
- margin-top: 10px;
153
- text-align: center;
154
- text-transform: uppercase;
155
- letter-spacing: 1px;
156
- }
157
-
158
- .download-btn:hover {
159
- transform: translateY(-2px);
160
- box-shadow: 0 5px 15px rgba(0, 255, 255, 0.4);
161
- }
162
-
163
- @keyframes glow {
164
- from {
165
- text-shadow: 0 0 5px #ff00de, 0 0 10px #ff00de, 0 0 15px #ff00de;
166
- }
167
- to {
168
- text-shadow: 0 0 10px #ff00de, 0 0 20px #ff00de, 0 0 30px #ff00de;
169
- }
170
- }
171
-
172
- @media (max-width: 768px) {
173
- .title-text { font-size: 1.8em; }
174
- .subtitle-text { font-size: 1em; }
175
- .image-container { padding: 10px; }
176
- .download-btn { padding: 10px 20px; }
177
- }
178
- """) as demo:
179
- gr.Markdown("""
180
- <h1 class="title-text">AI Background Removal</h1>
181
- <p class="subtitle-text">Remove backgrounds instantly using advanced AI technology</p>
182
- """)
183
-
184
- with gr.Row():
185
- with gr.Column():
186
- input_image = gr.Image(
187
- label="Upload Image",
188
- type="numpy",
189
- elem_classes="image-container"
190
- )
191
-
192
- output_image = gr.Image(
193
- label="Result",
194
- type="numpy", # Changed from filepath to numpy
195
- elem_classes="image-container"
196
- )
197
-
198
- download_file = gr.File(
199
- label="Download Processed Image",
200
- visible=False,
201
- elem_classes="download-btn"
202
- )
203
-
204
- # Automatic processing when image is uploaded
205
- input_image.change(
206
- fn=process,
207
- inputs=input_image,
208
- outputs=[output_image, download_file],
209
- api_name="process_image"
210
- )
211
 
212
- if __name__ == "__main__":
213
- demo.launch()
214
- progress(0.4, desc="Processing with AI model...")
215
- if torch.cuda.is_available():
216
- im_tensor = im_tensor.cuda()
217
-
218
- with torch.no_grad():
219
- result = net(im_tensor)
220
-
221
- progress(0.6, desc="Post-processing...")
222
- result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
223
- ma = torch.max(result)
224
- mi = torch.min(result)
225
- result = (result - mi) / (ma - mi)
226
-
227
- result_array = (result * 255).cpu().data.numpy().astype(np.uint8)
228
- pil_mask = Image.fromarray(np.squeeze(result_array))
229
-
230
- if pil_mask.size != original_size:
231
- pil_mask = pil_mask.resize(original_size, Image.LANCZOS)
232
-
233
- new_im = orig_image.copy()
234
- new_im.putalpha(pil_mask)
235
-
236
- progress(0.8, desc="Saving result...")
237
- unique_id = str(uuid.uuid4())[:8]
238
- filename = f"background_removed_{unique_id}.png"
239
- filepath = os.path.join(OUTPUT_DIR, filename)
240
- new_im.save(filepath, format='PNG', quality=100)
241
-
242
- progress(1.0, desc="Done!")
243
- return gr.Image.update(value=filepath, visible=True), gr.File.update(value=filepath, visible=True)
244
-
245
- except Exception as e:
246
- print(f"Error processing image: {str(e)}")
247
- return None, None
248
 
249
- # Gradio interface
250
- with gr.Blocks(css="""
251
- @import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;500;700&display=swap');
252
-
253
- .container { max-width: 850px; margin: 0 auto; padding: 20px; }
254
-
255
- .title-text {
256
- color: #ff00de;
257
- font-family: 'Orbitron', sans-serif;
258
- font-size: 2.5em;
259
- text-align: center;
260
- margin: 20px 0;
261
- text-shadow: 0 0 10px rgba(255, 0, 222, 0.7);
262
- animation: glow 2s ease-in-out infinite alternate;
263
- }
264
-
265
- .subtitle-text {
266
- color: #00ffff;
267
- text-align: center;
268
- margin-bottom: 30px;
269
- font-size: 1.2em;
270
- text-shadow: 0 0 8px rgba(0, 255, 255, 0.7);
271
- }
272
-
273
- .image-container {
274
- background: rgba(10, 10, 30, 0.3);
275
- border-radius: 15px;
276
- padding: 20px;
277
- margin: 10px 0;
278
- border: 2px solid #00ffff;
279
- box-shadow: 0 0 15px rgba(0, 255, 255, 0.2);
280
- transition: all 0.3s ease;
281
- }
282
-
283
- .image-container:hover {
284
- box-shadow: 0 0 20px rgba(0, 255, 255, 0.4);
285
- transform: translateY(-2px);
286
- }
287
-
288
- .download-btn {
289
- background: linear-gradient(45deg, #00ffff, #ff00de);
290
- border: none;
291
- padding: 12px 25px;
292
- border-radius: 8px;
293
- color: white;
294
- font-family: 'Orbitron', sans-serif;
295
- cursor: pointer;
296
- transition: all 0.3s ease;
297
- margin-top: 10px;
298
- text-align: center;
299
- text-transform: uppercase;
300
- letter-spacing: 1px;
301
- }
302
-
303
- .download-btn:hover {
304
- transform: translateY(-2px);
305
- box-shadow: 0 5px 15px rgba(0, 255, 255, 0.4);
306
- }
307
-
308
- @keyframes glow {
309
- from {
310
- text-shadow: 0 0 5px #ff00de, 0 0 10px #ff00de, 0 0 15px #ff00de;
311
- }
312
- to {
313
- text-shadow: 0 0 10px #ff00de, 0 0 20px #ff00de, 0 0 30px #ff00de;
314
- }
315
  }
316
-
317
- @media (max-width: 768px) {
318
- .title-text { font-size: 1.8em; }
319
- .subtitle-text { font-size: 1em; }
320
- .image-container { padding: 10px; }
321
- .download-btn { padding: 10px 20px; }
322
  }
323
- """) as demo:
 
 
 
 
 
 
 
 
 
 
324
  gr.Markdown("""
325
  <h1 class="title-text">AI Background Removal</h1>
326
  <p class="subtitle-text">Remove backgrounds instantly using advanced AI technology</p>
@@ -336,9 +181,8 @@ with gr.Blocks(css="""
336
 
337
  output_image = gr.Image(
338
  label="Result",
339
- type="filepath",
340
- elem_classes="image-container",
341
- visible=True
342
  )
343
 
344
  download_file = gr.File(
@@ -347,7 +191,6 @@ with gr.Blocks(css="""
347
  elem_classes="download-btn"
348
  )
349
 
350
- # Automatic processing when image is uploaded
351
  input_image.change(
352
  fn=process,
353
  inputs=input_image,
 
12
  import uuid
13
  import shutil
14
 
 
15
  print("Loading model...")
16
  net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
17
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
19
  net.eval()
20
  print(f"Model loaded on {device}")
21
 
 
22
  OUTPUT_DIR = "output_images"
23
  os.makedirs(OUTPUT_DIR, exist_ok=True)
24
 
25
  def resize_image(image, max_size=1024):
26
  width, height = image.size
27
  aspect_ratio = width / height
 
28
  if width > max_size or height > max_size:
29
  if width > height:
30
  new_width = max_size
 
32
  else:
33
  new_height = max_size
34
  new_width = int(max_size * aspect_ratio)
35
+ image = resize_image.resize((new_width, new_height), Image.LANCZOS)
 
36
  return image
37
 
 
 
38
  def process(image, progress=gr.Progress()):
39
  if image is None:
40
  return None, None
 
41
  try:
42
  progress(0, desc="Starting processing...")
43
  orig_image = Image.fromarray(image)
 
47
  process_image = resize_image(orig_image)
48
  w, h = process_image.size
49
 
 
50
  im_np = np.array(process_image)
51
  im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
52
  im_tensor = torch.unsqueeze(im_tensor, 0)
 
81
  filepath = os.path.join(OUTPUT_DIR, filename)
82
  new_im.save(filepath, format='PNG', quality=100)
83
 
 
84
  output_array = np.array(new_im.convert('RGBA'))
85
 
86
  progress(1.0, desc="Done!")
87
  return output_array, gr.File.update(value=filepath, visible=True)
88
+
89
  except Exception as e:
90
  print(f"Error processing image: {str(e)}")
91
  return None, None
92
 
93
+ css = """
94
+ @import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;500;700&display=swap');
95
 
96
+ .container { max-width: 850px; margin: 0 auto; padding: 20px; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
+ .title-text {
99
+ color: #ff00de;
100
+ font-family: 'Orbitron', sans-serif;
101
+ font-size: 2.5em;
102
+ text-align: center;
103
+ margin: 20px 0;
104
+ text-shadow: 0 0 10px rgba(255, 0, 222, 0.7);
105
+ animation: glow 2s ease-in-out infinite alternate;
106
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
+ .subtitle-text {
109
+ color: #00ffff;
110
+ text-align: center;
111
+ margin-bottom: 30px;
112
+ font-size: 1.2em;
113
+ text-shadow: 0 0 8px rgba(0, 255, 255, 0.7);
114
+ }
115
+
116
+ .image-container {
117
+ background: rgba(10, 10, 30, 0.3);
118
+ border-radius: 15px;
119
+ padding: 20px;
120
+ margin: 10px 0;
121
+ border: 2px solid #00ffff;
122
+ box-shadow: 0 0 15px rgba(0, 255, 255, 0.2);
123
+ transition: all 0.3s ease;
124
+ }
125
+
126
+ .image-container:hover {
127
+ box-shadow: 0 0 20px rgba(0, 255, 255, 0.4);
128
+ transform: translateY(-2px);
129
+ }
130
+
131
+ .download-btn {
132
+ background: linear-gradient(45deg, #00ffff, #ff00de);
133
+ border: none;
134
+ padding: 12px 25px;
135
+ border-radius: 8px;
136
+ color: white;
137
+ font-family: 'Orbitron', sans-serif;
138
+ cursor: pointer;
139
+ transition: all 0.3s ease;
140
+ margin-top: 10px;
141
+ text-align: center;
142
+ text-transform: uppercase;
143
+ letter-spacing: 1px;
144
+ }
145
+
146
+ .download-btn:hover {
147
+ transform: translateY(-2px);
148
+ box-shadow: 0 5px 15px rgba(0, 255, 255, 0.4);
149
+ }
150
+
151
+ @keyframes glow {
152
+ from {
153
+ text-shadow: 0 0 5px #ff00de, 0 0 10px #ff00de, 0 0 15px #ff00de;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  }
155
+ to {
156
+ text-shadow: 0 0 10px #ff00de, 0 0 20px #ff00de, 0 0 30px #ff00de;
 
 
 
 
157
  }
158
+ }
159
+
160
+ @media (max-width: 768px) {
161
+ .title-text { font-size: 1.8em; }
162
+ .subtitle-text { font-size: 1em; }
163
+ .image-container { padding: 10px; }
164
+ .download-btn { padding: 10px 20px; }
165
+ }
166
+ """
167
+
168
+ with gr.Blocks(css=css) as demo:
169
  gr.Markdown("""
170
  <h1 class="title-text">AI Background Removal</h1>
171
  <p class="subtitle-text">Remove backgrounds instantly using advanced AI technology</p>
 
181
 
182
  output_image = gr.Image(
183
  label="Result",
184
+ type="numpy",
185
+ elem_classes="image-container"
 
186
  )
187
 
188
  download_file = gr.File(
 
191
  elem_classes="download-btn"
192
  )
193
 
 
194
  input_image.change(
195
  fn=process,
196
  inputs=input_image,