File size: 6,624 Bytes
6268a55 4993e87 6268a55 4993e87 6268a55 4993e87 6268a55 4993e87 6268a55 4993e87 6268a55 4993e87 6268a55 4993e87 6268a55 4993e87 6268a55 4993e87 6268a55 4993e87 6268a55 4993e87 6268a55 4993e87 6268a55 4993e87 6268a55 4993e87 6268a55 4993e87 6268a55 4993e87 6268a55 4993e87 6268a55 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from fire import Fire
from s3od import BackgroundRemoval
from s3od.visualizer import visualize_removal
# Dropdown label -> Hugging Face model id for each published S3OD variant.
MODEL_VARIANTS = {
    "General (Synth + Real)": "okupyn/s3od",
    "Synthetic Only": "okupyn/s3od-synth",
    "DIS-tuned": "okupyn/s3od-dis",
    "SOD-tuned": "okupyn/s3od-sod",
}

# Detectors that have already been loaded, keyed by model id, so switching
# variants in the UI does not reload a model that was used before.
_model_cache = dict()
def get_detector(model_name):
    """Return the detector for ``model_name``, loading it on first request.

    Loaded detectors are kept in the module-level ``_model_cache`` dict, so
    each model id is instantiated at most once per process.
    """
    # Fast path: this model id has been requested before.
    if model_name in _model_cache:
        return _model_cache[model_name]

    print(f"Loading model: {model_name}")
    loaded = BackgroundRemoval(model_id=model_name)
    _model_cache[model_name] = loaded
    return loaded
# Load the default model at import time; get_detector() stores it in the
# detector cache, so later requests for 'okupyn/s3od' reuse this instance.
detector = get_detector('okupyn/s3od')
# Radio-button label -> internal output-format key used by process_image().
VISUALIZATION_METHODS = {
    "Transparent Background": "transparent",
    "White Background": "white",
    "Green Background": "green",
    "Mask Only": "mask",
}
def compute_mask_iou(mask1, mask2, threshold=0.5):
    """Compute the intersection-over-union (IoU) of two masks.

    Each mask is binarized with ``mask > threshold`` and the IoU is taken
    over the resulting foreground regions.

    Args:
        mask1: First mask (array-like supporting ``>`` against a scalar).
        mask2: Second mask, broadcast-compatible with ``mask1``.
        threshold: Binarization cut-off. Defaults to 0.5, the value that was
            previously hard-coded.

    Returns:
        ``intersection / (union + 1e-6)``. The epsilon guards against division
        by zero, so two masks with no foreground at all yield 0.0.
    """
    # Binarize each mask once and reuse the result for both set operations.
    foreground1 = mask1 > threshold
    foreground2 = mask2 > threshold
    intersection = np.logical_and(foreground1, foreground2).sum()
    union = np.logical_or(foreground1, foreground2).sum()
    return intersection / (union + 1e-6)
def is_ambiguous(all_masks, threshold=0.8):
    """Report whether the predicted masks disagree with each other.

    Args:
        all_masks: Sequence of candidate masks for one image.
        threshold: Minimum pairwise IoU for two masks to count as agreeing.

    Returns:
        True if any pair of masks has an IoU (per ``compute_mask_iou``) below
        ``threshold``; False otherwise, including when fewer than two masks
        are given.
    """
    # Local import keeps this block self-contained.
    from itertools import combinations

    # combinations() yields each unordered pair once, in the same order as a
    # nested index loop; with fewer than two masks it yields nothing, so the
    # result is False. any() stops at the first disagreeing pair.
    return any(
        compute_mask_iou(first, second) < threshold
        for first, second in combinations(all_masks, 2)
    )
def create_masks_grid(all_masks, all_ious, image_shape):
    """Lay the predicted masks out left to right in one grayscale image.

    Args:
        all_masks: Masks to show; each is scaled by 255 and cast to uint8.
        all_ious: Scores paired with the masks. They are only zipped with the
            masks (so iteration stops at the shorter sequence); the values
            are not drawn onto the grid.
        image_shape: Shape of the source image; its first two entries give
            the tile height and width.

    Returns:
        A PIL image in mode 'L' that is ``len(all_masks)`` tiles wide and one
        tile high.
    """
    tile_h, tile_w = image_shape[:2]

    # Black canvas wide enough for one tile per mask.
    canvas = Image.new('L', (tile_w * len(all_masks), tile_h), color=0)

    for column, (mask, _score) in enumerate(zip(all_masks, all_ious)):
        tile = Image.fromarray((mask * 255).astype(np.uint8), mode='L')
        canvas.paste(tile, (column * tile_w, 0))

    return canvas
