heyoujue commited on
Commit
59fdec9
·
0 Parent(s):

initial commit

Browse files
Files changed (3) hide show
  1. app.py +330 -0
  2. corrupt_mask.py +212 -0
  3. in/semantic_class_0.png +0 -0
app.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import tempfile
4
+ import os
5
+ from pathlib import Path
6
+ import cv2
7
+ from PIL import Image
8
+ import matplotlib.pyplot as plt
9
+ import io
10
+ from typing import Optional
11
+
12
+
13
+ from corrupt_mask import MaskCorruptor
14
+
15
+
16
+ class GradioMaskCorruptor:
17
+ """Wrapper for MaskCorruptor with Gradio-specific functionality."""
18
+
19
+ def __init__(self):
20
+ self.corruptor = None
21
+ self.original_mask = None
22
+
23
+ def visualize_masks(self, original_mask, corrupted_mask, colormap='viridis'):
24
+ """Create a visualization comparing original and corrupted masks."""
25
+ # Create figure with subplots
26
+ fig, axes = plt.subplots(1, 3, figsize=(12, 4))
27
+
28
+ # Plot original mask
29
+ im1 = axes[0].imshow(original_mask, cmap=colormap)
30
+ axes[0].set_title('Original Mask')
31
+ axes[0].axis('off')
32
+ plt.colorbar(im1, ax=axes[0], fraction=0.046, pad=0.04)
33
+
34
+ # Plot corrupted mask
35
+ im2 = axes[1].imshow(corrupted_mask, cmap=colormap)
36
+ axes[1].set_title('Corrupted Mask')
37
+ axes[1].axis('off')
38
+ plt.colorbar(im2, ax=axes[1], fraction=0.046, pad=0.04)
39
+
40
+ # Plot difference
41
+ diff = np.where(corrupted_mask != original_mask, 1, 0)
42
+ im3 = axes[2].imshow(diff, cmap='Reds', vmin=0, vmax=1)
43
+ axes[2].set_title('Corrupted Pixels (Red)')
44
+ axes[2].axis('off')
45
+ plt.colorbar(im3, ax=axes[2], fraction=0.046, pad=0.04)
46
+
47
+ plt.tight_layout()
48
+
49
+ # Convert to PIL Image
50
+ buf = io.BytesIO()
51
+ plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
52
+ buf.seek(0)
53
+ img = Image.open(buf)
54
+ plt.close(fig)
55
+
56
+ return img
57
+
58
+ def process_single_mask(self,
59
+ input_image,
60
+ drop_probability: float,
61
+ mislabel_probability: float,
62
+ max_label: int,
63
+ preserve_background: bool,
64
+ seed: Optional[int],
65
+ colormap: str):
66
+ """Process a single uploaded mask image."""
67
+ try:
68
+ # Convert PIL Image to numpy array
69
+ if isinstance(input_image, np.ndarray):
70
+ mask = input_image
71
+ else:
72
+ mask = np.array(input_image)
73
+
74
+ # If RGB, convert to grayscale
75
+ if len(mask.shape) == 3:
76
+ mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
77
+
78
+ # Store original mask
79
+ self.original_mask = mask.copy()
80
+
81
+ # Initialize corruptor with parameters
82
+ self.corruptor = MaskCorruptor(
83
+ drop_probability=drop_probability,
84
+ mislabel_probability=mislabel_probability,
85
+ seed=seed if seed != 0 else None
86
+ )
87
+
88
+ # Corrupt the mask
89
+ corrupted_mask = self.corruptor.corrupt_single_mask(
90
+ mask=mask,
91
+ max_label=max_label if max_label > 0 else None,
92
+ preserve_background=preserve_background
93
+ )
94
+
95
+ # Get statistics
96
+ original_labels = np.unique(mask)
97
+ corrupted_labels = np.unique(corrupted_mask)
98
+
99
+ dropped_instances = len(original_labels) - len(corrupted_labels)
100
+ if preserve_background and 0 in original_labels:
101
+ dropped_instances -= 1 # Don't count background
102
+
103
+ # Create visualization
104
+ viz_image = self.visualize_masks(mask, corrupted_mask, colormap)
105
+
106
+ # Create statistics text
107
+ stats_text = f"""
108
+ ## 📊 Corruption Statistics:
109
+
110
+ ### Original Mask:
111
+ - Unique labels: {len(original_labels)}
112
+ - Label values: {original_labels.tolist()}
113
+ - Shape: {mask.shape}
114
+
115
+ ### Corrupted Mask:
116
+ - Unique labels: {len(corrupted_labels)}
117
+ - Label values: {corrupted_labels.tolist()}
118
+ - Dropped instances: {max(0, dropped_instances)}
119
+ - Corruption probability: {drop_probability * 100:.1f}%
120
+ - Mislabel probability: {mislabel_probability * 100:.1f}%
121
+
122
+ ### Parameters:
123
+ - Preserve background: {preserve_background}
124
+ - Max label: {max_label if max_label > 0 else 'Auto'}
125
+ - Random seed: {seed if seed != 0 else 'Random'}
126
+ """
127
+
128
+ return viz_image, stats_text # ✅ Only return image and stats
129
+
130
+ except Exception as e:
131
+ return None, f"❌ Error processing image: {str(e)}"
132
+
133
+ def create_example_mask(self):
134
+ """Create an example synthetic mask for demonstration."""
135
+ # Create a synthetic mask with 5 instances
136
+ mask = np.zeros((256, 256), dtype=np.uint8)
137
+ for i in range(1, 6):
138
+ mask[30 * i:30 * i + 20, 30 * i:30 * i + 20] = i
139
+
140
+ # Add some non-rectangular shapes
141
+ cv2.circle(mask, (100, 100), 15, 6, -1)
142
+ cv2.ellipse(mask, (200, 150), (20, 10), 0, 0, 360, 7, -1)
143
+
144
+ # Convert to PIL Image
145
+ mask_img = Image.fromarray(mask.astype(np.uint8))
146
+ return mask_img
147
+
148
+ def create_gradio_app():
149
+ """Create and configure the Gradio interface."""
150
+
151
+ corruptor = GradioMaskCorruptor()
152
+
153
+ # Create example mask
154
+ example_mask = corruptor.create_example_mask()
155
+
156
+ # Define CSS for better styling
157
+ css = """
158
+ .gradio-container {
159
+ max-width: 1200px !important;
160
+ }
161
+ .output-image {
162
+ border: 2px solid #4CAF50;
163
+ border-radius: 10px;
164
+ }
165
+ .stats-box {
166
+ background-color: #f0f8ff;
167
+ padding: 15px;
168
+ border-radius: 10px;
169
+ border-left: 5px solid #2196F3;
170
+ }
171
+ """
172
+
173
+ # Define the Gradio interface
174
+ with gr.Blocks(title="Mask Corruption Tool", css=css) as app:
175
+ gr.Markdown("""
176
+ # 🎭 Mask Corruption Tool
177
+
178
+ Upload a segmentation mask and artificially corrupt it by:
179
+ 1. **Randomly dropping** mask instances
180
+ 2. **Assigning wrong labels** to mask instances
181
+
182
+ Perfect for testing segmentation model robustness!
183
+ """)
184
+
185
+ with gr.Row():
186
+ with gr.Column(scale=1):
187
+ # Input section
188
+ gr.Markdown("## 📤 Input Settings")
189
+
190
+ input_image = gr.Image(
191
+ label="Upload Mask Image",
192
+ type="pil",
193
+ height=300,
194
+ elem_classes=["input-image"]
195
+ )
196
+
197
+ with gr.Row():
198
+ use_example = gr.Button("📋 Load Example Mask", variant="secondary")
199
+ clear_btn = gr.Button("🗑️ Clear", variant="secondary")
200
+
201
+ # Parameters section
202
+ gr.Markdown("## ⚙️ Corruption Parameters")
203
+
204
+ drop_prob = gr.Slider(
205
+ minimum=0.0,
206
+ maximum=1.0,
207
+ value=0.1,
208
+ step=0.05,
209
+ label="Drop Probability",
210
+ info="Probability of completely removing each mask instance"
211
+ )
212
+
213
+ mislabel_prob = gr.Slider(
214
+ minimum=0.0,
215
+ maximum=1.0,
216
+ value=0.1,
217
+ step=0.05,
218
+ label="Mislabel Probability",
219
+ info="Probability of assigning wrong label to each instance"
220
+ )
221
+
222
+ with gr.Row():
223
+ max_label = gr.Number(
224
+ value=10,
225
+ label="Max Label Value",
226
+ info="Set to 0 for auto-detect",
227
+ precision=0
228
+ )
229
+
230
+ seed = gr.Number(
231
+ value=42,
232
+ label="Random Seed",
233
+ info="Set to 0 for random",
234
+ precision=0
235
+ )
236
+
237
+ preserve_bg = gr.Checkbox(
238
+ value=True,
239
+ label="Preserve Background (label 0)",
240
+ info="Keep background label unchanged"
241
+ )
242
+
243
+ colormap = gr.Dropdown(
244
+ choices=['viridis', 'plasma', 'inferno', 'magma', 'cividis',
245
+ 'tab20', 'Set3', 'Set2', 'tab20c'],
246
+ value='viridis',
247
+ label="Colormap for Visualization"
248
+ )
249
+
250
+ process_btn = gr.Button("✨ Corrupt Mask!", variant="primary")
251
+
252
+ with gr.Column(scale=2):
253
+ # Output section
254
+ gr.Markdown("## 📊 Results")
255
+
256
+ output_image = gr.Image(
257
+ label="Visualization Comparison",
258
+ type="pil",
259
+ height=400,
260
+ elem_classes=["output-image"]
261
+ )
262
+
263
+ # Statistics section
264
+ stats_output = gr.Markdown(
265
+ label="Statistics",
266
+ elem_classes=["stats-box"]
267
+ )
268
+
269
+ # Example callbacks
270
+ use_example.click(
271
+ fn=lambda: example_mask,
272
+ outputs=[input_image]
273
+ )
274
+
275
+ clear_btn.click(
276
+ fn=lambda: (None, None, "## 📊 Corruption Statistics:\n\n*Upload a mask to see results here...*"),
277
+ outputs=[input_image, output_image, stats_output]
278
+ )
279
+
280
+ # Main processing callback
281
+ process_btn.click(
282
+ fn=corruptor.process_single_mask,
283
+ inputs=[input_image, drop_prob, mislabel_prob, max_label, preserve_bg, seed, colormap],
284
+ outputs=[output_image, stats_output]
285
+ )
286
+
287
+ # Examples section
288
+ gr.Markdown("## 🚀 Quick Examples")
289
+
290
+ gr.Examples(
291
+ examples=[
292
+ [example_mask, 0.2, 0.3, 10, True, 42, 'viridis'],
293
+ [example_mask, 0.5, 0.1, 10, True, 123, 'plasma'],
294
+ [example_mask, 0.1, 0.5, 10, False, 42, 'inferno'],
295
+ ],
296
+ inputs=[input_image, drop_prob, mislabel_prob, max_label, preserve_bg, seed, colormap],
297
+ outputs=[output_image, stats_output],
298
+ fn=corruptor.process_single_mask,
299
+ cache_examples=True
300
+ )
301
+
302
+ # Footer
303
+ gr.Markdown("""
304
+ ---
305
+ ### 📝 How to use:
306
+ 1. Upload a mask image (grayscale, each instance with unique integer label)
307
+ 2. Adjust corruption parameters using the sliders
308
+ 3. Click "Corrupt Mask!" to process
309
+ 4. View the comparison visualization and statistics
310
+
311
+ ### 💡 Tips:
312
+ - Use the example mask to get started quickly
313
+ - Set Random Seed to 0 for different results each time
314
+ - Higher drop/mislabel probabilities = more corruption
315
+ - Preserve background keeps label 0 unchanged (recommended for most cases)
316
+ """)
317
+ return app
318
+
319
+
320
+ # For HuggingFace Spaces deployment
321
+ app = create_gradio_app()
322
+
323
+ if __name__ == "__main__":
324
+ # For local testing
325
+ app.launch(
326
+ debug=True,
327
+ css="css",
328
+ theme=gr.themes.Soft(),
329
+ show_error = True,
330
+ )
corrupt_mask.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+ from pathlib import Path
4
+ from typing import List, Tuple, Dict, Optional, Union
5
+ import tifffile
6
+ import cv2
7
+ import json
8
+ import argparse
9
+
10
+
11
+ class MaskCorruptor:
12
+ """
13
+ Corrupts segmentation masks by randomly dropping masks or assigning wrong labels.
14
+ """
15
+
16
+ def __init__(self,
17
+ drop_probability: float = 0.1,
18
+ mislabel_probability: float = 0.1,
19
+ mislabel_noise_level: float = 0.2,
20
+ seed: Optional[int] = None):
21
+ """
22
+ Initialize the corruptor with corruption parameters.
23
+
24
+ Args:
25
+ drop_probability: Probability of completely dropping a mask instance
26
+ mislabel_probability: Probability of assigning a wrong label to a mask instance
27
+ mislabel_noise_level: Maximum relative change in label value (0-1)
28
+ seed: Random seed for reproducibility
29
+ """
30
+ self.drop_probability = drop_probability
31
+ self.mislabel_probability = mislabel_probability
32
+ self.mislabel_noise_level = mislabel_noise_level
33
+
34
+ if seed is not None:
35
+ np.random.seed(seed)
36
+ random.seed(seed)
37
+
38
+ def corrupt_single_mask(self,
39
+ mask: np.ndarray,
40
+ max_label: Optional[int] = None,
41
+ preserve_background: bool = True) -> np.ndarray:
42
+ """
43
+ Corrupt a single mask image.
44
+
45
+ Args:
46
+ mask: Input mask array (2D or 3D)
47
+ max_label: Maximum label value to consider (if None, use max in mask)
48
+ preserve_background: If True, keep label 0 as background unchanged
49
+
50
+ Returns:
51
+ Corrupted mask array
52
+ """
53
+ mask = mask.copy()
54
+
55
+ if max_label is None:
56
+ max_label = np.max(mask)
57
+
58
+ # Get unique labels (excluding background if preserve_background is True)
59
+ unique_labels = np.unique(mask)
60
+ if preserve_background and 0 in unique_labels:
61
+ unique_labels = unique_labels[unique_labels != 0]
62
+
63
+ if len(unique_labels) == 0:
64
+ return mask # No instances to corrupt
65
+
66
+ corrupted_mask = np.zeros_like(mask)
67
+
68
+ # Preserve background in output
69
+ if preserve_background:
70
+ corrupted_mask[mask == 0] = 0
71
+
72
+ for label in unique_labels:
73
+ # Skip if this is background and we're preserving it
74
+ if preserve_background and label == 0:
75
+ continue
76
+
77
+ # Randomly decide whether to drop this mask
78
+ if np.random.random() < self.drop_probability:
79
+ continue # Skip this mask entirely
80
+
81
+ # Extract the current mask instance
82
+ instance_mask = (mask == label)
83
+
84
+ # Randomly decide whether to mislabel
85
+ if np.random.random() < self.mislabel_probability:
86
+ # Generate a wrong label
87
+ if preserve_background:
88
+ # Generate label from 1 to max_label
89
+ possible_labels = [l for l in range(1, max_label + 1) if l != label]
90
+ else:
91
+ possible_labels = [l for l in range(0, max_label + 1) if l != label]
92
+
93
+ if possible_labels:
94
+ new_label = np.random.choice(possible_labels)
95
+ else:
96
+ new_label = label # No alternative labels available
97
+ else:
98
+ new_label = label
99
+
100
+ # Apply the (possibly modified) label
101
+ corrupted_mask[instance_mask] = new_label
102
+
103
+ return corrupted_mask
104
+
105
+ def corrupt_masks_from_directory(self,
106
+ input_dir: Union[str, Path],
107
+ output_dir: Union[str, Path],
108
+ file_pattern: str = "*.png",
109
+ max_label: Optional[int] = None,
110
+ preserve_background: bool = True):
111
+ """
112
+ Corrupt all masks in a directory and save to output directory.
113
+
114
+ Args:
115
+ input_dir: Directory containing input masks
116
+ output_dir: Directory to save corrupted masks
117
+ file_pattern: Pattern to match mask files
118
+ max_label: Maximum label value
119
+ preserve_background: Whether to preserve background (label 0)
120
+ """
121
+ input_path = Path(input_dir)
122
+ output_path = Path(output_dir)
123
+ output_path.mkdir(parents=True, exist_ok=True)
124
+
125
+ mask_files = list(input_path.glob(file_pattern))
126
+
127
+ print(f"Found {len(mask_files)} mask files to corrupt")
128
+
129
+ for i, mask_file in enumerate(mask_files, 1):
130
+ # Read mask (supporting different formats)
131
+ if mask_file.suffix.lower() in ['.tif', '.tiff']:
132
+ mask = tifffile.imread(mask_file)
133
+ elif mask_file.suffix.lower() in ['.png', '.jpg', '.jpeg']:
134
+ mask = cv2.imread(str(mask_file), cv2.IMREAD_GRAYSCALE)
135
+ else:
136
+ print(f"Unsupported file format: {mask_file.suffix}")
137
+ continue
138
+
139
+ # Corrupt the mask
140
+ corrupted_mask = self.corrupt_single_mask(
141
+ mask, max_label, preserve_background
142
+ )
143
+
144
+ # Save corrupted mask
145
+ output_file = output_path / mask_file.name
146
+ if mask_file.suffix.lower() in ['.tif', '.tiff']:
147
+ tifffile.imwrite(output_file, corrupted_mask)
148
+ else:
149
+ cv2.imwrite(str(output_file), corrupted_mask)
150
+
151
+ if i % 10 == 0 or i == len(mask_files):
152
+ print(f"Processed {i}/{len(mask_files)} files")
153
+
154
+ # Save corruption parameters as metadata
155
+ self.save_parameters(output_path)
156
+
157
+ def save_parameters(self, output_dir: Union[str, Path]):
158
+ """Save corruption parameters as JSON file."""
159
+ params = {
160
+ 'drop_probability': self.drop_probability,
161
+ 'mislabel_probability': self.mislabel_probability,
162
+ 'mislabel_noise_level': self.mislabel_noise_level,
163
+ 'corruption_type': 'random_drop_and_mislabel'
164
+ }
165
+
166
+ output_path = Path(output_dir)
167
+ with open(output_path / 'corruption_parameters.json', 'w') as f:
168
+ json.dump(params, f, indent=2)
169
+
170
+
171
+ def main():
172
+ parser = argparse.ArgumentParser(description='Corrupt segmentation masks')
173
+ parser.add_argument('--input_dir', type=str, required=True,
174
+ help='Input directory containing masks')
175
+ parser.add_argument('--output_dir', type=str, required=True,
176
+ help='Output directory for corrupted masks')
177
+ parser.add_argument('--drop_prob', type=float, default=0.1,
178
+ help='Probability of dropping a mask (0-1)')
179
+ parser.add_argument('--mislabel_prob', type=float, default=0.1,
180
+ help='Probability of mislabeling a mask (0-1)')
181
+ parser.add_argument('--max_label', type=int, default=None,
182
+ help='Maximum label value (if not specified, use max from data)')
183
+ parser.add_argument('--file_pattern', type=str, default='*',
184
+ help='File pattern for mask files')
185
+ parser.add_argument('--preserve_background', action='store_true',
186
+ help='Preserve label 0 as background (unchanged)')
187
+ parser.add_argument('--seed', type=int, default=None,
188
+ help='Random seed for reproducibility')
189
+
190
+ args = parser.parse_args()
191
+
192
+ # Create corruptor instance
193
+ corruptor = MaskCorruptor(
194
+ drop_probability=args.drop_prob,
195
+ mislabel_probability=args.mislabel_prob,
196
+ seed=args.seed
197
+ )
198
+
199
+ # Corrupt masks
200
+ corruptor.corrupt_masks_from_directory(
201
+ input_dir=args.input_dir,
202
+ output_dir=args.output_dir,
203
+ file_pattern=args.file_pattern,
204
+ max_label=args.max_label,
205
+ preserve_background=args.preserve_background
206
+ )
207
+
208
+ print(f"Corruption complete! Masks saved to {args.output_dir}")
209
+
210
+
211
+ if __name__ == "__main__":
212
+ main()
in/semantic_class_0.png ADDED