ihabooe committed
Commit 64dbdd9 · verified · 1 Parent(s): 3b63c3b

Update app.py

Files changed (1):
  1. app.py (+41, -21)
app.py CHANGED
@@ -24,17 +24,26 @@ print(f"Model loaded on {device}")
 OUTPUT_DIR = "output_images"
 os.makedirs(OUTPUT_DIR, exist_ok=True)
 
-# Define different sizes for different devices
-DESKTOP_SIZE = (512, 512)
-MOBILE_SIZE = (300, 300)
-
-# Resize the input image for model compatibility
-def resize_image(image, size=DESKTOP_SIZE):
-    image = image.convert('RGB')
-    image = image.resize(size, Image.LANCZOS)
+def resize_image(image, max_size=1024):
+    """Resize image while maintaining aspect ratio and quality"""
+    # Get original size
+    width, height = image.size
+
+    # Calculate aspect ratio
+    aspect_ratio = width / height
+
+    # Only resize if the image is larger than max_size in either dimension
+    if width > max_size or height > max_size:
+        if width > height:
+            new_width = max_size
+            new_height = int(max_size / aspect_ratio)
+        else:
+            new_height = max_size
+            new_width = int(max_size * aspect_ratio)
+        image = image.resize((new_width, new_height), Image.LANCZOS)
+
     return image
 
-# Background removal process
 def process(image, progress=gr.Progress()):
     if image is None:
         return None, gr.update(visible=False)
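Note on the new resize_image: it now only ever downscales, and it preserves aspect ratio instead of forcing everything to 512x512. A standalone sketch of the arithmetic (the sample dimensions are illustrative, and expected_size is a hypothetical helper mirroring the branch logic above, not part of this commit):

def expected_size(width, height, max_size=1024):
    # Mirrors resize_image(): untouched when within the limit,
    # longest edge pinned to max_size otherwise.
    if width <= max_size and height <= max_size:
        return (width, height)
    aspect_ratio = width / height
    if width > height:
        return (max_size, int(max_size / aspect_ratio))
    return (int(max_size * aspect_ratio), max_size)

assert expected_size(2048, 1536) == (1024, 768)   # landscape, downscaled
assert expected_size(1200, 3000) == (409, 1024)   # portrait, downscaled
assert expected_size(800, 600) == (800, 600)      # within limit, unchanged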
 
@@ -44,11 +53,13 @@ def process(image, progress=gr.Progress()):
     # Prepare the input
     progress(0.1, desc="Preparing image...")
     orig_image = Image.fromarray(image)
-    # Resize input image to fixed size
-    orig_image = resize_image(orig_image, DESKTOP_SIZE)
-    w, h = DESKTOP_SIZE
-    image = orig_image
-    im_np = np.array(image)
+    original_size = orig_image.size
+
+    # Resize only if needed for processing
+    process_image = resize_image(orig_image)
+    w, h = process_image.size
+
+    im_np = np.array(process_image)
     im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
     im_tensor = torch.unsqueeze(im_tensor, 0)
     im_tensor = torch.divide(im_tensor, 255.0)
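One behavioral difference worth flagging: the old resize_image called image.convert('RGB'), and this version does not, so a grayscale or RGBA frame would reach permute(2, 0, 1) with the wrong channel count. Gradio's numpy image input is RGB by default, so this normally holds; if hardening is wanted, a minimal guard sketch (to_model_array is a hypothetical helper, not part of this commit):

import numpy as np
from PIL import Image

def to_model_array(pil_image):
    # Force 3 channels before the tensor conversion above.
    return np.array(pil_image.convert("RGB"), dtype=np.float32) / 255.0

chw = to_model_array(Image.new("RGBA", (64, 64))).transpose(2, 0, 1)
assert chw.shape == (3, 64, 64)   # CHW, matching the permute(2, 0, 1)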
 
@@ -73,6 +84,10 @@ def process(image, progress=gr.Progress()):
     result_array = (result * 255).cpu().data.numpy().astype(np.uint8)
     pil_mask = Image.fromarray(np.squeeze(result_array))
 
+    # Resize mask back to original size if needed
+    if pil_mask.size != original_size:
+        pil_mask = pil_mask.resize(original_size, Image.LANCZOS)
+
     # Add the mask as alpha channel to the original image
     new_im = orig_image.copy()
     new_im.putalpha(pil_mask)
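This hunk is what lets the model run at reduced resolution while the saved cutout stays full size: only the single-channel mask is upscaled (LANCZOS keeps the matte edges smooth), never the RGB pixels. Image.fromarray on a squeezed 2-D uint8 array yields the mode-'L' image that putalpha expects. A self-contained illustration, with a synthetic all-opaque mask standing in for the model output:

import numpy as np
from PIL import Image

orig = Image.new("RGB", (2048, 1536), "white")
mask = Image.fromarray(np.full((768, 1024), 255, dtype=np.uint8))  # mode 'L'
mask = mask.resize(orig.size, Image.LANCZOS)  # back up to 2048x1536
cutout = orig.copy()
cutout.putalpha(mask)                         # RGBA at full resolution
assert cutout.mode == "RGBA" and cutout.size == (2048, 1536)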
 
@@ -83,8 +98,8 @@ def process(image, progress=gr.Progress()):
     filename = f"background_removed_{unique_id}.png"
     filepath = os.path.join(OUTPUT_DIR, filename)
 
-    # Save the processed image
-    new_im.save(filepath, format='PNG')
+    # Save the processed image in original resolution
+    new_im.save(filepath, format='PNG', quality=100)
 
     # Convert to numpy array for display
     output_array = np.array(new_im.convert("RGBA"))
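One nit on the save call: quality is a JPEG-oriented option that Pillow's PNG encoder ignores, so it is harmless but does nothing here; PNG output is lossless either way. If the intent is to tune the PNG itself, compress_level is the knob Pillow does expose (a suggestion, not part of this commit):

from PIL import Image

img = Image.new("RGBA", (256, 256))
# compress_level trades encoding speed for file size (0 = none, 9 = max);
# the decoded pixels are identical at every level.
img.save("out.png", format="PNG", compress_level=6)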
 
@@ -178,9 +193,9 @@ with gr.Blocks(css="""
     /* Input/Output areas with responsive sizing */
     .input-image, .output-image {
         width: 100% !important;
-        max-width: 512px !important;
+        max-width: 800px !important;
         height: auto !important;
-        aspect-ratio: 1/1 !important;
+        min-height: 300px !important;
         object-fit: contain !important;
         background: rgba(18, 18, 56, 0.7) !important;
         border: 2px solid var(--neon-cyan) !important;
 
@@ -191,9 +206,10 @@ with gr.Blocks(css="""
     }
 
     .input-image img, .output-image img {
-        width: 100% !important;
-        height: 100% !important;
+        max-width: 100% !important;
+        max-height: 800px !important;
         object-fit: contain !important;
+        margin: auto !important;
     }
 
     /* Responsive columns */
 
@@ -236,7 +252,11 @@ with gr.Blocks(css="""
     /* Responsive layout */
     @media (max-width: 768px) {
         .input-image, .output-image {
-            max-width: 300px !important;
+            min-height: 200px !important;
+        }
+
+        .input-image img, .output-image img {
+            max-height: 500px !important;
         }
 
         label {