rdjarbeng's picture
correct usage of the dropdown
09cdf11
raw
history blame
2.33 kB
import functools
import logging

import gradio as gr
from PIL import Image
from rembg import remove, new_session
# Configure the root logger so INFO-level (and more severe) messages are shown.
logging.basicConfig(level="INFO")
@functools.lru_cache(maxsize=4)
def _get_session(model_name):
    """Return the rembg session for ``model_name``, creating it on first use.

    rembg recommends reusing one session across images instead of building a
    new one per call, so sessions are cached per model name.
    """
    return new_session(model_name)


def _parse_bgcolor(color):
    """Convert a colour-picker value into the RGBA tuple passed to rembg.

    Accepted inputs:
        * falsy values (``None``, ``""``) -> ``None`` (no background colour)
        * tuples/lists -> returned unchanged as a tuple
        * hex strings ``#rgb``, ``#rgba``, ``#rrggbb``, ``#rrggbbaa``
        * CSS-style ``rgb(r, g, b)`` / ``rgba(r, g, b, a)`` strings, where
          ``a`` is an opacity between 0 and 1

    Raises:
        ValueError: If a string is in none of the supported formats.
    """
    if not color:
        return None
    if isinstance(color, (tuple, list)):
        return tuple(color)

    text = str(color).strip().lower()

    if text.startswith("#"):
        digits = text[1:]
        if len(digits) in (3, 4):
            # Short form: each digit stands for a doubled pair (#f0a -> #ff00aa).
            digits = "".join(ch * 2 for ch in digits)
        if len(digits) == 6:
            digits += "ff"  # no alpha given -> fully opaque
        if len(digits) != 8:
            raise ValueError(f"Unsupported colour value: {color!r}")
        return tuple(int(digits[i:i + 2], 16) for i in range(0, 8, 2))

    if text.startswith(("rgb(", "rgba(")) and text.endswith(")"):
        parts = [p.strip() for p in text[text.index("(") + 1:-1].split(",")]
        if len(parts) not in (3, 4):
            raise ValueError(f"Unsupported colour value: {color!r}")

        def to_channel(value):
            # Clamp into the 0-255 range of an 8-bit channel.
            return max(0, min(255, round(value)))

        red, green, blue = (to_channel(float(p)) for p in parts[:3])
        alpha = to_channel(float(parts[3]) * 255) if len(parts) == 4 else 255
        return (red, green, blue, alpha)

    raise ValueError(f"Unsupported colour value: {color!r}")


def remove_background(input_image, bg_color, model_name, alpha_matting, post_process_mask, only_mask):
    """Remove the background from an image with rembg and save the result.

    Args:
        input_image: Image to process (the Gradio interface supplies a PIL
            image). ``None`` is reported as a failure.
        bg_color: Optional background colour, as a hex or ``rgb()``/``rgba()``
            string or an RGBA tuple. Falsy values mean "no background colour".
        model_name: Name of the rembg model to use. A falsy value passes no
            session, leaving the choice to rembg.
        alpha_matting: Forwarded to ``rembg.remove``.
        post_process_mask: Forwarded to ``rembg.remove``.
        only_mask: Forwarded to ``rembg.remove``; when true the result is not
            converted to RGB.

    Returns:
        A ``(image, path)`` tuple with the processed image and the path of the
        PNG file it was saved to, or ``(None, None)`` if anything failed.
    """
    if input_image is None:
        # Nothing to process: fail fast instead of loading a model first.
        logging.error("An error occurred: no input image provided")
        return None, None

    try:
        # Prepare the options; entries left as None are not passed on so that
        # rembg applies its own defaults for them.
        remove_options = {
            "session": _get_session(model_name) if model_name else None,
            # rembg expects an RGBA integer tuple, while the colour picker
            # delivers a string, hence the conversion.
            "bgcolor": _parse_bgcolor(bg_color),
            "alpha_matting": alpha_matting,
            "post_process_mask": post_process_mask,
            "only_mask": only_mask,
        }
        passed_options = {k: v for k, v in remove_options.items() if v is not None}

        # Remove the background
        output_image = remove(input_image, **passed_options)
        logging.info("Background removed")

        # Convert to RGB mode if necessary.
        # NOTE(review): converting to RGB discards the alpha channel, so a
        # cut-out without a background colour loses its transparency --
        # confirm this is intended.
        if not only_mask and output_image.mode != 'RGB':
            output_image = output_image.convert('RGB')
            logging.info("Converted to RGB mode")

        # Save the output image so it can be offered as a download.
        # NOTE(review): the file name is fixed and relative to the working
        # directory, so every call overwrites the previous result.
        output_path = "output_image.png"
        output_image.save(output_path)
        logging.info(f"Saved output image {output_path}")
        return output_image, output_path
    except Exception as e:
        # Best-effort boundary for the UI: log with traceback and signal
        # failure through empty outputs instead of raising.
        logging.exception(f"An error occurred: {e}")
        return None, None
# Gradio interface: the inputs are listed in the same order as the parameters
# of remove_background, and the two outputs receive its (image, path) result.
iface = gr.Interface(
    fn=remove_background,
    inputs=[
        gr.Image(type="pil"),
        gr.ColorPicker(label="Background Color", value=None),  # Background color picker
        # Every choice must be a model name that rembg's new_session() knows.
        # "unet" is not one of them, so the lightweight "u2netp" is offered.
        gr.Dropdown(choices=["u2net", "isnet-general-use", "u2netp"], label="Model Selection", value="u2net"),
        gr.Checkbox(label="Enable Alpha Matting", value=False),
        gr.Checkbox(label="Post-Process Mask", value=False),
        gr.Checkbox(label="Only Return Mask", value=False),
    ],
    outputs=[
        gr.Image(type="pil", label="Output Image"),
        gr.File(label="Download the output image"),
    ],
    title="Advanced Background Remover",
    description="Upload an image to remove the background. Customize the result with different options, including background color, model selection, alpha matting, and more.",
    allow_flagging="never",
)

if __name__ == "__main__":
    # Start the web UI only when executed as a script, not when imported.
    iface.launch()