Update app.py
Browse files
app.py
CHANGED
|
@@ -2,102 +2,93 @@ import gradio as gr
|
|
| 2 |
from rembg import remove
|
| 3 |
from PIL import Image
|
| 4 |
import numpy as np
|
|
|
|
| 5 |
import cv2
|
| 6 |
-
|
| 7 |
-
import time
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
new_size = (int(width * ratio), int(height * ratio))
|
| 26 |
-
img = img.resize(new_size, Image.LANCZOS)
|
| 27 |
-
return img
|
| 28 |
|
| 29 |
-
def
|
| 30 |
-
|
|
|
|
|
|
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
try:
|
|
|
|
| 33 |
if isinstance(input_image, np.ndarray):
|
| 34 |
input_image = Image.fromarray(input_image)
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
alpha_matting_erode_size=10
|
| 45 |
-
)
|
| 46 |
-
|
| 47 |
-
if output.mode == 'RGBA':
|
| 48 |
-
mask = output.split()[-1]
|
| 49 |
-
mask_np = np.array(mask)
|
| 50 |
-
if post_process:
|
| 51 |
-
mask_np = enhance_mask(mask_np)
|
| 52 |
else:
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
-
|
| 59 |
-
return output, Image.fromarray(mask_np), f"Processed in {proc_time:.2f} seconds"
|
| 60 |
|
| 61 |
except Exception as e:
|
| 62 |
print(f"Error processing image: {str(e)}")
|
| 63 |
-
return None, None
|
| 64 |
-
|
| 65 |
-
# Custom CSS for better UI
|
| 66 |
-
custom_css = """
|
| 67 |
-
.gradio-container { max-width: 900px !important; }
|
| 68 |
-
.output-image { border: 1px solid #e2e8f0 !important; border-radius: 8px !important; }
|
| 69 |
-
.processing-time { font-size: 0.9em; color: #64748b; margin-top: 8px; }
|
| 70 |
-
"""
|
| 71 |
|
| 72 |
# Create interface
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
inputs=[input_img, post_process, alpha_matting],
|
| 95 |
-
outputs=[output_img, output_mask, time_text]
|
| 96 |
-
)
|
| 97 |
|
|
|
|
| 98 |
if __name__ == "__main__":
|
| 99 |
-
|
| 100 |
-
server_name="0.0.0.0",
|
| 101 |
-
server_port=7860,
|
| 102 |
-
show_error=True
|
| 103 |
-
)
|
|
|
|
| 2 |
from rembg import remove
|
| 3 |
from PIL import Image
|
| 4 |
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
import cv2
|
| 7 |
+
import os
|
|
|
|
| 8 |
|
| 9 |
+
# Initialize InSPyReNet if available
|
| 10 |
+
try:
|
| 11 |
+
from InSPyReNet.models.InSPyReNet import InSPyReNet
|
| 12 |
+
from InSPyReNet.utils.dataloader import test_dataset
|
| 13 |
|
| 14 |
+
# Download InSPyReNet weights
|
| 15 |
+
if not os.path.exists('InSPyReNet.pth'):
|
| 16 |
+
os.system('wget https://github.com/plemeri/InSPyReNet/releases/download/v1.0/InSPyReNet.pth')
|
| 17 |
+
|
| 18 |
+
# Load InSPyReNet model
|
| 19 |
+
inspyrenet = InSPyReNet()
|
| 20 |
+
inspyrenet.load_state_dict(torch.load('InSPyReNet.pth', map_location='cpu'))
|
| 21 |
+
inspyrenet.eval()
|
| 22 |
+
HAS_INSPYRE = True
|
| 23 |
+
except:
|
| 24 |
+
HAS_INSPYRE = False
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
+
def process_with_inspyrenet(image):
|
| 27 |
+
# Preprocess
|
| 28 |
+
image = test_dataset.preprocess(np.array(image))
|
| 29 |
+
image = torch.from_numpy(image).unsqueeze(0)
|
| 30 |
|
| 31 |
+
# Predict
|
| 32 |
+
with torch.no_grad():
|
| 33 |
+
pred = inspyrenet(image)
|
| 34 |
+
|
| 35 |
+
# Post-process
|
| 36 |
+
mask = (pred.squeeze().cpu().numpy() > 0.5).astype(np.uint8) * 255
|
| 37 |
+
return mask
|
| 38 |
+
|
| 39 |
+
def remove_background(input_image, model_choice="Rembg (U²-Net)"):
|
| 40 |
try:
|
| 41 |
+
# Convert to PIL Image if it's a numpy array
|
| 42 |
if isinstance(input_image, np.ndarray):
|
| 43 |
input_image = Image.fromarray(input_image)
|
| 44 |
|
| 45 |
+
# Process with selected model
|
| 46 |
+
if model_choice == "InSPyReNet" and HAS_INSPYRE:
|
| 47 |
+
mask = process_with_inspyrenet(input_image)
|
| 48 |
+
mask_img = Image.fromarray(mask)
|
| 49 |
+
|
| 50 |
+
# Apply mask to original image
|
| 51 |
+
output = input_image.copy()
|
| 52 |
+
output.putalpha(mask_img)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
else:
|
| 54 |
+
# Default to Rembg
|
| 55 |
+
output = remove(input_image)
|
| 56 |
+
if output.mode == 'RGBA':
|
| 57 |
+
mask = output.split()[-1]
|
| 58 |
+
mask_np = np.array(mask)
|
| 59 |
+
else:
|
| 60 |
+
mask_np = np.ones(output.size[::-1], dtype=np.uint8) * 255
|
| 61 |
+
mask_img = Image.fromarray(mask_np)
|
| 62 |
|
| 63 |
+
return output, mask_img
|
|
|
|
| 64 |
|
| 65 |
except Exception as e:
|
| 66 |
print(f"Error processing image: {str(e)}")
|
| 67 |
+
return None, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
# Create interface
|
| 70 |
+
iface = gr.Interface(
|
| 71 |
+
fn=remove_background,
|
| 72 |
+
inputs=[
|
| 73 |
+
gr.Image(type="pil", label="Input Image"),
|
| 74 |
+
gr.Radio(
|
| 75 |
+
choices=["Rembg (U²-Net)", "InSPyReNet"],
|
| 76 |
+
value="Rembg (U²-Net)",
|
| 77 |
+
label="Model Selection"
|
| 78 |
+
)
|
| 79 |
+
],
|
| 80 |
+
outputs=[
|
| 81 |
+
gr.Image(type="pil", label="Result with Transparent Background"),
|
| 82 |
+
gr.Image(type="pil", label="Segmentation Mask")
|
| 83 |
+
],
|
| 84 |
+
title="Hybrid Background Remover (CPU)",
|
| 85 |
+
description="""
|
| 86 |
+
Upload an image to remove the background. Choose between:
|
| 87 |
+
- Rembg (U²-Net): Faster (5-15 sec)
|
| 88 |
+
- InSPyReNet: More accurate but slower (15-30 sec)
|
| 89 |
+
"""
|
| 90 |
+
)
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
+
# Launch with minimal configuration
|
| 93 |
if __name__ == "__main__":
|
| 94 |
+
iface.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
|
|
|
|
|
|
|
|
|