Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from PIL import Image | |
| import numpy as np | |
| import cv2 | |
| import json | |
| import albumentations as A | |
| from typing import List, Tuple, Dict, Any | |
| import supervision as sv | |
| import uuid | |
| import random | |
| from pathlib import Path | |
| import colorsys | |
| import logging | |
| import zipfile | |
| import io | |
| from datetime import datetime | |
# Set up logging: configure the root logger at INFO level and create the
# module-level logger that the methods below write to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
| class PolygonAugmentation: | |
def __init__(self, tolerance: float = 0.2, area_threshold: float = 0.01, debug: bool = False):
    """Store the default settings shared by the augmentation methods.

    Args:
        tolerance: Default epsilon that ``simplify_polygon`` hands to
            ``cv2.approxPolyDP`` when the caller gives no tolerance.
        area_threshold: Default minimum ratio between a contour's area and
            the matching original polygon area; contours below it are
            dropped by ``masks_to_labelme_polygons``.
        debug: When true, methods emit ``[DEBUG]`` messages through the
            module logger.
    """
    self.tolerance = tolerance
    self.area_threshold = area_threshold
    self.debug = debug
    # Not referenced by any method visible in this part of the file.
    self.supported_extensions = ['.png', '.jpg', '.jpeg', '.bmp', '.PNG', '.JPEG']
    self.augmented_results = []  # Store all augmentation results
def __getattr__(self, name: str) -> Any:
    """Raise ``AttributeError`` for an attribute that normal lookup missed.

    Python only calls ``__getattr__`` after regular attribute lookup has
    failed, so this hook never returns a value; it just supplies a fixed
    error message.

    Raises:
        AttributeError: Always.
    """
    raise AttributeError(f"'PolygonAugmentation' object has no attribute '{name}'")
def calculate_polygon_area(self, points: List[List[float]]) -> float:
    """Return the area enclosed by a polygon.

    Args:
        points: Polygon vertices as ``[x, y]`` pairs.

    Returns:
        The area reported by ``cv2.contourArea`` for the vertices, or
        ``0.0`` when fewer than three vertices are given (such input
        encloses no area and an empty array is rejected by OpenCV).
    """
    if len(points) < 3:
        # Degenerate input: skip OpenCV entirely instead of letting it raise.
        area = 0.0
    else:
        area = float(cv2.contourArea(np.array(points, dtype=np.float32)))
    if self.debug:
        logger.info(f"[DEBUG] Calculating polygon area: {area:.2f}")
    return area
