Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -41,7 +41,11 @@ class PolygonAugmentation:
|
|
| 41 |
if isinstance(json_file, str):
|
| 42 |
with open(json_file, 'r', encoding='utf-8') as f:
|
| 43 |
data = json.load(f)
|
|
|
|
|
|
|
|
|
|
| 44 |
else:
|
|
|
|
| 45 |
data = json.load(json_file)
|
| 46 |
|
| 47 |
shapes = []
|
|
@@ -344,8 +348,8 @@ class PolygonAugmentation:
|
|
| 344 |
|
| 345 |
aug_dict = {
|
| 346 |
"rotate": A.Rotate(limit=aug_param, p=1.0),
|
| 347 |
-
"horizontal_flip": A.HorizontalFlip(p=1.0 if aug_param
|
| 348 |
-
"vertical_flip": A.VerticalFlip(p=1.0 if aug_param
|
| 349 |
"scale": A.Affine(scale=aug_param, p=1.0),
|
| 350 |
"brightness_contrast": A.RandomBrightnessContrast(
|
| 351 |
brightness_limit=aug_param,
|
|
@@ -361,16 +365,39 @@ class PolygonAugmentation:
|
|
| 361 |
transform = A.Compose([
|
| 362 |
aug_dict[aug_type],
|
| 363 |
A.RandomCrop(width=crop_width, height=crop_height, p=0.8)
|
| 364 |
-
]
|
| 365 |
|
| 366 |
masks, mask_labels = self.polygons_to_masks(image, polygons, labels)
|
| 367 |
if masks.shape[0] == 0:
|
| 368 |
raise ValueError("No valid masks created from polygons")
|
| 369 |
|
| 370 |
-
#
|
| 371 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
aug_image = aug_result['image']
|
| 373 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
|
| 375 |
# Validate augmented image
|
| 376 |
if aug_image is None or aug_image.size == 0:
|
|
@@ -387,23 +414,28 @@ class PolygonAugmentation:
|
|
| 387 |
|
| 388 |
def batch_augment_images(self, image_json_pairs, aug_configs, num_augmentations):
|
| 389 |
"""Batch process multiple images with multiple augmentation configurations"""
|
|
|
|
| 390 |
self.augmented_results = []
|
| 391 |
results = []
|
| 392 |
|
| 393 |
-
for pair_idx, (image,
|
| 394 |
-
if image is None or
|
|
|
|
| 395 |
continue
|
| 396 |
|
| 397 |
try:
|
|
|
|
| 398 |
# Convert PIL image to NumPy
|
| 399 |
img_np = np.array(image)
|
| 400 |
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
|
| 401 |
|
| 402 |
-
# Load data
|
| 403 |
-
img_np, polygons, labels, original_areas, original_data, _ = self.load_labelme_data(
|
|
|
|
| 404 |
|
| 405 |
# Apply each augmentation configuration
|
| 406 |
-
for config in aug_configs:
|
|
|
|
| 407 |
for aug_idx in range(num_augmentations):
|
| 408 |
# Generate random parameter within range
|
| 409 |
min_val, max_val = config['param_range']
|
|
@@ -413,6 +445,7 @@ class PolygonAugmentation:
|
|
| 413 |
aug_param = random.uniform(min_val, max_val)
|
| 414 |
|
| 415 |
try:
|
|
|
|
| 416 |
aug_image, aug_data = self.augment_single_image(
|
| 417 |
img_np, polygons, labels, original_areas,
|
| 418 |
original_data, config['aug_type'], aug_param
|
|
@@ -438,15 +471,21 @@ class PolygonAugmentation:
|
|
| 438 |
|
| 439 |
self.augmented_results.append(result_data)
|
| 440 |
results.append(aug_image_vis)
|
|
|
|
| 441 |
|
| 442 |
except Exception as e:
|
| 443 |
logger.error(f"Error augmenting image {pair_idx} with {config['aug_type']}: {str(e)}")
|
|
|
|
|
|
|
| 444 |
continue
|
| 445 |
|
| 446 |
except Exception as e:
|
| 447 |
logger.error(f"Error processing image pair {pair_idx}: {str(e)}")
|
|
|
|
|
|
|
| 448 |
continue
|
| 449 |
|
|
|
|
| 450 |
return results
|
| 451 |
|
| 452 |
def create_visualization(self, aug_image, aug_data):
|
|
@@ -544,9 +583,17 @@ def create_interface():
|
|
| 544 |
if images[i] is not None and json_files[i] is not None:
|
| 545 |
try:
|
| 546 |
image = Image.open(images[i].name)
|
| 547 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 548 |
except Exception as e:
|
| 549 |
-
logger.error(f"Error loading image {i}: {str(e)}")
|
|
|
|
|
|
|
| 550 |
continue
|
| 551 |
|
| 552 |
if not image_json_pairs:
|
|
@@ -596,6 +643,7 @@ def create_interface():
|
|
| 596 |
|
| 597 |
# Process augmentations
|
| 598 |
try:
|
|
|
|
| 599 |
augmented_images = augmenter.batch_augment_images(
|
| 600 |
image_json_pairs, aug_configs, num_augmentations
|
| 601 |
)
|
|
@@ -604,11 +652,15 @@ def create_interface():
|
|
| 604 |
json_summary = json.dumps([result['metadata'] for result in augmenter.augmented_results], indent=2)
|
| 605 |
|
| 606 |
status = f"Generated {len(augmented_images)} augmented images from {len(image_json_pairs)} input pairs"
|
|
|
|
| 607 |
return augmented_images, json_summary, status
|
| 608 |
|
| 609 |
except Exception as e:
|
| 610 |
-
|
| 611 |
-
|
|
|
|
|
|
|
|
|
|
| 612 |
|
| 613 |
def download_package():
|
| 614 |
return augmenter.create_download_package()
|
|
|
|
| 41 |
if isinstance(json_file, str):
|
| 42 |
with open(json_file, 'r', encoding='utf-8') as f:
|
| 43 |
data = json.load(f)
|
| 44 |
+
elif isinstance(json_file, dict):
|
| 45 |
+
# Handle dictionary data directly
|
| 46 |
+
data = json_file
|
| 47 |
else:
|
| 48 |
+
# Handle file object
|
| 49 |
data = json.load(json_file)
|
| 50 |
|
| 51 |
shapes = []
|
|
|
|
| 348 |
|
| 349 |
aug_dict = {
|
| 350 |
"rotate": A.Rotate(limit=aug_param, p=1.0),
|
| 351 |
+
"horizontal_flip": A.HorizontalFlip(p=1.0 if aug_param == 1 else 0.0),
|
| 352 |
+
"vertical_flip": A.VerticalFlip(p=1.0 if aug_param == 1 else 0.0),
|
| 353 |
"scale": A.Affine(scale=aug_param, p=1.0),
|
| 354 |
"brightness_contrast": A.RandomBrightnessContrast(
|
| 355 |
brightness_limit=aug_param,
|
|
|
|
| 365 |
transform = A.Compose([
|
| 366 |
aug_dict[aug_type],
|
| 367 |
A.RandomCrop(width=crop_width, height=crop_height, p=0.8)
|
| 368 |
+
])
|
| 369 |
|
| 370 |
masks, mask_labels = self.polygons_to_masks(image, polygons, labels)
|
| 371 |
if masks.shape[0] == 0:
|
| 372 |
raise ValueError("No valid masks created from polygons")
|
| 373 |
|
| 374 |
+
# Convert masks array to list for albumentations
|
| 375 |
+
masks_list = [masks[i] for i in range(masks.shape[0])]
|
| 376 |
+
|
| 377 |
+
# Create additional targets for each mask
|
| 378 |
+
additional_targets = {f'mask{i}': 'mask' for i in range(len(masks_list))}
|
| 379 |
+
|
| 380 |
+
# Update transform with additional targets
|
| 381 |
+
transform = A.Compose([
|
| 382 |
+
aug_dict[aug_type],
|
| 383 |
+
A.RandomCrop(width=crop_width, height=crop_height, p=0.8)
|
| 384 |
+
], additional_targets=additional_targets)
|
| 385 |
+
|
| 386 |
+
# Prepare input dictionary
|
| 387 |
+
input_dict = {'image': image}
|
| 388 |
+
for i, mask in enumerate(masks_list):
|
| 389 |
+
input_dict[f'mask{i}'] = mask
|
| 390 |
+
|
| 391 |
+
# Apply augmentation
|
| 392 |
+
aug_result = transform(**input_dict)
|
| 393 |
aug_image = aug_result['image']
|
| 394 |
+
|
| 395 |
+
# Collect augmented masks
|
| 396 |
+
aug_masks_list = []
|
| 397 |
+
for i in range(len(masks_list)):
|
| 398 |
+
aug_masks_list.append(aug_result[f'mask{i}'])
|
| 399 |
+
|
| 400 |
+
aug_masks = np.array(aug_masks_list, dtype=np.uint8)
|
| 401 |
|
| 402 |
# Validate augmented image
|
| 403 |
if aug_image is None or aug_image.size == 0:
|
|
|
|
| 414 |
|
| 415 |
def batch_augment_images(self, image_json_pairs, aug_configs, num_augmentations):
|
| 416 |
"""Batch process multiple images with multiple augmentation configurations"""
|
| 417 |
+
logger.info(f"Starting batch augmentation with {len(image_json_pairs)} pairs, {len(aug_configs)} configs, {num_augmentations} augmentations each")
|
| 418 |
self.augmented_results = []
|
| 419 |
results = []
|
| 420 |
|
| 421 |
+
for pair_idx, (image, json_data) in enumerate(image_json_pairs):
|
| 422 |
+
if image is None or json_data is None:
|
| 423 |
+
logger.warning(f"Skipping pair {pair_idx}: missing image or JSON data")
|
| 424 |
continue
|
| 425 |
|
| 426 |
try:
|
| 427 |
+
logger.info(f"Processing image pair {pair_idx}")
|
| 428 |
# Convert PIL image to NumPy
|
| 429 |
img_np = np.array(image)
|
| 430 |
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
|
| 431 |
|
| 432 |
+
# Load data - pass the JSON data directly
|
| 433 |
+
img_np, polygons, labels, original_areas, original_data, _ = self.load_labelme_data(json_data, img_np)
|
| 434 |
+
logger.info(f"Loaded {len(polygons)} polygons for image {pair_idx}")
|
| 435 |
|
| 436 |
# Apply each augmentation configuration
|
| 437 |
+
for config_idx, config in enumerate(aug_configs):
|
| 438 |
+
logger.info(f"Applying config {config_idx}: {config['aug_type']}")
|
| 439 |
for aug_idx in range(num_augmentations):
|
| 440 |
# Generate random parameter within range
|
| 441 |
min_val, max_val = config['param_range']
|
|
|
|
| 445 |
aug_param = random.uniform(min_val, max_val)
|
| 446 |
|
| 447 |
try:
|
| 448 |
+
logger.info(f"Generating augmentation {aug_idx} with {config['aug_type']}, param: {aug_param}")
|
| 449 |
aug_image, aug_data = self.augment_single_image(
|
| 450 |
img_np, polygons, labels, original_areas,
|
| 451 |
original_data, config['aug_type'], aug_param
|
|
|
|
| 471 |
|
| 472 |
self.augmented_results.append(result_data)
|
| 473 |
results.append(aug_image_vis)
|
| 474 |
+
logger.info(f"Successfully generated augmentation {aug_idx} for image {pair_idx}")
|
| 475 |
|
| 476 |
except Exception as e:
|
| 477 |
logger.error(f"Error augmenting image {pair_idx} with {config['aug_type']}: {str(e)}")
|
| 478 |
+
import traceback
|
| 479 |
+
logger.error(traceback.format_exc())
|
| 480 |
continue
|
| 481 |
|
| 482 |
except Exception as e:
|
| 483 |
logger.error(f"Error processing image pair {pair_idx}: {str(e)}")
|
| 484 |
+
import traceback
|
| 485 |
+
logger.error(traceback.format_exc())
|
| 486 |
continue
|
| 487 |
|
| 488 |
+
logger.info(f"Batch augmentation completed. Generated {len(results)} total results.")
|
| 489 |
return results
|
| 490 |
|
| 491 |
def create_visualization(self, aug_image, aug_data):
|
|
|
|
| 583 |
if images[i] is not None and json_files[i] is not None:
|
| 584 |
try:
|
| 585 |
image = Image.open(images[i].name)
|
| 586 |
+
# Load JSON file content properly
|
| 587 |
+
json_path = json_files[i].name
|
| 588 |
+
logger.info(f"Loading JSON from: {json_path}")
|
| 589 |
+
with open(json_path, 'r', encoding='utf-8') as f:
|
| 590 |
+
json_data = json.load(f)
|
| 591 |
+
logger.info(f"Successfully loaded JSON with keys: {list(json_data.keys())}")
|
| 592 |
+
image_json_pairs.append((image, json_data))
|
| 593 |
except Exception as e:
|
| 594 |
+
logger.error(f"Error loading image/JSON pair {i}: {str(e)}")
|
| 595 |
+
import traceback
|
| 596 |
+
logger.error(traceback.format_exc())
|
| 597 |
continue
|
| 598 |
|
| 599 |
if not image_json_pairs:
|
|
|
|
| 643 |
|
| 644 |
# Process augmentations
|
| 645 |
try:
|
| 646 |
+
logger.info(f"Starting batch augmentation with {len(image_json_pairs)} image pairs and {len(aug_configs)} configurations")
|
| 647 |
augmented_images = augmenter.batch_augment_images(
|
| 648 |
image_json_pairs, aug_configs, num_augmentations
|
| 649 |
)
|
|
|
|
| 652 |
json_summary = json.dumps([result['metadata'] for result in augmenter.augmented_results], indent=2)
|
| 653 |
|
| 654 |
status = f"Generated {len(augmented_images)} augmented images from {len(image_json_pairs)} input pairs"
|
| 655 |
+
logger.info(status)
|
| 656 |
return augmented_images, json_summary, status
|
| 657 |
|
| 658 |
except Exception as e:
|
| 659 |
+
error_msg = f"Batch augmentation error: {str(e)}"
|
| 660 |
+
logger.error(error_msg)
|
| 661 |
+
import traceback
|
| 662 |
+
logger.error(traceback.format_exc())
|
| 663 |
+
return [], error_msg, None
|
| 664 |
|
| 665 |
def download_package():
|
| 666 |
return augmenter.create_download_package()
|