laudari committed on
Commit
a7aaf94
·
verified ·
1 Parent(s): f3ed335

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +301 -104
app.py CHANGED
@@ -11,6 +11,9 @@ import random
11
  from pathlib import Path
12
  import colorsys
13
  import logging
 
 
 
14
 
15
  # Set up logging
16
  logging.basicConfig(level=logging.INFO)
@@ -22,6 +25,7 @@ class PolygonAugmentation:
22
  self.area_threshold = area_threshold
23
  self.debug = debug
24
  self.supported_extensions = ['.png', '.jpg', '.jpeg', '.bmp', '.PNG', '.JPEG']
 
25
 
26
  def __getattr__(self, name: str) -> Any:
27
  raise AttributeError(f"'PolygonAugmentation' object has no attribute '{name}'")
@@ -381,44 +385,72 @@ class PolygonAugmentation:
381
  logger.info(f"Augmentation completed: {len(aug_polygons)} polygons generated")
382
  return aug_image, aug_data
383
 
384
- def augment_image(image: Image.Image, json_file: Any, aug_type: str, aug_param: float):
385
- try:
386
- # Validate aug_param based on aug_type
387
- aug_ranges = {
388
- "rotate": (-30, 30),
389
- "horizontal_flip": (0, 1),
390
- "vertical_flip": (0, 1),
391
- "scale": (0.5, 1.5),
392
- "brightness_contrast": (-0.3, 0.3),
393
- "pixel_dropout": (0.01, 0.1)
394
- }
395
- if aug_type not in aug_ranges:
396
- raise ValueError(f"Invalid augmentation type: {aug_type}")
397
- min_val, max_val = aug_ranges[aug_type]
398
- if not (min_val <= aug_param <= max_val):
399
- raise ValueError(f"Parameter {aug_param} for {aug_type} is out of range [{min_val}, {max_val}]")
400
-
401
- # Convert PIL image to NumPy
402
- if image is None:
403
- raise ValueError("Input image is None")
404
- img_np = np.array(image)
405
- img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
406
-
407
- # Initialize augmenter
408
- augmenter = PolygonAugmentation(tolerance=2.0, area_threshold=0.01, debug=True)
409
 
410
- # Load data
411
- img_np, polygons, labels, original_areas, original_data, _ = augmenter.load_labelme_data(json_file, img_np)
412
-
413
- # Perform augmentation
414
- aug_image, aug_data = augmenter.augment_single_image(
415
- img_np, polygons, labels, original_areas, original_data, aug_type, aug_param
416
- )
417
-
418
- # Validate augmented image
419
- if aug_image is None or aug_image.size == 0:
420
- raise ValueError("Augmented image is empty or invalid")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
 
 
 
 
 
422
  # Create a dynamic color map for unique labels
423
  unique_labels = list(set(shape['label'] for shape in aug_data['shapes']))
424
  if not unique_labels:
