ihabooe commited on
Commit
866ec3b
·
verified ·
1 Parent(s): 64dbdd9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -278
app.py CHANGED
@@ -1,155 +1,65 @@
1
- import numpy as np
2
- import torch
3
- import torch.nn.functional as F
4
- from torchvision.transforms.functional import normalize
5
- import gradio as gr
6
- from briarmbg import BriaRMBG
7
- import PIL
8
- from PIL import Image
9
- import tempfile
10
- import os
11
- import time
12
- import uuid
13
- import shutil
14
 
15
- # Load the pre-trained model
16
- print("Loading model...")
17
- net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
18
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
- net.to(device)
20
- net.eval()
21
- print(f"Model loaded on {device}")
22
-
23
- # Create output directory if it doesn't exist
24
- OUTPUT_DIR = "output_images"
25
- os.makedirs(OUTPUT_DIR, exist_ok=True)
26
-
27
- def resize_image(image, max_size=1024):
28
- """Resize image while maintaining aspect ratio and quality"""
29
- # Get original size
30
- width, height = image.size
31
-
32
- # Calculate aspect ratio
33
- aspect_ratio = width / height
34
-
35
- # Only resize if the image is larger than max_size in either dimension
36
- if width > max_size or height > max_size:
37
- if width > height:
38
- new_width = max_size
39
- new_height = int(max_size / aspect_ratio)
40
- else:
41
- new_height = max_size
42
- new_width = int(max_size * aspect_ratio)
43
- image = image.resize((new_width, new_height), Image.LANCZOS)
44
-
45
- return image
46
-
47
- def process(image, progress=gr.Progress()):
48
- if image is None:
49
- return None, gr.update(visible=False)
50
-
51
- progress(0, desc="Starting processing...")
52
-
53
- # Prepare the input
54
- progress(0.1, desc="Preparing image...")
55
- orig_image = Image.fromarray(image)
56
- original_size = orig_image.size
57
-
58
- # Resize only if needed for processing
59
- process_image = resize_image(orig_image)
60
- w, h = process_image.size
61
-
62
- im_np = np.array(process_image)
63
- im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
64
- im_tensor = torch.unsqueeze(im_tensor, 0)
65
- im_tensor = torch.divide(im_tensor, 255.0)
66
- im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
67
-
68
- progress(0.3, desc="Processing with AI model...")
69
- if torch.cuda.is_available():
70
- im_tensor = im_tensor.cuda()
71
-
72
- # Inference with the model
73
- with torch.no_grad():
74
- result = net(im_tensor)
75
-
76
- progress(0.6, desc="Post-processing...")
77
- # Post-process the result
78
- result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
79
- ma = torch.max(result)
80
- mi = torch.min(result)
81
- result = (result - mi) / (ma - mi)
82
-
83
- # Convert the result to an image
84
- result_array = (result * 255).cpu().data.numpy().astype(np.uint8)
85
- pil_mask = Image.fromarray(np.squeeze(result_array))
86
-
87
- # Resize mask back to original size if needed
88
- if pil_mask.size != original_size:
89
- pil_mask = pil_mask.resize(original_size, Image.LANCZOS)
90
-
91
- # Add the mask as alpha channel to the original image
92
- new_im = orig_image.copy()
93
- new_im.putalpha(pil_mask)
94
-
95
- progress(0.8, desc="Preparing download...")
96
- # Generate a unique filename
97
- unique_id = str(uuid.uuid4())[:8]
98
- filename = f"background_removed_{unique_id}.png"
99
- filepath = os.path.join(OUTPUT_DIR, filename)
100
-
101
- # Save the processed image in original resolution
102
- new_im.save(filepath, format='PNG', quality=100)
103
-
104
- # Convert to numpy array for display
105
- output_array = np.array(new_im.convert("RGBA"))
106
-
107
- progress(1.0, desc="Done!")
108
-
109
- return output_array, gr.update(visible=True, value=filepath, interactive=True)
110
-
111
- # Gradio interface setup
112
  title = "Background Removal Tool"
