import colorsys
import io
import json
import logging
import random
import tempfile
import uuid
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Tuple

import albumentations as A
import cv2
import gradio as gr
import numpy as np
import supervision as sv  # noqa: F401 -- kept from the original module; not used below
from PIL import Image

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Label names that get "ring" handling (metadata tag in saved shapes).
RING_LABELS = ('ring', 'donut', 'annulus', 'circle', 'round')
# Label names that get "background" handling (metadata tag, looser simplification).
BACKGROUND_LABELS = ('background', 'bg', 'back')


class PolygonAugmentation:
    """Augment images annotated with LabelMe polygons.

    Polygons are rasterised to one binary mask per shape, the masks are transformed
    together with the image by albumentations, and the transformed masks are traced
    back into polygons (shapes with holes become a single bridged "donut" polygon).
    """

    def __init__(self, tolerance=0.2, area_threshold=0.01, debug=False):
        """
        Args:
            tolerance: Default epsilon (pixels) passed to ``cv2.approxPolyDP``.
            area_threshold: Minimum area of a traced outer contour relative to the
                area of the source polygon; smaller contours are dropped.
            debug: Emit verbose ``[DEBUG]`` log lines when True.
        """
        self.tolerance = tolerance
        self.area_threshold = area_threshold
        self.debug = debug
        self.supported_extensions = ['.png', '.jpg', '.jpeg', '.bmp', '.PNG', '.JPEG']
        # One dict per generated augmentation; (re)filled by batch_augment_images().
        self.augmented_results = []

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _to_bgr(image: Any) -> np.ndarray:
        """Return *image* (PIL image or RGB-ordered array) as a 3-channel BGR array.

        PIL images are first converted to RGB so RGBA, greyscale and palette files
        do not break the RGB->BGR conversion.
        """
        if isinstance(image, Image.Image):
            image = image.convert('RGB')
        arr = np.asarray(image)
        if arr.ndim == 3 and arr.shape[2] == 1:
            arr = arr[:, :, 0]
        if arr.ndim == 2:
            return cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
        if arr.shape[2] == 4:
            return cv2.cvtColor(arr, cv2.COLOR_RGBA2BGR)
        return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)

    @staticmethod
    def _clip_to_image(points, width: int, height: int) -> List[List[float]]:
        """Clamp points into ``[0, width-1] x [0, height-1]`` and round to 2 decimals."""
        return [
            [round(max(0, min(float(x), width - 1)), 2),
             round(max(0, min(float(y), height - 1)), 2)]
            for x, y in points
        ]

    # ------------------------------------------------------------- geometry / IO

    def calculate_polygon_area(self, points: List[List[float]]) -> float:
        """Return the area of *points* as computed by ``cv2.contourArea``."""
        poly_np = np.array(points, dtype=np.float32)
        area = cv2.contourArea(poly_np)
        if self.debug:
            logger.info(f"[DEBUG] Calculating polygon area: {area:.2f}")
        return area

    def load_labelme_data(self, json_file: Any, image: np.ndarray) -> Tuple:
        """Parse polygon annotations from a LabelMe-style JSON source.

        Accepts either the LabelMe ``shapes`` list or a ``segments`` list whose items
        carry ``class``/``polygon``/``confidence`` keys.

        Args:
            json_file: A path (str), an already-parsed dict, or a readable file object.
            image: The image the annotations belong to; returned unchanged.

        Returns:
            ``(image, polygons, labels, original_areas, data, "input")`` where each
            polygon is a list of ``[x, y]`` floats with at least 3 points.

        Raises:
            ValueError: If the JSON is not an object with a ``shapes`` or
                ``segments`` list.
        """
        if isinstance(json_file, str):
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        elif isinstance(json_file, dict):
            # Handle dictionary data directly
            data = json_file
        else:
            # Handle file object
            data = json.load(json_file)

        if not isinstance(data, dict):
            raise ValueError("Invalid JSON: Neither 'shapes' nor 'segments' key found or not a list")

        if isinstance(data.get('shapes'), list):
            shapes = data['shapes']
        elif isinstance(data.get('segments'), list):
            shapes = [
                {
                    "label": seg.get("class", "unknown"),
                    "points": seg.get("polygon", []),
                    "shape_type": "polygon",
                    "group_id": None,
                    "flags": {},
                    "confidence": seg.get("confidence", 1.0),
                }
                for seg in data['segments']
            ]
        else:
            raise ValueError("Invalid JSON: Neither 'shapes' nor 'segments' key found or not a list")

        polygons = []
        labels = []
        original_areas = []
        for shape in shapes:
            raw_points = shape.get('points')
            # Some exporters omit shape_type (or write null); treat that as a polygon.
            shape_type = shape.get('shape_type') or 'polygon'
            if shape_type != 'polygon' or not raw_points or len(raw_points) < 3:
                if self.debug:
                    logger.info(f"[DEBUG] Skipping invalid shape: {shape}")
                continue
            try:
                points = [[float(x), float(y)] for x, y in raw_points]
            except (ValueError, TypeError) as e:
                if self.debug:
                    logger.info(f"[DEBUG] Error processing points: {raw_points}, error: {str(e)}")
                continue
            polygons.append(points)
            labels.append(shape.get('label', 'unknown'))
            original_areas.append(self.calculate_polygon_area(points))

        if not polygons and self.debug:
            logger.info("[DEBUG] Warning: No valid polygons in JSON")
        return image, polygons, labels, original_areas, data, "input"

    def simplify_polygon(self, polygon: List[List[float]], tolerance: float = None,
                         label: str = None) -> List[List[float]]:
        """Simplify *polygon* with ``cv2.approxPolyDP`` (closed curve).

        Background labels use three times the tolerance. Polygons with fewer than
        3 points are returned unchanged.
        """
        tol = tolerance if tolerance is not None else self.tolerance
        if label and label.lower() in BACKGROUND_LABELS:
            tol = tol * 3
            if self.debug:
                logger.info(f"[DEBUG] Using increased tolerance {tol} for background label '{label}'")
        if len(polygon) < 3:
            if self.debug:
                logger.info("[DEBUG] Polygon has fewer than 3 points, skipping simplification.")
            return polygon
        poly_np = np.array(polygon, dtype=np.float32)
        approx = cv2.approxPolyDP(poly_np, tol, closed=True)
        simplified = approx.reshape(-1, 2).tolist()
        if self.debug:
            logger.info(f"[DEBUG] Simplified polygon from {len(polygon)} to {len(simplified)} points with tolerance {tol}")
        return simplified

    def create_donut_polygon(self, external_contour: np.ndarray,
                             internal_contours: List[np.ndarray]) -> List[List[float]]:
        """Merge an outer contour and its holes into one bridged ("keyhole") polygon.

        Each hole is spliced into the running outline at the closest pair of points
        (one on the current outline, one on the hole). The resulting single ring
        walks along the bridge, all the way round the hole, and back.

        Args:
            external_contour: OpenCV contour of the outer boundary.
            internal_contours: OpenCV contours of the holes inside it.

        Returns:
            The merged polygon as ``[x, y]`` points (just the outer points when
            there are no holes).
        """
        external_points = external_contour.reshape(-1, 2).tolist()
        if not internal_contours:
            if self.debug:
                logger.info("[DEBUG] No internal contours found, returning external points.")
            return external_points

        result_points = external_points.copy()
        for hole_idx, internal_contour in enumerate(internal_contours):
            internal_points = internal_contour.reshape(-1, 2).tolist()

            # All pairwise squared distances (outline x hole) in one vectorised step.
            # argmin returns the first minimum in row-major order, i.e. the same pair
            # a nested "for outline: for hole: if d < best" loop would select.
            outline = np.asarray(result_points, dtype=np.float64)
            hole = np.asarray(internal_points, dtype=np.float64)
            sq_dist = ((outline[:, None, :] - hole[None, :, :]) ** 2).sum(axis=2)
            flat_best = np.unravel_index(np.argmin(sq_dist), sq_dist.shape)
            best_ext_idx, best_int_idx = int(flat_best[0]), int(flat_best[1])
            min_dist = float(np.sqrt(sq_dist[best_ext_idx, best_int_idx]))

            bridge_start = result_points[best_ext_idx]
            if self.debug:
                logger.info(f"[DEBUG] Creating bridge for hole {hole_idx}: ext_idx={best_ext_idx}, int_idx={best_int_idx}, distance={min_dist:.2f}")

            result_points = (
                result_points[:best_ext_idx + 1]      # outline up to the bridge
                + internal_points[best_int_idx:]      # hole from the connection point to its end
                + internal_points[:best_int_idx + 1]  # hole from its start back to the connection point
                + [bridge_start]                      # bridge back to the outline
                + result_points[best_ext_idx + 1:]    # remaining outline
            )

        if self.debug:
            logger.info(f"[DEBUG] Created donut polygon with {len(result_points)} total points")
        return result_points

    def save_augmented_data(
        self,
        aug_image: np.ndarray,
        aug_polygons: List[List[List[float]]],
        aug_labels: List[str],
        original_data: Dict[str, Any],
        base_name: str
    ) -> Dict[str, Any]:
        """Build a LabelMe-compatible annotation dict for an augmented image.

        Polygons with fewer than 3 points are skipped. Each shape gets an
        ``attributes.polygon_type`` of ``ring``, ``background`` or ``object``.
        ``imageData`` is always ``None``. Nothing is written to disk.
        """
        aug_id = uuid.uuid4().hex[:4]
        aug_img_name = f"{base_name}_{aug_id}_aug.png"
        new_shapes = []
        for poly, label in zip(aug_polygons, aug_labels):
            if not poly or len(poly) < 3:
                continue
            label_lc = label.lower()
            if label_lc in RING_LABELS:
                polygon_type = "ring"
            elif label_lc in BACKGROUND_LABELS:
                polygon_type = "background"
            else:
                polygon_type = "object"
            new_shapes.append({
                "label": label,
                "points": poly,
                "group_id": None,
                "shape_type": "polygon",
                "flags": {},
                "description": "",
                "attributes": {"polygon_type": polygon_type},
                "iscrowd": 0,
                "difficult": 0,
            })

        # Dimensions come from the (possibly rotated/cropped) augmented image.
        aug_height, aug_width = aug_image.shape[:2]
        aug_data = {
            "version": original_data.get("version", "5.0.1"),
            "flags": original_data.get("flags", {}),
            "shapes": new_shapes,
            "imagePath": aug_img_name,
            "imageData": None,
            "imageHeight": aug_height,
            "imageWidth": aug_width,
            "imageDepth": 3 if len(aug_image.shape) == 3 else 1,
            # Additional LabelMe metadata
            "lineColor": [0, 255, 0, 128],
            "fillColor": [255, 0, 0, 128],
            "textSize": 10,
            "textColor": [0, 0, 0, 255],
            # Augmentation metadata
            "augmentation": {
                "augmented": True,
                "augmentation_id": aug_id,
                "original_file": original_data.get("imagePath", "unknown"),
                "augmentation_timestamp": datetime.now().isoformat(),
                "augmentation_tool": "PolygonAugmentation v1.0",
            },
        }
        if self.debug:
            logger.info(f"[DEBUG] Created LabelMe JSON: {len(new_shapes)} shapes, size: {aug_width}x{aug_height}")
            logger.info(f"[DEBUG] Shape types: {[s['attributes'].get('polygon_type', 'unknown') for s in new_shapes]}")
        return aug_data

    def polygons_to_masks(self, image: np.ndarray, polygons: List[List[List[float]]],
                          labels: List[str]) -> Tuple[np.ndarray, List[str]]:
        """Rasterise each polygon into its own ``uint8`` mask (1 inside, 0 outside).

        Returns:
            ``(masks, labels)`` where ``masks`` has shape ``(N, H, W)`` and
            ``labels`` lists the label of every mask that was kept.
        """
        height, width = image.shape[:2]
        all_masks = []
        all_labels = []
        for poly_idx, (poly, label) in enumerate(zip(polygons, labels)):
            try:
                poly_np = np.array(poly, dtype=np.int32)
                if len(poly_np) < 3:
                    if self.debug:
                        logger.info(f"[DEBUG] Skipping polygon {poly_idx}: fewer than 3 points")
                    continue
                mask = np.zeros((height, width), dtype=np.uint8)
                cv2.fillPoly(mask, [poly_np], 1)
                all_masks.append(mask)
                all_labels.append(label)
            except Exception as e:  # best effort: one bad polygon must not abort the image
                if self.debug:
                    logger.info(f"[DEBUG] Error processing polygon {poly_idx}: {str(e)}")
        if not all_masks:
            return np.zeros((0, height, width), dtype=np.uint8), []
        return np.array(all_masks, dtype=np.uint8), all_labels

    def process_contours(
        self,
        external_contour: np.ndarray,
        internal_contours: List[np.ndarray],
        width: int,
        height: int,
        label: str,
        all_polygons: List[List[List[float]]],
        all_labels: List[str],
        tolerance: float = None
    ) -> None:
        """Append the simplified outer contour and each hole as separate polygons.

        Results are clipped to the image and appended in place to *all_polygons*
        and *all_labels*; contours that simplify to fewer than 3 points are dropped.
        """
        tol = tolerance if tolerance is not None else self.tolerance
        for kind, contour in [("external", external_contour)] + [("internal", c) for c in internal_contours]:
            simplified = self.simplify_polygon(contour.reshape(-1, 2).tolist(), tolerance=tol, label=label)
            if len(simplified) < 3:
                continue
            poly_labelme = self._clip_to_image(simplified, width, height)
            all_polygons.append(poly_labelme)
            all_labels.append(label)
            if self.debug:
                logger.info(f"[DEBUG] Added simplified {kind} polygon with {len(poly_labelme)} points.")

    def masks_to_labelme_polygons(
        self,
        masks: np.ndarray,
        labels: List[str],
        original_areas: List[float],
        area_threshold: float = None,
        tolerance: float = None
    ) -> Tuple[List[List[List[float]]], List[str]]:
        """Trace binary masks back into LabelMe polygons.

        Each mask is traced with ``cv2.RETR_CCOMP``. An outer contour with holes
        becomes one bridged donut polygon (falling back to separate polygons on
        failure). Outer contours whose area relative to ``original_areas[i]`` is
        below the threshold are dropped.

        Returns:
            ``(polygons, labels)``, both in trace order.
        """
        tol = tolerance if tolerance is not None else self.tolerance
        area_thresh = area_threshold if area_threshold is not None else self.area_threshold
        height, width = masks[0].shape if len(masks) > 0 else (0, 0)
        all_polygons = []
        all_labels = []

        for mask_idx, (mask, label) in enumerate(zip(masks, labels)):
            if mask.sum() < 10:
                if self.debug:
                    logger.info(f"[DEBUG] Skipping mask {mask_idx}: very small or empty.")
                continue
            contours, hierarchy = cv2.findContours(mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
            if hierarchy is None or len(contours) == 0:
                if self.debug:
                    logger.info(f"[DEBUG] No contours found in mask {mask_idx}.")
                continue
            hierarchy = hierarchy[0]

            # RETR_CCOMP gives a two-level hierarchy: outer boundaries (parent == -1)
            # and holes whose parent field is the index *into ``contours``* of the
            # enclosing outer boundary. Key the holes by that contour index.
            holes_by_parent = {i: [] for i, h in enumerate(hierarchy) if h[3] == -1}
            for i, h in enumerate(hierarchy):
                parent_idx = int(h[3])
                if parent_idx in holes_by_parent:
                    holes_by_parent[parent_idx].append(contours[i])

            if not holes_by_parent:
                if self.debug:
                    logger.info(f"[DEBUG] No external contours found in mask {mask_idx}.")
                continue

            is_background = label.lower() in BACKGROUND_LABELS
            for ext_idx, internal_contours in holes_by_parent.items():
                external_contour = contours[ext_idx]
                ext_area = cv2.contourArea(external_contour)
                if ext_area <= 0:
                    continue
                if mask_idx < len(original_areas) and original_areas[mask_idx] > 0:
                    relative_area = ext_area / original_areas[mask_idx]
                    if relative_area < area_thresh:
                        if self.debug:
                            logger.info(f"[DEBUG] Skipping contour {ext_idx} (area too small: {relative_area:.4f})")
                        continue

                if not internal_contours:
                    # Plain polygon without holes.
                    self.process_contours(external_contour, internal_contours, width, height,
                                          label, all_polygons, all_labels, tol)
                    continue

                # Shape with holes: emit one bridged donut polygon.
                is_ring_shape = label.lower() in RING_LABELS or len(internal_contours) > 0
                try:
                    donut_points = self.create_donut_polygon(external_contour, internal_contours)
                    simplified_donut = self.simplify_polygon(donut_points, tolerance=tol, label=label)
                    if len(simplified_donut) >= 3:
                        poly_labelme = self._clip_to_image(simplified_donut, width, height)
                        all_polygons.append(poly_labelme)
                        all_labels.append(label)
                        if self.debug:
                            logger.info(f"[DEBUG] Added {'ring' if is_ring_shape else 'background'} donut polygon with {len(poly_labelme)} points, {len(internal_contours)} holes")
                    else:
                        if self.debug:
                            logger.info("[DEBUG] Donut polygon too small after simplification, falling back to separate contours")
                        self.process_contours(external_contour, internal_contours, width, height,
                                              label, all_polygons, all_labels, tol)
                except Exception as e:  # best effort: fall back rather than lose the shape
                    if self.debug:
                        logger.info(f"[DEBUG] Error creating donut for {label}: {str(e)}, fallback to separate polygons.")
                    self.process_contours(external_contour, internal_contours, width, height,
                                          label, all_polygons, all_labels, tol)
        return all_polygons, all_labels

    # ------------------------------------------------------------- augmentation

    @staticmethod
    def _build_transform(aug_type: str, aug_param: float):
        """Return the albumentations transform for *aug_type* with parameter *aug_param*.

        Raises:
            ValueError: For an unknown *aug_type*.
        """
        if aug_type == "rotate":
            # Angles smaller than 5 degrees are bumped to +/-15 so the rotation is visible.
            rotation_angle = aug_param if abs(aug_param) >= 5 else (15 if aug_param >= 0 else -15)
            transform = A.Rotate(limit=(rotation_angle, rotation_angle), p=1.0,
                                 border_mode=cv2.BORDER_CONSTANT, value=0)
            logger.info(f"Applying rotation: {rotation_angle} degrees")
            return transform
        if aug_type == "horizontal_flip":
            return A.HorizontalFlip(p=1.0 if aug_param == 1 else 0.0)
        if aug_type == "vertical_flip":
            return A.VerticalFlip(p=1.0 if aug_param == 1 else 0.0)
        if aug_type == "scale":
            scale_factor = max(0.5, min(2.0, aug_param))
            transform = A.Affine(scale=scale_factor, p=1.0, keep_ratio=True)
            logger.info(f"Applying scale: {scale_factor}")
            return transform
        if aug_type == "brightness_contrast":
            brightness_factor = max(-0.5, min(0.5, aug_param))
            return A.RandomBrightnessContrast(
                brightness_limit=abs(brightness_factor),
                contrast_limit=abs(brightness_factor),
                p=1.0,
            )
        if aug_type == "pixel_dropout":
            dropout_prob = min(max(aug_param, 0.0), 0.2)
            return A.PixelDropout(dropout_prob=dropout_prob, p=1.0)
        raise ValueError(f"Unsupported augmentation type: {aug_type}")

    def augment_single_image(
        self,
        image: np.ndarray,
        polygons: List[List[List[float]]],
        labels: List[str],
        original_areas: List[float],
        original_data: Dict[str, Any],
        aug_type: str,
        aug_param: float
    ) -> Tuple[np.ndarray, Dict[str, Any]]:
        """Apply one augmentation to *image* and its polygons.

        The polygons are rasterised to masks and transformed with the image. With
        30% probability a random 85-95% crop follows. The masks are then traced
        back into polygons once.

        Returns:
            ``(aug_image, aug_data)``: the BGR augmented image and its LabelMe dict.

        Raises:
            ValueError: Unknown *aug_type*, no usable polygons, or an empty result.
        """
        logger.info(f"Applying augmentation: {aug_type} with parameter {aug_param}")
        aug_transform = self._build_transform(aug_type, aug_param)

        masks, mask_labels = self.polygons_to_masks(image, polygons, labels)
        if masks.shape[0] == 0:
            raise ValueError("No valid masks created from polygons")

        # albumentations receives every mask as its own named 'mask' target.
        mask_keys = [f'mask{i}' for i in range(masks.shape[0])]
        mask_targets = {key: 'mask' for key in mask_keys}
        transform = A.Compose([aug_transform], additional_targets=mask_targets)
        aug_result = transform(image=image, **dict(zip(mask_keys, masks)))

        aug_image = aug_result['image']
        if aug_image is None or aug_image.size == 0:
            raise ValueError("Augmented image is empty or invalid")

        aug_height, aug_width = aug_image.shape[:2]
        aug_masks_list = []
        for key in mask_keys:
            aug_mask = aug_result[key]
            if aug_mask.shape[:2] != (aug_height, aug_width):
                aug_mask = cv2.resize(aug_mask, (aug_width, aug_height), interpolation=cv2.INTER_NEAREST)
            aug_masks_list.append(aug_mask)

        # Optional random crop for extra variety. It is applied to the image and masks
        # together *before* tracing, so polygons are traced only once.
        if random.random() < 0.3:
            crop_scale = random.uniform(0.85, 0.95)
            crop_transform = A.Compose(
                [A.RandomCrop(width=int(aug_width * crop_scale),
                              height=int(aug_height * crop_scale), p=1.0)],
                additional_targets=mask_targets,
            )
            crop_result = crop_transform(image=aug_image, **dict(zip(mask_keys, aug_masks_list)))
            aug_image = crop_result['image']
            aug_masks_list = [crop_result[key] for key in mask_keys]

        aug_masks = np.array(aug_masks_list, dtype=np.uint8)
        aug_polygons, aug_labels = self.masks_to_labelme_polygons(
            aug_masks, mask_labels, original_areas, self.area_threshold, self.tolerance
        )

        aug_data = self.save_augmented_data(aug_image, aug_polygons, aug_labels, original_data, "input")
        logger.info(f"Augmentation completed: {len(aug_polygons)} polygons generated, final size: {aug_image.shape[:2]}")
        return aug_image, aug_data

    def batch_augment_images(self, image_json_pairs, aug_configs, num_augmentations):
        """Run every augmentation config *num_augmentations* times on every pair.

        Args:
            image_json_pairs: Iterable of ``(PIL image or RGB array, LabelMe dict)``.
            aug_configs: Dicts with ``aug_type`` and ``param_range=(min, max)``.
            num_augmentations: Repetitions per config (coerced to int).

        Returns:
            The list of visualisation images. Side effect: ``self.augmented_results``
            is replaced with one dict per success, holding ``image`` (visualisation),
            ``raw_image`` (clean augmented RGB image), ``json_data`` and ``metadata``.
        """
        logger.info(f"Starting batch augmentation with {len(image_json_pairs)} pairs, {len(aug_configs)} configs, {num_augmentations} augmentations each")
        self.augmented_results = []
        results = []
        # Gradio sliders may deliver floats; range() needs an int.
        num_augmentations = int(num_augmentations)

        for pair_idx, (image, json_data) in enumerate(image_json_pairs):
            if image is None or json_data is None:
                logger.warning(f"Skipping pair {pair_idx}: missing image or JSON data")
                continue
            try:
                logger.info(f"Processing image pair {pair_idx}")
                img_np = self._to_bgr(image)
                img_np, polygons, labels, original_areas, original_data, _ = self.load_labelme_data(json_data, img_np)
                logger.info(f"Loaded {len(polygons)} polygons for image {pair_idx}")

                for config_idx, config in enumerate(aug_configs):
                    logger.info(f"Applying config {config_idx}: {config['aug_type']}")
                    for aug_idx in range(num_augmentations):
                        min_val, max_val = config['param_range']
                        if config['aug_type'] in ['horizontal_flip', 'vertical_flip']:
                            aug_param = random.choice([0, 1])
                        else:
                            aug_param = random.uniform(min_val, max_val)
                        try:
                            logger.info(f"Generating augmentation {aug_idx} with {config['aug_type']}, param: {aug_param}")
                            aug_image, aug_data = self.augment_single_image(
                                img_np, polygons, labels, original_areas, original_data,
                                config['aug_type'], aug_param
                            )
                            aug_image_vis = self.create_visualization(aug_image, aug_data)
                            # Clean image (no overlay) for the exported dataset.
                            raw_image = Image.fromarray(cv2.cvtColor(aug_image, cv2.COLOR_BGR2RGB))
                            self.augmented_results.append({
                                'image': aug_image_vis,
                                'raw_image': raw_image,
                                'json_data': aug_data,
                                'metadata': {
                                    'original_image_index': pair_idx,
                                    'augmentation_index': aug_idx,
                                    'augmentation_type': config['aug_type'],
                                    'parameter_value': aug_param,
                                    'parameter_range': config['param_range'],
                                    'timestamp': datetime.now().isoformat(),
                                    'filename': f'aug_{pair_idx}_{config["aug_type"]}_{aug_idx}.png',
                                },
                            })
                            results.append(aug_image_vis)
                            logger.info(f"Successfully generated augmentation {aug_idx} for image {pair_idx}")
                        except Exception as e:
                            logger.exception(f"Error augmenting image {pair_idx} with {config['aug_type']}: {str(e)}")
                            continue
            except Exception as e:
                logger.exception(f"Error processing image pair {pair_idx}: {str(e)}")
                continue

        logger.info(f"Batch augmentation completed. Generated {len(results)} total results.")
        return results

    # ------------------------------------------------------------ visualisation

    def create_visualization(self, aug_image, aug_data):
        """Render *aug_image* (BGR) with translucent per-label fills, outlines and labels.

        Returns:
            A PIL RGB image.
        """
        # Sorted so colour assignment is stable across runs (str set order is hash-seeded).
        unique_labels = sorted({shape['label'] for shape in aug_data['shapes']})
        label_color_map = {}
        for i, label in enumerate(unique_labels):
            label_lc = label.lower()
            if label_lc in BACKGROUND_LABELS:
                rgb = (100, 149, 237)  # cornflower blue, drawn at low opacity
            elif 'ring' in label_lc or 'donut' in label_lc:
                hue = 0.8 + (i * 0.1) % 0.2  # purple/pink range
                rgb = tuple(int(c * 255) for c in colorsys.hsv_to_rgb(hue, 0.8, 0.9))
            else:
                hue = (i * 0.618033988749895) % 1.0  # golden-ratio hue spacing
                saturation = 0.7 + (i % 3) * 0.1
                value = 0.8 + (i % 2) * 0.15
                rgb = tuple(int(c * 255) for c in colorsys.hsv_to_rgb(hue, saturation, value))
            label_color_map[label] = rgb

        overlay = cv2.cvtColor(aug_image, cv2.COLOR_BGR2RGB)
        height, width = aug_image.shape[:2]

        shapes_by_label = {}
        for shape in aug_data['shapes']:
            shapes_by_label.setdefault(shape['label'], []).append(shape)

        # Translucent fill, one combined mask per label.
        for label, shapes in shapes_by_label.items():
            color = label_color_map.get(label, (0, 255, 0))
            label_mask = np.zeros((height, width), dtype=np.uint8)
            for shape in shapes:
                points = np.array(shape['points'], dtype=np.int32)
                if len(points) >= 3:
                    cv2.fillPoly(label_mask, [points], 255)
            mask_area = label_mask == 255
            if not mask_area.any():
                continue
            label_lc = label.lower()
            if label_lc in BACKGROUND_LABELS:
                alpha = 0.15
            elif 'ring' in label_lc or 'donut' in label_lc:
                alpha = 0.4
            else:
                alpha = 0.35
            colored_mask = np.zeros_like(overlay)
            colored_mask[mask_area] = color
            overlay[mask_area] = cv2.addWeighted(
                overlay[mask_area], 1.0 - alpha, colored_mask[mask_area], alpha, 0
            )

        # Outlines and text labels.
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.4
        text_thickness = 1
        for shape in aug_data['shapes']:
            label = shape['label']
            color = label_color_map.get(label, (0, 255, 0))
            points = np.array(shape['points'], dtype=np.int32)
            if len(points) < 3:
                continue
            label_lc = label.lower()
            if label_lc in BACKGROUND_LABELS:
                thickness = 1
            elif 'ring' in label_lc or 'donut' in label_lc:
                thickness = 3
            else:
                thickness = 2
            cv2.polylines(overlay, [points], isClosed=True, color=color, thickness=thickness)

            moments = cv2.moments(points)
            if moments['m00'] != 0:
                cx = int(moments['m10'] / moments['m00'])
                cy = int(moments['m01'] / moments['m00'])
            else:
                cx, cy = int(points[0][0]), int(points[0][1])
            cx = max(10, min(cx, width - 50))
            cy = max(20, min(cy, height - 10))

            text_size = cv2.getTextSize(label, font, font_scale, text_thickness)[0]
            cv2.rectangle(overlay, (cx - 2, cy - text_size[1] - 4),
                          (cx + text_size[0] + 2, cy + 2), (0, 0, 0), -1)
            cv2.putText(overlay, label, (cx, cy - 2), font, font_scale, color, text_thickness)

        if self.debug:
            logger.info(f"[DEBUG] Created visualization with {len(unique_labels)} unique labels: {list(unique_labels)}")
        return Image.fromarray(overlay)

    # ----------------------------------------------------------------- export

    def create_download_package(self):
        """Zip every augmented image with its LabelMe JSON, a summary and a README.

        The clean augmented image (``raw_image``) is exported, not the annotated
        visualisation; results without ``raw_image`` fall back to ``image``.

        Returns:
            The ZIP archive as bytes, or ``None`` if there is nothing to export or
            archiving fails.
        """
        if not self.augmented_results:
            logger.warning("No augmented results available for download")
            return None

        logger.info(f"Creating download package with {len(self.augmented_results)} results")
        zip_buffer = io.BytesIO()
        try:
            with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
                for result in self.augmented_results:
                    filename = result['metadata']['filename']

                    try:
                        image_to_save = result.get('raw_image')
                        if image_to_save is None:
                            image_to_save = result['image']
                        if image_to_save.mode != 'RGB':
                            image_to_save = image_to_save.convert('RGB')
                        img_buffer = io.BytesIO()
                        image_to_save.save(img_buffer, format='PNG', optimize=True)
                        zip_file.writestr(filename, img_buffer.getvalue())
                        logger.info(f"Added image: {filename}")
                    except Exception as e:
                        logger.error(f"Error saving image {filename}: {str(e)}")
                        continue

                    json_filename = filename.replace('.png', '.json')
                    try:
                        clean_json_data = {
                            "version": "5.0.1",
                            "flags": {},
                            "shapes": [
                                {
                                    "label": shape['label'],
                                    "points": shape['points'],
                                    "group_id": shape.get('group_id'),
                                    "shape_type": "polygon",
                                    "flags": shape.get('flags', {}),
                                    "description": shape.get('description', ''),
                                    "iscrowd": shape.get('iscrowd', 0),
                                    "attributes": shape.get('attributes', {}),
                                }
                                for shape in result['json_data']['shapes']
                            ],
                            "imagePath": filename,
                            "imageData": None,
                            "imageHeight": result['json_data']['imageHeight'],
                            "imageWidth": result['json_data']['imageWidth'],
                            "imageDepth": 3,
                        }
                        json_str = json.dumps(clean_json_data, indent=2, ensure_ascii=False)
                        zip_file.writestr(json_filename, json_str)
                        logger.info(f"Added JSON: {json_filename} with {len(clean_json_data['shapes'])} shapes")
                    except Exception as e:
                        logger.error(f"Error saving JSON {json_filename}: {str(e)}")
                        continue

                total = len(self.augmented_results)
                total_polygons = sum(len(r['json_data']['shapes']) for r in self.augmented_results)
                unique_labels = list({
                    shape['label']
                    for r in self.augmented_results
                    for shape in r['json_data']['shapes']
                })
                aug_types = list({r['metadata']['augmentation_type'] for r in self.augmented_results})

                summary = {
                    'package_info': {
                        'total_augmentations': total,
                        'generation_timestamp': datetime.now().isoformat(),
                        'generator': 'PolygonAugmentation v1.0',
                        'format': 'LabelMe JSON + PNG images',
                    },
                    'augmentation_summary': [
                        {
                            'filename': r['metadata']['filename'],
                            'json_file': r['metadata']['filename'].replace('.png', '.json'),
                            'augmentation_type': r['metadata']['augmentation_type'],
                            'parameter_value': r['metadata']['parameter_value'],
                            'polygon_count': len(r['json_data']['shapes']),
                            'image_size': f"{r['json_data']['imageWidth']}x{r['json_data']['imageHeight']}",
                            'timestamp': r['metadata']['timestamp'],
                            'labels': list({shape['label'] for shape in r['json_data']['shapes']}),
                        }
                        for r in self.augmented_results
                    ],
                    'statistics': {
                        'unique_augmentation_types': aug_types,
                        'total_polygons': total_polygons,
                        'unique_labels': unique_labels,
                        'average_polygons_per_image': total_polygons / total if total else 0,
                    },
                }
                zip_file.writestr('augmentation_summary.json', json.dumps(summary, indent=2, ensure_ascii=False))

                readme_content = f"""# Augmented Dataset Package

## Overview
This package contains {total} augmented images with their corresponding LabelMe annotation files.

## Contents
- **Images**: PNG format augmented images
- **Annotations**: LabelMe JSON format annotation files (standard format)
- **Summary**: augmentation_summary.json with detailed metadata

## File Structure
- Each image file (*.png) has a corresponding annotation file (*.json) with the same base name
- All annotations are in standard LabelMe format without embedded image data
- Compatible with LabelMe, CVAT, and other annotation tools

## Statistics
- Total augmented images: {total}
- Total polygons: {total_polygons}
- Unique labels: {unique_labels}
- Augmentation types used: {aug_types}

## Usage
1. Extract the ZIP file
2. Load images and annotations using any tool that supports LabelMe format
3. Use the augmentation_summary.json for batch processing or analysis

Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
Tool: PolygonAugmentation v1.0
"""
                zip_file.writestr('README.md', readme_content)
                logger.info("Successfully created ZIP package with all files")

            zip_buffer.seek(0)
            logger.info(f"Created download package with {len(self.augmented_results)} image-annotation pairs")
            return zip_buffer.getvalue()
        except Exception as e:
            logger.exception(f"Error creating ZIP package: {str(e)}")
            return None


def create_interface():
    """Build the Gradio Blocks UI around one shared :class:`PolygonAugmentation`."""
    augmenter = PolygonAugmentation(tolerance=2.0, area_threshold=0.01, debug=True)

    def _upload_path(upload):
        """Return the filesystem path of a Gradio upload (plain path or tempfile wrapper)."""
        return upload if isinstance(upload, str) else upload.name

    def _pair_uploads(images, json_files):
        """Pair image uploads with JSON uploads.

        Files are matched by base name (``a.png`` <-> ``a.json``) when every image has
        a same-named JSON; otherwise they are paired in upload order.
        """
        image_paths = [_upload_path(f) for f in images if f is not None]
        json_paths = [_upload_path(f) for f in json_files if f is not None]
        json_by_stem = {Path(p).stem: p for p in json_paths}
        if image_paths and all(Path(p).stem in json_by_stem for p in image_paths):
            return [(p, json_by_stem[Path(p).stem]) for p in image_paths]
        return list(zip(image_paths, json_paths))

    def process_batch_augmentation(
        images, json_files, num_augmentations,
        rotate_enabled, rotate_min, rotate_max,
        hflip_enabled, vflip_enabled,
        scale_enabled, scale_min, scale_max,
        brightness_enabled, brightness_min, brightness_max,
        dropout_enabled, dropout_min, dropout_max
    ):
        """Gradio handler.

        Returns:
            ``(gallery_images, metadata_json, status_message)``, matching the
            ``[augmented_gallery, json_output, status_text]`` outputs.
        """
        if not images or not json_files:
            return [], None, "No images or JSON files uploaded"

        image_json_pairs = []
        for i, (image_path, json_path) in enumerate(_pair_uploads(images, json_files)):
            try:
                # Decode fully and close the file; RGB so every input mode is handled alike.
                with Image.open(image_path) as img:
                    image = img.convert('RGB')
                logger.info(f"Loading JSON from: {json_path}")
                with open(json_path, 'r', encoding='utf-8') as f:
                    json_data = json.load(f)
                logger.info(f"Successfully loaded JSON with keys: {list(json_data.keys())}")
                image_json_pairs.append((image, json_data))
            except Exception as e:
                logger.exception(f"Error loading image/JSON pair {i}: {str(e)}")
                continue

        if not image_json_pairs:
            return [], None, "No valid image-JSON pairs found"

        aug_configs = []
        if rotate_enabled:
            aug_configs.append({'aug_type': 'rotate', 'param_range': (rotate_min, rotate_max)})
        if hflip_enabled:
            aug_configs.append({'aug_type': 'horizontal_flip', 'param_range': (0, 1)})
        if vflip_enabled:
            aug_configs.append({'aug_type': 'vertical_flip', 'param_range': (0, 1)})
        if scale_enabled:
            aug_configs.append({'aug_type': 'scale', 'param_range': (scale_min, scale_max)})
        if brightness_enabled:
            aug_configs.append({'aug_type': 'brightness_contrast', 'param_range': (brightness_min, brightness_max)})
        if dropout_enabled:
            aug_configs.append({'aug_type': 'pixel_dropout', 'param_range': (dropout_min, dropout_max)})

        if not aug_configs:
            return [], None, "No augmentation types selected"

        try:
            logger.info(f"Starting batch augmentation with {len(image_json_pairs)} image pairs and {len(aug_configs)} configurations")
            augmented_images = augmenter.batch_augment_images(image_json_pairs, aug_configs, num_augmentations)
            json_summary = json.dumps([result['metadata'] for result in augmenter.augmented_results], indent=2)
            status = f"Generated {len(augmented_images)} augmented images from {len(image_json_pairs)} input pairs"
            logger.info(status)
            return augmented_images, json_summary, status
        except Exception as e:
            error_msg = f"Batch augmentation error: {str(e)}"
            logger.exception(error_msg)
            return [], None, error_msg

    def download_package():
        """Write the ZIP package to a temp file and return its path (``None`` on failure)."""
        try:
            package_data = augmenter.create_download_package()
            if package_data is None:
                return None
            # Write through the handle we already hold: reopening a still-open
            # NamedTemporaryFile by name fails on Windows and leaks the handle.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.zip', prefix='augmented_dataset_') as tmp:
                tmp.write(package_data)
                package_path = tmp.name
            logger.info(f"Created download package: {package_path}")
            return package_path
        except Exception as e:
            logger.exception(f"Error creating download package: {str(e)}")
            return None

    def show_mask_overlay(evt: gr.SelectData):
        """Return the visualisation for the gallery item the user clicked."""
        if evt.index < len(augmenter.augmented_results):
            return augmenter.augmented_results[evt.index]['image']
        return None

    with gr.Blocks(title="Dynamic Donut Polygon Augmentation") as demo:
        gr.Markdown("# 🌀 Dynamic Donut Polygon Augmentation Tool")
        gr.Markdown("Upload multiple images and JSON files to apply batch augmentation with configurable parameter ranges")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("## 📁 Input Files")
                images_input = gr.File(file_count="multiple", file_types=["image"], label="Upload Images")
                json_input = gr.File(file_count="multiple", file_types=[".json"], label="Upload LabelMe JSON Files")
                num_augmentations = gr.Slider(minimum=1, maximum=5, value=2, step=1,
                                              label="Augmentations per configuration")

                gr.Markdown("## ⚙️ Augmentation Configuration")

                with gr.Group():
                    rotate_enabled = gr.Checkbox(label="Enable Rotation", value=True)
                    with gr.Row():
                        rotate_min = gr.Slider(-45, 45, -15, label="Min Rotation (degrees)")
                        rotate_max = gr.Slider(-45, 45, 15, label="Max Rotation (degrees)")

                with gr.Group():
                    hflip_enabled = gr.Checkbox(label="Enable Horizontal Flip", value=True)
                    vflip_enabled = gr.Checkbox(label="Enable Vertical Flip", value=False)

                with gr.Group():
                    scale_enabled = gr.Checkbox(label="Enable Scale", value=True)
                    with gr.Row():
                        scale_min = gr.Slider(0.7, 1.3, 0.9, label="Min Scale")
                        scale_max = gr.Slider(0.7, 1.3, 1.1, label="Max Scale")

                with gr.Group():
                    brightness_enabled = gr.Checkbox(label="Enable Brightness/Contrast", value=True)
                    with gr.Row():
                        brightness_min = gr.Slider(-0.3, 0.3, -0.1, label="Min Brightness")
                        brightness_max = gr.Slider(-0.3, 0.3, 0.1, label="Max Brightness")

                with gr.Group():
                    dropout_enabled = gr.Checkbox(label="Enable Pixel Dropout", value=False)
                    with gr.Row():
                        dropout_min = gr.Slider(0.01, 0.1, 0.02, label="Min Dropout")
                        dropout_max = gr.Slider(0.01, 0.1, 0.05, label="Max Dropout")

                generate_btn = gr.Button("🚀 Generate Augmentations", variant="primary")
                status_text = gr.Textbox(label="Status", interactive=False)

            with gr.Column(scale=2):
                gr.Markdown("## 🖼️ Augmented Results")
                gr.Markdown("*Click on any image to view with enhanced mask overlay*")
                augmented_gallery = gr.Gallery(
                    label="Augmented Images with Polygon Masks",
                    show_label=False,
                    elem_id="gallery",
                    columns=3,
                    rows=3,
                    height="auto",
                )
                with gr.Row():
                    download_btn = gr.Button("📥 Download All (ZIP)", variant="secondary")
                    download_file = gr.File(label="Download Package", visible=True)

                gr.Markdown("## 📋 Augmentation Metadata")
                json_output = gr.Code(label="Generated Metadata JSON", language="json", lines=15)

                gr.Markdown("## 🎭 Enhanced Preview")
                mask_preview = gr.Image(label="Selected Image with Mask Overlay")

        generate_btn.click(
            process_batch_augmentation,
            inputs=[
                images_input, json_input, num_augmentations,
                rotate_enabled, rotate_min, rotate_max,
                hflip_enabled, vflip_enabled,
                scale_enabled, scale_min, scale_max,
                brightness_enabled, brightness_min, brightness_max,
                dropout_enabled, dropout_min, dropout_max,
            ],
            outputs=[augmented_gallery, json_output, status_text],
        )
        download_btn.click(download_package, outputs=download_file)
        augmented_gallery.select(show_mask_overlay, outputs=mask_preview)

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch()