Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from rembg import remove | |
| from PIL import ImageOps, ImageEnhance, Image | |
| from streamlit_option_menu import option_menu | |
| from markup import real_estate_app, real_estate_app_hf, sliders_intro, perspective_intro, manual_bg_intro, segement_intro | |
| from perspective_correction import perspective_correction, perspective_correction2 | |
| from streamlit_drawable_canvas import st_canvas | |
| import tempfile | |
| from ultralytics import YOLO | |
| import numpy as np | |
| import cv2 | |
| import gdown | |
| import os | |
| from manual_removal import remove_background | |
def tab1():
    """Render the intro tab: banner image, project description, and badge links."""
    st.header("Image Background Remover")
    banner_col, text_col = st.columns([1, 2])
    with banner_col:
        st.image("image.jpg", use_column_width=True)
    with text_col:
        st.markdown(real_estate_app(), unsafe_allow_html=True)
        st.markdown(real_estate_app_hf(), unsafe_allow_html=True)
        github_link = '[<img src="https://badgen.net/badge/icon/github?icon=github&label">](https://github.com/ethanrom)'
        huggingface_link = '[<img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue">](https://huggingface.co/ethanrom)'
        st.write(github_link + ' ' + huggingface_link, unsafe_allow_html=True)
def tab2():
    """AI background-removal tab.

    Lets the user upload one or more images, tune RGB/curves/mask adjustments
    with a live preview of the first image, then run rembg on every upload
    with the chosen settings.
    """
    st.header("Image Background Remover")
    st.markdown(sliders_intro(), unsafe_allow_html=True)

    upload_option = st.radio("Upload Option", ("Single Image", "Multiple Images"))
    if upload_option == "Single Image":
        uploaded = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"], accept_multiple_files=False)
        images = [Image.open(uploaded)] if uploaded else []
    else:
        uploaded = st.file_uploader("Upload multiple images", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
        images = [Image.open(f) for f in uploaded] if uploaded else []
    if not images:
        return

    preview_col, controls_col = st.columns([2, 1])
    with preview_col:
        st.image(images[0], caption="Original Image", use_column_width=True)
        first_image = preprocess_image_1(images[0])
    with controls_col:
        st.subheader("RGB Adjustments")
        with st.expander("Expand"):
            r_min, r_max = st.slider("Red", min_value=0, max_value=255, value=(0, 255), step=1)
            g_min, g_max = st.slider("Green", min_value=0, max_value=255, value=(0, 255), step=1)
            b_min, b_max = st.slider("Blue", min_value=0, max_value=255, value=(0, 255), step=1)
            preview = adjust_rgb(first_image, r_min, r_max, g_min, g_max, b_min, b_max)
            st.image(preview, caption="Adjusted Image", use_column_width=True)
        st.subheader("Curves Adjustment")
        with st.expander("Expand"):
            r_curve = st.slider("Red Curve", min_value=0.0, max_value=1.0, value=1.0, step=0.05)
            g_curve = st.slider("Green Curve", min_value=0.0, max_value=1.0, value=1.0, step=0.05)
            b_curve = st.slider("Blue Curve", min_value=0.0, max_value=1.0, value=1.0, step=0.05)
            preview = adjust_curves(preview, r_curve, g_curve, b_curve)
            st.image(preview, caption="Adjusted Image", use_column_width=True)
        st.subheader("Masking")
        with st.expander("Expand"):
            threshold = st.slider("Threshold", min_value=0, max_value=255, value=128, step=1)
            preview = apply_masking(preview, threshold)
            st.image(preview, caption="Adjusted Image", use_column_width=True)
    with preview_col:
        if st.button("Remove Background"):
            with st.spinner("Removing background..."):
                # Re-run the full adjustment pipeline on every upload, not just
                # the previewed first image, before handing off to rembg.
                output_images = []
                for img in images:
                    staged = adjust_rgb(preprocess_image_1(img), r_min, r_max, g_min, g_max, b_min, b_max)
                    staged = adjust_curves(staged, r_curve, g_curve, b_curve)
                    staged = apply_masking(staged, threshold)
                    output_images.append(remove(staged))
            with st.expander("Background Removed Images"):
                for idx, out in enumerate(output_images, start=1):
                    st.image(out, caption=f"Background Removed Image {idx}", use_column_width=True)
def preprocess_image_1(image):
    """Return *image* in RGBA mode (no-op when it already is RGBA)."""
    return image if image.mode == "RGBA" else image.convert("RGBA")
def adjust_rgb(image, r_min, r_max, g_min, g_max, b_min, b_max):
    """Remap each RGB channel of an RGBA image, leaving alpha untouched.

    Each channel is linearly compressed into its [lo, hi] slider range and
    then autocontrast-stretched back over the full 0-255 range.
    """
    red, green, blue, alpha = image.split()

    def _remap(channel, lo, hi):
        scaled = channel.point(lambda p: int(p * (hi - lo) / 255 + lo))
        return ImageOps.autocontrast(scaled)

    return Image.merge(
        "RGBA",
        (_remap(red, r_min, r_max), _remap(green, g_min, g_max), _remap(blue, b_min, b_max), alpha),
    )
def adjust_curves(image, r_curve, g_curve, b_curve):
    """Scale the brightness of each RGB channel independently; alpha is kept as-is."""
    red, green, blue, alpha = image.split()
    # ImageEnhance.Brightness(...).enhance(f) returns a new channel image
    # with every pixel multiplied by f.
    adjusted = (
        ImageEnhance.Brightness(red).enhance(r_curve),
        ImageEnhance.Brightness(green).enhance(g_curve),
        ImageEnhance.Brightness(blue).enhance(b_curve),
        alpha,
    )
    return Image.merge("RGBA", adjusted)
def apply_masking(image, threshold):
    """Binarize the alpha channel: fully opaque above *threshold*, fully transparent otherwise."""
    red, green, blue, alpha = image.split()
    binary_alpha = alpha.point(lambda p: 255 if p > threshold else 0)
    return Image.merge("RGBA", (red, green, blue, binary_alpha))
def tab3():
    """Instance-segmentation background removal.

    Downloads the selected YOLO segmentation model (ONNX) from Google Drive on
    first use, runs it on an uploaded image, and composes a transparent cut-out
    from the predicted masks, with optional edge-smoothing post-processing.
    """
    model_file_v1 = 'bestslab-seg.onnx'
    model_url_v1 = 'https://drive.google.com/uc?id=1---iqs2llLrgDbzr_S1nzkKUr3sJ_ru3'
    model_file_v2 = 'best2-seg.onnx'
    model_url_v2 = 'https://drive.google.com/file/d/1k8HASXNFnAhEiDPPNM-hX22rM0bHMWtb/view?usp=sharing'
    st.header("Background Removal with instance Segmentaion")
    st.markdown(segement_intro(), unsafe_allow_html=True)
    model_option = st.selectbox('Select Model', ['model_new', 'modelv1'])
    if model_option == 'modelv1':
        model_file, model_url = model_file_v1, model_url_v1
    else:
        model_file, model_url = model_file_v2, model_url_v2
    if not os.path.exists(model_file):
        # fuzzy=True lets gdown resolve a /file/d/<id>/view share link (the v2
        # URL) as well as the plain uc?id= form; without it that download fails.
        gdown.download(model_url, model_file, quiet=False, fuzzy=True)
    model = YOLO(model_file)
    uploaded_file = st.file_uploader('Choose an image', type=['jpg', 'jpeg', 'png'])
    if uploaded_file is not None:
        # Force 3-channel RGB: an RGBA or palette PNG would make COLOR_RGB2BGR fail.
        image = Image.open(uploaded_file).convert('RGB')
        image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        col1, col2 = st.columns([2, 1])
        with col2:
            iou_threshold = st.slider('IoU Threshold', min_value=0.0, max_value=1.0, value=0.7)
            conf_threshold = st.slider('Confidence Threshold', min_value=0.0, max_value=1.0, value=0.65)
            show_labels = st.checkbox('Show Labels', value=False)
            show_conf = st.checkbox('Show Confidence Scores', value=False)
            boxes = st.checkbox('Show Boxes', value=True)
            smooth_edges = st.checkbox('Smooth Edges', value=False)
            if smooth_edges:
                smoothing_method = st.selectbox('Smoothing Method', ['Morphological Closing', 'Image Dilation', 'Inward Morphological Closing', 'Inward Image Dilation'])
                if smoothing_method == 'Morphological Closing':
                    closing_radius = st.slider('Closing Radius', min_value=1, max_value=10, value=5)
                elif smoothing_method == 'Image Dilation':
                    dilation_strength = st.slider('Dilation Strength', min_value=1, max_value=10, value=5)
                elif smoothing_method == 'Inward Morphological Closing':
                    closing_radius_inward = st.slider('Inward Closing Radius', min_value=1, max_value=50, value=10)
                elif smoothing_method == 'Inward Image Dilation':
                    dilation_strength_inward = st.slider('Inward Dilation Strength', min_value=1, max_value=10, value=5)
        with col1:
            st.image(image, caption='Input Image', use_column_width=True)
            if st.button('Apply and Predict'):
                results = model(
                    image_cv,
                    iou=iou_threshold,
                    conf=conf_threshold,
                    show_labels=show_labels,
                    show_conf=show_conf,
                    boxes=boxes,
                )
                masks = results[0].masks
                if masks is None:
                    # Ultralytics returns masks=None when nothing is detected;
                    # masks.xy below would raise AttributeError.
                    st.warning('No objects detected; try lowering the confidence threshold.')
                    return
                mask_image = np.zeros((image_cv.shape[0], image_cv.shape[1], 4), dtype=np.uint8)
                annotated_frame = results[0].plot()
                annotated_image = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
                # Rasterize every predicted polygon into an opaque white RGBA mask.
                for segment in masks.xy:
                    polygon = np.array(segment, dtype=np.int32).reshape((-1, 1, 2))
                    cv2.fillPoly(mask_image, [polygon], (255, 255, 255, 255))
                if smooth_edges:
                    if smoothing_method == 'Morphological Closing':
                        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (closing_radius, closing_radius))
                        mask_image = cv2.morphologyEx(mask_image, cv2.MORPH_CLOSE, kernel)
                    elif smoothing_method == 'Image Dilation':
                        kernel_size = 2 * dilation_strength - 1
                        kernel = np.ones((kernel_size, kernel_size), np.uint8)
                        mask_image = cv2.dilate(mask_image, kernel, iterations=1)
                    elif smoothing_method == 'Inward Morphological Closing':
                        kernel_inward = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (closing_radius_inward, closing_radius_inward))
                        # Subtracting the morphological gradient (the edge band)
                        # shrinks the mask inward along its border.
                        mask_image_borders = cv2.morphologyEx(mask_image, cv2.MORPH_GRADIENT, kernel_inward)
                        mask_image = cv2.subtract(mask_image, mask_image_borders)
                    elif smoothing_method == 'Inward Image Dilation':
                        kernel_size_inward = 2 * dilation_strength_inward - 1
                        kernel_inward = np.ones((kernel_size_inward, kernel_size_inward), np.uint8)
                        # BUGFIX: the original computed mask - dilate(mask), which is
                        # identically zero (dilation only grows the mask) and blanked
                        # the whole cut-out. Erosion is the intended inward shrink:
                        # erode(mask) == mask - (mask - erode(mask)).
                        mask_image = cv2.erode(mask_image, kernel_inward, iterations=1)
                alpha_channel = mask_image[:, :, 3]
                # Compose against RGB (not the BGR working copy) so the cut-out
                # keeps correct colors when rendered as an RGBA PIL image.
                image_rgb = cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB)
                image_rgba = np.concatenate((image_rgb, np.expand_dims(alpha_channel, axis=2)), axis=2)
                masked_image = image_rgba * (mask_image / 255)
                masked_pil = Image.fromarray(masked_image.astype(np.uint8), 'RGBA')
                st.image([annotated_image, masked_pil], caption=['Detections', 'Masked Image'], use_column_width=True)
