ihabooe commited on
Commit
c0a712f
·
verified ·
1 Parent(s): 116136a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -50
app.py CHANGED
@@ -25,6 +25,7 @@ OUTPUT_DIR = "output_images"
25
  os.makedirs(OUTPUT_DIR, exist_ok=True)
26
 
27
  def resize_image(image, max_size=1024):
 
28
  width, height = image.size
29
  aspect_ratio = width / height
30
 
@@ -40,53 +41,60 @@ def resize_image(image, max_size=1024):
40
  return image
41
 
42
  def process(image, progress=gr.Progress()):
 
43
  if image is None:
44
- return None
45
-
46
- progress(0, desc="Starting processing...")
47
- orig_image = Image.fromarray(image)
48
- original_size = orig_image.size
49
-
50
- progress(0.2, desc="Preparing image...")
51
- process_image = resize_image(orig_image)
52
- w, h = process_image.size
53
-
54
- im_np = np.array(process_image)
55
- im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
56
- im_tensor = torch.unsqueeze(im_tensor, 0)
57
- im_tensor = torch.divide(im_tensor, 255.0)
58
- im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
59
-
60
- progress(0.4, desc="Processing with AI model...")
61
- if torch.cuda.is_available():
62
- im_tensor = im_tensor.cuda()
63
-
64
- with torch.no_grad():
65
- result = net(im_tensor)
66
-
67
- progress(0.6, desc="Post-processing...")
68
- result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
69
- ma = torch.max(result)
70
- mi = torch.min(result)
71
- result = (result - mi) / (ma - mi)
72
-
73
- result_array = (result * 255).cpu().data.numpy().astype(np.uint8)
74
- pil_mask = Image.fromarray(np.squeeze(result_array))
75
-
76
- if pil_mask.size != original_size:
77
- pil_mask = pil_mask.resize(original_size, Image.LANCZOS)
78
-
79
- new_im = orig_image.copy()
80
- new_im.putalpha(pil_mask)
81
-
82
- progress(0.8, desc="Saving result...")
83
- unique_id = str(uuid.uuid4())[:8]
84
- filename = f"background_removed_{unique_id}.png"
85
- filepath = os.path.join(OUTPUT_DIR, filename)
86
- new_im.save(filepath, format='PNG', quality=100)
87
-
88
- progress(1.0, desc="Done!")
89
- return gr.Image.update(value=filepath, visible=True)
 
 
 
 
 
 
90
 
91
  # Gradio interface
92
  with gr.Blocks(css="""
@@ -101,6 +109,7 @@ with gr.Blocks(css="""
101
  text-align: center;
102
  margin: 20px 0;
103
  text-shadow: 0 0 10px rgba(255, 0, 222, 0.7);
 
104
  }
105
 
