|
|
import gradio as gr |
|
|
import numpy as np |
|
|
from PIL import Image, ImageDraw, ImageFont |
|
|
from fire import Fire |
|
|
|
|
|
from s3od import BackgroundRemoval |
|
|
from s3od.visualizer import visualize_removal |
|
|
|
|
|
|
|
|
MODEL_VARIANTS = { |
|
|
'General (Synth + Real)': 'okupyn/s3od', |
|
|
'Synthetic Only': 'okupyn/s3od-synth', |
|
|
'DIS-tuned': 'okupyn/s3od-dis', |
|
|
'SOD-tuned': 'okupyn/s3od-sod', |
|
|
} |
|
|
|
|
|
|
|
|
_model_cache = {} |
|
|
|
|
|
def get_detector(model_name): |
|
|
"""Get or load detector for the specified model.""" |
|
|
if model_name not in _model_cache: |
|
|
print(f"Loading model: {model_name}") |
|
|
_model_cache[model_name] = BackgroundRemoval(model_id=model_name) |
|
|
return _model_cache[model_name] |
|
|
|
|
|
|
|
|
detector = get_detector('okupyn/s3od') |
|
|
|
|
|
VISUALIZATION_METHODS = { |
|
|
'Transparent Background': 'transparent', |
|
|
'White Background': 'white', |
|
|
'Green Background': 'green', |
|
|
'Mask Only': 'mask' |
|
|
} |
|
|
|
|
|
|
|
|
def compute_mask_iou(mask1, mask2): |
|
|
"""Compute IoU between two masks.""" |
|
|
intersection = np.logical_and(mask1 > 0.5, mask2 > 0.5).sum() |
|
|
union = np.logical_or(mask1 > 0.5, mask2 > 0.5).sum() |
|
|
return intersection / (union + 1e-6) |
|
|
|
|
|
|
|
|
def is_ambiguous(all_masks, threshold=0.8): |
|
|
"""Check if prediction is ambiguous based on mask IoU.""" |
|
|
if len(all_masks) < 2: |
|
|
return False |
|
|
|
|
|
|
|
|
for i in range(len(all_masks)): |
|
|
for j in range(i + 1, len(all_masks)): |
|
|
iou = compute_mask_iou(all_masks[i], all_masks[j]) |
|
|
if iou < threshold: |
|
|
return True |
|
|
return False |
|
|
|
|
|
|
|
|
def create_masks_grid(all_masks, all_ious, image_shape): |
|
|
"""Create a grid showing all 3 masks side by side.""" |
|
|
h, w = image_shape[:2] |
|
|
num_masks = len(all_masks) |
|
|
|
|
|
|
|
|
grid_w = w * num_masks |
|
|
grid_h = h |
|
|
grid = Image.new('L', (grid_w, grid_h), color=0) |
|
|
|
|
|
for idx, (mask, iou) in enumerate(zip(all_masks, all_ious)): |
|
|
|
|
|
mask_img = Image.fromarray((mask * 255).astype(np.uint8), mode='L') |
|
|
|
|
|
|
|
|
grid.paste(mask_img, (idx * w, 0)) |
|
|
|
|
|
return grid |
|
|
|
|
|
|
|
|
def process_image(image, model_key, method_key, threshold): |
|
|
if image is None: |
|
|
return None, None, None |
|
|
|
|
|
|
|
|
model_id = MODEL_VARIANTS.get(model_key, 'okupyn/s3od') |
|
|
detector = get_detector(model_id) |
|
|
|
|
|
result = detector.remove_background(image, threshold=threshold) |
|
|
method = VISUALIZATION_METHODS.get(method_key, 'transparent') |
|
|
|
|
|
|
|
|
if method == 'transparent': |
|
|
main_output = result.rgba_image |
|
|
elif method == 'white': |
|
|
main_output = visualize_removal(image, result, background_color=(255, 255, 255)) |
|
|
elif method == 'green': |
|
|
main_output = visualize_removal(image, result, background_color=(0, 255, 0)) |
|
|
elif method == 'mask': |
|
|
mask_vis = (result.predicted_mask * 255).astype(np.uint8) |
|
|
main_output = Image.fromarray(mask_vis, mode='L') |
|
|
else: |
|
|
main_output = result.rgba_image |
|
|
|
|
|
|
|
|
masks_grid = create_masks_grid(result.all_masks, result.all_ious, image.shape) |
|
|
|
|
|
|
|
|
ambiguous = is_ambiguous(result.all_masks) |
|
|
ambiguity_label = "β οΈ Ambiguous prediction (IoU < 0.8 between masks)" if ambiguous else "β Clear prediction" |
|
|
|
|
|
return main_output, masks_grid, ambiguity_label |
|
|
|
|
|
|
|
|
with gr.Blocks(title="S3OD - Synthetic Salient Object Detection") as demo: |
|
|
gr.Markdown(""" |
|
|
# S3OD: Synthetic Salient Object Detection |
|
|
|
|
|
Upload an image to remove its background using **S3OD**! |
|
|
|
|
|
S3OD is trained on a large-scale fully synthetic dataset (140K+ images) generated with diffusion models. |
|
|
The model uses a DPT-based architecture with DINOv3 vision transformer backbone for robust salient object detection. |
|
|
|
|
|
**Model Variants:** |
|
|
- **General (Synth + Real)**: Default model trained on synthetic data and fine-tuned on all real datasets (DUTS, DIS, HR-SOD) |
|
|
- **Synthetic Only**: Trained exclusively on S3OD synthetic dataset |
|
|
- **DIS-tuned**: Fine-tuned specifically for highly-accurate dichotomous segmentation |
|
|
- **SOD-tuned**: Optimized for general salient object detection tasks |
|
|
|
|
|
**Key Features:** |
|
|
- Single-step background removal with soft masks (smooth edges) |
|
|
- Multi-mask prediction with IoU scoring |
|
|
- Ambiguity detection for uncertain predictions |
|
|
- Works on any image resolution |
|
|
|
|
|
π [Paper](https://arxiv.org/abs/2510.21605) | π» [GitHub](https://github.com/KupynOrest/s3od) | π€ [Model](https://huggingface.co/okupyn/s3od) | ποΈ [Dataset](https://huggingface.co/datasets/okupyn/s3od_dataset) |
|
|
""") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
input_image = gr.Image(type="numpy", label="Upload an Image") |
|
|
model_dropdown = gr.Dropdown( |
|
|
choices=list(MODEL_VARIANTS.keys()), |
|
|
label="Model Variant", |
|
|
value='General (Synth + Real)', |
|
|
info="Choose the model variant trained on different datasets" |
|
|
) |
|
|
method_radio = gr.Radio( |
|
|
list(VISUALIZATION_METHODS.keys()), |
|
|
label="Output Format", |
|
|
value='Transparent Background' |
|
|
) |
|
|
threshold_slider = gr.Slider( |
|
|
minimum=0.0, |
|
|
maximum=1.0, |
|
|
value=0.5, |
|
|
step=0.05, |
|
|
label="Mask Threshold" |
|
|
) |
|
|
submit_btn = gr.Button("Remove Background", variant="primary") |
|
|
|
|
|
with gr.Column(): |
|
|
output_image = gr.Image(type="pil", label="Result") |
|
|
ambiguity_label = gr.Textbox(label="Prediction Quality", interactive=False) |
|
|
|
|
|
with gr.Row(): |
|
|
masks_grid = gr.Image(type="pil", label="All 3 Predicted Masks (with IoU scores)") |
|
|
|
|
|
submit_btn.click( |
|
|
fn=process_image, |
|
|
inputs=[input_image, model_dropdown, method_radio, threshold_slider], |
|
|
outputs=[output_image, masks_grid, ambiguity_label] |
|
|
) |
|
|
|
|
|
|
|
|
input_image.change( |
|
|
fn=process_image, |
|
|
inputs=[input_image, model_dropdown, method_radio, threshold_slider], |
|
|
outputs=[output_image, masks_grid, ambiguity_label] |
|
|
) |
|
|
|
|
|
|
|
|
def main(server_name="0.0.0.0", server_port=7860, share=False): |
|
|
demo.launch( |
|
|
server_name=server_name, |
|
|
server_port=server_port, |
|
|
share=share |
|
|
) |
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
Fire(main) |
|
|
|
|
|
|