113
  description = """
114
  <style>
115
- .custom-container {
116
- text-align: center;
117
- max-width: 100%;
118
- margin: 0 auto;
119
- padding: 20px;
120
- background: rgba(10, 10, 30, 0.6);
121
- border-radius: 15px;
122
- box-shadow: 0 0 20px rgba(0, 255, 255, 0.2);
 
 
 
 
123
  }
124
- .title-text {
125
- color: #ff00de;
126
- font-family: 'Orbitron', sans-serif;
127
- font-size: 2em;
128
- margin: 20px 0;
129
- text-shadow: 0 0 10px rgba(255, 0, 222, 0.5);
130
- animation: title-pulse 2s infinite alternate;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  }
132
- .subtitle-text {
133
- color: #00ffff;
134
- font-family: 'Roboto Mono', monospace;
135
- font-size: 1em;
136
- margin-top: 10px;
137
- line-height: 1.5;
138
- text-shadow: 0 0 10px rgba(0, 255, 255, 0.5);
 
 
 
139
  }
140
- @keyframes title-pulse {
141
- 0% { text-shadow: 0 0 5px rgba(255, 0, 222, 0.5); }
142
- 100% { text-shadow: 0 0 15px rgba(255, 0, 222, 0.8), 0 0 25px rgba(255, 0, 222, 0.5); }
 
143
  }
144
 
145
- /* Responsive text sizes */
146
- @media (max-width: 768px) {
147
- .title-text {
148
- font-size: 1.5em;
149
- }
150
- .subtitle-text {
151
- font-size: 0.9em;
152
- }
153
  }
154
  </style>
155
  <div class="custom-container">
@@ -162,154 +72,118 @@ description = """
162
 
163
  # Create the Gradio interface
164
  with gr.Blocks(css="""
165
- /* Import fonts */
166
- @import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;500;700&family=Roboto+Mono:wght@300;400;700&display=swap');
167
-
168
- /* Variables */
169
- :root {
170
- --neon-cyan: #00ffff;
171
- --neon-pink: #ff00de;
172
- --neon-yellow: #ffdd00;
173
- --dark-background: #0a0a1e;
174
- --deep-blue: #121238;
175
- }
176
-
177
- /* Global styles */
178
- body {
179
- font-family: 'Roboto Mono', monospace;
180
- background: linear-gradient(135deg, var(--dark-background) 0%, var(--deep-blue) 100%);
181
- color: #ffffff;
182
- min-height: 100vh;
183
- }
184
-
185
- /* Responsive container */
186
- .container {
187
- width: 100%;
188
- max-width: 1200px;
189
- margin: 0 auto;
190
- padding: 10px;
191
- }
192
-
193
- /* Input/Output areas with responsive sizing */
194
- .input-image, .output-image {
195
- width: 100% !important;
196
- max-width: 800px !important;
197
- height: auto !important;
198
- min-height: 300px !important;
199
- object-fit: contain !important;
200
- background: rgba(18, 18, 56, 0.7) !important;
201
- border: 2px solid var(--neon-cyan) !important;
202
- border-radius: 12px !important;
203
- transition: all 0.3s ease !important;
204
- overflow: hidden !important;
205
- margin: 0 auto !important;
206
- }
207
-
208
- .input-image img, .output-image img {
209
- max-width: 100% !important;
210
- max-height: 800px !important;
211
- object-fit: contain !important;
212
- margin: auto !important;
213
- }
214
-
215
- /* Responsive columns */
216
- .contain-center {
217
- display: flex;
218
- flex-direction: column;
219
- align-items: center;
220
- gap: 20px;
221
- }
222
-
223
- /* Download button styling */
224
- .download-container [data-testid="file"] button {
225
- background: linear-gradient(45deg, var(--neon-cyan), var(--neon-pink)) !important;
226
- color: white !important;
227
- border: none !important;
228
- padding: 12px 28px !important;
229
- font-family: 'Orbitron', sans-serif !important;
230
- font-size: 16px !important;
231
- font-weight: 600 !important;
232
- text-transform: uppercase !important;
233
- letter-spacing: 1px !important;
234
- border-radius: 8px !important;
235
- cursor: pointer !important;
236
- transition: all 0.3s ease !important;
237
- animation: button-glow 2s infinite alternate !important;
238
- width: 100% !important;
239
- max-width: 300px !important;
240
- }
241
-
242
- /* Labels */
243
- label {
244
- color: var(--neon-cyan) !important;
245
- font-family: 'Orbitron', sans-serif !important;
246
- font-size: 1.1em !important;
247
- text-shadow: 0 0 5px rgba(0, 255, 255, 0.5) !important;
248
- margin-bottom: 8px !important;
249
- text-align: center !important;
250
- }
251
-
252
- /* Responsive layout */
253
- @media (max-width: 768px) {
254
- .input-image, .output-image {
255
- min-height: 200px !important;
256
- }
257
-
258
- .input-image img, .output-image img {
259
- max-height: 500px !important;
260
- }
261
-
262
- label {
263
- font-size: 0.9em !important;
264
- }
265
-
266
- .download-container [data-testid="file"] button {
267
- padding: 10px 20px !important;
268
- font-size: 14px !important;
269
- }
270
- }
271
-
272
- /* Additional Animations */
273
- @keyframes button-glow {
274
- 0% { box-shadow: 0 0 5px rgba(0, 255, 255, 0.5); }
275
- 100% { box-shadow: 0 0 15px rgba(0, 255, 255, 0.8), 0 0 25px rgba(255, 0, 222, 0.5); }
276
- }
277
  """) as demo:
278
  gr.Markdown(description)
279
 
280
- with gr.Row(equal_height=True):
281
- with gr.Column(scale=1):
282
- input_image = gr.Image(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  type="numpy",
284
- label="Upload Your Image",
285
- elem_id="input-image",
286
- elem_classes="input-image",
287
  container=True
288
  )
289
 
290
- with gr.Column(scale=1):
291
- output_image = gr.Image(
292
- type="numpy",
293
- label="Result",
294
- elem_id="output-image",
295
- elem_classes="output-image",
296
- container=True
297
  )
 
 
 
 
 
 
298
 
299
- with gr.Row(elem_classes="download-container"):
300
- download_file = gr.File(
301
- label="",
302
- file_count="single",
303
- interactive=True,
304
- visible=False
305
- )
306
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  # Process automatically when image is uploaded
308
  input_image.change(
309
  fn=process,
310
  inputs=input_image,
311
- outputs=[output_image, download_file],
312
- show_progress="full"
 
 
 
313
  )
314
 
315
  if __name__ == "__main__":
 
1
+ # ... keep all imports and processing functions the same until the interface part ...
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
+ # Update the Gradio interface setup
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  title = "Background Removal Tool"
5
  description = """
6
  <style>
7
+ /* ... previous styles ... */
8
+
9
+ /* Image comparison slider styles */
10
+ .image-comparison {
11
+ position: relative;
12
+ width: 100% !important;
13
+ max-width: 800px !important;
14
+ margin: 0 auto !important;
15
+ overflow: hidden !important;
16
+ border-radius: 12px !important;
17
+ border: 2px solid var(--neon-cyan) !important;
18
+ background: rgba(18, 18, 56, 0.7) !important;
19
  }
20
+
21
+ .comparison-slider {
22
+ position: absolute !important;
23
+ width: 4px !important;
24
+ height: 100% !important;
25
+ background: var(--neon-cyan) !important;
26
+ box-shadow: 0 0 10px rgba(0, 255, 255, 0.5) !important;
27
+ z-index: 10 !important;
28
+ cursor: ew-resize !important;
29
+ }
30
+
31
+ .slider-handle {
32
+ position: absolute !important;
33
+ width: 40px !important;
34
+ height: 40px !important;
35
+ background: var(--neon-cyan) !important;
36
+ border-radius: 50% !important;
37
+ top: 50% !important;
38
+ left: 50% !important;
39
+ transform: translate(-50%, -50%) !important;
40
+ box-shadow: 0 0 15px rgba(0, 255, 255, 0.8) !important;
41
+ cursor: ew-resize !important;
42
  }
43
+
44
+ .slider-handle::before,
45
+ .slider-handle::after {
46
+ content: '';
47
+ position: absolute;
48
+ width: 2px;
49
+ height: 50%;
50
+ background: rgba(0, 0, 0, 0.8);
51
+ left: 50%;
52
+ transform: translateX(-50%);
53
  }
54
+
55
+ .slider-handle::before {
56
+ top: 25%;
57
+ transform: translateX(-50%) rotate(45deg);
58
  }
59
 
60
+ .slider-handle::after {
61
+ top: 25%;
62
+ transform: translateX(-50%) rotate(-45deg);
 
 
 
 
 
63
  }
64
  </style>
65
  <div class="custom-container">
 
72
 
73
  # Create the Gradio interface
74
  with gr.Blocks(css="""
