slau8405 commited on
Commit
c495fed
·
1 Parent(s): 8d41fcd

Added few more features

Browse files
Files changed (1) hide show
  1. app.py +36 -15
app.py CHANGED
@@ -9,6 +9,7 @@ import supervision as sv
9
  import uuid
10
  import random
11
  from pathlib import Path
 
12
 
13
  class PolygonAugmentation:
14
  def __init__(self, tolerance=0.2, area_threshold=0.01, debug=False):
@@ -34,12 +35,10 @@ class PolygonAugmentation:
34
  else:
35
  data = json.load(json_file)
36
 
37
- # Check for 'shapes' (LabelMe) or 'segments' (custom format)
38
  shapes = []
39
  if 'shapes' in data and isinstance(data['shapes'], list):
40
  shapes = data['shapes']
41
  elif 'segments' in data and isinstance(data['segments'], list):
42
- # Convert custom 'segments' to LabelMe 'shapes' format
43
  shapes = [
44
  {
45
  "label": seg.get("class", "unknown"),
@@ -52,7 +51,7 @@ class PolygonAugmentation:
52
  for seg in data['segments']
53
  ]
54
  else:
55
- raise ValueError("Invalid JSON: Neither 'shapes' nor ' segments' key found or not a list")
56
 
57
  polygons = []
58
  labels = []
@@ -156,7 +155,7 @@ class PolygonAugmentation:
156
  "group_id": None,
157
  "shape_type": "polygon",
158
  "flags": {},
159
- "confidence": 1.0 # Default confidence for augmented shapes
160
  })
161
 
162
  aug_data = {
@@ -343,7 +342,7 @@ class PolygonAugmentation:
343
  contrast_limit=aug_param,
344
  p=1.0
345
  ),
346
- "pixel_dropout": A.PixelDropout(dropout_prob=aug_param, p=1.0)
347
  }
348
 
349
  if aug_type not in aug_dict:
@@ -372,6 +371,21 @@ class PolygonAugmentation:
372
 
373
  def augment_image(image: Image.Image, json_file: Any, aug_type: str, aug_param: float):
374
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
  # Convert PIL image to NumPy
376
  img_np = np.array(image)
377
  img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
@@ -387,13 +401,18 @@ def augment_image(image: Image.Image, json_file: Any, aug_type: str, aug_param:
387
  img_np, polygons, labels, original_areas, original_data, aug_type, aug_param
388
  )
389
 
390
- # Create a color map for unique labels
391
  unique_labels = list(set(shape['label'] for shape in aug_data['shapes']))
392
- colors = [
393
- (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255), (0, 255, 255),
394
- (128, 0, 0), (0, 128, 0), (0, 0, 128), (128, 128, 0) # Add more if needed
395
- ]
396
- label_color_map = {label: colors[i % len(colors)] for i, label in enumerate(unique_labels)}
 
 
 
 
 
397
 
398
  # Convert augmented image to RGB for visualization
399
  aug_image_rgb = cv2.cvtColor(aug_image, cv2.COLOR_BGR2RGB)
@@ -403,7 +422,7 @@ def augment_image(image: Image.Image, json_file: Any, aug_type: str, aug_param:
403
  height, width = aug_image.shape[:2]
404
  for shape in aug_data['shapes']:
405
  label = shape['label']
406
- color = label_color_map[label]
407
  points = np.array(shape['points'], dtype=np.int32)
408
 
409
  # Draw filled mask with transparency
@@ -411,7 +430,7 @@ def augment_image(image: Image.Image, json_file: Any, aug_type: str, aug_param:
411
  cv2.fillPoly(mask, [points], 1)
412
  colored_mask = np.zeros_like(aug_image_rgb)
413
  colored_mask[mask == 1] = color
414
- alpha = 0.3 # Transparency for mask
415
  cv2.addWeighted(colored_mask, alpha, overlay, 1 - alpha, 0, overlay)
416
 
417
  # Draw polygon outline
@@ -455,7 +474,8 @@ def create_interface():
455
  minimum=aug_options["Rotate"]["range"][0],
456
  maximum=aug_options["Rotate"]["range"][1],
457
  value=aug_options["Rotate"]["default"],
458
- label=aug_options["Rotate"]["param_name"]
 
459
  )
460
 
461
  def update_slider(aug_type):
@@ -464,7 +484,8 @@ def create_interface():
464
  minimum=aug_options[aug_type]["range"][0],
465
  maximum=aug_options[aug_type]["range"][1],
466
  value=aug_options[aug_type]["default"],
467
- label=aug_options[aug_type]["param_name"]
 
468
  )
469
  }
