laudari commited on
Commit
1cc9814
·
verified ·
1 Parent(s): a7aaf94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -15
app.py CHANGED
@@ -41,7 +41,11 @@ class PolygonAugmentation:
41
  if isinstance(json_file, str):
42
  with open(json_file, 'r', encoding='utf-8') as f:
43
  data = json.load(f)
 
 
 
44
  else:
 
45
  data = json.load(json_file)
46
 
47
  shapes = []
@@ -344,8 +348,8 @@ class PolygonAugmentation:
344
 
345
  aug_dict = {
346
  "rotate": A.Rotate(limit=aug_param, p=1.0),
347
- "horizontal_flip": A.HorizontalFlip(p=1.0 if aug_param > 0 else 0.0),
348
- "vertical_flip": A.VerticalFlip(p=1.0 if aug_param > 0 else 0.0),
349
  "scale": A.Affine(scale=aug_param, p=1.0),
350
  "brightness_contrast": A.RandomBrightnessContrast(
351
  brightness_limit=aug_param,
@@ -361,16 +365,39 @@ class PolygonAugmentation:
361
  transform = A.Compose([
362
  aug_dict[aug_type],
363
  A.RandomCrop(width=crop_width, height=crop_height, p=0.8)
364
- ], additional_targets={'mask': 'mask'})
365
 
366
  masks, mask_labels = self.polygons_to_masks(image, polygons, labels)
367
  if masks.shape[0] == 0:
368
  raise ValueError("No valid masks created from polygons")
369
 
370
- # Ensure masks are processed correctly
371
- aug_result = transform(image=image, masks=masks)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  aug_image = aug_result['image']
373
- aug_masks = aug_result['masks']
 
 
 
 
 
 
374
 
375
  # Validate augmented image
376
  if aug_image is None or aug_image.size == 0:
@@ -387,23 +414,28 @@ class PolygonAugmentation:
387
 
388
  def batch_augment_images(self, image_json_pairs, aug_configs, num_augmentations):
389
  """Batch process multiple images with multiple augmentation configurations"""
 
390
  self.augmented_results = []
391
  results = []
392
 
393
- for pair_idx, (image, json_file) in enumerate(image_json_pairs):
394
- if image is None or json_file is None:
 
395
  continue
396
 
397
  try:
 
398
  # Convert PIL image to NumPy
399
  img_np = np.array(image)
400
  img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
401
 
402
- # Load data
403
- img_np, polygons, labels, original_areas, original_data, _ = self.load_labelme_data(json_file, img_np)
 
404
 
405
  # Apply each augmentation configuration
406
- for config in aug_configs:
 
407
  for aug_idx in range(num_augmentations):
408
  # Generate random parameter within range
409
  min_val, max_val = config['param_range']
@@ -413,6 +445,7 @@ class PolygonAugmentation:
413
  aug_param = random.uniform(min_val, max_val)
414
 
415
  try:
 
416
  aug_image, aug_data = self.augment_single_image(
417
  img_np, polygons, labels, original_areas,
418
  original_data, config['aug_type'], aug_param
@@ -438,15 +471,21 @@ class PolygonAugmentation:
438
 
439
  self.augmented_results.append(result_data)
440
  results.append(aug_image_vis)
 
441
 
442
  except Exception as e:
443
  logger.error(f"Error augmenting image {pair_idx} with {config['aug_type']}: {str(e)}")
 
 
444
  continue
445
 
446
  except Exception as e:
447
  logger.error(f"Error processing image pair {pair_idx}: {str(e)}")
 
 
448
  continue
449
 
 
450
  return results
451
 
452
  def create_visualization(self, aug_image, aug_data):
@@ -544,9 +583,17 @@ def create_interface():
544
  if images[i] is not None and json_files[i] is not None:
545
  try:
546
  image = Image.open(images[i].name)
547
- image_json_pairs.append((image, images[i].name))
 
 
 
 
 
 
548
  except Exception as e:
549
- logger.error(f"Error loading image {i}: {str(e)}")
 
 
550
  continue
551
 
552
  if not image_json_pairs:
@@ -596,6 +643,7 @@ def create_interface():
596
 
597
  # Process augmentations
598
  try:
 
599
  augmented_images = augmenter.batch_augment_images(
600
  image_json_pairs, aug_configs, num_augmentations
601
  )
@@ -604,11 +652,15 @@ def create_interface():
604
  json_summary = json.dumps([result['metadata'] for result in augmenter.augmented_results], indent=2)
605
 
606
  status = f"Generated {len(augmented_images)} augmented images from {len(image_json_pairs)} input pairs"
 
607
  return augmented_images, json_summary, status
608
 
609
  except Exception as e:
610
- logger.error(f"Batch augmentation error: {str(e)}")
611
- return [], f"Error: {str(e)}", None
 
 
 
612
 
613
  def download_package():
614
  return augmenter.create_download_package()
 
41
  if isinstance(json_file, str):
42
  with open(json_file, 'r', encoding='utf-8') as f:
43
  data = json.load(f)
44
+ elif isinstance(json_file, dict):
45
+ # Handle dictionary data directly
46
+ data = json_file
47
  else:
48
+ # Handle file object
49
  data = json.load(json_file)
50
 
51
  shapes = []
 
348
 
349
  aug_dict = {
350
  "rotate": A.Rotate(limit=aug_param, p=1.0),
351
+ "horizontal_flip": A.HorizontalFlip(p=1.0 if aug_param == 1 else 0.0),
352
+ "vertical_flip": A.VerticalFlip(p=1.0 if aug_param == 1 else 0.0),
353
  "scale": A.Affine(scale=aug_param, p=1.0),
354
  "brightness_contrast": A.RandomBrightnessContrast(
355
  brightness_limit=aug_param,
 
365
  transform = A.Compose([
366
  aug_dict[aug_type],
367
  A.RandomCrop(width=crop_width, height=crop_height, p=0.8)
368
+ ])
369
 
370
  masks, mask_labels = self.polygons_to_masks(image, polygons, labels)
371
  if masks.shape[0] == 0:
372
  raise ValueError("No valid masks created from polygons")
373
 
374
+ # Convert masks array to list for albumentations
375
+ masks_list = [masks[i] for i in range(masks.shape[0])]
376
+
377
+ # Create additional targets for each mask
378
+ additional_targets = {f'mask{i}': 'mask' for i in range(len(masks_list))}
379
+
380
+ # Update transform with additional targets
381
+ transform = A.Compose([
382
+ aug_dict[aug_type],
383
+ A.RandomCrop(width=crop_width, height=crop_height, p=0.8)
384
+ ], additional_targets=additional_targets)
385
+
386
+ # Prepare input dictionary
387
+ input_dict = {'image': image}
388
+ for i, mask in enumerate(masks_list):
389
+ input_dict[f'mask{i}'] = mask
390
+
391
+ # Apply augmentation
392
+ aug_result = transform(**input_dict)
393
  aug_image = aug_result['image']
394
+
395
+ # Collect augmented masks
396
+ aug_masks_list = []
397
+ for i in range(len(masks_list)):
398
+ aug_masks_list.append(aug_result[f'mask{i}'])
399
+
400
+ aug_masks = np.array(aug_masks_list, dtype=np.uint8)
401
 
402
  # Validate augmented image
403
  if aug_image is None or aug_image.size == 0:
 
414
 
415
  def batch_augment_images(self, image_json_pairs, aug_configs, num_augmentations):
416
  """Batch process multiple images with multiple augmentation configurations"""
417
+ logger.info(f"Starting batch augmentation with {len(image_json_pairs)} pairs, {len(aug_configs)} configs, {num_augmentations} augmentations each")
418
  self.augmented_results = []
419
  results = []
420
 
421
+ for pair_idx, (image, json_data) in enumerate(image_json_pairs):
422
+ if image is None or json_data is None:
423
+ logger.warning(f"Skipping pair {pair_idx}: missing image or JSON data")
424
  continue
425
 
426
  try:
427
+ logger.info(f"Processing image pair {pair_idx}")
428
  # Convert PIL image to NumPy
429
  img_np = np.array(image)
430
  img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
431
 
432
+ # Load data - pass the JSON data directly
433
+ img_np, polygons, labels, original_areas, original_data, _ = self.load_labelme_data(json_data, img_np)
434
+ logger.info(f"Loaded {len(polygons)} polygons for image {pair_idx}")
435
 
436
  # Apply each augmentation configuration
437
+ for config_idx, config in enumerate(aug_configs):
438
+ logger.info(f"Applying config {config_idx}: {config['aug_type']}")
439
  for aug_idx in range(num_augmentations):
440
  # Generate random parameter within range
441
  min_val, max_val = config['param_range']
 
445
  aug_param = random.uniform(min_val, max_val)
446
 
447
  try:
448
+ logger.info(f"Generating augmentation {aug_idx} with {config['aug_type']}, param: {aug_param}")
449
  aug_image, aug_data = self.augment_single_image(
450
  img_np, polygons, labels, original_areas,
451
  original_data, config['aug_type'], aug_param
 
471
 
472
  self.augmented_results.append(result_data)
473
  results.append(aug_image_vis)
474
+ logger.info(f"Successfully generated augmentation {aug_idx} for image {pair_idx}")
475
 
476
  except Exception as e:
477
  logger.error(f"Error augmenting image {pair_idx} with {config['aug_type']}: {str(e)}")
478
+ import traceback
479
+ logger.error(traceback.format_exc())
480
  continue
481
 
482
  except Exception as e:
483
  logger.error(f"Error processing image pair {pair_idx}: {str(e)}")
484
+ import traceback
485
+ logger.error(traceback.format_exc())
486
  continue
487
 
488
+ logger.info(f"Batch augmentation completed. Generated {len(results)} total results.")
489
  return results
490
 
491
  def create_visualization(self, aug_image, aug_data):
 
583
  if images[i] is not None and json_files[i] is not None:
584
  try:
585
  image = Image.open(images[i].name)
586
+ # Load JSON file content properly
587
+ json_path = json_files[i].name
588
+ logger.info(f"Loading JSON from: {json_path}")
589
+ with open(json_path, 'r', encoding='utf-8') as f:
590
+ json_data = json.load(f)
591
+ logger.info(f"Successfully loaded JSON with keys: {list(json_data.keys())}")
592
+ image_json_pairs.append((image, json_data))
593
  except Exception as e:
594
+ logger.error(f"Error loading image/JSON pair {i}: {str(e)}")
595
+ import traceback
596
+ logger.error(traceback.format_exc())
597
  continue
598
 
599
  if not image_json_pairs:
 
643
 
644
  # Process augmentations
645
  try:
646
+ logger.info(f"Starting batch augmentation with {len(image_json_pairs)} image pairs and {len(aug_configs)} configurations")
647
  augmented_images = augmenter.batch_augment_images(
648
  image_json_pairs, aug_configs, num_augmentations
649
  )
 
652
  json_summary = json.dumps([result['metadata'] for result in augmenter.augmented_results], indent=2)
653
 
654
  status = f"Generated {len(augmented_images)} augmented images from {len(image_json_pairs)} input pairs"
655
+ logger.info(status)
656
  return augmented_images, json_summary, status
657
 
658
  except Exception as e:
659
+ error_msg = f"Batch augmentation error: {str(e)}"
660
+ logger.error(error_msg)
661
+ import traceback
662
+ logger.error(traceback.format_exc())
663
+ return [], error_msg, None
664
 
665
  def download_package():
666
  return augmenter.create_download_package()