Janeka commited on
Commit
d74c67e
·
verified ·
1 Parent(s): 1fe7394

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -81
app.py CHANGED
@@ -2,102 +2,93 @@ import gradio as gr
2
  from rembg import remove
3
  from PIL import Image
4
  import numpy as np
 
5
  import cv2
6
- from skimage import filters
7
- import time
8
 
9
- def enhance_mask(mask):
10
- """Refine the mask edges for better quality"""
11
- if len(mask.shape) == 3:
12
- mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
13
 
14
- mask = cv2.GaussianBlur(mask, (5, 5), 0)
15
- _, binary_mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
16
- edges = filters.sobel(binary_mask)
17
- refined_mask = np.where(edges > 0.1, 255, binary_mask)
18
- return refined_mask.astype(np.uint8)
19
-
20
- def resize_image(img, max_size=1024):
21
- """Resize large images while maintaining aspect ratio"""
22
- width, height = img.size
23
- if max(width, height) > max_size:
24
- ratio = max_size / max(width, height)
25
- new_size = (int(width * ratio), int(height * ratio))
26
- img = img.resize(new_size, Image.LANCZOS)
27
- return img
28
 
29
- def remove_background(input_image, post_process=True, alpha_matting=False):
30
- start_time = time.time()
 
 
31
 
 
 
 
 
 
 
 
 
 
32
  try:
 
33
  if isinstance(input_image, np.ndarray):
34
  input_image = Image.fromarray(input_image)
35
 
36
- input_image = resize_image(input_image)
37
-
38
- output = remove(
39
- input_image,
40
- post_process=post_process,
41
- alpha_matting=alpha_matting,
42
- alpha_matting_foreground_threshold=240,
43
- alpha_matting_background_threshold=10,
44
- alpha_matting_erode_size=10
45
- )
46
-
47
- if output.mode == 'RGBA':
48
- mask = output.split()[-1]
49
- mask_np = np.array(mask)
50
- if post_process:
51
- mask_np = enhance_mask(mask_np)
52
  else:
53
- mask_np = np.ones(output.size[::-1], dtype=np.uint8) * 255
54
-
55
- if post_process:
56
- output.putalpha(Image.fromarray(mask_np))
 
 
 
 
57
 
58
- proc_time = time.time() - start_time
59
- return output, Image.fromarray(mask_np), f"Processed in {proc_time:.2f} seconds"
60
 
61
  except Exception as e:
62
  print(f"Error processing image: {str(e)}")
63
- return None, None, "Error processing image"
64
-
65
- # Custom CSS for better UI
66
- custom_css = """
67
- .gradio-container { max-width: 900px !important; }
68
- .output-image { border: 1px solid #e2e8f0 !important; border-radius: 8px !important; }
69
- .processing-time { font-size: 0.9em; color: #64748b; margin-top: 8px; }
70
- """
71
 
72
  # Create interface
73
- with gr.Blocks(css=custom_css) as demo:
74
- gr.Markdown("""
75
- # 🖼️ Professional Background Remover
76
- *Powered by U²-Net with enhanced post-processing*
77
- """)
78
-
79
- with gr.Row():
80
- with gr.Column():
81
- input_img = gr.Image(label="Upload Image", type="pil", elem_id="input-image")
82
- with gr.Accordion("Advanced Options", open=False):
83
- post_process = gr.Checkbox(label="Enhanced Post-Processing", value=True)
84
- alpha_matting = gr.Checkbox(label="Use Alpha Matting (for fine details)", value=False)
85
- submit_btn = gr.Button("Remove Background", variant="primary")
86
-
87
- with gr.Column():
88
- output_img = gr.Image(label="Result", type="pil", elem_id="output-image")
89
- output_mask = gr.Image(label="Segmentation Mask", type="pil")
90
- time_text = gr.Markdown(elem_classes=["processing-time"])
91
-
92
- submit_btn.click(
93
- fn=remove_background,
94
- inputs=[input_img, post_process, alpha_matting],
95
- outputs=[output_img, output_mask, time_text]
96
- )
97
 
 
98
  if __name__ == "__main__":
