slau8405 commited on
Commit
2df9731
·
1 Parent(s): e2b6524

Added few more features

Browse files
Files changed (2) hide show
  1. app.py +439 -27
  2. requirements.txt +2 -0
app.py CHANGED
@@ -1,34 +1,446 @@
1
- # -*- coding: utf-8 -*-
2
  import gradio as gr
3
  from PIL import Image
4
  import numpy as np
5
  import cv2
 
 
 
 
 
 
 
6
 
7
- def dummy_polygon_augment(image: Image.Image):
8
- # Convert to NumPy
9
- img = np.array(image)
10
-
11
- # Dummy augmentation: Draw a donut-like polygon
12
- overlay = img.copy()
13
- height, width = img.shape[:2]
14
- center = (width // 2, height // 2)
15
- outer_radius = min(width, height) // 4
16
- inner_radius = outer_radius // 2
17
-
18
- cv2.circle(overlay, center, outer_radius, (0, 255, 0), -1)
19
- cv2.circle(overlay, center, inner_radius, (0, 0, 0), -1)
20
-
21
- alpha = 0.4
22
- cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0, img)
23
-
24
- return Image.fromarray(img)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- demo = gr.Interface(
27
- fn=dummy_polygon_augment,
28
- inputs=gr.Image(type="pil"),
29
- outputs=gr.Image(type="pil"),
30
- title="Donut Polygon Augmentation",
31
- description="Upload an image to visualize a sample donut-shaped polygon augmentation."
32
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- demo.launch()
 
 
 
 
1
  import gradio as gr
2
  from PIL import Image
3
  import numpy as np
4
  import cv2
5
+ import json
6
+ import albumentations as A
7
+ from typing import List, Tuple, Dict, Any
8
+ import supervision as sv
9
+ import uuid
10
+ import random
11
+ from pathlib import Path
12
 
13
+ class PolygonAugmentation:
14
+ def __init__(self, tolerance=0.2, area_threshold=0.01, debug=False):
15
+ self.tolerance = tolerance
16
+ self.area_threshold = area_threshold
17
+ self.debug = debug
18
+ self.supported_extensions = ['.png', '.jpg', '.jpeg', '.bmp', '.PNG', '.JPEG']
19
+
20
+ def __getattr__(self, name: str) -> Any:
21
+ raise AttributeError(f"'PolygonAugmentation' object has no attribute '{name}'")
22
+
23
+ def calculate_polygon_area(self, points: List[List[float]]) -> float:
24
+ poly_np = np.array(points, dtype=np.float32)
25
+ area = cv2.contourArea(poly_np)
26
+ if self.debug:
27
+ print(f"[DEBUG] Calculating polygon area: {area:.2f}")
28
+ return area
29
+
30
+ def load_labelme_data(self, json_file: Any, image: np.ndarray) -> Tuple:
31
+ if isinstance(json_file, str):
32
+ with open(json_file, 'r', encoding='utf-8') as f:
33
+ data = json.load(f)
34
+ else:
35
+ data = json.load(json_file)
36
+ if 'shapes' not in data or not isinstance(data['shapes'], list):
37
+ raise ValueError("Invalid JSON: 'shapes' key missing or not a list")
38
+
39
+ polygons = []
40
+ labels = []
41
+ original_areas = []
42
+
43
+ for shape in data['shapes']:
44
+ if shape.get('shape_type') != 'polygon' or not shape.get('points') or len(shape['points']) < 3:
45
+ if self.debug:
46
+ print(f"[DEBUG] Skipping invalid shape")
47
+ continue
48
+ try:
49
+ points = [[float(x), float(y)] for x, y in shape['points']]
50
+ polygons.append(points)
51
+ labels.append(shape['label'])
52
+ original_areas.append(self.calculate_polygon_area(points))
53
+ except (ValueError, TypeError):
54
+ if self.debug:
55
+ print(f"[DEBUG] Error processing points: {shape['points']}")
56
+ continue
57
+
58
+ if not polygons and self.debug:
59
+ print(f"[DEBUG] Warning: No valid polygons in JSON")
60
+ return image, polygons, labels, original_areas, data, "input"
61
+
62
+ def simplify_polygon(self, polygon: List[List[float]], tolerance: float = None, label: str = None) -> List[List[float]]:
63
+ tol = tolerance if tolerance is not None else self.tolerance
64
+ if label and label.lower() in ['background', 'bg', 'back']:
65
+ tol = tol * 3
66
+ if self.debug:
67
+ print(f"[DEBUG] Using increased tolerance {tol} for background label '{label}'")
68
+
69
+ if len(polygon) < 3:
70
+ if self.debug:
71
+ print(f"[DEBUG] Polygon has fewer than 3 points, skipping simplification.")
72
+ return polygon
73
+ poly_np = np.array(polygon, dtype=np.float32)
74
+ approx = cv2.approxPolyDP(poly_np, tol, closed=True)
75
+ simplified = approx.reshape(-1, 2).tolist()
76
+
77
+ if self.debug:
78
+ print(f"[DEBUG] Simplified polygon from {len(polygon)} to {len(simplified)} points with tolerance {tol}")
79
+ return simplified
80
+
81
+ def create_donut_polygon(self, external_contour: np.ndarray, internal_contours: List[np.ndarray]) -> List[List[float]]:
82
+ external_points = external_contour.reshape(-1, 2).tolist()
83
+ if not internal_contours:
84
+ if self.debug:
85
+ print("[DEBUG] No internal contours found, returning external points.")
86
+ return external_points
87
+
88
+ result_points = external_points.copy()
89
+
90
+ for internal_contour in internal_contours:
91
+ internal_points = internal_contour.reshape(-1, 2).tolist()
92
+ min_dist = float('inf')
93
+ ext_idx = 0
94
+ int_idx = 0
95
+
96
+ for i, p1 in enumerate(external_points):
97
+ for j, p2 in enumerate(internal_points):
98
+ dist = np.sqrt((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)
99
+ if dist < min_dist:
100
+ min_dist = dist
101
+ ext_idx = i
102
+ int_idx = j
103
+
104
+ bridge_to = external_points[ext_idx]
105
+ bridge_from = internal_points[int_idx]
106
+
107
+ if self.debug:
108
+ print(f"[DEBUG] Creating bridge between external index {ext_idx} and internal index {int_idx}, distance {min_dist:.2f}")
109
+
110
+ new_points = (
111
+ result_points[:ext_idx+1] +
112
+ internal_points[int_idx:] + internal_points[:int_idx+1] +
113
+ [bridge_to] +
114
+ external_points[ext_idx+1:]
115
+ )
116
+ result_points = new_points
117
+
118
+ return result_points
119
+
120
+ def save_augmented_data(
121
+ self,
122
+ aug_image: np.ndarray,
123
+ aug_polygons: List[List[List[float]]],
124
+ aug_labels: List[str],
125
+ original_data: Dict[str, Any],
126
+ base_name: str
127
+ ) -> Dict[str, Any]:
128
+ aug_id = uuid.uuid4().hex[:4]
129
+ aug_img_name = f"{base_name}_{aug_id}_aug.png"
130
+
131
+ new_shapes = []
132
+ for poly, label in zip(aug_polygons, aug_labels):
133
+ if not poly or len(poly) < 3:
134
+ continue
135
+ new_shapes.append({
136
+ "label": label,
137
+ "points": poly,
138
+ "group_id": None,
139
+ "shape_type": "polygon",
140
+ "flags": {}
141
+ })
142
+
143
+ aug_data = {
144
+ "version": original_data.get("version", "5.0.1"),
145
+ "flags": original_data.get("flags", {}),
146
+ "shapes": new_shapes,
147
+ "imagePath": aug_img_name,
148
+ "imageData": None,
149
+ "imageHeight": aug_image.shape[0],
150
+ "imageWidth": aug_image.shape[1]
151
+ }
152
+
153
+ return aug_data
154
 
155
+ def polygons_to_masks(self, image: np.ndarray, polygons: List[List[List[float]]], labels: List[str]) -> Tuple[np.ndarray, List[str]]:
156
+ height, width = image.shape[:2]
157
+ all_masks = []
158
+ all_labels = []
159
+
160
+ for poly_idx, (poly, label) in enumerate(zip(polygons, labels)):
161
+ try:
162
+ poly_np = np.array(poly, dtype=np.int32)
163
+ if len(poly_np) < 3:
164
+ if self.debug:
165
+ print(f"[DEBUG] Skipping polygon {poly_idx}: fewer than 3 points")
166
+ continue
167
+ mask = np.zeros((height, width), dtype=np.uint8)
168
+ cv2.fillPoly(mask, [poly_np], 1)
169
+ all_masks.append(mask)
170
+ all_labels.append(label)
171
+ except Exception as e:
172
+ if self.debug:
173
+ print(f"[DEBUG] Error processing polygon {poly_idx}: {str(e)}")
174
+
175
+ if not all_masks:
176
+ return np.zeros((0, height, width), dtype=np.uint8), []
177
+
178
+ return np.array(all_masks, dtype=np.uint8), all_labels
179
+
180
+ def process_contours(
181
+ self,
182
+ external_contour: np.ndarray,
183
+ internal_contours: List[np.ndarray],
184
+ width: int,
185
+ height: int,
186
+ label: str,
187
+ all_polygons: List[List[List[float]]],
188
+ all_labels: List[str],
189
+ tolerance: float = None
190
+ ) -> None:
191
+ tol = tolerance if tolerance is not None else self.tolerance
192
+ external_points = external_contour.reshape(-1, 2).tolist()
193
+ simplified_external = self.simplify_polygon(external_points, tolerance=tol, label=label)
194
+
195
+ if len(simplified_external) >= 3:
196
+ poly_labelme = [[round(max(0, min(float(x), width - 1)), 2),
197
+ round(max(0, min(float(y), height - 1)), 2)]
198
+ for x, y in simplified_external]
199
+ all_polygons.append(poly_labelme)
200
+ all_labels.append(label)
201
+ if self.debug:
202
+ print(f"[DEBUG] Added simplified external Nghia with {len(poly_labelme)} points.")
203
+
204
+ for internal_contour in internal_contours:
205
+ internal_points = internal_contour.reshape(-1, 2).tolist()
206
+ simplified_internal = self.simplify_polygon(internal_points, tolerance=tol, label=label)
207
+
208
+ if len(simplified_internal) >= 3:
209
+ poly_labelme = [[round(max(0, min(float(x), width - 1)), 2),
210
+ round(max(0, min(float(y), height - 1)), 2)]
211
+ for x, y in simplified_internal]
212
+ all_polygons.append(poly_labelme)
213
+ all_labels.append(label)
214
+ if self.debug:
215
+ print(f"[DEBUG] Added simplified internal polygon with {len(poly_labelme)} points.")
216
+
217
+ def masks_to_labelme_polygons(
218
+ self,
219
+ masks: np.ndarray,
220
+ labels: List[str],
221
+ original_areas: List[float],
222
+ area_threshold: float = None,
223
+ tolerance: float = None
224
+ ) -> Tuple[List[List[List[float]]], List[str]]:
225
+ tol = tolerance if tolerance is not None else self.tolerance
226
+ area_thresh = area_threshold if area_threshold is not None else self.area_threshold
227
+ height, width = masks[0].shape if len(masks) > 0 else (0, 0)
228
+ all_polygons = []
229
+ all_labels = []
230
+
231
+ for mask_idx, (mask, label) in enumerate(zip(masks, labels)):
232
+ if mask.sum() < 10:
233
+ if self.debug:
234
+ print(f"[DEBUG] Skipping mask {mask_idx}: very small or empty.")
235
+ continue
236
+ contours, hierarchy = cv2.findContours(mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
237
+ if hierarchy is None or len(contours) == 0:
238
+ if self.debug:
239
+ print(f"[DEBUG] No contours found in mask {mask_idx}.")
240
+ continue
241
+
242
+ hierarchy = hierarchy[0]
243
+ external_contours = []
244
+ internal_contours_map = {}
245
+
246
+ for i, (contour, h) in enumerate(zip(contours, hierarchy)):
247
+ if h[3] == -1:
248
+ external_contours.append(contour)
249
+ internal_contours_map[len(external_contours)-1] = []
250
+ else:
251
+ parent_idx = h[3]
252
+ for j, _ in enumerate(external_contours):
253
+ if parent_idx == j:
254
+ internal_contours_map[j].append(contour)
255
+ break
256
+
257
+ if not external_contours:
258
+ if self.debug:
259
+ print(f"[DEBUG] No external contours found in mask {mask_idx}.")
260
+ continue
261
+
262
+ for ext_idx, external_contour in enumerate(external_contours):
263
+ internal_contours = internal_contours_map.get(ext_idx, [])
264
+ ext_area = cv2.contourArea(external_contour)
265
+ if ext_area <= 0:
266
+ continue
267
+ if mask_idx < len(original_areas) and original_areas[mask_idx] > 0:
268
+ relative_area = ext_area / original_areas[mask_idx]
269
+ if relative_area < area_thresh:
270
+ if self.debug:
271
+ print(f"[DEBUG] Skipping contour {ext_idx} (area too small: {relative_area:.4f})")
272
+ continue
273
+
274
+ is_background = label.lower() in ['background', 'bg', 'back']
275
+ if is_background and internal_contours:
276
+ try:
277
+ donut_points = self.create_donut_polygon(external_contour, internal_contours)
278
+ simplified_donut = self.simplify_polygon(donut_points, tolerance=tol, label=label)
279
+ if len(simplified_donut) >= 3:
280
+ poly_labelme = [[round(max(0, min(float(x), width - 1)), 2),
281
+ round(max(0, min(float(y), height - 1)), 2)]
282
+ for x, y in simplified_donut]
283
+ all_polygons.append(poly_labelme)
284
+ all_labels.append(label)
285
+ if self.debug:
286
+ print(f"[DEBUG] Added donut polygon with {len(poly_labelme)} points.")
287
+ except Exception as e:
288
+ if self.debug:
289
+ print(f"[DEBUG] Error creating donut: {str(e)}, fallback to separate polygons.")
290
+ self.process_contours(
291
+ external_contour, internal_contours, width, height,
292
+ label, all_polygons, all_labels, tol
293
+ )
294
+ else:
295
+ self.process_contours(
296
+ external_contour, internal_contours, width, height,
297
+ label, all_polygons, all_labels, tol
298
+ )
299
+
300
+ return all_polygons, all_labels
301
+
302
+ def augment_single_image(
303
+ self,
304
+ image: np.ndarray,
305
+ polygons: List[List[List[float]]],
306
+ labels: List[str],
307
+ original_areas: List[float],
308
+ original_data: Dict[str, Any],
309
+ aug_type: str,
310
+ aug_param: float
311
+ ) -> Tuple[np.ndarray, Dict[str, Any]]:
312
+ height, width = image.shape[:2]
313
+ crop_scale = random.uniform(0.8, 0.9)
314
+ crop_height = int(height * crop_scale)
315
+ crop_width = int(width * crop_scale)
316
+
317
+ aug_dict = {
318
+ "rotate": A.Rotate(limit=aug_param, p=1.0),
319
+ "horizontal_flip": A.HorizontalFlip(p=1.0 if aug_param > 0 else 0.0),
320
+ "vertical_flip": A.VerticalFlip(p=1.0 if aug_param > 0 else 0.0),
321
+ "scale": A.Affine(scale=aug_param, p=1.0),
322
+ "brightness_contrast": A.RandomBrightnessContrast(
323
+ brightness_limit=aug_param,
324
+ contrast_limit=aug_param,
325
+ p=1.0
326
+ ),
327
+ "pixel_dropout": A.PixelDropout(dropout_prob=aug_param, p=1.0)
328
+ }
329
+
330
+ if aug_type not in aug_dict:
331
+ raise ValueError(f"Unsupported augmentation type: {aug_type}")
332
+
333
+ transform = A.Compose([
334
+ aug_dict[aug_type],
335
+ A.RandomCrop(width=crop_width, height=crop_height, p=0.8)
336
+ ])
337
+
338
+ masks, mask_labels = self.polygons_to_masks(image, polygons, labels)
339
+ if masks.shape[0] == 0:
340
+ raise ValueError("No valid masks created from polygons")
341
+
342
+ aug_result = transform(image=image, masks=masks)
343
+ aug_image = aug_result['image']
344
+ aug_masks = aug_result['masks']
345
+
346
+ aug_polygons, aug_labels = self.masks_to_labelme_polygons(
347
+ aug_masks, mask_labels, original_areas, self.area_threshold, self.tolerance
348
+ )
349
+
350
+ aug_data = self.save_augmented_data(aug_image, aug_polygons, aug_labels, original_data, "input")
351
+
352
+ return aug_image, aug_data
353
+
354
+ def augment_image(image: Image.Image, json_file: Any, aug_type: str, aug_param: float):
355
+ try:
356
+ # Convert PIL image to NumPy
357
+ img_np = np.array(image)
358
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
359
+
360
+ # Initialize augmenter
361
+ augmenter = PolygonAugmentation(tolerance=2.0, area_threshold=0.01, debug=False)
362
+
363
+ # Load data
364
+ img_np, polygons, labels, original_areas, original_data, _ = augmenter.load_labelme_data(json_file, img_np)
365
+
366
+ # Perform augmentation
367
+ aug_image, aug_data = augmenter.augment_single_image(
368
+ img_np, polygons, labels, original_areas, original_data, aug_type, aug_param
369
+ )
370
+
371
+ # Overlay polygons on augmented image for visualization
372
+ aug_image_rgb = cv2.cvtColor(aug_image, cv2.COLOR_BGR2RGB)
373
+ for poly in aug_data['shapes']:
374
+ points = np.array(poly['points'], dtype=np.int32)
375
+ cv2.polylines(aug_image_rgb, [points], isClosed=True, color=(0, 255, 0), thickness=2)
376
+
377
+ # Convert augmented image back to PIL
378
+ aug_image_pil = Image.fromarray(aug_image_rgb)
379
+
380
+ # Format JSON for display
381
+ aug_json_str = json.dumps(aug_data, indent=2)
382
+
383
+ return aug_image_pil, aug_json_str
384
+ except Exception as e:
385
+ return None, f"Error: {str(e)}"
386
+
387
+ # Define augmentation types and parameter ranges
388
+ aug_options = {
389
+ "Rotate": {"param_name": "Angle (degrees)", "range": (-30, 30), "default": 0},
390
+ "Horizontal Flip": {"param_name": "Apply Flip (0 or 1)", "range": (0, 1), "default": 0},
391
+ "Vertical Flip": {"param_name": "Apply Flip (0 or 1)", "range": (0, 1), "default": 0},
392
+ "Scale": {"param_name": "Scale Factor", "range": (0.5, 1.5), "default": 1.0},
393
+ "Brightness/Contrast": {"param_name": "Brightness/Contrast Limit", "range": (-0.3, 0.3), "default": 0},
394
+ "Pixel Dropout": {"param_name": "Dropout Probability", "range": (0.01, 0.1), "default": 0.05}
395
+ }
396
+
397
+ def create_interface():
398
+ with gr.Blocks(title="Donut Polygon Augmentation") as demo:
399
+ gr.Markdown("# Donut Polygon Augmentation 🌀")
400
+ gr.Markdown("Upload an image and a LabelMe JSON file to apply topology-preserving augmentation to donut-shaped polygons.")
401
+
402
+ with gr.Row():
403
+ with gr.Column():
404
+ image_input = gr.Image(type="pil", label="Input Image")
405
+ json_input = gr.File(label="LabelMe JSON File", file_types=[".json"])
406
+ aug_type = gr.Dropdown(
407
+ choices=list(aug_options.keys()),
408
+ label="Augmentation Type",
409
+ value="Rotate"
410
+ )
411
+ aug_param = gr.Slider(
412
+ minimum=aug_options["Rotate"]["range"][0],
413
+ maximum=aug_options["Rotate"]["range"][1],
414
+ value=aug_options["Rotate"]["default"],
415
+ label=aug_options["Rotate"]["param_name"]
416
+ )
417
+
418
+ def update_slider(aug_type):
419
+ return {
420
+ aug_param: gr.update(
421
+ minimum=aug_options[aug_type]["range"][0],
422
+ maximum=aug_options[aug_type]["range"][1],
423
+ value=aug_options[aug_type]["default"],
424
+ label=aug_options[aug_type]["param_name"]
425
+ )
426
+ }
427
+
428
+ aug_type.change(fn=update_slider, inputs=aug_type, outputs=[aug_param])
429
+
430
+ submit_btn = gr.Button("Apply Augmentation")
431
+
432
+ with gr.Column():
433
+ output_image = gr.Image(type="pil", label="Augmented Image")
434
+ output_json = gr.Textbox(label="Augmented LabelMe JSON", lines=10, max_lines=20)
435
+
436
+ submit_btn.click(
437
+ fn=augment_image,
438
+ inputs=[image_input, json_input, aug_type, aug_param],
439
+ outputs=[output_image, output_json]
440
+ )
441
+
442
+ return demo
443
 
444
+ if __name__ == "__main__":
445
+ demo = create_interface()
446
+ demo.launch()
requirements.txt CHANGED
@@ -2,3 +2,5 @@ opencv-python-headless
2
  gradio==5.30.0
3
  Pillow
4
  numpy
 
 
 
2
  gradio==5.30.0
3
  Pillow
4
  numpy
5
+ albumentations
6
+ supervision