106
  .subtitle-text {
@@ -118,6 +127,12 @@ with gr.Blocks(css="""
118
  margin: 10px 0;
119
  border: 2px solid #00ffff;
120
  box-shadow: 0 0 15px rgba(0, 255, 255, 0.2);
 
 
 
 
 
 
121
  }
122
 
123
  .download-btn {
@@ -129,6 +144,10 @@ with gr.Blocks(css="""
129
  font-family: 'Orbitron', sans-serif;
130
  cursor: pointer;
131
  transition: all 0.3s ease;
 
 
 
 
132
  }
133
 
134
  .download-btn:hover {
@@ -136,10 +155,20 @@ with gr.Blocks(css="""
136
  box-shadow: 0 5px 15px rgba(0, 255, 255, 0.4);
137
  }
138
 
 
 
 
 
 
 
 
 
 
139
  @media (max-width: 768px) {
140
  .title-text { font-size: 1.8em; }
141
  .subtitle-text { font-size: 1em; }
142
  .image-container { padding: 10px; }
 
143
  }
144
  """) as demo:
145
  gr.Markdown("""
@@ -158,15 +187,22 @@ with gr.Blocks(css="""
158
  output_image = gr.Image(
159
  label="Result",
160
  type="filepath",
161
- elem_classes="image-container"
 
162
  )
163
 
164
- process_btn = gr.Button("Remove Background", variant="primary")
 
 
 
 
165
 
166
- process_btn.click(
 
167
  fn=process,
168
  inputs=input_image,
169
- outputs=output_image
 
170
  )
171
 
172
  if __name__ == "__main__":
 
25
  os.makedirs(OUTPUT_DIR, exist_ok=True)
26
 
27
  def resize_image(image, max_size=1024):
28
+ """Resize image while maintaining aspect ratio"""
29
  width, height = image.size
30
  aspect_ratio = width / height
31
 
 
41
  return image
42
 
43
  def process(image, progress=gr.Progress()):
44
+ """Process the image and remove background"""
45
  if image is None:
46
+ return None, None
47
+
48
+ try:
49
+ progress(0, desc="Starting processing...")
50
+ orig_image = Image.fromarray(image)
51
+ original_size = orig_image.size
52
+
53
+ progress(0.2, desc="Preparing image...")
54
+ process_image = resize_image(orig_image)
55
+ w, h = process_image.size
56
+
57
+ # Convert image to tensor
58
+ im_np = np.array(process_image)
59
+ im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
60
+ im_tensor = torch.unsqueeze(im_tensor, 0)
61
+ im_tensor = torch.divide(im_tensor, 255.0)
62
+ im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
63
+
64
+ progress(0.4, desc="Processing with AI model...")
65
+ if torch.cuda.is_available():
66
+ im_tensor = im_tensor.cuda()
67
+
68
+ with torch.no_grad():
69
+ result = net(im_tensor)
70
+
71
+ progress(0.6, desc="Post-processing...")
72
+ result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
73
+ ma = torch.max(result)
74
+ mi = torch.min(result)
75
+ result = (result - mi) / (ma - mi)
76
+
77
+ result_array = (result * 255).cpu().data.numpy().astype(np.uint8)
78
+ pil_mask = Image.fromarray(np.squeeze(result_array))
79
+
80
+ if pil_mask.size != original_size:
81
+ pil_mask = pil_mask.resize(original_size, Image.LANCZOS)
82
+
83
+ new_im = orig_image.copy()
84
+ new_im.putalpha(pil_mask)
85
+
86
+ progress(0.8, desc="Saving result...")
87
+ unique_id = str(uuid.uuid4())[:8]
88
+ filename = f"background_removed_{unique_id}.png"
89
+ filepath = os.path.join(OUTPUT_DIR, filename)
90
+ new_im.save(filepath, format='PNG', quality=100)
91
+
92
+ progress(1.0, desc="Done!")
93
+ return gr.Image.update(value=filepath, visible=True), gr.File.update(value=filepath, visible=True)
94
+
95
+ except Exception as e:
96
+ print(f"Error processing image: {str(e)}")
97
+ return None, None
98
 
99
  # Gradio interface
100
  with gr.Blocks(css="""
 
109
  text-align: center;
110
  margin: 20px 0;
111
  text-shadow: 0 0 10px rgba(255, 0, 222, 0.7);
112
+ animation: glow 2s ease-in-out infinite alternate;
113
  }
114
 
115
  .subtitle-text {
 
127
  margin: 10px 0;
128
  border: 2px solid #00ffff;
129
  box-shadow: 0 0 15px rgba(0, 255, 255, 0.2);
130
+ transition: all 0.3s ease;
131
+ }
132
+
133
+ .image-container:hover {
134
+ box-shadow: 0 0 20px rgba(0, 255, 255, 0.4);
135
+ transform: translateY(-2px);
136
  }
137
 
138
  .download-btn {
 
144
  font-family: 'Orbitron', sans-serif;
145
  cursor: pointer;
146
  transition: all 0.3s ease;
147
+ margin-top: 10px;
148
+ text-align: center;
149
+ text-transform: uppercase;
150
+ letter-spacing: 1px;
151
  }
152
 
153
  .download-btn:hover {
 
155
  box-shadow: 0 5px 15px rgba(0, 255, 255, 0.4);
156
  }
157
 
158
+ @keyframes glow {
159
+ from {
160
+ text-shadow: 0 0 5px #ff00de, 0 0 10px #ff00de, 0 0 15px #ff00de;
161
+ }
162
+ to {
163
+ text-shadow: 0 0 10px #ff00de, 0 0 20px #ff00de, 0 0 30px #ff00de;
164
+ }
165
+ }
166
+
167
  @media (max-width: 768px) {
168
  .title-text { font-size: 1.8em; }
169
  .subtitle-text { font-size: 1em; }
170
  .image-container { padding: 10px; }
171
+ .download-btn { padding: 10px 20px; }
172
  }
173
  """) as demo:
174
  gr.Markdown("""
 
187
  output_image = gr.Image(
188
  label="Result",
189
  type="filepath",
190
+ elem_classes="image-container",
191
+ visible=True
192
  )
193
 
194
+ download_file = gr.File(
195
+ label="Download Processed Image",
196
+ visible=False,
197
+ elem_classes="download-btn"
198
+ )
199
 
200
+ # Automatic processing when image is uploaded
201
+ input_image.change(
202
  fn=process,
203
  inputs=input_image,
204
+ outputs=[output_image, download_file],
205
+ api_name="process_image"
206
  )
207
 
208
  if __name__ == "__main__":