99
- demo.launch(
100
- server_name="0.0.0.0",
101
- server_port=7860,
102
- show_error=True
103
- )
 
2
  from rembg import remove
3
  from PIL import Image
4
  import numpy as np
5
+ import torch
6
  import cv2
7
+ import os
 
8
 
9
+ # Initialize InSPyReNet if available
10
+ try:
11
+ from InSPyReNet.models.InSPyReNet import InSPyReNet
12
+ from InSPyReNet.utils.dataloader import test_dataset
13
 
14
+ # Download InSPyReNet weights
15
+ if not os.path.exists('InSPyReNet.pth'):
16
+ os.system('wget https://github.com/plemeri/InSPyReNet/releases/download/v1.0/InSPyReNet.pth')
17
+
18
+ # Load InSPyReNet model
19
+ inspyrenet = InSPyReNet()
20
+ inspyrenet.load_state_dict(torch.load('InSPyReNet.pth', map_location='cpu'))
21
+ inspyrenet.eval()
22
+ HAS_INSPYRE = True
23
+ except:
24
+ HAS_INSPYRE = False
 
 
 
25
 
26
+ def process_with_inspyrenet(image):
27
+ # Preprocess
28
+ image = test_dataset.preprocess(np.array(image))
29
+ image = torch.from_numpy(image).unsqueeze(0)
30
 
31
+ # Predict
32
+ with torch.no_grad():
33
+ pred = inspyrenet(image)
34
+
35
+ # Post-process
36
+ mask = (pred.squeeze().cpu().numpy() > 0.5).astype(np.uint8) * 255
37
+ return mask
38
+
39
+ def remove_background(input_image, model_choice="Rembg (U²-Net)"):
40
  try:
41
+ # Convert to PIL Image if it's a numpy array
42
  if isinstance(input_image, np.ndarray):
43
  input_image = Image.fromarray(input_image)
44
 
45
+ # Process with selected model
46
+ if model_choice == "InSPyReNet" and HAS_INSPYRE:
47
+ mask = process_with_inspyrenet(input_image)
48
+ mask_img = Image.fromarray(mask)
49
+
50
+ # Apply mask to original image
51
+ output = input_image.copy()
52
+ output.putalpha(mask_img)
 
 
 
 
 
 
 
 
53
  else:
54
+ # Default to Rembg
55
+ output = remove(input_image)
56
+ if output.mode == 'RGBA':
57
+ mask = output.split()[-1]
58
+ mask_np = np.array(mask)
59
+ else:
60
+ mask_np = np.ones(output.size[::-1], dtype=np.uint8) * 255
61
+ mask_img = Image.fromarray(mask_np)
62
 
63
+ return output, mask_img
 
64
 
65
  except Exception as e:
66
  print(f"Error processing image: {str(e)}")
67
+ return None, None
 
 
 
 
 
 
 
68
 
69
  # Create interface
70
+ iface = gr.Interface(
71
+ fn=remove_background,
72
+ inputs=[
73
+ gr.Image(type="pil", label="Input Image"),
74
+ gr.Radio(
75
+ choices=["Rembg (U²-Net)", "InSPyReNet"],
76
+ value="Rembg (U²-Net)",
77
+ label="Model Selection"
78
+ )
79
+ ],
80
+ outputs=[
81
+ gr.Image(type="pil", label="Result with Transparent Background"),
82
+ gr.Image(type="pil", label="Segmentation Mask")
83
+ ],
84
+ title="Hybrid Background Remover (CPU)",
85
+ description="""
86
+ Upload an image to remove the background. Choose between:
87
+ - Rembg (-Net): Faster (5-15 sec)
88
+ - InSPyReNet: More accurate but slower (15-30 sec)
89
+ """
90
+ )
 
 
 
91
 
92
+ # Launch with minimal configuration
93
  if __name__ == "__main__":
94
+ iface.launch(server_name="0.0.0.0", server_port=7860)