def load_labelme_data(self, json_file: Any, image: np.ndarray) -> Tuple:
    """Parse polygon annotations and return them alongside ``image``.

    Two JSON layouts are accepted: LabelMe (a ``'shapes'`` list) and a
    ``'segments'`` list whose entries carry ``'class'`` / ``'polygon'`` /
    ``'confidence'``; the latter is converted to LabelMe-style shapes.

    Args:
        json_file: A path (``str`` or ``Path``), an already-parsed dict, or
            an open file object readable by ``json.load``.
        image: Passed through unchanged as the first element of the result.

    Returns:
        ``(image, polygons, labels, original_areas, data, "input")`` where
        the three lists are parallel (one entry per accepted polygon) and
        ``data`` is the parsed JSON.

    Raises:
        ValueError: If neither ``'shapes'`` nor ``'segments'`` is present
            as a list.
    """
    if isinstance(json_file, (str, Path)):
        with open(json_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    elif isinstance(json_file, dict):
        # Handle dictionary data directly
        data = json_file
    else:
        # Handle file object
        data = json.load(json_file)

    if 'shapes' in data and isinstance(data['shapes'], list):
        shapes = data['shapes']
    elif 'segments' in data and isinstance(data['segments'], list):
        shapes = [
            {
                "label": seg.get("class", "unknown"),
                "points": seg.get("polygon", []),
                "shape_type": "polygon",
                "group_id": None,
                "flags": {},
                "confidence": seg.get("confidence", 1.0)
            }
            for seg in data['segments']
        ]
    else:
        raise ValueError("Invalid JSON: Neither 'shapes' nor 'segments' key found or not a list")

    polygons = []
    labels = []
    original_areas = []
    for shape in shapes:
        raw_points = shape.get('points')
        if shape.get('shape_type') != 'polygon' or not raw_points or len(raw_points) < 3:
            if self.debug:
                logger.info(f"[DEBUG] Skipping invalid shape: {shape}")
            continue
        try:
            points = [[float(x), float(y)] for x, y in raw_points]
            area = self.calculate_polygon_area(points)
        except (ValueError, TypeError) as e:
            if self.debug:
                logger.info(f"[DEBUG] Error processing points: {raw_points}, error: {str(e)}")
            continue
        # Append only after every step succeeded so the three lists can
        # never drift out of step with each other.
        polygons.append(points)
        # Same fallback label the 'segments' conversion above uses.
        labels.append(shape.get('label', 'unknown'))
        original_areas.append(area)

    if not polygons and self.debug:
        logger.info(f"[DEBUG] Warning: No valid polygons in JSON")
    return image, polygons, labels, original_areas, data, "input"
def simplify_polygon(self, polygon: List[List[float]], tolerance: float = None, label: str = None) -> List[List[float]]:
    """Reduce the vertex count of a polygon with ``cv2.approxPolyDP``.

    Args:
        polygon: Vertices as ``[x, y]`` pairs.
        tolerance: Approximation epsilon; ``self.tolerance`` when ``None``.
        label: Optional class label. Background labels
            (``background`` / ``bg`` / ``back``, case-insensitive) use
            three times the epsilon.

    Returns:
        The simplified vertex list, or ``polygon`` itself (same object)
        when it has fewer than three vertices.
    """
    epsilon = self.tolerance if tolerance is None else tolerance
    if label and label.lower() in ['background', 'bg', 'back']:
        epsilon = epsilon * 3
        if self.debug:
            logger.info(f"[DEBUG] Using increased tolerance {epsilon} for background label '{label}'")

    if len(polygon) >= 3:
        contour = np.array(polygon, dtype=np.float32)
        reduced = cv2.approxPolyDP(contour, epsilon, closed=True)
        simplified = reduced.reshape(-1, 2).tolist()
        if self.debug:
            logger.info(f"[DEBUG] Simplified polygon from {len(polygon)} to {len(simplified)} points with tolerance {epsilon}")
        return simplified

    # Too few vertices to form a polygon: hand the input back untouched.
    if self.debug:
        logger.info(f"[DEBUG] Polygon has fewer than 3 points, skipping simplification.")
    return polygon
def create_donut_polygon(self, external_contour: np.ndarray, internal_contours: List[np.ndarray]) -> List[List[float]]:
    """Create a donut/ring polygon by connecting external and internal contours with bridges.

    Each hole is spliced into the running outline at the pair of points
    (outline point, hole point) that are closest to each other: the outline
    is followed up to the bridge point, the hole is walked once around
    starting and ending at its connection point, and the path then returns
    to the bridge point before continuing along the outline.

    Args:
        external_contour: Outer boundary; reshaped to ``(-1, 2)``.
        internal_contours: Hole boundaries, each reshaped to ``(-1, 2)``.
            Holes without any points are skipped.

    Returns:
        A single vertex list. With no holes this is just the external
        contour's points.
    """
    external_points = external_contour.reshape(-1, 2).tolist()
    if not internal_contours:
        if self.debug:
            logger.info("[DEBUG] No internal contours found, returning external points.")
        return external_points

    # Start with external contour points
    result_points = list(external_points)

    # Process each internal contour (hole)
    for hole_idx, internal_contour in enumerate(internal_contours):
        internal_points = internal_contour.reshape(-1, 2).tolist()
        if not internal_points or not result_points:
            # Nothing to bridge to/from; indexing below would fail.
            continue

        hole_xy = np.asarray(internal_points, dtype=np.float64)
        hole_x = hole_xy[:, 0]
        hole_y = hole_xy[:, 1]

        # Find the closest (outline, hole) pair. The scan over hole points
        # is vectorised; the strict '<' keeps the first pair in
        # (outline index, hole index) order when distances tie.
        min_dist = float('inf')
        best_ext_idx = 0
        best_int_idx = 0
        for i, ext_point in enumerate(result_points):
            dists = np.sqrt((ext_point[0] - hole_x) ** 2 + (ext_point[1] - hole_y) ** 2)
            j = int(np.argmin(dists))
            if dists[j] < min_dist:
                min_dist = float(dists[j])
                best_ext_idx = i
                best_int_idx = j

        bridge_start = result_points[best_ext_idx]
        if self.debug:
            logger.info(f"[DEBUG] Creating bridge for hole {hole_idx}: ext_idx={best_ext_idx}, int_idx={best_int_idx}, distance={min_dist:.2f}")

        # Insert the internal contour into the result
        result_points = (
            result_points[:best_ext_idx + 1] +      # External points up to bridge
            internal_points[best_int_idx:] +        # Hole from connection point to end
            internal_points[:best_int_idx + 1] +    # Hole from start back to connection point
            [bridge_start] +                        # Bridge back to external
            result_points[best_ext_idx + 1:]        # Remaining external points
        )

    if self.debug:
        logger.info(f"[DEBUG] Created donut polygon with {len(result_points)} total points")
    return result_points
def save_augmented_data(
    self,
    aug_image: np.ndarray,
    aug_polygons: List[List[List[float]]],
    aug_labels: List[str],
    original_data: Dict[str, Any],
    base_name: str
) -> Dict[str, Any]:
    """Build a LabelMe-compatible annotation dict for an augmented image.

    Despite the name, nothing is written to disk here; the method only
    assembles and returns the JSON-serialisable structure.

    Args:
        aug_image: Augmented image; only its shape is read.
        aug_polygons: Polygons as lists of ``[x, y]`` pairs. Entries with
            fewer than three points are left out.
        aug_labels: Label per polygon, parallel to ``aug_polygons``.
        original_data: Source annotation dict; ``version``, ``flags`` and
            ``imagePath`` are copied from it when present.
        base_name: Prefix of the generated ``imagePath``.

    Returns:
        Dict with the LabelMe fields plus an ``augmentation`` metadata
        section. ``imageDepth`` is the channel count of ``aug_image``
        (``1`` for a 2-D array).
    """
    ring_labels = {'ring', 'donut', 'annulus', 'circle', 'round'}
    background_labels = {'background', 'bg', 'back'}

    aug_id = uuid.uuid4().hex[:4]
    aug_img_name = f"{base_name}_{aug_id}_aug.png"

    new_shapes = []
    for poly, label in zip(aug_polygons, aug_labels):
        if not poly or len(poly) < 3:
            continue
        # str() keeps classification from crashing on non-string labels.
        label_key = str(label).lower()
        if label_key in ring_labels:
            polygon_type = "ring"
        elif label_key in background_labels:
            polygon_type = "background"
        else:
            polygon_type = "object"
        # Create LabelMe format shape
        new_shapes.append({
            "label": label,
            "points": poly,
            "group_id": None,
            "shape_type": "polygon",
            "flags": {},
            "description": "",
            "attributes": {"polygon_type": polygon_type},
            "iscrowd": 0,
            "difficult": 0
        })

    # Get actual dimensions from augmented image
    aug_height, aug_width = aug_image.shape[:2]
    # Report the real channel count instead of assuming 3 for any 3-D array.
    image_depth = int(aug_image.shape[2]) if aug_image.ndim == 3 else 1

    # Create LabelMe compatible JSON structure
    aug_data = {
        "version": original_data.get("version", "5.0.1"),
        "flags": original_data.get("flags", {}),
        "shapes": new_shapes,
        "imagePath": aug_img_name,
        "imageData": None,  # Explicitly set to None as requested
        "imageHeight": aug_height,
        "imageWidth": aug_width,
        "imageDepth": image_depth,
        # Additional LabelMe metadata
        "lineColor": [0, 255, 0, 128],
        "fillColor": [255, 0, 0, 128],
        "textSize": 10,
        "textColor": [0, 0, 0, 255],
        # Augmentation metadata
        "augmentation": {
            "augmented": True,
            "augmentation_id": aug_id,
            "original_file": original_data.get("imagePath", "unknown"),
            "augmentation_timestamp": datetime.now().isoformat(),
            "augmentation_tool": "PolygonAugmentation v1.0"
        }
    }
    if self.debug:
        logger.info(f"[DEBUG] Created LabelMe JSON: {len(new_shapes)} shapes, size: {aug_width}x{aug_height}")
        logger.info(f"[DEBUG] Shape types: {[s['attributes'].get('polygon_type', 'unknown') for s in new_shapes]}")
    return aug_data
def polygons_to_masks(self, image: np.ndarray, polygons: List[List[List[float]]], labels: List[str]) -> Tuple[np.ndarray, List[str]]:
    """Rasterise each polygon into its own binary mask.

    Args:
        image: Only its first two shape entries (height, width) are used.
        polygons: Polygons as lists of ``[x, y]`` pairs; coordinates are
            cast to ``int32`` before filling.
        labels: Label per polygon, parallel to ``polygons``.

    Returns:
        ``(masks, labels)`` where ``masks`` is a ``uint8`` array stacked
        along the first axis (filled pixels are ``1``) and ``labels``
        holds the labels of the polygons that produced a mask. Polygons
        with fewer than three points, or that raise while being drawn, are
        left out. With no usable polygon the result is an empty
        ``(0, height, width)`` array and an empty list.
    """
    height, width = image.shape[:2]
    rasterised = []
    kept_labels = []

    for poly_idx, (poly, label) in enumerate(zip(polygons, labels)):
        try:
            vertices = np.array(poly, dtype=np.int32)
            if len(vertices) >= 3:
                canvas = np.zeros((height, width), dtype=np.uint8)
                cv2.fillPoly(canvas, [vertices], 1)
                rasterised.append(canvas)
                kept_labels.append(label)
            elif self.debug:
                logger.info(f"[DEBUG] Skipping polygon {poly_idx}: fewer than 3 points")
        except Exception as e:
            if self.debug:
                logger.info(f"[DEBUG] Error processing polygon {poly_idx}: {str(e)}")

    if rasterised:
        return np.array(rasterised, dtype=np.uint8), kept_labels
    return np.zeros((0, height, width), dtype=np.uint8), []
def process_contours(
    self,
    external_contour: np.ndarray,
    internal_contours: List[np.ndarray],
    width: int,
    height: int,
    label: str,
    all_polygons: List[List[List[float]]],
    all_labels: List[str],
    tolerance: float = None
) -> None:
    """Append the external contour and each hole as separate polygons.

    Every contour is simplified with ``simplify_polygon``, clamped to
    ``[0, width - 1] x [0, height - 1]``, rounded to two decimals and
    appended to ``all_polygons`` (with ``label`` appended to
    ``all_labels``). Contours that simplify to fewer than three points are
    dropped. The two output lists are mutated in place.
    """
    epsilon = self.tolerance if tolerance is None else tolerance
    x_limit = width - 1
    y_limit = height - 1

    def clamp_to_image(vertices):
        return [[round(max(0, min(float(px), x_limit)), 2),
                 round(max(0, min(float(py), y_limit)), 2)]
                for px, py in vertices]

    # External contour first, then the holes, tagged for the debug message.
    tagged = [("external", external_contour)]
    tagged.extend(("internal", hole) for hole in internal_contours)

    for kind, contour in tagged:
        vertices = contour.reshape(-1, 2).tolist()
        simplified = self.simplify_polygon(vertices, tolerance=epsilon, label=label)
        if len(simplified) < 3:
            continue
        poly_labelme = clamp_to_image(simplified)
        all_polygons.append(poly_labelme)
        all_labels.append(label)
        if self.debug:
            logger.info(f"[DEBUG] Added simplified {kind} polygon with {len(poly_labelme)} points.")
def masks_to_labelme_polygons(
    self,
    masks: np.ndarray,
    labels: List[str],
    original_areas: List[float],
    area_threshold: float = None,
    tolerance: float = None
) -> Tuple[List[List[List[float]]], List[str]]:
    """Convert binary masks back into LabelMe polygons.

    Each mask is traced with ``cv2.findContours`` in ``RETR_CCOMP`` mode.
    An outer contour that has holes is stitched into one keyhole polygon by
    ``create_donut_polygon``; outer contours without holes, and donuts that
    fail or collapse below three points, go through ``process_contours``,
    which emits the outer contour and each hole as separate polygons.

    Args:
        masks: Stack of 2-D masks, iterated along the first axis.
        labels: Label per mask, in the same order.
        original_areas: Pre-augmentation polygon areas, looked up by mask
            position for the relative-area filter.
        area_threshold: Minimum ``contour area / original area`` ratio;
            ``self.area_threshold`` when ``None``.
        tolerance: Simplification epsilon; ``self.tolerance`` when ``None``.

    Returns:
        ``(polygons, labels)`` as parallel lists.
    """
    tol = tolerance if tolerance is not None else self.tolerance
    area_thresh = area_threshold if area_threshold is not None else self.area_threshold
    height, width = masks[0].shape if len(masks) > 0 else (0, 0)
    all_polygons = []
    all_labels = []

    for mask_idx, (mask, label) in enumerate(zip(masks, labels)):
        if mask.sum() < 10:
            if self.debug:
                logger.info(f"[DEBUG] Skipping mask {mask_idx}: very small or empty.")
            continue
        contours, hierarchy = cv2.findContours(mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
        if hierarchy is None or len(contours) == 0:
            if self.debug:
                logger.info(f"[DEBUG] No contours found in mask {mask_idx}.")
            continue
        hierarchy = hierarchy[0]

        # hierarchy[i][3] is the parent's index in `contours`, so holes are
        # grouped under that contour index, not under the position of the
        # parent in the list of external contours.
        external_contours = []   # (index in `contours`, contour)
        holes_by_parent = {}     # parent contour index -> hole contours
        for contour_idx, (contour, node) in enumerate(zip(contours, hierarchy)):
            parent_idx = int(node[3])
            if parent_idx == -1:
                external_contours.append((contour_idx, contour))
            else:
                holes_by_parent.setdefault(parent_idx, []).append(contour)

        if not external_contours:
            if self.debug:
                logger.info(f"[DEBUG] No external contours found in mask {mask_idx}.")
            continue

        is_background = label.lower() in ['background', 'bg', 'back']

        for ext_idx, (contour_idx, external_contour) in enumerate(external_contours):
            internal_contours = holes_by_parent.get(contour_idx, [])
            ext_area = cv2.contourArea(external_contour)
            if ext_area <= 0:
                continue
            if mask_idx < len(original_areas) and original_areas[mask_idx] > 0:
                relative_area = ext_area / original_areas[mask_idx]
                if relative_area < area_thresh:
                    if self.debug:
                        logger.info(f"[DEBUG] Skipping contour {ext_idx} (area too small: {relative_area:.4f})")
                    continue

            # Try to build one keyhole polygon for any shape with holes.
            donut_polygon = None
            if internal_contours:
                try:
                    donut_points = self.create_donut_polygon(external_contour, internal_contours)
                    simplified_donut = self.simplify_polygon(donut_points, tolerance=tol, label=label)
                    if len(simplified_donut) >= 3:
                        # Ensure all points are within image boundaries
                        donut_polygon = [
                            [round(max(0, min(float(x), width - 1)), 2),
                             round(max(0, min(float(y), height - 1)), 2)]
                            for x, y in simplified_donut
                        ]
                    elif self.debug:
                        logger.info(f"[DEBUG] Donut polygon too small after simplification, falling back to separate contours")
                except Exception as e:
                    # Best effort: any failure falls back to separate polygons.
                    if self.debug:
                        logger.info(f"[DEBUG] Error creating donut for {label}: {str(e)}, fallback to separate polygons.")

            if donut_polygon is not None:
                all_polygons.append(donut_polygon)
                all_labels.append(label)
                if self.debug:
                    logger.info(f"[DEBUG] Added {'background' if is_background else 'ring'} donut polygon with {len(donut_polygon)} points, {len(internal_contours)} holes")
            else:
                # Regular polygons, or fallback; runs outside the try so a
                # failure here cannot trigger a second, duplicate attempt.
                self.process_contours(
                    external_contour, internal_contours, width, height,
                    label, all_polygons, all_labels, tol
                )
    return all_polygons, all_labels
def augment_single_image(
    self,
    image: np.ndarray,
    polygons: List[List[List[float]]],
    labels: List[str],
    original_areas: List[float],
    original_data: Dict[str, Any],
    aug_type: str,
    aug_param: float
) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Apply one augmentation to an image and its polygon annotations.

    Polygons are rasterised to masks, the image and masks are transformed
    together with albumentations, a random crop is applied with 30 %
    probability, and the resulting masks are traced back into polygons.

    Args:
        image: Source image array.
        polygons: Polygons as lists of ``[x, y]`` pairs.
        labels: Label per polygon.
        original_areas: Area per polygon, used for the relative-area filter.
        original_data: Source annotation dict, forwarded to
            ``save_augmented_data``.
        aug_type: One of ``rotate``, ``horizontal_flip``, ``vertical_flip``,
            ``scale``, ``brightness_contrast``, ``pixel_dropout``.
        aug_param: Parameter for the chosen augmentation (degrees, flip
            on/off as 1/0, scale factor, limit, or dropout probability).

    Returns:
        ``(augmented_image, annotation_dict)``.

    Raises:
        ValueError: For an unsupported ``aug_type``, when no polygon yields
            a mask, or when the augmented image is empty.
    """
    logger.info(f"Applying augmentation: {aug_type} with parameter {aug_param}")

    # Setup augmentation based on type with proper parameters
    if aug_type == "rotate":
        # Angles smaller than 5 degrees are replaced by +/-15 so the effect is visible.
        rotation_angle = aug_param if abs(aug_param) >= 5 else (15 if aug_param >= 0 else -15)
        # limit=(a, a) pins the rotation to exactly that angle.
        aug_transform = A.Rotate(limit=(rotation_angle, rotation_angle), p=1.0, border_mode=cv2.BORDER_CONSTANT, value=0)
        logger.info(f"Applying rotation: {rotation_angle} degrees")
    elif aug_type == "horizontal_flip":
        aug_transform = A.HorizontalFlip(p=1.0 if aug_param == 1 else 0.0)
    elif aug_type == "vertical_flip":
        aug_transform = A.VerticalFlip(p=1.0 if aug_param == 1 else 0.0)
    elif aug_type == "scale":
        # Ensure scale parameter is reasonable
        scale_factor = max(0.5, min(2.0, aug_param))
        aug_transform = A.Affine(scale=scale_factor, p=1.0, keep_ratio=True)
        logger.info(f"Applying scale: {scale_factor}")
    elif aug_type == "brightness_contrast":
        brightness_factor = max(-0.5, min(0.5, aug_param))
        aug_transform = A.RandomBrightnessContrast(
            brightness_limit=abs(brightness_factor),
            contrast_limit=abs(brightness_factor),
            p=1.0
        )
    elif aug_type == "pixel_dropout":
        dropout_prob = min(max(aug_param, 0.0), 0.2)
        aug_transform = A.PixelDropout(dropout_prob=dropout_prob, p=1.0)
    else:
        raise ValueError(f"Unsupported augmentation type: {aug_type}")

    # Create masks from polygons
    masks, mask_labels = self.polygons_to_masks(image, polygons, labels)
    if masks.shape[0] == 0:
        raise ValueError("No valid masks created from polygons")
    # NOTE(review): masks_to_labelme_polygons looks up original_areas by
    # mask position; if polygons_to_masks dropped a polygon the two no
    # longer line up. Confirm callers only pass polygons with >= 3 points.

    # One named albumentations target per mask so all are transformed together.
    masks_list = list(masks)
    target_names = [f'mask{i}' for i in range(len(masks_list))]
    additional_targets = {name: 'mask' for name in target_names}

    transform = A.Compose([aug_transform], additional_targets=additional_targets)
    aug_result = transform(image=image, **dict(zip(target_names, masks_list)))
    aug_image = aug_result['image']

    # Validate before touching .shape below.
    if aug_image is None or aug_image.size == 0:
        raise ValueError("Augmented image is empty or invalid")

    # Collect augmented masks and ensure they match image dimensions
    aug_height, aug_width = aug_image.shape[:2]
    aug_masks_list = []
    for name in target_names:
        aug_mask = aug_result[name]
        if aug_mask.shape[:2] != (aug_height, aug_width):
            aug_mask = cv2.resize(aug_mask, (aug_width, aug_height), interpolation=cv2.INTER_NEAREST)
        aug_masks_list.append(aug_mask)

    # Apply random crop as post-processing to add variety
    if random.random() < 0.3:  # 30% chance of cropping
        crop_scale = random.uniform(0.85, 0.95)
        crop_height = int(aug_height * crop_scale)
        crop_width = int(aug_width * crop_scale)
        # Tiny images can round down to a zero-sized crop; skip it then.
        if crop_height > 0 and crop_width > 0:
            crop_transform = A.Compose([
                A.RandomCrop(width=crop_width, height=crop_height, p=1.0)
            ], additional_targets=additional_targets)
            crop_result = crop_transform(image=aug_image, **dict(zip(target_names, aug_masks_list)))
            aug_image = crop_result['image']
            aug_masks_list = [crop_result[name] for name in target_names]

    # Convert the final masks back to polygons once, after any crop.
    aug_masks = np.array(aug_masks_list, dtype=np.uint8)
    aug_polygons, aug_labels = self.masks_to_labelme_polygons(
        aug_masks, mask_labels, original_areas, self.area_threshold, self.tolerance
    )

    # Create augmented data with correct dimensions
    aug_data = self.save_augmented_data(aug_image, aug_polygons, aug_labels, original_data, "input")
    logger.info(f"Augmentation completed: {len(aug_polygons)} polygons generated, final size: {aug_image.shape[:2]}")
    return aug_image, aug_data
def batch_augment_images(self, image_json_pairs, aug_configs, num_augmentations):
    """Batch process multiple images with multiple augmentation configurations.

    Args:
        image_json_pairs: Iterable of ``(image, json_data)``; pairs with a
            ``None`` member are skipped. PIL images are converted to RGB
            first; the pixel data is then treated as RGB and converted to
            BGR for processing.
        aug_configs: Dicts with ``'aug_type'`` and, for non-flip types,
            ``'param_range'`` as ``(min, max)``.
        num_augmentations: Number of samples generated per image and config.

    Returns:
        List of visualisation images (one per successful augmentation).
        Full results, including the annotation dict and metadata, are also
        stored in ``self.augmented_results``, which is reset on each call.
    """
    logger.info(f"Starting batch augmentation with {len(image_json_pairs)} pairs, {len(aug_configs)} configs, {num_augmentations} augmentations each")
    self.augmented_results = []
    results = []

    for pair_idx, (image, json_data) in enumerate(image_json_pairs):
        if image is None or json_data is None:
            logger.warning(f"Skipping pair {pair_idx}: missing image or JSON data")
            continue
        try:
            logger.info(f"Processing image pair {pair_idx}")
            # Grayscale / RGBA / palette images would make the RGB->BGR
            # conversion below fail, so normalise PIL inputs to RGB first.
            if isinstance(image, Image.Image) and image.mode != 'RGB':
                image = image.convert('RGB')
            # Convert PIL image to NumPy (BGR channel order)
            img_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
            # Load data - pass the JSON data directly
            img_np, polygons, labels, original_areas, original_data, _ = self.load_labelme_data(json_data, img_np)
            logger.info(f"Loaded {len(polygons)} polygons for image {pair_idx}")

            # Apply each augmentation configuration
            for config_idx, config in enumerate(aug_configs):
                aug_type = config['aug_type']
                param_range = config.get('param_range')
                is_flip = aug_type in ['horizontal_flip', 'vertical_flip']
                logger.info(f"Applying config {config_idx}: {aug_type}")

                for aug_idx in range(num_augmentations):
                    try:
                        # Flips only need on/off; other types sample from
                        # the configured range. A bad range fails just this
                        # sample instead of aborting the whole image.
                        if is_flip:
                            aug_param = random.choice([0, 1])
                        else:
                            min_val, max_val = param_range
                            aug_param = random.uniform(min_val, max_val)

                        logger.info(f"Generating augmentation {aug_idx} with {aug_type}, param: {aug_param}")
                        aug_image, aug_data = self.augment_single_image(
                            img_np, polygons, labels, original_areas,
                            original_data, aug_type, aug_param
                        )
                        # Create visualization
                        aug_image_vis = self.create_visualization(aug_image, aug_data)
                        # NOTE(review): the stored 'image' is the visualisation
                        # (overlays drawn in), and create_download_package
                        # writes result['image'] into the ZIP. Confirm that
                        # exporting overlaid images is intended.
                        result_data = {
                            'image': aug_image_vis,
                            'json_data': aug_data,
                            'metadata': {
                                'original_image_index': pair_idx,
                                'augmentation_index': aug_idx,
                                'augmentation_type': aug_type,
                                'parameter_value': aug_param,
                                'parameter_range': param_range,
                                'timestamp': datetime.now().isoformat(),
                                'filename': f'aug_{pair_idx}_{aug_type}_{aug_idx}.png'
                            }
                        }
                        self.augmented_results.append(result_data)
                        results.append(aug_image_vis)
                        logger.info(f"Successfully generated augmentation {aug_idx} for image {pair_idx}")
                    except Exception as e:
                        # logger.exception logs at ERROR level with traceback.
                        logger.exception(f"Error augmenting image {pair_idx} with {aug_type}: {str(e)}")
                        continue
        except Exception as e:
            logger.exception(f"Error processing image pair {pair_idx}: {str(e)}")
            continue

    logger.info(f"Batch augmentation completed. Generated {len(results)} total results.")
    return results
def create_visualization(self, aug_image, aug_data):
    """Create visualization with colored polygon masks and outlines for each class.

    Args:
        aug_image: Image array; it is converted with ``cv2.COLOR_BGR2RGB``
            before drawing.
        aug_data: Dict with a ``'shapes'`` list whose entries provide
            ``'label'`` and ``'points'``.

    Returns:
        A PIL image with each label's polygons tinted, outlined and
        captioned. Colours are assigned by sorted label order, so a given
        set of labels always gets the same colours.
    """
    fallback_color = (0, 255, 0)
    shapes = aug_data['shapes']

    def is_background(name):
        return name.lower() in ['background', 'bg', 'back']

    def is_ring(name):
        lowered = name.lower()
        return 'ring' in lowered or 'donut' in lowered

    # Sorted so the label -> colour mapping does not depend on set ordering.
    unique_labels = sorted({shape['label'] for shape in shapes})
    label_color_map = {}
    for i, label in enumerate(unique_labels):
        if is_background(label):
            # Background gets a fixed cornflower blue
            rgb = (100, 149, 237)
        elif is_ring(label):
            # Ring/donut shapes get purple-pink colors
            hue = 0.8 + (i * 0.1) % 0.2
            rgb = tuple(int(c * 255) for c in colorsys.hsv_to_rgb(hue, 0.8, 0.9))
        else:
            # Regular objects: golden-ratio hue steps spread colours apart
            hue = (i * 0.618033988749895) % 1.0
            saturation = 0.7 + (i % 3) * 0.1
            value = 0.8 + (i % 2) * 0.15
            rgb = tuple(int(c * 255) for c in colorsys.hsv_to_rgb(hue, saturation, value))
        label_color_map[label] = rgb

    # Convert augmented image to RGB for visualization
    aug_image_rgb = cv2.cvtColor(aug_image, cv2.COLOR_BGR2RGB)
    overlay = aug_image_rgb.copy()
    height, width = aug_image.shape[:2]

    # Group shapes by label so each label is tinted once
    shapes_by_label = {}
    for shape in shapes:
        shapes_by_label.setdefault(shape['label'], []).append(shape)

    for label, label_shapes in shapes_by_label.items():
        color = label_color_map.get(label, fallback_color)
        # Create mask for all polygons of this label
        label_mask = np.zeros((height, width), dtype=np.uint8)
        for shape in label_shapes:
            points = np.array(shape['points'], dtype=np.int32)
            if len(points) < 3:
                continue
            cv2.fillPoly(label_mask, [points], 255)

        mask_area = label_mask == 255
        if not mask_area.any():
            continue

        # Determine alpha based on label type
        if is_background(label):
            alpha = 0.15  # Lower opacity for background
        elif is_ring(label):
            alpha = 0.4   # Medium opacity for rings
        else:
            alpha = 0.35  # Standard opacity for objects

        colored_mask = np.zeros_like(aug_image_rgb)
        colored_mask[mask_area] = color
        # Blend only the masked pixels
        overlay[mask_area] = cv2.addWeighted(
            overlay[mask_area], 1.0 - alpha,
            colored_mask[mask_area], alpha,
            0
        )

    # Draw polygon outlines and captions on top of the tint
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.4
    text_thickness = 1
    for shape in shapes:
        label = shape['label']
        color = label_color_map.get(label, fallback_color)
        points = np.array(shape['points'], dtype=np.int32)
        if len(points) < 3:
            continue

        # Determine line thickness based on polygon type
        if is_background(label):
            thickness = 1
        elif is_ring(label):
            thickness = 3
        else:
            thickness = 2
        cv2.polylines(overlay, [points], isClosed=True, color=color, thickness=thickness)

        # Caption at the centroid, or at the first vertex for zero-area shapes
        moments = cv2.moments(points)
        if moments['m00'] != 0:
            cx = int(moments['m10'] / moments['m00'])
            cy = int(moments['m01'] / moments['m00'])
        else:
            # Plain ints: OpenCV drawing calls expect Python integers.
            cx, cy = int(points[0][0]), int(points[0][1])

        # Ensure text position is within image bounds
        cx = max(10, min(cx, width - 50))
        cy = max(20, min(cy, height - 10))

        text_size = cv2.getTextSize(label, font, font_scale, text_thickness)[0]
        # Dark box behind the caption for readability
        cv2.rectangle(overlay,
                      (cx - 2, cy - text_size[1] - 4),
                      (cx + text_size[0] + 2, cy + 2),
                      (0, 0, 0), -1)
        cv2.putText(overlay, label, (cx, cy - 2), font, font_scale, color, text_thickness)

    if self.debug:
        logger.info(f"[DEBUG] Created visualization with {len(unique_labels)} unique labels: {list(unique_labels)}")
    return Image.fromarray(overlay)
| def create_download_package(self): | |
| """Create a zip file with all augmented images and proper LabelMe JSON files""" | |
| if not self.augmented_results: | |
| logger.warning("No augmented results available for download") | |
| return None | |
| logger.info(f"Creating download package with {len(self.augmented_results)} results") | |
| zip_buffer = io.BytesIO() | |
| try: | |
| with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file: | |
| # Add all augmented images and their corresponding LabelMe JSON files | |
| for idx, result in enumerate(self.augmented_results): | |
| filename = result['metadata']['filename'] | |
| # Save augmented image | |
| try: | |
| # Convert PIL image to RGB if needed | |
| if result['image'].mode != 'RGB': | |
| img_rgb = result['image'].convert('RGB') | |
| else: | |
| img_rgb = result['image'] | |
| # Save as PNG bytes | |
| img_buffer = io.BytesIO() | |
| img_rgb.save(img_buffer, format='PNG', optimize=True) | |
| zip_file.writestr(filename, img_buffer.getvalue()) | |
| logger.info(f"Added image: {filename}") | |
| except Exception as e: | |
| logger.error(f"Error saving image {filename}: {str(e)}") | |
| continue | |
| # Save corresponding LabelMe JSON file | |
| json_filename = filename.replace('.png', '.json') | |
| try: | |
| # Create a clean LabelMe JSON structure | |
| clean_json_data = { | |
| "version": "5.0.1", | |
| "flags": {}, | |
| "shapes": [], | |
| "imagePath": filename, | |
| "imageData": None, # No embedded image data as requested | |
| "imageHeight": result['json_data']['imageHeight'], | |
| "imageWidth": result['json_data']['imageWidth'], | |
| "imageDepth": 3 | |
| } | |
| # Copy shapes with proper LabelMe format | |
| for shape in result['json_data']['shapes']: | |
| clean_shape = { | |
| "label": shape['label'], | |
| "points": shape['points'], | |
| "group_id": shape.get('group_id'), | |
| "shape_type": "polygon", | |
| "flags": shape.get('flags', {}), | |
| "description": shape.get('description', ''), | |
| "iscrowd": shape.get('iscrowd', 0), | |
| "attributes": shape.get('attributes', {}) | |
| } | |
| clean_json_data['shapes'].append(clean_shape) | |
| # Write JSON file | |
| json_str = json.dumps(clean_json_data, indent=2, ensure_ascii=False) | |
| zip_file.writestr(json_filename, json_str) | |
| logger.info(f"Added JSON: {json_filename} with {len(clean_json_data['shapes'])} shapes") | |
| except Exception as e: | |
| logger.error(f"Error saving JSON {json_filename}: {str(e)}") | |
| continue | |
| # Add comprehensive summary metadata | |
| summary = { | |
| 'package_info': { | |
| 'total_augmentations': len(self.augmented_results), | |
| 'generation_timestamp': datetime.now().isoformat(), | |
| 'generator': 'PolygonAugmentation v1.0', | |
| 'format': 'LabelMe JSON + PNG images' | |
| }, | |
| 'augmentation_summary': [ | |
| { | |
| 'filename': result['metadata']['filename'], | |
| 'json_file': result['metadata']['filename'].replace('.png', '.json'), | |
| 'augmentation_type': result['metadata']['augmentation_type'], | |
| 'parameter_value': result['metadata']['parameter_value'], | |
| 'polygon_count': len(result['json_data']['shapes']), | |
| 'image_size': f"{result['json_data']['imageWidth']}x{result['json_data']['imageHeight']}", | |
| 'timestamp': result['metadata']['timestamp'], | |
| 'labels': list(set([shape['label'] for shape in result['json_data']['shapes']])) | |
| } | |
| for result in self.augmented_results | |
| ], | |
| 'statistics': { | |
| 'unique_augmentation_types': list(set([r['metadata']['augmentation_type'] for r in self.augmented_results])), | |
| 'total_polygons': sum([len(r['json_data']['shapes']) for r in self.augmented_results]), | |
| 'unique_labels': list(set([ | |
| shape['label'] | |
| for result in self.augmented_results | |
| for shape in result['json_data']['shapes'] | |
| ])), | |
| 'average_polygons_per_image': sum([len(r['json_data']['shapes']) for r in self.augmented_results]) / len(self.augmented_results) if self.augmented_results else 0 | |
| } | |
| } | |
| zip_file.writestr('augmentation_summary.json', json.dumps(summary, indent=2, ensure_ascii=False)) | |
| # Add README for the package | |
| readme_content = f"""# Augmented Dataset Package | |
| ## Overview | |
| This package contains {len(self.augmented_results)} augmented images with their corresponding LabelMe annotation files. | |
| ## Contents | |
| - **Images**: PNG format augmented images | |
| - **Annotations**: LabelMe JSON format annotation files (standard format) | |
| - **Summary**: augmentation_summary.json with detailed metadata | |
| ## File Structure | |
| - Each image file (*.png) has a corresponding annotation file (*.json) with the same base name | |
| - All annotations are in standard LabelMe format without embedded image data | |
| - Compatible with LabelMe, CVAT, and other annotation tools | |
| ## Statistics | |
| - Total augmented images: {len(self.augmented_results)} | |
| - Total polygons: {sum([len(r['json_data']['shapes']) for r in self.augmented_results])} | |
| - Unique labels: {list(set([shape['label'] for result in self.augmented_results for shape in result['json_data']['shapes']]))} | |
| - Augmentation types used: {list(set([r['metadata']['augmentation_type'] for r in self.augmented_results]))} | |
| ## Usage | |
| 1. Extract the ZIP file | |
| 2. Load images and annotations using any tool that supports LabelMe format | |
| 3. Use the augmentation_summary.json for batch processing or analysis | |
| Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} | |
| Tool: PolygonAugmentation v1.0 | |
| """ | |
| zip_file.writestr('README.md', readme_content) | |
| logger.info("Successfully created ZIP package with all files") | |
| zip_buffer.seek(0) | |
| logger.info(f"Created download package with {len(self.augmented_results)} image-annotation pairs") | |
| return zip_buffer.getvalue() | |
| except Exception as e: | |
| logger.error(f"Error creating ZIP package: {str(e)}") | |
| import traceback | |
| logger.error(traceback.format_exc()) | |
| return None | |
def create_interface():
    """Build the Gradio Blocks UI for batch polygon augmentation.

    A single ``PolygonAugmentation`` instance is created here and shared by
    the three callbacks defined below (generate, download, gallery select),
    so results produced by "Generate" are what "Download" packages and what
    a gallery click previews.

    Returns:
        The assembled ``gr.Blocks`` app. It is not launched here.
    """
    augmenter = PolygonAugmentation(tolerance=2.0, area_threshold=0.01, debug=True)

    def _uploaded_path(upload):
        """Return the filesystem path for an uploaded file.

        Accepts either an object exposing ``.name`` (what the original code
        relied on) or a plain path string, which is returned unchanged.
        """
        return getattr(upload, "name", upload)

    def _ordered(low, high):
        """Return ``(low, high)`` sorted ascending.

        The min and max sliders are independent, so the user can set
        min > max; normalise so consumers always get a well-formed range.
        """
        return (low, high) if low <= high else (high, low)

    def process_batch_augmentation(
        images, json_files, num_augmentations,
        rotate_enabled, rotate_min, rotate_max,
        hflip_enabled, vflip_enabled,
        scale_enabled, scale_min, scale_max,
        brightness_enabled, brightness_min, brightness_max,
        dropout_enabled, dropout_min, dropout_max
    ):
        """Run the selected augmentations over the uploaded image/JSON pairs.

        Images and JSON files are paired by upload position; surplus files on
        either side are ignored. Pairs that fail to load are logged and
        skipped.

        Returns:
            A 3-tuple matching the click handler's outputs
            ``(gallery items, metadata JSON text, status message)``.
            On any failure the gallery is empty, the JSON slot is ``None``
            and the status slot carries the error message.
        """
        if not images or not json_files:
            return [], None, "No images or JSON files uploaded"

        if len(images) != len(json_files):
            logger.warning(
                f"Got {len(images)} images but {len(json_files)} JSON files; "
                "pairing by position and ignoring the surplus"
            )

        # Pair images with JSON files (positional pairing).
        image_json_pairs = []
        for i, (image_file, json_file) in enumerate(zip(images, json_files)):
            if image_file is None or json_file is None:
                continue
            try:
                image = Image.open(_uploaded_path(image_file))
                # Image.open is lazy; decode now so a corrupt file is caught
                # by this per-pair handler instead of failing the whole batch.
                image.load()

                json_path = _uploaded_path(json_file)
                logger.info(f"Loading JSON from: {json_path}")
                with open(json_path, 'r', encoding='utf-8') as f:
                    json_data = json.load(f)
                logger.info(f"Successfully loaded JSON with keys: {list(json_data.keys())}")
                image_json_pairs.append((image, json_data))
            except Exception as e:
                # logger.exception records the traceback along with the message.
                logger.exception(f"Error loading image/JSON pair {i}: {str(e)}")
                continue

        if not image_json_pairs:
            return [], None, "No valid image-JSON pairs found"

        # Configure augmentations based on user selections. Order matters:
        # it is the order in which configurations are handed to the augmenter.
        candidate_configs = [
            (rotate_enabled, 'rotate', _ordered(rotate_min, rotate_max)),
            (hflip_enabled, 'horizontal_flip', (0, 1)),
            (vflip_enabled, 'vertical_flip', (0, 1)),
            (scale_enabled, 'scale', _ordered(scale_min, scale_max)),
            (brightness_enabled, 'brightness_contrast', _ordered(brightness_min, brightness_max)),
            (dropout_enabled, 'pixel_dropout', _ordered(dropout_min, dropout_max)),
        ]
        aug_configs = [
            {'aug_type': aug_type, 'param_range': param_range}
            for enabled, aug_type, param_range in candidate_configs
            if enabled
        ]
        if not aug_configs:
            return [], None, "No augmentation types selected"

        # Process augmentations
        try:
            logger.info(f"Starting batch augmentation with {len(image_json_pairs)} image pairs and {len(aug_configs)} configurations")
            augmented_images = augmenter.batch_augment_images(
                image_json_pairs, aug_configs, num_augmentations
            )
            # Create JSON summary from the metadata of every stored result.
            json_summary = json.dumps(
                [result['metadata'] for result in augmenter.augmented_results],
                indent=2
            )
            status = f"Generated {len(augmented_images)} augmented images from {len(image_json_pairs)} input pairs"
            logger.info(status)
            return augmented_images, json_summary, status
        except Exception as e:
            error_msg = f"Batch augmentation error: {str(e)}"
            logger.exception(error_msg)
            return [], None, error_msg

    def download_package():
        """Write the augmenter's ZIP package to a temp file and return its path.

        Returns:
            Path of the created ``.zip`` file, or ``None`` when the augmenter
            produced no package or an error occurred.
        """
        import tempfile

        try:
            package_data = augmenter.create_download_package()
            if package_data is None:
                return None
            # delete=False: the file must outlive this function so the
            # download component can serve it. Writing through the handle we
            # already hold avoids opening the same path a second time.
            with tempfile.NamedTemporaryFile(
                delete=False,
                suffix='.zip',
                prefix='augmented_dataset_'
            ) as temp_file:
                temp_file.write(package_data)
            logger.info(f"Created download package: {temp_file.name}")
            return temp_file.name
        except Exception as e:
            logger.exception(f"Error creating download package: {str(e)}")
            return None

    # The ``gr.SelectData`` annotation is required: Gradio uses it to inject
    # the selection event into the callback.
    def show_mask_overlay(evt: gr.SelectData):
        """Return the stored 'image' entry for the clicked gallery item."""
        if evt.index < len(augmenter.augmented_results):
            return augmenter.augmented_results[evt.index]['image']
        return None

    # NOTE(review): the leading glyphs in several labels below look like
    # mis-decoded emoji; they are kept verbatim here - confirm against the
    # original file's encoding before changing them.
    with gr.Blocks(title="Dynamic Donut Polygon Augmentation") as demo:
        gr.Markdown("# π Dynamic Donut Polygon Augmentation Tool")
        gr.Markdown("Upload multiple images and JSON files to apply batch augmentation with configurable parameter ranges")

        with gr.Row():
            # Left column: inputs and augmentation settings.
            with gr.Column(scale=1):
                gr.Markdown("## π Input Files")
                images_input = gr.File(
                    file_count="multiple",
                    file_types=["image"],
                    label="Upload Images"
                )
                json_input = gr.File(
                    file_count="multiple",
                    file_types=[".json"],
                    label="Upload LabelMe JSON Files"
                )
                num_augmentations = gr.Slider(
                    minimum=1, maximum=5, value=2, step=1,
                    label="Augmentations per configuration"
                )

                gr.Markdown("## βοΈ Augmentation Configuration")

                # Rotation parameters
                with gr.Group():
                    rotate_enabled = gr.Checkbox(label="Enable Rotation", value=True)
                    with gr.Row():
                        rotate_min = gr.Slider(-45, 45, -15, label="Min Rotation (degrees)")
                        rotate_max = gr.Slider(-45, 45, 15, label="Max Rotation (degrees)")

                # Flip parameters
                with gr.Group():
                    hflip_enabled = gr.Checkbox(label="Enable Horizontal Flip", value=True)
                    vflip_enabled = gr.Checkbox(label="Enable Vertical Flip", value=False)

                # Scale parameters
                with gr.Group():
                    scale_enabled = gr.Checkbox(label="Enable Scale", value=True)
                    with gr.Row():
                        scale_min = gr.Slider(0.7, 1.3, 0.9, label="Min Scale")
                        scale_max = gr.Slider(0.7, 1.3, 1.1, label="Max Scale")

                # Brightness parameters
                with gr.Group():
                    brightness_enabled = gr.Checkbox(label="Enable Brightness/Contrast", value=True)
                    with gr.Row():
                        brightness_min = gr.Slider(-0.3, 0.3, -0.1, label="Min Brightness")
                        brightness_max = gr.Slider(-0.3, 0.3, 0.1, label="Max Brightness")

                # Dropout parameters
                with gr.Group():
                    dropout_enabled = gr.Checkbox(label="Enable Pixel Dropout", value=False)
                    with gr.Row():
                        dropout_min = gr.Slider(0.01, 0.1, 0.02, label="Min Dropout")
                        dropout_max = gr.Slider(0.01, 0.1, 0.05, label="Max Dropout")

                generate_btn = gr.Button("π Generate Augmentations", variant="primary")
                status_text = gr.Textbox(label="Status", interactive=False)

            # Right column: results, download, metadata and preview.
            with gr.Column(scale=2):
                gr.Markdown("## πΌοΈ Augmented Results")
                gr.Markdown("*Click on any image to view with enhanced mask overlay*")
                augmented_gallery = gr.Gallery(
                    label="Augmented Images with Polygon Masks",
                    show_label=False,
                    elem_id="gallery",
                    columns=3,
                    rows=3,
                    height="auto"
                )
                with gr.Row():
                    download_btn = gr.Button("π₯ Download All (ZIP)", variant="secondary")
                    download_file = gr.File(label="Download Package", visible=True)

                gr.Markdown("## π Augmentation Metadata")
                json_output = gr.Code(
                    label="Generated Metadata JSON",
                    language="json",
                    lines=15
                )

                gr.Markdown("## π Enhanced Preview")
                mask_preview = gr.Image(label="Selected Image with Mask Overlay")

        # Event handlers. The input order must match the parameter order of
        # process_batch_augmentation; the output order must match its return tuple.
        generate_btn.click(
            process_batch_augmentation,
            inputs=[
                images_input, json_input, num_augmentations,
                rotate_enabled, rotate_min, rotate_max,
                hflip_enabled, vflip_enabled,
                scale_enabled, scale_min, scale_max,
                brightness_enabled, brightness_min, brightness_max,
                dropout_enabled, dropout_min, dropout_max
            ],
            outputs=[augmented_gallery, json_output, status_text]
        )
        download_btn.click(
            download_package,
            outputs=download_file
        )
        augmented_gallery.select(
            show_mask_overlay,
            outputs=mask_preview
        )

    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks app and launch it. Nothing is
    # launched when this module is merely imported.
    demo = create_interface()
    demo.launch()