ihabooe commited on
Commit
1cbc077
·
verified ·
1 Parent(s): 866ec3b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +171 -145
app.py CHANGED
@@ -1,181 +1,207 @@
1
- # ... keep all imports and processing functions the same until the interface part ...
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- # Update the Gradio interface setup
4
- title = "Background Removal Tool"
5
- description = """
6
- <style>
7
- /* ... previous styles ... */
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- /* Image comparison slider styles */
10
  .image-comparison {
11
  position: relative;
12
  width: 100% !important;
13
  max-width: 800px !important;
14
  margin: 0 auto !important;
15
- overflow: hidden !important;
16
  border-radius: 12px !important;
17
  border: 2px solid var(--neon-cyan) !important;
18
  background: rgba(18, 18, 56, 0.7) !important;
 
19
  }
20
 
21
- .comparison-slider {
22
- position: absolute !important;
23
- width: 4px !important;
24
- height: 100% !important;
25
- background: var(--neon-cyan) !important;
26
- box-shadow: 0 0 10px rgba(0, 255, 255, 0.5) !important;
27
- z-index: 10 !important;
28
- cursor: ew-resize !important;
29
- }
30
-
31
- .slider-handle {
32
- position: absolute !important;
33
- width: 40px !important;
34
- height: 40px !important;
35
- background: var(--neon-cyan) !important;
36
- border-radius: 50% !important;
37
- top: 50% !important;
38
- left: 50% !important;
39
- transform: translate(-50%, -50%) !important;
40
- box-shadow: 0 0 15px rgba(0, 255, 255, 0.8) !important;
41
- cursor: ew-resize !important;
42
  }
43
 
44
- .slider-handle::before,
45
- .slider-handle::after {
46
- content: '';
47
- position: absolute;
48
- width: 2px;
49
- height: 50%;
50
- background: rgba(0, 0, 0, 0.8);
51
- left: 50%;
52
- transform: translateX(-50%);
53
  }
54
 
55
- .slider-handle::before {
56
- top: 25%;
57
- transform: translateX(-50%) rotate(45deg);
 
 
 
 
 
 
58
  }
59
 
60
- .slider-handle::after {
61
- top: 25%;
62
- transform: translateX(-50%) rotate(-45deg);
63
  }
64
- </style>
65
- <div class="custom-container">
66
- <h1 class="title-text">AI Background Removal</h1>
67
- <p class="subtitle-text">
68
- Remove backgrounds instantly using advanced AI technology
69
- </p>
70
- </div>
71
- """
72
-
73
- # Create the Gradio interface
74
- with gr.Blocks(css="""
75
- /* ... previous CSS styles ... */
76
  """) as demo:
77
- gr.Markdown(description)
 
 
 
78
 
79
- with gr.Column(scale=1):
80
  input_image = gr.Image(
81
  type="numpy",
82
  label="Upload Your Image",
83
- elem_id="input-image",
84
- elem_classes="input-image",
85
- container=True
86
  )
87
 
88
- # Hidden output image for processing
89
  output_image = gr.Image(
90
  type="numpy",
91
  visible=False
92
  )
93
 
