slau8405 commited on
Commit
f3ed335
·
1 Parent(s): 4530dbe

Added few more features

Browse files
Files changed (1) hide show
  1. app.py +54 -29
app.py CHANGED
@@ -10,6 +10,11 @@ import uuid
10
  import random
11
  from pathlib import Path
12
  import colorsys
 
 
 
 
 
13
 
14
  class PolygonAugmentation:
15
  def __init__(self, tolerance=0.2, area_threshold=0.01, debug=False):
@@ -25,7 +30,7 @@ class PolygonAugmentation:
25
  poly_np = np.array(points, dtype=np.float32)
26
  area = cv2.contourArea(poly_np)
27
  if self.debug:
28
- print(f"[DEBUG] Calculating polygon area: {area:.2f}")
29
  return area
30
 
31
  def load_labelme_data(self, json_file: Any, image: np.ndarray) -> Tuple:
@@ -60,7 +65,7 @@ class PolygonAugmentation:
60
  for shape in shapes:
61
  if shape.get('shape_type') != 'polygon' or not shape.get('points') or len(shape['points']) < 3:
62
  if self.debug:
63
- print(f"[DEBUG] Skipping invalid shape: {shape}")
64
  continue
65
  try:
66
  points = [[float(x), float(y)] for x, y in shape['points']]
@@ -69,11 +74,11 @@ class PolygonAugmentation:
69
  original_areas.append(self.calculate_polygon_area(points))
70
  except (ValueError, TypeError) as e:
71
  if self.debug:
72
- print(f"[DEBUG] Error processing points: {shape['points']}, error: {str(e)}")
73
  continue
74
 
75
  if not polygons and self.debug:
76
- print(f"[DEBUG] Warning: No valid polygons in JSON")
77
  return image, polygons, labels, original_areas, data, "input"
78
 
79
  def simplify_polygon(self, polygon: List[List[float]], tolerance: float = None, label: str = None) -> List[List[float]]:
@@ -81,25 +86,25 @@ class PolygonAugmentation:
81
  if label and label.lower() in ['background', 'bg', 'back']:
82
  tol = tol * 3
83
  if self.debug:
84
- print(f"[DEBUG] Using increased tolerance {tol} for background label '{label}'")
85
 
86
  if len(polygon) < 3:
87
  if self.debug:
88
- print(f"[DEBUG] Polygon has fewer than 3 points, skipping simplification.")
89
  return polygon
90
  poly_np = np.array(polygon, dtype=np.float32)
91
  approx = cv2.approxPolyDP(poly_np, tol, closed=True)
92
  simplified = approx.reshape(-1, 2).tolist()
93
 
94
  if self.debug:
95
- print(f"[DEBUG] Simplified polygon from {len(polygon)} to {len(simplified)} points with tolerance {tol}")
96
  return simplified
97
 
98
  def create_donut_polygon(self, external_contour: np.ndarray, internal_contours: List[np.ndarray]) -> List[List[float]]:
99
  external_points = external_contour.reshape(-1, 2).tolist()
100
  if not internal_contours:
101
  if self.debug:
102
- print("[DEBUG] No internal contours found, returning external points.")
103
  return external_points
104
 
105
  result_points = external_points.copy()
@@ -122,7 +127,7 @@ class PolygonAugmentation:
122
  bridge_from = internal_points[int_idx]
123
 
124
  if self.debug:
125
- print(f"[DEBUG] Creating bridge between external index {ext_idx} and internal index {int_idx}, distance {min_dist:.2f}")
126
 
