Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -11,6 +11,9 @@ import random
|
|
| 11 |
from pathlib import Path
|
| 12 |
import colorsys
|
| 13 |
import logging
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
# Set up logging
|
| 16 |
logging.basicConfig(level=logging.INFO)
|
|
@@ -22,6 +25,7 @@ class PolygonAugmentation:
|
|
| 22 |
self.area_threshold = area_threshold
|
| 23 |
self.debug = debug
|
| 24 |
self.supported_extensions = ['.png', '.jpg', '.jpeg', '.bmp', '.PNG', '.JPEG']
|
|
|
|
| 25 |
|
| 26 |
def __getattr__(self, name: str) -> Any:
|
| 27 |
raise AttributeError(f"'PolygonAugmentation' object has no attribute '{name}'")
|
|
@@ -381,44 +385,72 @@ class PolygonAugmentation:
|
|
| 381 |
logger.info(f"Augmentation completed: {len(aug_polygons)} polygons generated")
|
| 382 |
return aug_image, aug_data
|
| 383 |
|
| 384 |
-
def
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
"rotate": (-30, 30),
|
| 389 |
-
"horizontal_flip": (0, 1),
|
| 390 |
-
"vertical_flip": (0, 1),
|
| 391 |
-
"scale": (0.5, 1.5),
|
| 392 |
-
"brightness_contrast": (-0.3, 0.3),
|
| 393 |
-
"pixel_dropout": (0.01, 0.1)
|
| 394 |
-
}
|
| 395 |
-
if aug_type not in aug_ranges:
|
| 396 |
-
raise ValueError(f"Invalid augmentation type: {aug_type}")
|
| 397 |
-
min_val, max_val = aug_ranges[aug_type]
|
| 398 |
-
if not (min_val <= aug_param <= max_val):
|
| 399 |
-
raise ValueError(f"Parameter {aug_param} for {aug_type} is out of range [{min_val}, {max_val}]")
|
| 400 |
-
|
| 401 |
-
# Convert PIL image to NumPy
|
| 402 |
-
if image is None:
|
| 403 |
-
raise ValueError("Input image is None")
|
| 404 |
-
img_np = np.array(image)
|
| 405 |
-
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
|
| 406 |
-
|
| 407 |
-
# Initialize augmenter
|
| 408 |
-
augmenter = PolygonAugmentation(tolerance=2.0, area_threshold=0.01, debug=True)
|
| 409 |
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 421 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
# Create a dynamic color map for unique labels
|
| 423 |
unique_labels = list(set(shape['label'] for shape in aug_data['shapes']))
|
| 424 |
if not unique_labels:
|
|
@@ -434,20 +466,15 @@ def augment_image(image: Image.Image, json_file: Any, aug_type: str, aug_param:
|
|
| 434 |
|
| 435 |
# Convert augmented image to RGB for visualization
|
| 436 |
aug_image_rgb = cv2.cvtColor(aug_image, cv2.COLOR_BGR2RGB)
|
| 437 |
-
if aug_image_rgb is None or aug_image_rgb.size == 0:
|
| 438 |
-
raise ValueError("Failed to convert augmented image to RGB")
|
| 439 |
-
|
| 440 |
-
# Create a clean copy of the augmented image for visualization
|
| 441 |
overlay = aug_image_rgb.copy()
|
| 442 |
|
| 443 |
# Create masks and outlines for visualization
|
| 444 |
height, width = aug_image.shape[:2]
|
| 445 |
for shape in aug_data['shapes']:
|
| 446 |
label = shape['label']
|
| 447 |
-
color = label_color_map.get(label, (0, 255, 0))
|
| 448 |
points = np.array(shape['points'], dtype=np.int32)
|
| 449 |
if len(points) < 3:
|
| 450 |
-
logger.warning(f"Skipping invalid polygon for label {label}: fewer than 3 points")
|
| 451 |
continue
|
| 452 |
|
| 453 |
# Draw semi-transparent mask
|
|
@@ -455,84 +482,254 @@ def augment_image(image: Image.Image, json_file: Any, aug_type: str, aug_param:
|
|
| 455 |
cv2.fillPoly(mask, [points], 1)
|
| 456 |
colored_mask = np.zeros_like(aug_image_rgb)
|
| 457 |
colored_mask[mask == 1] = color
|
| 458 |
-
alpha = 0.3
|
| 459 |
overlay = cv2.addWeighted(overlay, 1.0, colored_mask, alpha, 0.0)
|
| 460 |
|
| 461 |
# Draw polygon outline
|
| 462 |
cv2.polylines(overlay, [points], isClosed=True, color=color, thickness=2)
|
| 463 |
|
| 464 |
-
|
| 465 |
-
aug_image_pil = Image.fromarray(overlay)
|
| 466 |
-
|
| 467 |
-
# Format JSON for display
|
| 468 |
-
aug_json_str = json.dumps(aug_data, indent=2)
|
| 469 |
-
|
| 470 |
-
logger.info("Visualization completed successfully")
|
| 471 |
-
return aug_image_pil, aug_json_str
|
| 472 |
-
except Exception as e:
|
| 473 |
-
logger.error(f"Error in augment_image: {str(e)}")
|
| 474 |
-
return None, f"Error: {str(e)}"
|
| 475 |
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 485 |
|
| 486 |
def create_interface():
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 490 |
|
| 491 |
with gr.Row():
|
| 492 |
-
with gr.Column():
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
)
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
label=
|
| 505 |
-
step=0.01
|
| 506 |
)
|
| 507 |
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 519 |
|
| 520 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 521 |
|
| 522 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 523 |
|
| 524 |
-
with gr.Column():
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 536 |
)
|
| 537 |
|
| 538 |
return demo
|
|
|
|
| 11 |
from pathlib import Path
|
| 12 |
import colorsys
|
| 13 |
import logging
|
| 14 |
+
import zipfile
|
| 15 |
+
import io
|
| 16 |
+
from datetime import datetime
|
| 17 |
|
| 18 |
# Set up logging
|
| 19 |
logging.basicConfig(level=logging.INFO)
|
|
|
|
| 25 |
self.area_threshold = area_threshold
|
| 26 |
self.debug = debug
|
| 27 |
self.supported_extensions = ['.png', '.jpg', '.jpeg', '.bmp', '.PNG', '.JPEG']
|
| 28 |
+
self.augmented_results = [] # Store all augmentation results
|
| 29 |
|
| 30 |
def __getattr__(self, name: str) -> Any:
|
| 31 |
raise AttributeError(f"'PolygonAugmentation' object has no attribute '{name}'")
|
|
|
|
| 385 |
logger.info(f"Augmentation completed: {len(aug_polygons)} polygons generated")
|
| 386 |
return aug_image, aug_data
|
| 387 |
|
| 388 |
+
def batch_augment_images(self, image_json_pairs, aug_configs, num_augmentations):
|
| 389 |
+
"""Batch process multiple images with multiple augmentation configurations"""
|
| 390 |
+
self.augmented_results = []
|
| 391 |
+
results = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
|
| 393 |
+
for pair_idx, (image, json_file) in enumerate(image_json_pairs):
|
| 394 |
+
if image is None or json_file is None:
|
| 395 |
+
continue
|
| 396 |
+
|
| 397 |
+
try:
|
| 398 |
+
# Convert PIL image to NumPy
|
| 399 |
+
img_np = np.array(image)
|
| 400 |
+
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
|
| 401 |
+
|
| 402 |
+
# Load data
|
| 403 |
+
img_np, polygons, labels, original_areas, original_data, _ = self.load_labelme_data(json_file, img_np)
|
| 404 |
+
|
| 405 |
+
# Apply each augmentation configuration
|
| 406 |
+
for config in aug_configs:
|
| 407 |
+
for aug_idx in range(num_augmentations):
|
| 408 |
+
# Generate random parameter within range
|
| 409 |
+
min_val, max_val = config['param_range']
|
| 410 |
+
if config['aug_type'] in ['horizontal_flip', 'vertical_flip']:
|
| 411 |
+
aug_param = random.choice([0, 1])
|
| 412 |
+
else:
|
| 413 |
+
aug_param = random.uniform(min_val, max_val)
|
| 414 |
+
|
| 415 |
+
try:
|
| 416 |
+
aug_image, aug_data = self.augment_single_image(
|
| 417 |
+
img_np, polygons, labels, original_areas,
|
| 418 |
+
original_data, config['aug_type'], aug_param
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
# Create visualization
|
| 422 |
+
aug_image_vis = self.create_visualization(aug_image, aug_data)
|
| 423 |
+
|
| 424 |
+
# Store result
|
| 425 |
+
result_data = {
|
| 426 |
+
'image': aug_image_vis,
|
| 427 |
+
'json_data': aug_data,
|
| 428 |
+
'metadata': {
|
| 429 |
+
'original_image_index': pair_idx,
|
| 430 |
+
'augmentation_index': aug_idx,
|
| 431 |
+
'augmentation_type': config['aug_type'],
|
| 432 |
+
'parameter_value': aug_param,
|
| 433 |
+
'parameter_range': config['param_range'],
|
| 434 |
+
'timestamp': datetime.now().isoformat(),
|
| 435 |
+
'filename': f'aug_{pair_idx}_{config["aug_type"]}_{aug_idx}.png'
|
| 436 |
+
}
|
| 437 |
+
}
|
| 438 |
+
|
| 439 |
+
self.augmented_results.append(result_data)
|
| 440 |
+
results.append(aug_image_vis)
|
| 441 |
+
|
| 442 |
+
except Exception as e:
|
| 443 |
+
logger.error(f"Error augmenting image {pair_idx} with {config['aug_type']}: {str(e)}")
|
| 444 |
+
continue
|
| 445 |
+
|
| 446 |
+
except Exception as e:
|
| 447 |
+
logger.error(f"Error processing image pair {pair_idx}: {str(e)}")
|
| 448 |
+
continue
|
| 449 |
|
| 450 |
+
return results
|
| 451 |
+
|
| 452 |
+
def create_visualization(self, aug_image, aug_data):
|
| 453 |
+
"""Create visualization with colored polygons and masks"""
|
| 454 |
# Create a dynamic color map for unique labels
|
| 455 |
unique_labels = list(set(shape['label'] for shape in aug_data['shapes']))
|
| 456 |
if not unique_labels:
|
|
|
|
| 466 |
|
| 467 |
# Convert augmented image to RGB for visualization
|
| 468 |
aug_image_rgb = cv2.cvtColor(aug_image, cv2.COLOR_BGR2RGB)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
overlay = aug_image_rgb.copy()
|
| 470 |
|
| 471 |
# Create masks and outlines for visualization
|
| 472 |
height, width = aug_image.shape[:2]
|
| 473 |
for shape in aug_data['shapes']:
|
| 474 |
label = shape['label']
|
| 475 |
+
color = label_color_map.get(label, (0, 255, 0))
|
| 476 |
points = np.array(shape['points'], dtype=np.int32)
|
| 477 |
if len(points) < 3:
|
|
|
|
| 478 |
continue
|
| 479 |
|
| 480 |
# Draw semi-transparent mask
|
|
|
|
| 482 |
cv2.fillPoly(mask, [points], 1)
|
| 483 |
colored_mask = np.zeros_like(aug_image_rgb)
|
| 484 |
colored_mask[mask == 1] = color
|
| 485 |
+
alpha = 0.3
|
| 486 |
overlay = cv2.addWeighted(overlay, 1.0, colored_mask, alpha, 0.0)
|
| 487 |
|
| 488 |
# Draw polygon outline
|
| 489 |
cv2.polylines(overlay, [points], isClosed=True, color=color, thickness=2)
|
| 490 |
|
| 491 |
+
return Image.fromarray(overlay)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 492 |
|
| 493 |
+
def create_download_package(self):
|
| 494 |
+
"""Create a zip file with all augmented images and JSON files"""
|
| 495 |
+
if not self.augmented_results:
|
| 496 |
+
return None
|
| 497 |
+
|
| 498 |
+
zip_buffer = io.BytesIO()
|
| 499 |
+
|
| 500 |
+
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
|
| 501 |
+
# Add all augmented images and their JSON files
|
| 502 |
+
for idx, result in enumerate(self.augmented_results):
|
| 503 |
+
# Save image
|
| 504 |
+
img_buffer = io.BytesIO()
|
| 505 |
+
result['image'].save(img_buffer, format='PNG')
|
| 506 |
+
img_filename = result['metadata']['filename']
|
| 507 |
+
zip_file.writestr(img_filename, img_buffer.getvalue())
|
| 508 |
+
|
| 509 |
+
# Save JSON data
|
| 510 |
+
json_filename = f"metadata_{idx}.json"
|
| 511 |
+
json_str = json.dumps(result['json_data'], indent=2)
|
| 512 |
+
zip_file.writestr(json_filename, json_str)
|
| 513 |
+
|
| 514 |
+
# Add summary metadata
|
| 515 |
+
summary = {
|
| 516 |
+
'total_augmentations': len(self.augmented_results),
|
| 517 |
+
'generation_timestamp': datetime.now().isoformat(),
|
| 518 |
+
'augmentation_summary': [result['metadata'] for result in self.augmented_results]
|
| 519 |
+
}
|
| 520 |
+
zip_file.writestr('augmentation_summary.json', json.dumps(summary, indent=2))
|
| 521 |
+
|
| 522 |
+
zip_buffer.seek(0)
|
| 523 |
+
return zip_buffer.getvalue()
|
| 524 |
|
| 525 |
def create_interface():
|
| 526 |
+
augmenter = PolygonAugmentation(tolerance=2.0, area_threshold=0.01, debug=True)
|
| 527 |
+
|
| 528 |
+
def process_batch_augmentation(
|
| 529 |
+
images, json_files, num_augmentations,
|
| 530 |
+
rotate_enabled, rotate_min, rotate_max,
|
| 531 |
+
hflip_enabled, vflip_enabled,
|
| 532 |
+
scale_enabled, scale_min, scale_max,
|
| 533 |
+
brightness_enabled, brightness_min, brightness_max,
|
| 534 |
+
dropout_enabled, dropout_min, dropout_max
|
| 535 |
+
):
|
| 536 |
+
if not images or not json_files:
|
| 537 |
+
return [], "No images or JSON files uploaded", None
|
| 538 |
+
|
| 539 |
+
# Pair images with JSON files
|
| 540 |
+
image_json_pairs = []
|
| 541 |
+
min_length = min(len(images), len(json_files))
|
| 542 |
+
|
| 543 |
+
for i in range(min_length):
|
| 544 |
+
if images[i] is not None and json_files[i] is not None:
|
| 545 |
+
try:
|
| 546 |
+
image = Image.open(images[i].name)
|
| 547 |
+
image_json_pairs.append((image, images[i].name))
|
| 548 |
+
except Exception as e:
|
| 549 |
+
logger.error(f"Error loading image {i}: {str(e)}")
|
| 550 |
+
continue
|
| 551 |
+
|
| 552 |
+
if not image_json_pairs:
|
| 553 |
+
return [], "No valid image-JSON pairs found", None
|
| 554 |
+
|
| 555 |
+
# Configure augmentations based on user selections
|
| 556 |
+
aug_configs = []
|
| 557 |
+
|
| 558 |
+
if rotate_enabled:
|
| 559 |
+
aug_configs.append({
|
| 560 |
+
'aug_type': 'rotate',
|
| 561 |
+
'param_range': (rotate_min, rotate_max)
|
| 562 |
+
})
|
| 563 |
+
|
| 564 |
+
if hflip_enabled:
|
| 565 |
+
aug_configs.append({
|
| 566 |
+
'aug_type': 'horizontal_flip',
|
| 567 |
+
'param_range': (0, 1)
|
| 568 |
+
})
|
| 569 |
+
|
| 570 |
+
if vflip_enabled:
|
| 571 |
+
aug_configs.append({
|
| 572 |
+
'aug_type': 'vertical_flip',
|
| 573 |
+
'param_range': (0, 1)
|
| 574 |
+
})
|
| 575 |
+
|
| 576 |
+
if scale_enabled:
|
| 577 |
+
aug_configs.append({
|
| 578 |
+
'aug_type': 'scale',
|
| 579 |
+
'param_range': (scale_min, scale_max)
|
| 580 |
+
})
|
| 581 |
+
|
| 582 |
+
if brightness_enabled:
|
| 583 |
+
aug_configs.append({
|
| 584 |
+
'aug_type': 'brightness_contrast',
|
| 585 |
+
'param_range': (brightness_min, brightness_max)
|
| 586 |
+
})
|
| 587 |
+
|
| 588 |
+
if dropout_enabled:
|
| 589 |
+
aug_configs.append({
|
| 590 |
+
'aug_type': 'pixel_dropout',
|
| 591 |
+
'param_range': (dropout_min, dropout_max)
|
| 592 |
+
})
|
| 593 |
+
|
| 594 |
+
if not aug_configs:
|
| 595 |
+
return [], "No augmentation types selected", None
|
| 596 |
+
|
| 597 |
+
# Process augmentations
|
| 598 |
+
try:
|
| 599 |
+
augmented_images = augmenter.batch_augment_images(
|
| 600 |
+
image_json_pairs, aug_configs, num_augmentations
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
# Create JSON summary
|
| 604 |
+
json_summary = json.dumps([result['metadata'] for result in augmenter.augmented_results], indent=2)
|
| 605 |
+
|
| 606 |
+
status = f"Generated {len(augmented_images)} augmented images from {len(image_json_pairs)} input pairs"
|
| 607 |
+
return augmented_images, json_summary, status
|
| 608 |
+
|
| 609 |
+
except Exception as e:
|
| 610 |
+
logger.error(f"Batch augmentation error: {str(e)}")
|
| 611 |
+
return [], f"Error: {str(e)}", None
|
| 612 |
+
|
| 613 |
+
def download_package():
|
| 614 |
+
return augmenter.create_download_package()
|
| 615 |
+
|
| 616 |
+
def show_mask_overlay(evt: gr.SelectData):
|
| 617 |
+
if evt.index < len(augmenter.augmented_results):
|
| 618 |
+
return augmenter.augmented_results[evt.index]['image']
|
| 619 |
+
return None
|
| 620 |
+
|
| 621 |
+
with gr.Blocks(title="Dynamic Donut Polygon Augmentation") as demo:
|
| 622 |
+
gr.Markdown("# π Dynamic Donut Polygon Augmentation Tool")
|
| 623 |
+
gr.Markdown("Upload multiple images and JSON files to apply batch augmentation with configurable parameter ranges")
|
| 624 |
|
| 625 |
with gr.Row():
|
| 626 |
+
with gr.Column(scale=1):
|
| 627 |
+
gr.Markdown("## π Input Files")
|
| 628 |
+
|
| 629 |
+
images_input = gr.File(
|
| 630 |
+
file_count="multiple",
|
| 631 |
+
file_types=["image"],
|
| 632 |
+
label="Upload Images"
|
| 633 |
)
|
| 634 |
+
|
| 635 |
+
json_input = gr.File(
|
| 636 |
+
file_count="multiple",
|
| 637 |
+
file_types=[".json"],
|
| 638 |
+
label="Upload LabelMe JSON Files"
|
|
|
|
| 639 |
)
|
| 640 |
|
| 641 |
+
num_augmentations = gr.Slider(
|
| 642 |
+
minimum=1, maximum=5, value=2, step=1,
|
| 643 |
+
label="Augmentations per configuration"
|
| 644 |
+
)
|
| 645 |
+
|
| 646 |
+
gr.Markdown("## βοΈ Augmentation Configuration")
|
| 647 |
+
|
| 648 |
+
# Rotation parameters
|
| 649 |
+
with gr.Group():
|
| 650 |
+
rotate_enabled = gr.Checkbox(label="Enable Rotation", value=True)
|
| 651 |
+
with gr.Row():
|
| 652 |
+
rotate_min = gr.Slider(-45, 45, -15, label="Min Rotation (degrees)")
|
| 653 |
+
rotate_max = gr.Slider(-45, 45, 15, label="Max Rotation (degrees)")
|
| 654 |
+
|
| 655 |
+
# Flip parameters
|
| 656 |
+
with gr.Group():
|
| 657 |
+
hflip_enabled = gr.Checkbox(label="Enable Horizontal Flip", value=True)
|
| 658 |
+
vflip_enabled = gr.Checkbox(label="Enable Vertical Flip", value=False)
|
| 659 |
|
| 660 |
+
# Scale parameters
|
| 661 |
+
with gr.Group():
|
| 662 |
+
scale_enabled = gr.Checkbox(label="Enable Scale", value=True)
|
| 663 |
+
with gr.Row():
|
| 664 |
+
scale_min = gr.Slider(0.7, 1.3, 0.9, label="Min Scale")
|
| 665 |
+
scale_max = gr.Slider(0.7, 1.3, 1.1, label="Max Scale")
|
| 666 |
|
| 667 |
+
# Brightness parameters
|
| 668 |
+
with gr.Group():
|
| 669 |
+
brightness_enabled = gr.Checkbox(label="Enable Brightness/Contrast", value=True)
|
| 670 |
+
with gr.Row():
|
| 671 |
+
brightness_min = gr.Slider(-0.3, 0.3, -0.1, label="Min Brightness")
|
| 672 |
+
brightness_max = gr.Slider(-0.3, 0.3, 0.1, label="Max Brightness")
|
| 673 |
+
|
| 674 |
+
# Dropout parameters
|
| 675 |
+
with gr.Group():
|
| 676 |
+
dropout_enabled = gr.Checkbox(label="Enable Pixel Dropout", value=False)
|
| 677 |
+
with gr.Row():
|
| 678 |
+
dropout_min = gr.Slider(0.01, 0.1, 0.02, label="Min Dropout")
|
| 679 |
+
dropout_max = gr.Slider(0.01, 0.1, 0.05, label="Max Dropout")
|
| 680 |
+
|
| 681 |
+
generate_btn = gr.Button("π Generate Augmentations", variant="primary")
|
| 682 |
+
status_text = gr.Textbox(label="Status", interactive=False)
|
| 683 |
|
| 684 |
+
with gr.Column(scale=2):
|
| 685 |
+
gr.Markdown("## πΌοΈ Augmented Results")
|
| 686 |
+
gr.Markdown("*Click on any image to view with enhanced mask overlay*")
|
| 687 |
+
|
| 688 |
+
augmented_gallery = gr.Gallery(
|
| 689 |
+
label="Augmented Images with Polygon Masks",
|
| 690 |
+
show_label=False,
|
| 691 |
+
elem_id="gallery",
|
| 692 |
+
columns=3,
|
| 693 |
+
rows=3,
|
| 694 |
+
height="auto"
|
| 695 |
+
)
|
| 696 |
+
|
| 697 |
+
with gr.Row():
|
| 698 |
+
download_btn = gr.Button("π₯ Download All (ZIP)", variant="secondary")
|
| 699 |
+
download_file = gr.File(label="Download Package", visible=False)
|
| 700 |
+
|
| 701 |
+
gr.Markdown("## π Augmentation Metadata")
|
| 702 |
+
json_output = gr.Code(
|
| 703 |
+
label="Generated Metadata JSON",
|
| 704 |
+
language="json",
|
| 705 |
+
lines=15
|
| 706 |
+
)
|
| 707 |
+
|
| 708 |
+
gr.Markdown("## π Enhanced Preview")
|
| 709 |
+
mask_preview = gr.Image(label="Selected Image with Mask Overlay")
|
| 710 |
+
|
| 711 |
+
# Event handlers
|
| 712 |
+
generate_btn.click(
|
| 713 |
+
process_batch_augmentation,
|
| 714 |
+
inputs=[
|
| 715 |
+
images_input, json_input, num_augmentations,
|
| 716 |
+
rotate_enabled, rotate_min, rotate_max,
|
| 717 |
+
hflip_enabled, vflip_enabled,
|
| 718 |
+
scale_enabled, scale_min, scale_max,
|
| 719 |
+
brightness_enabled, brightness_min, brightness_max,
|
| 720 |
+
dropout_enabled, dropout_min, dropout_max
|
| 721 |
+
],
|
| 722 |
+
outputs=[augmented_gallery, json_output, status_text]
|
| 723 |
+
)
|
| 724 |
+
|
| 725 |
+
download_btn.click(
|
| 726 |
+
download_package,
|
| 727 |
+
outputs=download_file
|
| 728 |
+
)
|
| 729 |
+
|
| 730 |
+
augmented_gallery.select(
|
| 731 |
+
show_mask_overlay,
|
| 732 |
+
outputs=mask_preview
|
| 733 |
)
|
| 734 |
|
| 735 |
return demo
|