def tab4():
    """Manual background removal: the user drops point markers on a drawable
    canvas and the selected region is cut out by the manual_removal helper."""
    st.header("Manual Background Removal")
    st.markdown(manual_bg_intro(), unsafe_allow_html=True)
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return
    image_col, controls_col = st.columns([2, 1])
    with image_col:
        image = Image.open(uploaded_file)
        max_image_size = 700
        if max(image.size) > max_image_size:
            # Shrink in place so the drawable canvas stays a manageable size.
            image.thumbnail((max_image_size, max_image_size), Image.LANCZOS)
        st.image(image, caption="Original Image")
        image_width, image_height = image.size
    with controls_col:
        drawing_mode = "point"
        stroke_width = st.slider("Stroke width: ", 1, 25, 3)
        realtime_update = st.checkbox("Update in realtime", True)
    with image_col:
        st.subheader("Select Points on the Canvas")
        canvas_result = st_canvas(
            fill_color="rgba(255, 165, 0, 0.3)",
            stroke_width=stroke_width,
            background_image=image,
            update_streamlit=realtime_update,
            height=image_height,
            width=image_width,
            drawing_mode=drawing_mode,
            key="canvas",
        )
        if st.button("Remove Background"):
            st.subheader("This Feature is broken in this hosting service, please DM for Private link")
            if canvas_result.json_data is not None:
                # Collect the circle markers the user placed on the canvas.
                points = [
                    (obj["left"], obj["top"])
                    for obj in canvas_result.json_data["objects"]
                    if obj.get("type") == "circle"
                ]
                result = remove_background(np.array(image), points)
                transparent_bg_result = Image.fromarray(result).convert("RGBA")
                file_path = "background_removed.png"
                transparent_bg_result.save(file_path, format="PNG")
                st.image(transparent_bg_result, caption="Background Removed Image")