127
  new_points = (
128
  result_points[:ext_idx+1] +
@@ -180,7 +185,7 @@ class PolygonAugmentation:
180
  poly_np = np.array(poly, dtype=np.int32)
181
  if len(poly_np) < 3:
182
  if self.debug:
183
- print(f"[DEBUG] Skipping polygon {poly_idx}: fewer than 3 points")
184
  continue
185
  mask = np.zeros((height, width), dtype=np.uint8)
186
  cv2.fillPoly(mask, [poly_np], 1)
@@ -188,7 +193,7 @@ class PolygonAugmentation:
188
  all_labels.append(label)
189
  except Exception as e:
190
  if self.debug:
191
- print(f"[DEBUG] Error processing polygon {poly_idx}: {str(e)}")
192
 
193
  if not all_masks:
194
  return np.zeros((0, height, width), dtype=np.uint8), []
@@ -217,7 +222,7 @@ class PolygonAugmentation:
217
  all_polygons.append(poly_labelme)
218
  all_labels.append(label)
219
  if self.debug:
220
- print(f"[DEBUG] Added simplified external polygon with {len(poly_labelme)} points.")
221
 
222
  for internal_contour in internal_contours:
223
  internal_points = internal_contour.reshape(-1, 2).tolist()
@@ -230,7 +235,7 @@ class PolygonAugmentation:
230
  all_polygons.append(poly_labelme)
231
  all_labels.append(label)
232
  if self.debug:
233
- print(f"[DEBUG] Added simplified internal polygon with {len(poly_labelme)} points.")
234
 
235
  def masks_to_labelme_polygons(
236
  self,
@@ -249,12 +254,12 @@ class PolygonAugmentation:
249
  for mask_idx, (mask, label) in enumerate(zip(masks, labels)):
250
  if mask.sum() < 10:
251
  if self.debug:
252
- print(f"[DEBUG] Skipping mask {mask_idx}: very small or empty.")
253
  continue
254
  contours, hierarchy = cv2.findContours(mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
255
  if hierarchy is None or len(contours) == 0:
256
  if self.debug:
257
- print(f"[DEBUG] No contours found in mask {mask_idx}.")
258
  continue
259
 
260
  hierarchy = hierarchy[0]
@@ -274,7 +279,7 @@ class PolygonAugmentation:
274
 
275
  if not external_contours:
276
  if self.debug:
277
- print(f"[DEBUG] No external contours found in mask {mask_idx}.")
278
  continue
279
 
280
  for ext_idx, external_contour in enumerate(external_contours):
@@ -286,7 +291,7 @@ class PolygonAugmentation:
286
  relative_area = ext_area / original_areas[mask_idx]
287
  if relative_area < area_thresh:
288
  if self.debug:
289
- print(f"[DEBUG] Skipping contour {ext_idx} (area too small: {relative_area:.4f})")
290
  continue
291
 
292
  is_background = label.lower() in ['background', 'bg', 'back']
@@ -301,10 +306,10 @@ class PolygonAugmentation:
301
  all_polygons.append(poly_labelme)
302
  all_labels.append(label)
303
  if self.debug:
304
- print(f"[DEBUG] Added donut polygon with {len(poly_labelme)} points.")
305
  except Exception as e:
306
  if self.debug:
307
- print(f"[DEBUG] Error creating donut: {str(e)}, fallback to separate polygons.")
308
  self.process_contours(
309
  external_contour, internal_contours, width, height,
310
  label, all_polygons, all_labels, tol
@@ -327,6 +332,7 @@ class PolygonAugmentation:
327
  aug_type: str,
328
  aug_param: float
329
  ) -> Tuple[np.ndarray, Dict[str, Any]]:
 
330
  height, width = image.shape[:2]
331
  crop_scale = random.uniform(0.8, 0.9)
332
  crop_height = int(height * crop_scale)
@@ -351,22 +357,28 @@ class PolygonAugmentation:
351
  transform = A.Compose([
352
  aug_dict[aug_type],
353
  A.RandomCrop(width=crop_width, height=crop_height, p=0.8)
354
- ])
355
 
356
  masks, mask_labels = self.polygons_to_masks(image, polygons, labels)
357
  if masks.shape[0] == 0:
358
  raise ValueError("No valid masks created from polygons")
359
 
 
360
  aug_result = transform(image=image, masks=masks)
361
  aug_image = aug_result['image']
362
  aug_masks = aug_result['masks']
363
 
 
 
 
 
364
  aug_polygons, aug_labels = self.masks_to_labelme_polygons(
365
  aug_masks, mask_labels, original_areas, self.area_threshold, self.tolerance
366
  )
367
 
368
  aug_data = self.save_augmented_data(aug_image, aug_polygons, aug_labels, original_data, "input")
369
 
 
370
  return aug_image, aug_data
371
 
372
  def augment_image(image: Image.Image, json_file: Any, aug_type: str, aug_param: float):
@@ -387,11 +399,13 @@ def augment_image(image: Image.Image, json_file: Any, aug_type: str, aug_param:
387
  raise ValueError(f"Parameter {aug_param} for {aug_type} is out of range [{min_val}, {max_val}]")
388
 
389
  # Convert PIL image to NumPy
 
 
390
  img_np = np.array(image)
391
  img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
392
 
393
  # Initialize augmenter
394
- augmenter = PolygonAugmentation(tolerance=2.0, area_threshold=0.01, debug=False)
395
 
396
  # Load data
397
  img_np, polygons, labels, original_areas, original_data, _ = augmenter.load_labelme_data(json_file, img_np)
@@ -401,6 +415,10 @@ def augment_image(image: Image.Image, json_file: Any, aug_type: str, aug_param:
401
  img_np, polygons, labels, original_areas, original_data, aug_type, aug_param
402
  )
403
 
 
 
 
 
404
  # Create a dynamic color map for unique labels
405
  unique_labels = list(set(shape['label'] for shape in aug_data['shapes']))
406
  if not unique_labels:
@@ -416,22 +434,29 @@ def augment_image(image: Image.Image, json_file: Any, aug_type: str, aug_param:
416
 
417
  # Convert augmented image to RGB for visualization
418
  aug_image_rgb = cv2.cvtColor(aug_image, cv2.COLOR_BGR2RGB)
 
 
 
 
419
  overlay = aug_image_rgb.copy()
420
 
421
- # Create masks for visualization
422
  height, width = aug_image.shape[:2]
423
  for shape in aug_data['shapes']:
424
  label = shape['label']
425
- color = label_color_map.get(label, (0, 255, 0))
426
  points = np.array(shape['points'], dtype=np.int32)
 
 
 
427
 
428
- # Draw filled mask with transparency
429
  mask = np.zeros((height, width), dtype=np.uint8)
430
  cv2.fillPoly(mask, [points], 1)
431
  colored_mask = np.zeros_like(aug_image_rgb)
432
  colored_mask[mask == 1] = color
433
- alpha = 0.3
434
- cv2.addWeighted(colored_mask, alpha, overlay, 1 - alpha, 0, overlay)
435
 
436
  # Draw polygon outline
437
  cv2.polylines(overlay, [points], isClosed=True, color=color, thickness=2)
@@ -442,8 +467,10 @@ def augment_image(image: Image.Image, json_file: Any, aug_type: str, aug_param:
442
  # Format JSON for display
443
  aug_json_str = json.dumps(aug_data, indent=2)
444
 
 
445
  return aug_image_pil, aug_json_str
446
  except Exception as e:
 
447
  return None, f"Error: {str(e)}"
448
 
449
  # Define augmentation types and parameter ranges
@@ -459,7 +486,7 @@ aug_options = {
459
  def create_interface():
460
  with gr.Blocks(title="Donut Polygon Augmentation") as demo:
461
  gr.Markdown("# Donut Polygon Augmentation 🌀")
462
- gr.Markdown("Upload an image and a LabelMe JSON file to apply topology-preserving augmentation to donut-shaped polygons. Each class is visualized with a unique color and semi-transparent mask.")
463
 
464
  with gr.Row():
465
  with gr.Column():
@@ -479,7 +506,6 @@ def create_interface():
479
  )
480
 
481
  def update_slider(display_name):
482
- # Map display_name back to internal key
483
  aug_key = next(k for k, v in aug_options.items() if v["display_name"] == display_name)
484
  return {
485
  aug_param: gr.update(
@@ -500,7 +526,6 @@ def create_interface():
500
  output_json = gr.Textbox(label="Augmented LabelMe JSON", lines=10, max_lines=20)
501
 
502
  def submit(image, json_file, display_name, aug_param):
503
- # Map display_name to internal key
504
  aug_key = next(k for k, v in aug_options.items() if v["display_name"] == display_name)
505
  return augment_image(image, json_file, aug_key, aug_param)
506
 
 
10
  import random
11
  from pathlib import Path
12
  import colorsys
13
+ import logging
14
+
15
+ # Set up logging
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
 
19
  class PolygonAugmentation:
20
  def __init__(self, tolerance=0.2, area_threshold=0.01, debug=False):
 
30
  poly_np = np.array(points, dtype=np.float32)
31
  area = cv2.contourArea(poly_np)
32
  if self.debug:
33
+ logger.info(f"[DEBUG] Calculating polygon area: {area:.2f}")
34
  return area
35
 
36
  def load_labelme_data(self, json_file: Any, image: np.ndarray) -> Tuple:
 
65
  for shape in shapes:
66
  if shape.get('shape_type') != 'polygon' or not shape.get('points') or len(shape['points']) < 3:
67
  if self.debug:
68
+ logger.info(f"[DEBUG] Skipping invalid shape: {shape}")
69
  continue
70
  try:
71
  points = [[float(x), float(y)] for x, y in shape['points']]
 
74
  original_areas.append(self.calculate_polygon_area(points))
75
  except (ValueError, TypeError) as e:
76
  if self.debug:
77
+ logger.info(f"[DEBUG] Error processing points: {shape['points']}, error: {str(e)}")
78
  continue
79
 
80
  if not polygons and self.debug:
81
+ logger.info(f"[DEBUG] Warning: No valid polygons in JSON")
82
  return image, polygons, labels, original_areas, data, "input"
83
 
84
  def simplify_polygon(self, polygon: List[List[float]], tolerance: float = None, label: str = None) -> List[List[float]]:
 
86
  if label and label.lower() in ['background', 'bg', 'back']:
87
  tol = tol * 3
88
  if self.debug:
89
+ logger.info(f"[DEBUG] Using increased tolerance {tol} for background label '{label}'")
90
 
91
  if len(polygon) < 3:
92
  if self.debug:
93
+ logger.info(f"[DEBUG] Polygon has fewer than 3 points, skipping simplification.")
94
  return polygon
95
  poly_np = np.array(polygon, dtype=np.float32)
96
  approx = cv2.approxPolyDP(poly_np, tol, closed=True)
97
  simplified = approx.reshape(-1, 2).tolist()
98
 
99
  if self.debug:
100
+ logger.info(f"[DEBUG] Simplified polygon from {len(polygon)} to {len(simplified)} points with tolerance {tol}")
101
  return simplified
102
 
103
  def create_donut_polygon(self, external_contour: np.ndarray, internal_contours: List[np.ndarray]) -> List[List[float]]:
104
  external_points = external_contour.reshape(-1, 2).tolist()
105
  if not internal_contours:
106
  if self.debug:
107
+ logger.info("[DEBUG] No internal contours found, returning external points.")
108
  return external_points
109
 
110
  result_points = external_points.copy()
 
127
  bridge_from = internal_points[int_idx]
128
 
129
  if self.debug:
130
+ logger.info(f"[DEBUG] Creating bridge between external index {ext_idx} and internal index {int_idx}, distance {min_dist:.2f}")
131
 
132
  new_points = (
133
  result_points[:ext_idx+1] +
 
185
  poly_np = np.array(poly, dtype=np.int32)
186
  if len(poly_np) < 3:
187
  if self.debug:
188
+ logger.info(f"[DEBUG] Skipping polygon {poly_idx}: fewer than 3 points")
189
  continue
190
  mask = np.zeros((height, width), dtype=np.uint8)
191
  cv2.fillPoly(mask, [poly_np], 1)
 
193
  all_labels.append(label)
194
  except Exception as e:
195
  if self.debug:
196
+ logger.info(f"[DEBUG] Error processing polygon {poly_idx}: {str(e)}")
197
 
198
  if not all_masks:
199
  return np.zeros((0, height, width), dtype=np.uint8), []
 
222
  all_polygons.append(poly_labelme)
223
  all_labels.append(label)
224
  if self.debug:
225
+ logger.info(f"[DEBUG] Added simplified external polygon with {len(poly_labelme)} points.")
226
 
227
  for internal_contour in internal_contours:
228
  internal_points = internal_contour.reshape(-1, 2).tolist()
 
235
  all_polygons.append(poly_labelme)
236
  all_labels.append(label)
237
  if self.debug:
238
+ logger.info(f"[DEBUG] Added simplified internal polygon with {len(poly_labelme)} points.")
239
 
240
  def masks_to_labelme_polygons(
241
  self,
 
254
  for mask_idx, (mask, label) in enumerate(zip(masks, labels)):
255
  if mask.sum() < 10:
256
  if self.debug:
257
+ logger.info(f"[DEBUG] Skipping mask {mask_idx}: very small or empty.")
258
  continue
259
  contours, hierarchy = cv2.findContours(mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
260
  if hierarchy is None or len(contours) == 0:
261
  if self.debug:
262
+ logger.info(f"[DEBUG] No contours found in mask {mask_idx}.")
263
  continue
264
 
265
  hierarchy = hierarchy[0]
 
279
 
280
  if not external_contours:
281
  if self.debug:
282
+ logger.info(f"[DEBUG] No external contours found in mask {mask_idx}.")
283
  continue
284
 
285
  for ext_idx, external_contour in enumerate(external_contours):
 
291
  relative_area = ext_area / original_areas[mask_idx]
292
  if relative_area < area_thresh:
293
  if self.debug:
294
+ logger.info(f"[DEBUG] Skipping contour {ext_idx} (area too small: {relative_area:.4f})")
295
  continue
296
 
297
  is_background = label.lower() in ['background', 'bg', 'back']
 
306
  all_polygons.append(poly_labelme)
307
  all_labels.append(label)
308
  if self.debug:
309
+ logger.info(f"[DEBUG] Added donut polygon with {len(poly_labelme)} points.")
310
  except Exception as e:
311
  if self.debug:
312
+ logger.info(f"[DEBUG] Error creating donut: {str(e)}, fallback to separate polygons.")
313
  self.process_contours(
314
  external_contour, internal_contours, width, height,
315
  label, all_polygons, all_labels, tol
 
332
  aug_type: str,
333
  aug_param: float
334
  ) -> Tuple[np.ndarray, Dict[str, Any]]:
335
+ logger.info(f"Applying augmentation: {aug_type} with parameter {aug_param}")
336
  height, width = image.shape[:2]
337
  crop_scale = random.uniform(0.8, 0.9)
338
  crop_height = int(height * crop_scale)
 
357
  transform = A.Compose([
358
  aug_dict[aug_type],
359
  A.RandomCrop(width=crop_width, height=crop_height, p=0.8)
360
+ ], additional_targets={'mask': 'mask'})
361
 
362
  masks, mask_labels = self.polygons_to_masks(image, polygons, labels)
363
  if masks.shape[0] == 0:
364
  raise ValueError("No valid masks created from polygons")
365
 
366
+ # Ensure masks are processed correctly
367
  aug_result = transform(image=image, masks=masks)
368
  aug_image = aug_result['image']
369
  aug_masks = aug_result['masks']
370
 
371
+ # Validate augmented image
372
+ if aug_image is None or aug_image.size == 0:
373
+ raise ValueError("Augmented image is empty or invalid")
374
+
375
  aug_polygons, aug_labels = self.masks_to_labelme_polygons(
376
  aug_masks, mask_labels, original_areas, self.area_threshold, self.tolerance
377
  )
378
 
379
  aug_data = self.save_augmented_data(aug_image, aug_polygons, aug_labels, original_data, "input")
380
 
381
+ logger.info(f"Augmentation completed: {len(aug_polygons)} polygons generated")
382
  return aug_image, aug_data
383
 
384
  def augment_image(image: Image.Image, json_file: Any, aug_type: str, aug_param: float):
 
399
  raise ValueError(f"Parameter {aug_param} for {aug_type} is out of range [{min_val}, {max_val}]")
400
 
401
  # Convert PIL image to NumPy
402
+ if image is None:
403
+ raise ValueError("Input image is None")
404
  img_np = np.array(image)
405
  img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
406
 
407
  # Initialize augmenter
408
+ augmenter = PolygonAugmentation(tolerance=2.0, area_threshold=0.01, debug=True)
409
 
410
  # Load data
411
  img_np, polygons, labels, original_areas, original_data, _ = augmenter.load_labelme_data(json_file, img_np)
 
415
  img_np, polygons, labels, original_areas, original_data, aug_type, aug_param
416
  )
417
 
418
+ # Validate augmented image
419
+ if aug_image is None or aug_image.size == 0:
420
+ raise ValueError("Augmented image is empty or invalid")
421
+
422
  # Create a dynamic color map for unique labels
423
  unique_labels = list(set(shape['label'] for shape in aug_data['shapes']))
424
  if not unique_labels:
 
434
 
435
  # Convert augmented image to RGB for visualization
436
  aug_image_rgb = cv2.cvtColor(aug_image, cv2.COLOR_BGR2RGB)
437
+ if aug_image_rgb is None or aug_image_rgb.size == 0:
438
+ raise ValueError("Failed to convert augmented image to RGB")
439
+
440
+ # Create a clean copy of the augmented image for visualization
441
  overlay = aug_image_rgb.copy()
442
 
443
+ # Create masks and outlines for visualization
444
  height, width = aug_image.shape[:2]
445
  for shape in aug_data['shapes']:
446
  label = shape['label']
447
+ color = label_color_map.get(label, (0, 255, 0)) # Fallback to green
448
  points = np.array(shape['points'], dtype=np.int32)
449
+ if len(points) < 3:
450
+ logger.warning(f"Skipping invalid polygon for label {label}: fewer than 3 points")
451
+ continue
452
 
453
+ # Draw semi-transparent mask
454
  mask = np.zeros((height, width), dtype=np.uint8)
455
  cv2.fillPoly(mask, [points], 1)
456
  colored_mask = np.zeros_like(aug_image_rgb)
457
  colored_mask[mask == 1] = color
458
+ alpha = 0.3 # Transparency for mask
459
+ overlay = cv2.addWeighted(overlay, 1.0, colored_mask, alpha, 0.0)
460
 
461
  # Draw polygon outline
462
  cv2.polylines(overlay, [points], isClosed=True, color=color, thickness=2)
 
467
  # Format JSON for display
468
  aug_json_str = json.dumps(aug_data, indent=2)
469
 
470
+ logger.info("Visualization completed successfully")
471
  return aug_image_pil, aug_json_str
472
  except Exception as e:
473
+ logger.error(f"Error in augment_image: {str(e)}")
474
  return None, f"Error: {str(e)}"
475
 
476
  # Define augmentation types and parameter ranges
 
486
  def create_interface():
487
  with gr.Blocks(title="Donut Polygon Augmentation") as demo:
488
  gr.Markdown("# Donut Polygon Augmentation 🌀")
489
+ gr.Markdown("Upload an image and a LabelMe JSON file to apply topology-preserving augmentation to donut-shaped polygons. Each class is visualized with a unique color and semi-transparent mask over the augmented image.")
490
 
491
  with gr.Row():
492
  with gr.Column():
 
506
  )
507
 
508
  def update_slider(display_name):
 
509
  aug_key = next(k for k, v in aug_options.items() if v["display_name"] == display_name)
510
  return {
511
  aug_param: gr.update(
 
526
  output_json = gr.Textbox(label="Augmented LabelMe JSON", lines=10, max_lines=20)
527
 
528
  def submit(image, json_file, display_name, aug_param):
 
529
  aug_key = next(k for k, v in aug_options.items() if v["display_name"] == display_name)
530
  return augment_image(image, json_file, aug_key, aug_param)
531