@@ -434,20 +466,15 @@ def augment_image(image: Image.Image, json_file: Any, aug_type: str, aug_param:
434
 
435
  # Convert augmented image to RGB for visualization
436
  aug_image_rgb = cv2.cvtColor(aug_image, cv2.COLOR_BGR2RGB)
437
- if aug_image_rgb is None or aug_image_rgb.size == 0:
438
- raise ValueError("Failed to convert augmented image to RGB")
439
-
440
- # Create a clean copy of the augmented image for visualization
441
  overlay = aug_image_rgb.copy()
442
 
443
  # Create masks and outlines for visualization
444
  height, width = aug_image.shape[:2]
445
  for shape in aug_data['shapes']:
446
  label = shape['label']
447
- color = label_color_map.get(label, (0, 255, 0)) # Fallback to green
448
  points = np.array(shape['points'], dtype=np.int32)
449
  if len(points) < 3:
450
- logger.warning(f"Skipping invalid polygon for label {label}: fewer than 3 points")
451
  continue
452
 
453
  # Draw semi-transparent mask
@@ -455,84 +482,254 @@ def augment_image(image: Image.Image, json_file: Any, aug_type: str, aug_param:
455
  cv2.fillPoly(mask, [points], 1)
456
  colored_mask = np.zeros_like(aug_image_rgb)
457
  colored_mask[mask == 1] = color
458
- alpha = 0.3 # Transparency for mask
459
  overlay = cv2.addWeighted(overlay, 1.0, colored_mask, alpha, 0.0)
460
 
461
  # Draw polygon outline
462
  cv2.polylines(overlay, [points], isClosed=True, color=color, thickness=2)
463
 
464
- # Convert overlay back to PIL
465
- aug_image_pil = Image.fromarray(overlay)
466
-
467
- # Format JSON for display
468
- aug_json_str = json.dumps(aug_data, indent=2)
469
-
470
- logger.info("Visualization completed successfully")
471
- return aug_image_pil, aug_json_str
472
- except Exception as e:
473
- logger.error(f"Error in augment_image: {str(e)}")
474
- return None, f"Error: {str(e)}"
475
 
476
- # Define augmentation types and parameter ranges
477
- aug_options = {
478
- "rotate": {"display_name": "Rotate", "param_name": "Angle (degrees)", "range": (-30, 30), "default": 0},
479
- "horizontal_flip": {"display_name": "Horizontal Flip", "param_name": "Apply Flip (0 or 1)", "range": (0, 1), "default": 0},
480
- "vertical_flip": {"display_name": "Vertical Flip", "param_name": "Apply Flip (0 or 1)", "range": (0, 1), "default": 0},
481
- "scale": {"display_name": "Scale", "param_name": "Scale Factor", "range": (0.5, 1.5), "default": 1.0},
482
- "brightness_contrast": {"display_name": "Brightness/Contrast", "param_name": "Brightness/Contrast Limit", "range": (-0.3, 0.3), "default": 0},
483
- "pixel_dropout": {"display_name": "Pixel Dropout", "param_name": "Dropout Probability", "range": (0.01, 0.1), "default": 0.05}
484
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
 
486
  def create_interface():
487
- with gr.Blocks(title="Donut Polygon Augmentation") as demo:
488
- gr.Markdown("# Donut Polygon Augmentation πŸŒ€")
489
- gr.Markdown("Upload an image and a LabelMe JSON file to apply topology-preserving augmentation to donut-shaped polygons. Each class is visualized with a unique color and semi-transparent mask over the augmented image.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
490
 
491
  with gr.Row():
492
- with gr.Column():
493
- image_input = gr.Image(type="pil", label="Input Image")
494
- json_input = gr.File(label="LabelMe JSON File", file_types=[".json"])
495
- aug_type = gr.Dropdown(
496
- choices=[v["display_name"] for v in aug_options.values()],
497
- label="Augmentation Type",
498
- value=aug_options["rotate"]["display_name"]
499
  )
500
- aug_param = gr.Slider(
501
- minimum=aug_options["rotate"]["range"][0],
502
- maximum=aug_options["rotate"]["range"][1],
503
- value=aug_options["rotate"]["default"],
504
- label=aug_options["rotate"]["param_name"],
505
- step=0.01
506
  )
507
 
508
- def update_slider(display_name):
509
- aug_key = next(k for k, v in aug_options.items() if v["display_name"] == display_name)
510
- return {
511
- aug_param: gr.update(
512
- minimum=aug_options[aug_key]["range"][0],
513
- maximum=aug_options[aug_key]["range"][1],
514
- value=aug_options[aug_key]["default"],
515
- label=aug_options[aug_key]["param_name"],
516
- step=0.01 if aug_key in ["pixel_dropout", "brightness_contrast", "scale"] else 1
517
- )
518
- }
 
 
 
 
 
 
 
519
 
520
- aug_type.change(fn=update_slider, inputs=aug_type, outputs=[aug_param])
 
 
 
 
 
521
 
522
- submit_btn = gr.Button("Apply Augmentation")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
 
524
- with gr.Column():
525
- output_image = gr.Image(type="pil", label="Augmented Image with Colored Polygons and Masks")
526
- output_json = gr.Textbox(label="Augmented LabelMe JSON", lines=10, max_lines=20)
527
-
528
- def submit(image, json_file, display_name, aug_param):
529
- aug_key = next(k for k, v in aug_options.items() if v["display_name"] == display_name)
530
- return augment_image(image, json_file, aug_key, aug_param)
531
-
532
- submit_btn.click(
533
- fn=submit,
534
- inputs=[image_input, json_input, aug_type, aug_param],
535
- outputs=[output_image, output_json]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
536
  )
537
 
538
  return demo
 
11
  from pathlib import Path
12
  import colorsys
13
  import logging
14
+ import zipfile
15
+ import io
16
+ from datetime import datetime
17
 
18
  # Set up logging
19
  logging.basicConfig(level=logging.INFO)
 
25
  self.area_threshold = area_threshold
26
  self.debug = debug
27
  self.supported_extensions = ['.png', '.jpg', '.jpeg', '.bmp', '.PNG', '.JPEG']
28
+ self.augmented_results = [] # Store all augmentation results
29
 
30
  def __getattr__(self, name: str) -> Any:
31
  raise AttributeError(f"'PolygonAugmentation' object has no attribute '{name}'")
 
385
  logger.info(f"Augmentation completed: {len(aug_polygons)} polygons generated")
386
  return aug_image, aug_data
387
 
388
def batch_augment_images(self, image_json_pairs, aug_configs, num_augmentations):
    """Augment every (image, annotation) pair with every configuration.

    Each usable pair is converted to a BGR array, its annotation is loaded
    through ``self.load_labelme_data``, and every entry of ``aug_configs`` is
    applied ``num_augmentations`` times with a freshly drawn parameter.
    Failures are logged and skipped so one bad input does not abort the batch.

    Args:
        image_json_pairs: Iterable of ``(image, json_file)`` tuples. Pairs
            where either element is ``None`` are skipped. ``image`` is fed to
            ``np.array`` and converted with ``cv2.COLOR_RGB2BGR``.
        aug_configs: List of dicts with the keys ``'aug_type'`` and
            ``'param_range'`` (a ``(min, max)`` pair).
        num_augmentations: How many samples to draw per configuration and
            image. Accepts any value ``int()`` can convert (for example a
            float delivered by a UI slider).

    Returns:
        The visualisation images produced by ``self.create_visualization``,
        in generation order. ``self.augmented_results`` is reset and then
        filled with one dict per result (``'image'``, ``'json_data'``,
        ``'metadata'``).
    """
    self.augmented_results = []
    results = []

    # range() rejects floats; coerce once here instead of failing (and being
    # logged) separately for every input pair inside the loop below.
    try:
        repeat_count = int(num_augmentations)
    except (TypeError, ValueError):
        logger.error(f"Invalid num_augmentations value: {num_augmentations!r}")
        return results

    flip_types = ('horizontal_flip', 'vertical_flip')

    for pair_idx, (image, json_file) in enumerate(image_json_pairs):
        if image is None or json_file is None:
            continue

        try:
            # Convert PIL image to a BGR NumPy array
            img_np = np.array(image)
            img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)

            # Load annotation data
            img_np, polygons, labels, original_areas, original_data, _ = self.load_labelme_data(json_file, img_np)

            # Apply each augmentation configuration
            for config in aug_configs:
                aug_type = config['aug_type']
                param_range = config['param_range']
                min_val, max_val = param_range

                for aug_idx in range(repeat_count):
                    # Flips are on/off switches; every other type samples a
                    # continuous value from the configured range.
                    if aug_type in flip_types:
                        aug_param = random.choice([0, 1])
                    else:
                        aug_param = random.uniform(min_val, max_val)

                    try:
                        aug_image, aug_data = self.augment_single_image(
                            img_np, polygons, labels, original_areas,
                            original_data, aug_type, aug_param
                        )

                        # Create visualization
                        aug_image_vis = self.create_visualization(aug_image, aug_data)

                        # Store result
                        result_data = {
                            'image': aug_image_vis,
                            'json_data': aug_data,
                            'metadata': {
                                'original_image_index': pair_idx,
                                'augmentation_index': aug_idx,
                                'augmentation_type': aug_type,
                                'parameter_value': aug_param,
                                'parameter_range': param_range,
                                'timestamp': datetime.now().isoformat(),
                                'filename': f'aug_{pair_idx}_{aug_type}_{aug_idx}.png'
                            }
                        }

                        self.augmented_results.append(result_data)
                        results.append(aug_image_vis)

                    except Exception as e:
                        # Best effort: a failed sample must not stop the rest.
                        logger.error(f"Error augmenting image {pair_idx} with {aug_type}: {str(e)}")
                        continue

        except Exception as e:
            # Best effort: a pair that cannot be loaded is skipped entirely.
            logger.error(f"Error processing image pair {pair_idx}: {str(e)}")
            continue

    return results
451
+
452
+ def create_visualization(self, aug_image, aug_data):
453
+ """Create visualization with colored polygons and masks"""
454
  # Create a dynamic color map for unique labels
455
  unique_labels = list(set(shape['label'] for shape in aug_data['shapes']))
456
  if not unique_labels:
 
466
 
467
  # Convert augmented image to RGB for visualization
468
  aug_image_rgb = cv2.cvtColor(aug_image, cv2.COLOR_BGR2RGB)
 
 
 
 
469
  overlay = aug_image_rgb.copy()
470
 
471
  # Create masks and outlines for visualization
472
  height, width = aug_image.shape[:2]
473
  for shape in aug_data['shapes']:
474
  label = shape['label']
475
+ color = label_color_map.get(label, (0, 255, 0))
476
  points = np.array(shape['points'], dtype=np.int32)
477
  if len(points) < 3:
 
478
  continue
479
 
480
  # Draw semi-transparent mask
 
482
  cv2.fillPoly(mask, [points], 1)
483
  colored_mask = np.zeros_like(aug_image_rgb)
484
  colored_mask[mask == 1] = color
485
+ alpha = 0.3
486
  overlay = cv2.addWeighted(overlay, 1.0, colored_mask, alpha, 0.0)
487
 
488
  # Draw polygon outline
489
  cv2.polylines(overlay, [points], isClosed=True, color=color, thickness=2)
490
 
491
+ return Image.fromarray(overlay)
 
 
 
 
 
 
 
 
 
 
492
 
493
def create_download_package(self):
    """Bundle all stored augmentation results into an in-memory ZIP archive.

    For every entry of ``self.augmented_results`` the archive receives the
    image (saved as PNG under the entry's ``metadata['filename']``) and a
    ``metadata_<index>.json`` file holding the entry's ``json_data``. A final
    ``augmentation_summary.json`` lists the metadata of every entry together
    with the total count and a generation timestamp.

    Returns:
        The ZIP archive as ``bytes``, or ``None`` when no results are stored.
    """
    stored = self.augmented_results
    if not stored:
        return None

    archive_bytes = io.BytesIO()

    with zipfile.ZipFile(archive_bytes, 'w', zipfile.ZIP_DEFLATED) as archive:
        for position, entry in enumerate(stored):
            # Image: encode to PNG in memory, then store under its filename
            png_bytes = io.BytesIO()
            entry['image'].save(png_bytes, format='PNG')
            archive.writestr(entry['metadata']['filename'], png_bytes.getvalue())

            # Annotation JSON belonging to this image
            archive.writestr(
                f"metadata_{position}.json",
                json.dumps(entry['json_data'], indent=2),
            )

        # Overall summary of the batch
        overview = {
            'total_augmentations': len(stored),
            'generation_timestamp': datetime.now().isoformat(),
            'augmentation_summary': [entry['metadata'] for entry in stored],
        }
        archive.writestr('augmentation_summary.json', json.dumps(overview, indent=2))

    return archive_bytes.getvalue()
524
 
525
def create_interface():
    """Build the Gradio Blocks UI for batch polygon augmentation.

    The interface lets the user upload several images and LabelMe JSON files,
    choose which augmentation types to apply and over which parameter ranges,
    browse the results in a gallery, and download everything as a ZIP archive.
    One ``PolygonAugmentation`` instance is created per interface and shared
    by all event handlers through closure.

    Returns:
        The constructed ``gr.Blocks`` application (not launched).
    """
    import tempfile

    augmenter = PolygonAugmentation(tolerance=2.0, area_threshold=0.01, debug=True)

    def process_batch_augmentation(
        images, json_files, num_augmentations,
        rotate_enabled, rotate_min, rotate_max,
        hflip_enabled, vflip_enabled,
        scale_enabled, scale_min, scale_max,
        brightness_enabled, brightness_min, brightness_max,
        dropout_enabled, dropout_min, dropout_max
    ):
        """Run the batch augmentation for the current UI selections.

        Returns a ``(gallery_images, metadata_json, status_message)`` triple,
        matching the order of the outputs wired to the generate button. On
        failure the gallery and metadata are empty and the reason is reported
        in the status message.
        """
        if not images or not json_files:
            return [], "", "No images or JSON files uploaded"

        # Pair images with JSON files positionally; zip() stops at the
        # shorter of the two upload lists.
        image_json_pairs = []
        for i, (image_file, json_file) in enumerate(zip(images, json_files)):
            if image_file is None or json_file is None:
                continue
            try:
                # Uploads may be objects exposing ``.name`` or plain paths.
                image_path = getattr(image_file, "name", image_file)
                # Normalise to 3-channel RGB because the augmenter converts
                # with COLOR_RGB2BGR; the context manager releases the file.
                with Image.open(image_path) as opened:
                    image = opened.convert("RGB")
                # Pass the JSON upload through unchanged.
                # NOTE(review): load_labelme_data is not visible here; confirm
                # it accepts the uploaded file value (as the previous
                # single-image flow passed it) rather than a path string.
                image_json_pairs.append((image, json_file))
            except Exception as e:
                logger.error(f"Error loading image {i}: {str(e)}")
                continue

        if not image_json_pairs:
            return [], "", "No valid image-JSON pairs found"

        # Configure augmentations based on user selections (order preserved)
        candidate_configs = [
            (rotate_enabled, 'rotate', (rotate_min, rotate_max)),
            (hflip_enabled, 'horizontal_flip', (0, 1)),
            (vflip_enabled, 'vertical_flip', (0, 1)),
            (scale_enabled, 'scale', (scale_min, scale_max)),
            (brightness_enabled, 'brightness_contrast', (brightness_min, brightness_max)),
            (dropout_enabled, 'pixel_dropout', (dropout_min, dropout_max)),
        ]
        aug_configs = [
            {'aug_type': aug_type, 'param_range': param_range}
            for enabled, aug_type, param_range in candidate_configs
            if enabled
        ]

        if not aug_configs:
            return [], "", "No augmentation types selected"

        # Process augmentations
        try:
            augmented_images = augmenter.batch_augment_images(
                image_json_pairs, aug_configs, num_augmentations
            )

            # Create JSON summary
            json_summary = json.dumps(
                [result['metadata'] for result in augmenter.augmented_results],
                indent=2
            )

            status = f"Generated {len(augmented_images)} augmented images from {len(image_json_pairs)} input pairs"
            return augmented_images, json_summary, status

        except Exception as e:
            logger.error(f"Batch augmentation error: {str(e)}")
            return [], "", f"Error: {str(e)}"

    def download_package():
        """Write the ZIP archive to a temp file and reveal the download slot."""
        package = augmenter.create_download_package()
        if package is None:
            # Nothing generated yet: keep the download slot hidden.
            return gr.update(value=None, visible=False)

        # The File component is given a path on disk, not raw bytes, so the
        # archive is spilled to a temporary file that outlives this handler.
        with tempfile.NamedTemporaryFile(
            prefix="augmentations_", suffix=".zip", delete=False
        ) as tmp:
            tmp.write(package)
        return gr.update(value=tmp.name, visible=True)

    def show_mask_overlay(evt: gr.SelectData):
        """Return the stored visualisation for the clicked gallery item."""
        if evt.index < len(augmenter.augmented_results):
            return augmenter.augmented_results[evt.index]['image']
        return None

    with gr.Blocks(title="Dynamic Donut Polygon Augmentation") as demo:
        gr.Markdown("# 🌀 Dynamic Donut Polygon Augmentation Tool")
        gr.Markdown("Upload multiple images and JSON files to apply batch augmentation with configurable parameter ranges")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("## 📁 Input Files")

                images_input = gr.File(
                    file_count="multiple",
                    file_types=["image"],
                    label="Upload Images"
                )

                json_input = gr.File(
                    file_count="multiple",
                    file_types=[".json"],
                    label="Upload LabelMe JSON Files"
                )

                num_augmentations = gr.Slider(
                    minimum=1, maximum=5, value=2, step=1,
                    label="Augmentations per configuration"
                )

                gr.Markdown("## ⚙️ Augmentation Configuration")

                # Rotation parameters
                with gr.Group():
                    rotate_enabled = gr.Checkbox(label="Enable Rotation", value=True)
                    with gr.Row():
                        rotate_min = gr.Slider(-45, 45, -15, label="Min Rotation (degrees)")
                        rotate_max = gr.Slider(-45, 45, 15, label="Max Rotation (degrees)")

                # Flip parameters
                with gr.Group():
                    hflip_enabled = gr.Checkbox(label="Enable Horizontal Flip", value=True)
                    vflip_enabled = gr.Checkbox(label="Enable Vertical Flip", value=False)

                # Scale parameters
                with gr.Group():
                    scale_enabled = gr.Checkbox(label="Enable Scale", value=True)
                    with gr.Row():
                        scale_min = gr.Slider(0.7, 1.3, 0.9, label="Min Scale")
                        scale_max = gr.Slider(0.7, 1.3, 1.1, label="Max Scale")

                # Brightness parameters
                with gr.Group():
                    brightness_enabled = gr.Checkbox(label="Enable Brightness/Contrast", value=True)
                    with gr.Row():
                        brightness_min = gr.Slider(-0.3, 0.3, -0.1, label="Min Brightness")
                        brightness_max = gr.Slider(-0.3, 0.3, 0.1, label="Max Brightness")

                # Dropout parameters
                with gr.Group():
                    dropout_enabled = gr.Checkbox(label="Enable Pixel Dropout", value=False)
                    with gr.Row():
                        dropout_min = gr.Slider(0.01, 0.1, 0.02, label="Min Dropout")
                        dropout_max = gr.Slider(0.01, 0.1, 0.05, label="Max Dropout")

                generate_btn = gr.Button("🚀 Generate Augmentations", variant="primary")
                status_text = gr.Textbox(label="Status", interactive=False)

            with gr.Column(scale=2):
                gr.Markdown("## 🖼️ Augmented Results")
                gr.Markdown("*Click on any image to view with enhanced mask overlay*")

                augmented_gallery = gr.Gallery(
                    label="Augmented Images with Polygon Masks",
                    show_label=False,
                    elem_id="gallery",
                    columns=3,
                    rows=3,
                    height="auto"
                )

                with gr.Row():
                    download_btn = gr.Button("📥 Download All (ZIP)", variant="secondary")
                    # Hidden until download_package() supplies an archive.
                    download_file = gr.File(label="Download Package", visible=False)

                gr.Markdown("## 📋 Augmentation Metadata")
                json_output = gr.Code(
                    label="Generated Metadata JSON",
                    language="json",
                    lines=15
                )

                gr.Markdown("## 🎭 Enhanced Preview")
                mask_preview = gr.Image(label="Selected Image with Mask Overlay")

        # Event handlers
        generate_btn.click(
            process_batch_augmentation,
            inputs=[
                images_input, json_input, num_augmentations,
                rotate_enabled, rotate_min, rotate_max,
                hflip_enabled, vflip_enabled,
                scale_enabled, scale_min, scale_max,
                brightness_enabled, brightness_min, brightness_max,
                dropout_enabled, dropout_min, dropout_max
            ],
            outputs=[augmented_gallery, json_output, status_text]
        )

        download_btn.click(
            download_package,
            outputs=download_file
        )

        augmented_gallery.select(
            show_mask_overlay,
            outputs=mask_preview
        )

    return demo