def tab5():
    """Perspective-correction tab.

    Takes a background-removed transparent PNG, lets the user pick one of two
    correction algorithms and tune the Hough-line parameters, then renders the
    corrected image.
    """
    st.header("Image Perspective Correction")
    st.write("Upload a transparent PNG image which you have removed the background using the previous tab.")
    st.markdown(perspective_intro(), unsafe_allow_html=True)
    uploaded_file = st.file_uploader("Choose a PNG image", type="png")
    if uploaded_file is None:
        return
    image = Image.open(uploaded_file)
    col1, col2 = st.columns([2, 1])
    with col1:
        st.image(image, caption="Original Image", use_column_width=True)
        image_np = np.array(image)
    with col2:
        correction_method = st.selectbox("Correction Method", ["Four-Point Perspective Correction", "Convex Hull Homography Perspective Correction"])
        # Both methods take identical tuning parameters, so a single widget
        # block serves both (the original duplicated it per branch with
        # byte-identical labels and defaults; selectbox can only yield these
        # two values, so the old "invalid method" branch was unreachable).
        threshold_value = st.slider("Threshold Value", min_value=1, max_value=255, value=100)
        min_line_length = st.slider("Minimum Line Length", min_value=1, max_value=500, value=100)
        max_line_gap = st.slider("Maximum Line Gap", min_value=1, max_value=100, value=10)
        # NOTE(review): the labels look swapped relative to the variable names
        # ("Width Ratio" feeds length_ratio and vice versa). Kept as-is to
        # preserve the positional meaning of the arguments passed to
        # perspective_correction/perspective_correction2 — confirm against
        # those functions' signatures before relabeling.
        length_ratio = st.number_input("Width Ratio", value=1, min_value=1, max_value=10, step=1)
        width_ratio = st.number_input("Length Ratio", value=1, min_value=1, max_value=10, step=1)
    with col1:
        if st.button("Correct Perspective"):
            with st.spinner("Correcting Perspective..."):
                if correction_method == "Four-Point Perspective Correction":
                    corrected_image = perspective_correction(image_np, threshold_value, min_line_length, max_line_gap, length_ratio, width_ratio)
                else:
                    corrected_image = perspective_correction2(image_np, threshold_value, min_line_length, max_line_gap, length_ratio, width_ratio)
                st.image(corrected_image, caption="Corrected Image", use_column_width=True)
def main():
    """Configure the page, render the sidebar menu, and dispatch to the selected tab."""
    st.set_page_config(page_title="Background Removal Demo", page_icon=":memo:", layout="wide")
    tab_names = ["Intro", "AI Background Removal", "Background Removal with Segmentaion", "Manual Background Removal", "Perspective Correction"]
    with st.sidebar:
        current_tab = option_menu("Select a Tab", tab_names, menu_icon="cast")
    # Map each menu entry to its renderer; unknown selections are ignored.
    handlers = dict(zip(tab_names, (tab1, tab2, tab3, tab4, tab5)))
    handler = handlers.get(current_tab)
    if handler is not None:
        handler()


if __name__ == "__main__":
    main()