470
 
 
9
  import uuid
10
  import random
11
  from pathlib import Path
12
+ import colorsys
13
 
14
  class PolygonAugmentation:
15
  def __init__(self, tolerance=0.2, area_threshold=0.01, debug=False):
 
35
  else:
36
  data = json.load(json_file)
37
 
 
38
  shapes = []
39
  if 'shapes' in data and isinstance(data['shapes'], list):
40
  shapes = data['shapes']
41
  elif 'segments' in data and isinstance(data['segments'], list):
 
42
  shapes = [
43
  {
44
  "label": seg.get("class", "unknown"),
 
51
  for seg in data['segments']
52
  ]
53
  else:
54
+ raise ValueError("Invalid JSON: Neither 'shapes' nor 'segments' key found or not a list")
55
 
56
  polygons = []
57
  labels = []
 
155
  "group_id": None,
156
  "shape_type": "polygon",
157
  "flags": {},
158
+ "confidence": 1.0
159
  })
160
 
161
  aug_data = {
 
342
  contrast_limit=aug_param,
343
  p=1.0
344
  ),
345
+ "pixel_dropout": A.PixelDropout(dropout_prob=min(max(aug_param, 0.0), 1.0), p=1.0)
346
  }
347
 
348
  if aug_type not in aug_dict:
 
371
 
372
  def augment_image(image: Image.Image, json_file: Any, aug_type: str, aug_param: float):
373
  try:
374
+ # Validate aug_param based on aug_type
375
+ aug_ranges = {
376
+ "Rotate": (-30, 30),
377
+ "Horizontal Flip": (0, 1),
378
+ "Vertical Flip": (0, 1),
379
+ "Scale": (0.5, 1.5),
380
+ "Brightness/Contrast": (-0.3, 0.3),
381
+ "Pixel Dropout": (0.01, 0.1)
382
+ }
383
+ if aug_type not in aug_ranges:
384
+ raise ValueError(f"Invalid augmentation type: {aug_type}")
385
+ min_val, max_val = aug_ranges[aug_type]
386
+ if not (min_val <= aug_param <= max_val):
387
+ raise ValueError(f"Parameter {aug_param} for {aug_type} is out of range [{min_val}, {max_val}]")
388
+
389
  # Convert PIL image to NumPy
390
  img_np = np.array(image)
391
  img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
 
401
  img_np, polygons, labels, original_areas, original_data, aug_type, aug_param
402
  )
403
 
404
+ # Create a dynamic color map for unique labels
405
  unique_labels = list(set(shape['label'] for shape in aug_data['shapes']))
406
+ if not unique_labels:
407
+ label_color_map = {"unknown": (0, 255, 0)}
408
+ else:
409
+ num_labels = len(unique_labels)
410
+ hues = [i / num_labels for i in range(num_labels)]
411
+ label_color_map = {}
412
+ for label, hue in zip(unique_labels, hues):
413
+ rgb = colorsys.hsv_to_rgb(hue, 1.0, 1.0)
414
+ rgb = tuple(int(c * 255) for c in rgb)
415
+ label_color_map[label] = rgb
416
 
417
  # Convert augmented image to RGB for visualization
418
  aug_image_rgb = cv2.cvtColor(aug_image, cv2.COLOR_BGR2RGB)
 
422
  height, width = aug_image.shape[:2]
423
  for shape in aug_data['shapes']:
424
  label = shape['label']
425
+ color = label_color_map.get(label, (0, 255, 0))
426
  points = np.array(shape['points'], dtype=np.int32)
427
 
428
  # Draw filled mask with transparency
 
430
  cv2.fillPoly(mask, [points], 1)
431
  colored_mask = np.zeros_like(aug_image_rgb)
432
  colored_mask[mask == 1] = color
433
+ alpha = 0.3
434
  cv2.addWeighted(colored_mask, alpha, overlay, 1 - alpha, 0, overlay)
435
 
436
  # Draw polygon outline
 
474
  minimum=aug_options["Rotate"]["range"][0],
475
  maximum=aug_options["Rotate"]["range"][1],
476
  value=aug_options["Rotate"]["default"],
477
+ label=aug_options["Rotate"]["param_name"],
478
+ step=0.01
479
  )
480
 
481
  def update_slider(aug_type):
 
484
  minimum=aug_options[aug_type]["range"][0],
485
  maximum=aug_options[aug_type]["range"][1],
486
  value=aug_options[aug_type]["default"],
487
+ label=aug_options[aug_type]["param_name"],
488
+ step=0.01 if aug_type in ["Pixel Dropout", "Brightness/Contrast", "Scale"] else 1
489
  )
490
  }
491