94
- # Image comparison component
95
- with gr.Row(elem_classes="image-comparison-container"):
96
- image_comparison = gr.Image(
97
- type="numpy",
98
- label="Before / After Comparison",
99
- elem_id="image-comparison",
100
- elem_classes="image-comparison",
101
- container=True
102
- )
103
-
104
- with gr.Row(elem_classes="download-container"):
105
- download_file = gr.File(
106
- label="",
107
- file_count="single",
108
- interactive=True,
109
- visible=False
110
- )
111
-
112
- # Custom JavaScript for image comparison slider
113
- demo.load(js="""
114
- function initComparison() {
115
- const container = document.querySelector('.image-comparison');
116
- if (!container) return;
117
-
118
- const slider = document.createElement('div');
119
- slider.className = 'comparison-slider';
120
- const handle = document.createElement('div');
121
- handle.className = 'slider-handle';
122
- slider.appendChild(handle);
123
- container.appendChild(slider);
124
-
125
- let isDown = false;
126
- let startX;
127
- let sliderLeft;
128
-
129
- slider.addEventListener('mousedown', (e) => {
130
- isDown = true;
131
- startX = e.pageX - slider.offsetLeft;
132
- });
133
-
134
- document.addEventListener('mouseup', () => {
135
- isDown = false;
136
- });
137
-
138
- document.addEventListener('mousemove', (e) => {
139
- if (!isDown) return;
140
- e.preventDefault();
141
-
142
- const x = e.pageX - container.offsetLeft;
143
- const walk = x - startX;
144
-
145
- const containerWidth = container.offsetWidth;
146
- let newLeft = (x / containerWidth) * 100;
147
- newLeft = Math.max(0, Math.min(100, newLeft));
148
-
149
- slider.style.left = `${newLeft}%`;
150
- container.style.setProperty('--slider-position', `${newLeft}%`);
151
- });
152
- }
153
-
154
- // Initialize comparison slider when images are loaded
155
- document.addEventListener('DOMContentLoaded', initComparison);
156
- // Reinitialize when new images are loaded
157
- const observer = new MutationObserver(initComparison);
158
- observer.observe(document.body, { childList: true, subtree: true });
159
- """)
160
-
161
- def update_comparison(image, result):
162
- if result is None:
163
- return None
164
- # Create side-by-side comparison
165
- orig_image = Image.fromarray(image)
166
- result_image = Image.fromarray(result)
167
-
168
- # Ensure both images are the same size
169
- width = max(orig_image.width, result_image.width)
170
- height = max(orig_image.height, result_image.height)
171
-
172
- comparison = Image.new('RGBA', (width * 2, height))
173
- comparison.paste(orig_image, (0, 0))
174
- comparison.paste(result_image, (width, 0))
175
 
176
- return np.array(comparison)
 
 
 
 
 
 
177
 