def process_image(image, model_key, method_key, threshold):
    """Run background removal on ``image`` and build the three UI outputs.

    Args:
        image: Input image as a NumPy array (its ``.shape`` is read), or None
            when no image is present.
        model_key: Display name of the model variant; looked up in
            ``MODEL_VARIANTS`` with 'okupyn/s3od' as the fallback.
        method_key: Display name of the output format; looked up in
            ``VISUALIZATION_METHODS`` with 'transparent' as the fallback.
        threshold: Mask threshold forwarded to ``remove_background``.

    Returns:
        A ``(main_output, masks_grid, ambiguity_label)`` tuple, or
        ``(None, None, None)`` when ``image`` is None.
    """
    if image is None:
        return None, None, None

    # Single source of truth for the ambiguity cut-off, so the check and the
    # message shown to the user cannot drift apart.
    ambiguity_threshold = 0.8

    # Resolve and run the requested model variant.
    model_id = MODEL_VARIANTS.get(model_key, 'okupyn/s3od')
    result = get_detector(model_id).remove_background(image, threshold=threshold)

    # Generate the main output in the requested format.
    method = VISUALIZATION_METHODS.get(method_key, 'transparent')
    background_colors = {
        'white': (255, 255, 255),
        'green': (0, 255, 0),
    }
    if method in background_colors:
        main_output = visualize_removal(
            image, result, background_color=background_colors[method]
        )
    elif method == 'mask':
        mask_vis = (result.predicted_mask * 255).astype(np.uint8)
        main_output = Image.fromarray(mask_vis, mode='L')
    else:
        # 'transparent' and any unrecognised method use the RGBA cut-out.
        main_output = result.rgba_image

    # Side-by-side view of every candidate mask.
    masks_grid = create_masks_grid(result.all_masks, result.all_ious, image.shape)

    # Flag the prediction when the candidate masks disagree with each other.
    if is_ambiguous(result.all_masks, threshold=ambiguity_threshold):
        ambiguity_label = (
            f"⚠️ Ambiguous prediction (IoU < {ambiguity_threshold} between masks)"
        )
    else:
        ambiguity_label = "✅ Clear prediction"

    return main_output, masks_grid, ambiguity_label
# Build the demo UI: controls on the left, results on the right, and the grid
# of candidate masks underneath. The Markdown text is kept at column 0 so it
# renders the same whether or not the component dedents its value.
with gr.Blocks(title="S3OD - Synthetic Salient Object Detection") as demo:
    gr.Markdown("""
# S3OD: Synthetic Salient Object Detection

Upload an image to remove its background using **S3OD**!

S3OD is trained on a large-scale fully synthetic dataset (140K+ images) generated with diffusion models.
The model uses a DPT-based architecture with DINOv3 vision transformer backbone for robust salient object detection.

**Model Variants:**
- **General (Synth + Real)**: Default model trained on synthetic data and fine-tuned on all real datasets (DUTS, DIS, HR-SOD)
- **Synthetic Only**: Trained exclusively on S3OD synthetic dataset
- **DIS-tuned**: Fine-tuned specifically for highly-accurate dichotomous segmentation
- **SOD-tuned**: Optimized for general salient object detection tasks

**Key Features:**
- Single-step background removal with soft masks (smooth edges)
- Multi-mask prediction with IoU scoring
- Ambiguity detection for uncertain predictions
- Works on any image resolution

📄 [Paper](https://arxiv.org/abs/2510.21605) | 💻 [GitHub](https://github.com/KupynOrest/s3od) | 🤗 [Model](https://huggingface.co/okupyn/s3od) | 🗂️ [Dataset](https://huggingface.co/datasets/okupyn/s3od_dataset)
""")

    with gr.Row():
        # Left column: input image and inference options.
        with gr.Column():
            input_image = gr.Image(type="numpy", label="Upload an Image")
            model_dropdown = gr.Dropdown(
                choices=list(MODEL_VARIANTS.keys()),
                label="Model Variant",
                value='General (Synth + Real)',
                info="Choose the model variant trained on different datasets"
            )
            method_radio = gr.Radio(
                list(VISUALIZATION_METHODS.keys()),
                label="Output Format",
                value='Transparent Background'
            )
            threshold_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.5,
                step=0.05,
                label="Mask Threshold"
            )
            submit_btn = gr.Button("Remove Background", variant="primary")

        # Right column: main result and the prediction-quality message.
        with gr.Column():
            output_image = gr.Image(type="pil", label="Result")
            ambiguity_label = gr.Textbox(label="Prediction Quality", interactive=False)

    with gr.Row():
        # NOTE(review): the label mentions IoU scores, but create_masks_grid()
        # does not draw them onto the image - confirm which one is intended.
        masks_grid = gr.Image(type="pil", label="All 3 Predicted Masks (with IoU scores)")

    submit_btn.click(
        fn=process_image,
        inputs=[input_image, model_dropdown, method_radio, threshold_slider],
        outputs=[output_image, masks_grid, ambiguity_label]
    )

    # Also trigger on image upload
    input_image.change(
        fn=process_image,
        inputs=[input_image, model_dropdown, method_radio, threshold_slider],
        outputs=[output_image, masks_grid, ambiguity_label]
    )
def main(server_name="0.0.0.0", server_port=7860, share=False):
    """Start the Gradio demo server.

    Args:
        server_name: Host/interface passed to ``demo.launch``.
        server_port: Port passed to ``demo.launch``.
        share: Forwarded to ``demo.launch`` as its ``share`` option.
    """
    launch_options = {
        "server_name": server_name,
        "server_port": server_port,
        "share": share,
    }
    demo.launch(**launch_options)
if "__main__" == __name__:
    # Script entry point: hand main() to Fire, which builds the command-line
    # interface from its signature.
    Fire(main)
|