75
+ /* ... previous CSS styles ... */
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  """) as demo:
77
  gr.Markdown(description)
78
 
79
+ with gr.Column(scale=1):
80
+ input_image = gr.Image(
81
+ type="numpy",
82
+ label="Upload Your Image",
83
+ elem_id="input-image",
84
+ elem_classes="input-image",
85
+ container=True
86
+ )
87
+
88
+ # Hidden output image for processing
89
+ output_image = gr.Image(
90
+ type="numpy",
91
+ visible=False
92
+ )
93
+
94
+ # Image comparison component
95
+ with gr.Row(elem_classes="image-comparison-container"):
96
+ image_comparison = gr.Image(
97
  type="numpy",
98
+ label="Before / After Comparison",
99
+ elem_id="image-comparison",
100
+ elem_classes="image-comparison",
101
  container=True
102
  )
103
 
104
+ with gr.Row(elem_classes="download-container"):
105
+ download_file = gr.File(
106
+ label="",
107
+ file_count="single",
108
+ interactive=True,
109
+ visible=False
 
110
  )
111
+
112
+ # Custom JavaScript for image comparison slider
113
+ demo.load(js="""
114
+ function initComparison() {
115
+ const container = document.querySelector('.image-comparison');
116
+ if (!container) return;
117
 
118
+ const slider = document.createElement('div');
119
+ slider.className = 'comparison-slider';
120
+ const handle = document.createElement('div');
121
+ handle.className = 'slider-handle';
122
+ slider.appendChild(handle);
123
+ container.appendChild(slider);
124
+
125
+ let isDown = false;
126
+ let startX;
127
+ let sliderLeft;
128
+
129
+ slider.addEventListener('mousedown', (e) => {
130
+ isDown = true;
131
+ startX = e.pageX - slider.offsetLeft;
132
+ });
133
+
134
+ document.addEventListener('mouseup', () => {
135
+ isDown = false;
136
+ });
137
+
138
+ document.addEventListener('mousemove', (e) => {
139
+ if (!isDown) return;
140
+ e.preventDefault();
141
+
142
+ const x = e.pageX - container.offsetLeft;
143
+ const walk = x - startX;
144
+
145
+ const containerWidth = container.offsetWidth;
146
+ let newLeft = (x / containerWidth) * 100;
147
+ newLeft = Math.max(0, Math.min(100, newLeft));
148
+
149
+ slider.style.left = `${newLeft}%`;
150
+ container.style.setProperty('--slider-position', `${newLeft}%`);
151
+ });
152
+ }
153
+
154
+ // Initialize comparison slider when images are loaded
155
+ document.addEventListener('DOMContentLoaded', initComparison);
156
+ // Reinitialize when new images are loaded
157
+ const observer = new MutationObserver(initComparison);
158
+ observer.observe(document.body, { childList: true, subtree: true });
159
+ """)
160
+
161
+ def update_comparison(image, result):
162
+ if result is None:
163
+ return None
164
+ # Create side-by-side comparison
165
+ orig_image = Image.fromarray(image)
166
+ result_image = Image.fromarray(result)
167
+
168
+ # Ensure both images are the same size
169
+ width = max(orig_image.width, result_image.width)
170
+ height = max(orig_image.height, result_image.height)
171
+
172
+ comparison = Image.new('RGBA', (width * 2, height))
173
+ comparison.paste(orig_image, (0, 0))
174
+ comparison.paste(result_image, (width, 0))
175
+
176
+ return np.array(comparison)
177
+
178
  # Process automatically when image is uploaded
179
  input_image.change(
180
  fn=process,
181
  inputs=input_image,
182
+ outputs=[output_image, download_file]
183
+ ).then(
184
+ fn=update_comparison,
185
+ inputs=[input_image, output_image],
186
+ outputs=image_comparison
187
  )
188
 
189
  if __name__ == "__main__":