okupyn commited on
Commit
4993e87
·
verified ·
1 Parent(s): 6268a55

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +144 -47
app.py CHANGED
@@ -1,12 +1,31 @@
1
  import gradio as gr
2
  import numpy as np
3
- from PIL import Image
4
  from fire import Fire
5
 
6
  from s3od import BackgroundRemoval
7
  from s3od.visualizer import visualize_removal
8
 
9
- detector = BackgroundRemoval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  VISUALIZATION_METHODS = {
12
  'Transparent Background': 'transparent',
@@ -16,73 +35,151 @@ VISUALIZATION_METHODS = {
16
  }
17
 
18
 
19
- def process_image(image, method, threshold):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  if image is None:
21
- return None
 
 
 
 
22
 
23
  result = detector.remove_background(image, threshold=threshold)
 
24
 
 
25
  if method == 'transparent':
26
- return result.rgba_image
27
  elif method == 'white':
28
- return visualize_removal(image, result, background_color=(255, 255, 255))
29
  elif method == 'green':
30
- return visualize_removal(image, result, background_color=(0, 255, 0))
31
  elif method == 'mask':
32
  mask_vis = (result.predicted_mask * 255).astype(np.uint8)
33
- return Image.fromarray(mask_vis, mode='L')
34
-
35
- return result.rgba_image
36
-
37
-
38
- iface = gr.Interface(
39
- fn=process_image,
40
- inputs=[
41
- gr.Image(type="numpy", label="Upload an Image"),
42
- gr.Radio(
43
- list(VISUALIZATION_METHODS.keys()),
44
- label="Output Format",
45
- value='Transparent Background'
46
- ),
47
- gr.Slider(
48
- minimum=0.0,
49
- maximum=1.0,
50
- value=0.5,
51
- step=0.05,
52
- label="Mask Threshold"
53
- )
54
- ],
55
- outputs=gr.Image(type="pil", label="Result"),
56
- title="Demo: S3OD - Synthetic Salient Object Detection",
57
- description="""
58
  Upload an image to remove its background using **S3OD**!
59
 
60
  S3OD is trained on a large-scale fully synthetic dataset (140K+ images) generated with diffusion models.
61
- Despite being trained only on synthetic data, it achieves state-of-the-art performance on real-world images.
62
 
63
- The model uses a DPT-based architecture with DINOv3 vision transformer backbone for robust salient object detection
64
- and can process images of any size. Choose from four visualization methods: transparent background (RGBA),
65
- white background, green background (chroma key), or mask only.
 
 
66
 
67
  **Key Features:**
68
- - Single-step background removal
69
  - Multi-mask prediction with IoU scoring
70
- - Adjustable threshold for fine-tuning
71
  - Works on any image resolution
72
 
73
- Ideal for applications in e-commerce, content creation, photo editing, and computer vision research.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- 📄 [Paper](https://arxiv.org/abs/XXXX.XXXXX) | 💻 [GitHub](https://github.com/KupynOrest/s3od) | 🤗 [Model](https://huggingface.co/okupyn/s3od) | 🗂️ [Dataset](https://huggingface.co/datasets/okupyn/s3od_dataset)
76
- """,
77
- allow_flagging='never',
78
- examples=[
79
- # Add example images here when available
80
- ]
81
- )
 
 
 
 
 
82
 
83
 
84
  def main(server_name="0.0.0.0", server_port=7860, share=False):
85
- iface.launch(
86
  server_name=server_name,
87
  server_port=server_port,
88
  share=share
 
1
  import gradio as gr
2
  import numpy as np
3
+ from PIL import Image, ImageDraw, ImageFont
4
  from fire import Fire
5
 
6
  from s3od import BackgroundRemoval
7
  from s3od.visualizer import visualize_removal
8
 
9
+ # Model variants mapping
10
+ MODEL_VARIANTS = {
11
+ 'General (Synth + Real)': 'okupyn/s3od',
12
+ 'Synthetic Only': 'okupyn/s3od-synth',
13
+ 'DIS-tuned': 'okupyn/s3od-dis',
14
+ 'SOD-tuned': 'okupyn/s3od-sod',
15
+ }
16
+
17
+ # Cache loaded models to avoid reloading
18
+ _model_cache = {}
19
+
20
+ def get_detector(model_name):
21
+ """Get or load detector for the specified model."""
22
+ if model_name not in _model_cache:
23
+ print(f"Loading model: {model_name}")
24
+ _model_cache[model_name] = BackgroundRemoval(model_id=model_name)
25
+ return _model_cache[model_name]
26
+
27
+ # Load default model
28
+ detector = get_detector('okupyn/s3od')
29
 
30
  VISUALIZATION_METHODS = {
31
  'Transparent Background': 'transparent',
 
35
  }
36
 
37
 
38
+ def compute_mask_iou(mask1, mask2):
39
+ """Compute IoU between two masks."""
40
+ intersection = np.logical_and(mask1 > 0.5, mask2 > 0.5).sum()
41
+ union = np.logical_or(mask1 > 0.5, mask2 > 0.5).sum()
42
+ return intersection / (union + 1e-6)
43
+
44
+
45
+ def is_ambiguous(all_masks, threshold=0.8):
46
+ """Check if prediction is ambiguous based on mask IoU."""
47
+ if len(all_masks) < 2:
48
+ return False
49
+
50
+ # Compute IoU between all pairs
51
+ for i in range(len(all_masks)):
52
+ for j in range(i + 1, len(all_masks)):
53
+ iou = compute_mask_iou(all_masks[i], all_masks[j])
54
+ if iou < threshold:
55
+ return True
56
+ return False
57
+
58
+
59
+ def create_masks_grid(all_masks, all_ious, image_shape):
60
+ """Create a grid showing all 3 masks side by side."""
61
+ h, w = image_shape[:2]
62
+ num_masks = len(all_masks)
63
+
64
+ # Create grid image
65
+ grid_w = w * num_masks
66
+ grid_h = h
67
+ grid = Image.new('L', (grid_w, grid_h), color=0)
68
+
69
+ for idx, (mask, iou) in enumerate(zip(all_masks, all_ious)):
70
+ # Convert mask to image
71
+ mask_img = Image.fromarray((mask * 255).astype(np.uint8), mode='L')
72
+
73
+ # Paste into grid
74
+ grid.paste(mask_img, (idx * w, 0))
75
+
76
+ return grid
77
+
78
+
79
+ def process_image(image, model_key, method_key, threshold):
80
  if image is None:
81
+ return None, None, None
82
+
83
+ # Get the appropriate model
84
+ model_id = MODEL_VARIANTS.get(model_key, 'okupyn/s3od')
85
+ detector = get_detector(model_id)
86
 
87
  result = detector.remove_background(image, threshold=threshold)
88
+ method = VISUALIZATION_METHODS.get(method_key, 'transparent')
89
 
90
+ # Generate main output
91
  if method == 'transparent':
92
+ main_output = result.rgba_image
93
  elif method == 'white':
94
+ main_output = visualize_removal(image, result, background_color=(255, 255, 255))
95
  elif method == 'green':
96
+ main_output = visualize_removal(image, result, background_color=(0, 255, 0))
97
  elif method == 'mask':
98
  mask_vis = (result.predicted_mask * 255).astype(np.uint8)
99
+ main_output = Image.fromarray(mask_vis, mode='L')
100
+ else:
101
+ main_output = result.rgba_image
102
+
103
+ # Create masks grid
104
+ masks_grid = create_masks_grid(result.all_masks, result.all_ious, image.shape)
105
+
106
+ # Check if ambiguous
107
+ ambiguous = is_ambiguous(result.all_masks)
108
+ ambiguity_label = "⚠️ Ambiguous prediction (IoU < 0.8 between masks)" if ambiguous else "✓ Clear prediction"
109
+
110
+ return main_output, masks_grid, ambiguity_label
111
+
112
+
113
+ with gr.Blocks(title="S3OD - Synthetic Salient Object Detection") as demo:
114
+ gr.Markdown("""
115
+ # S3OD: Synthetic Salient Object Detection
116
+
 
 
 
 
 
 
 
117
  Upload an image to remove its background using **S3OD**!
118
 
119
  S3OD is trained on a large-scale fully synthetic dataset (140K+ images) generated with diffusion models.
120
+ The model uses a DPT-based architecture with DINOv3 vision transformer backbone for robust salient object detection.
121
 
122
+ **Model Variants:**
123
+ - **General (Synth + Real)**: Default model trained on synthetic data and fine-tuned on all real datasets (DUTS, DIS, HR-SOD)
124
+ - **Synthetic Only**: Trained exclusively on S3OD synthetic dataset
125
+ - **DIS-tuned**: Fine-tuned specifically for highly-accurate dichotomous segmentation
126
+ - **SOD-tuned**: Optimized for general salient object detection tasks
127
 
128
  **Key Features:**
129
+ - Single-step background removal with soft masks (smooth edges)
130
  - Multi-mask prediction with IoU scoring
131
+ - Ambiguity detection for uncertain predictions
132
  - Works on any image resolution
133
 
134
+ 📄 [Paper](https://arxiv.org/abs/2510.21605) | 💻 [GitHub](https://github.com/KupynOrest/s3od) | 🤗 [Model](https://huggingface.co/okupyn/s3od) | 🗂️ [Dataset](https://huggingface.co/datasets/okupyn/s3od_dataset)
135
+ """)
136
+
137
+ with gr.Row():
138
+ with gr.Column():
139
+ input_image = gr.Image(type="numpy", label="Upload an Image")
140
+ model_dropdown = gr.Dropdown(
141
+ choices=list(MODEL_VARIANTS.keys()),
142
+ label="Model Variant",
143
+ value='General (Synth + Real)',
144
+ info="Choose the model variant trained on different datasets"
145
+ )
146
+ method_radio = gr.Radio(
147
+ list(VISUALIZATION_METHODS.keys()),
148
+ label="Output Format",
149
+ value='Transparent Background'
150
+ )
151
+ threshold_slider = gr.Slider(
152
+ minimum=0.0,
153
+ maximum=1.0,
154
+ value=0.5,
155
+ step=0.05,
156
+ label="Mask Threshold"
157
+ )
158
+ submit_btn = gr.Button("Remove Background", variant="primary")
159
+
160
+ with gr.Column():
161
+ output_image = gr.Image(type="pil", label="Result")
162
+ ambiguity_label = gr.Textbox(label="Prediction Quality", interactive=False)
163
+
164
+ with gr.Row():
165
+ masks_grid = gr.Image(type="pil", label="All 3 Predicted Masks (with IoU scores)")
166
 
167
+ submit_btn.click(
168
+ fn=process_image,
169
+ inputs=[input_image, model_dropdown, method_radio, threshold_slider],
170
+ outputs=[output_image, masks_grid, ambiguity_label]
171
+ )
172
+
173
+ # Also trigger on image upload
174
+ input_image.change(
175
+ fn=process_image,
176
+ inputs=[input_image, model_dropdown, method_radio, threshold_slider],
177
+ outputs=[output_image, masks_grid, ambiguity_label]
178
+ )
179
 
180
 
181
  def main(server_name="0.0.0.0", server_port=7860, share=False):
182
+ demo.launch(
183
  server_name=server_name,
184
  server_port=server_port,
185
  share=share