178
- # Process automatically when image is uploaded
179
  input_image.change(
180
  fn=process,
181
  inputs=input_image,
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torchvision.transforms.functional import normalize
5
+ import gradio as gr # Make sure this import is present
6
+ from briarmbg import BriaRMBG
7
+ import PIL
8
+ from PIL import Image
9
+ import tempfile
10
+ import os
11
+ import time
12
+ import uuid
13
+ import shutil
14
 
15
+ # Load the pre-trained model
16
+ print("Loading model...")
17
+ net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ net.to(device)
20
+ net.eval()
21
+ print(f"Model loaded on {device}")
22
+
23
+ # Create output directory if it doesn't exist
24
+ OUTPUT_DIR = "output_images"
25
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
26
+
27
+ def resize_image(image, max_size=1024):
28
+ """Resize image while maintaining aspect ratio and quality"""
29
+ width, height = image.size
30
+ aspect_ratio = width / height
31
+
32
+ if width > max_size or height > max_size:
33
+ if width > height:
34
+ new_width = max_size
35
+ new_height = int(max_size / aspect_ratio)
36
+ else:
37
+ new_height = max_size
38
+ new_width = int(max_size * aspect_ratio)
39
+ image = image.resize((new_width, new_height), Image.LANCZOS)
40
+
41
+ return image
42
+
43
+ def process(image, progress=gr.Progress()):
44
+ if image is None:
45
+ return None, gr.update(visible=False)
46
+
47
+ progress(0, desc="Starting processing...")
48
+ orig_image = Image.fromarray(image)
49
+ original_size = orig_image.size
50
+
51
+ progress(0.1, desc="Preparing image...")
52
+ process_image = resize_image(orig_image)
53
+ w, h = process_image.size
54
+
55
+ im_np = np.array(process_image)
56
+ im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
57
+ im_tensor = torch.unsqueeze(im_tensor, 0)
58
+ im_tensor = torch.divide(im_tensor, 255.0)
59
+ im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
60
+
61
+ progress(0.3, desc="Processing with AI model...")
62
+ if torch.cuda.is_available():
63
+ im_tensor = im_tensor.cuda()
64
+
65
+ with torch.no_grad():
66
+ result = net(im_tensor)
67
+
68
+ progress(0.6, desc="Post-processing...")
69
+ result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
70
+ ma = torch.max(result)
71
+ mi = torch.min(result)
72
+ result = (result - mi) / (ma - mi)
73
+
74
+ result_array = (result * 255).cpu().data.numpy().astype(np.uint8)
75
+ pil_mask = Image.fromarray(np.squeeze(result_array))
76
+
77
+ if pil_mask.size != original_size:
78
+ pil_mask = pil_mask.resize(original_size, Image.LANCZOS)
79
+
80
+ new_im = orig_image.copy()
81
+ new_im.putalpha(pil_mask)
82
+
83
+ progress(0.8, desc="Preparing download...")
84
+ unique_id = str(uuid.uuid4())[:8]
85
+ filename = f"background_removed_{unique_id}.png"
86
+ filepath = os.path.join(OUTPUT_DIR, filename)
87
+
88
+ new_im.save(filepath, format='PNG', quality=100)
89
+ output_array = np.array(new_im.convert("RGBA"))
90
+
91
+ progress(1.0, desc="Done!")
92
+ return output_array, gr.update(visible=True, value=filepath, interactive=True)
93
+
94
+ def update_comparison(image, result):
95
+ if result is None:
96
+ return None
97
+ orig_image = Image.fromarray(image)
98
+ result_image = Image.fromarray(result)
99
+
100
+ width = max(orig_image.width, result_image.width)
101
+ height = max(orig_image.height, result_image.height)
102
+
103
+ comparison = Image.new('RGBA', (width * 2, height))
104
+ comparison.paste(orig_image, (0, 0))
105
+ comparison.paste(result_image, (width, 0))
106
+
107
+ return np.array(comparison)
108
+
109
+ # Gradio interface
110
+ with gr.Blocks(css="""
111
+ @import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;500;700&family=Roboto+Mono:wght@300;400;700&display=swap');
112
+
113
+ :root {
114
+ --neon-cyan: #00ffff;
115
+ --neon-pink: #ff00de;
116
+ --dark-background: #0a0a1e;
117
+ }
118
+
119
+ body {
120
+ font-family: 'Roboto Mono', monospace;
121
+ background: linear-gradient(135deg, var(--dark-background) 0%, #121238 100%);
122
+ color: #ffffff;
123
+ }
124
+
125
+ .container {
126
+ max-width: 800px;
127
+ margin: 0 auto;
128
+ padding: 20px;
129
+ }
130
 
 
131
  .image-comparison {
132
  position: relative;
133
  width: 100% !important;
134
  max-width: 800px !important;
135
  margin: 0 auto !important;
 
136
  border-radius: 12px !important;
137
  border: 2px solid var(--neon-cyan) !important;
138
  background: rgba(18, 18, 56, 0.7) !important;
139
+ overflow: hidden !important;
140
  }
141
 
142
+ .title-text {
143
+ color: var(--neon-pink);
144
+ font-family: 'Orbitron', sans-serif;
145
+ font-size: 2em;
146
+ text-align: center;
147
+ margin: 20px 0;
148
+ text-shadow: 0 0 10px rgba(255, 0, 222, 0.5);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  }
150
 
151
+ .subtitle-text {
152
+ color: var(--neon-cyan);
153
+ text-align: center;
154
+ margin-bottom: 20px;
155
+ text-shadow: 0 0 10px rgba(0, 255, 255, 0.5);
 
 
 
 
156
  }
157
 
158
+ .download-button {
159
+ background: linear-gradient(45deg, var(--neon-cyan), var(--neon-pink)) !important;
160
+ border: none !important;
161
+ padding: 10px 20px !important;
162
+ border-radius: 8px !important;
163
+ color: white !important;
164
+ font-family: 'Orbitron', sans-serif !important;
165
+ cursor: pointer !important;
166
+ transition: all 0.3s ease !important;
167
  }
168
 
169
+ .download-button:hover {
170
+ transform: translateY(-2px) !important;
171
+ box-shadow: 0 5px 15px rgba(0, 255, 255, 0.4) !important;
172
  }
 
 
 
 
 
 
 
 
 
 
 
 
173
  """) as demo:
174
+ gr.Markdown("""
175
+ <h1 class="title-text">AI Background Removal</h1>
176
+ <p class="subtitle-text">Remove backgrounds instantly using advanced AI technology</p>
177
+ """)
178
 
179
+ with gr.Column():
180
  input_image = gr.Image(
181
  type="numpy",
182
  label="Upload Your Image",
183
+ elem_classes="image-comparison"
 
 
184
  )
185
 
 
186
  output_image = gr.Image(
187
  type="numpy",
188
  visible=False
189
  )
190
 
191
+ image_comparison = gr.Image(
192
+ type="numpy",
193
+ label="Before / After Comparison",
194
+ elem_classes="image-comparison"
195
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
+ download_file = gr.File(
198
+ label="",
199
+ file_count="single",
200
+ interactive=True,
201
+ visible=False,
202
+ elem_classes="download-button"
203
+ )
204
 
 
205
  input_image.change(
206
  fn=process,
207
  inputs=input_image,