"""Visual comparison generation for evaluation."""

import json
import logging
from pathlib import Path

import numpy as np
from PIL import Image, ImageDraw, ImageFont

logger = logging.getLogger(__name__)


class VisualComparator:
    """Generate visual comparisons between ground truth and predictions.

    For each sample directory, ground-truth masks and prediction masks are
    drawn as coloured, semi-transparent overlays on top of the original image,
    and a colour legend is appended underneath.
    """

    def __init__(self):
        """Initialize comparator with the RGBA colours used for overlays.

        The alpha component of each colour is the maximum opacity of that
        overlay: a fully-on mask pixel is drawn at exactly this alpha, so the
        underlying image and any other overlay remain visible.
        """
        self.colors = {
            "ground_truth": (0, 255, 0, 128),  # Green
            "prediction": (255, 0, 0, 128),  # Red
            "true_positive": (255, 255, 0, 128),  # Yellow
            "false_positive": (255, 0, 0, 128),  # Red
            "false_negative": (0, 0, 255, 128),  # Blue
        }

    def create_comparison(
        self, image_dir: Path, output_path: Path | None = None
    ) -> Path:
        """Create visual comparison for image.

        Expected layout of ``image_dir``::

            image.jpg
            ground_truth/metadata.json  (+ mask files listed under "masks")
            inference/metadata.json     (+ mask files listed under "masks")

        Missing mask directories or metadata files simply produce an empty
        overlay for that source.

        Args:
            image_dir: Directory containing image and masks
            output_path: Optional output path (default: image_dir/comparison.png)

        Returns:
            Path to generated comparison image

        Raises:
            ValueError: If required files are missing
        """
        image_path = image_dir / "image.jpg"
        if not image_path.exists():
            raise ValueError(f"Image not found: {image_path}")

        # Use a context manager so the source file handle is released promptly.
        with Image.open(image_path) as img:
            original = img.convert("RGBA")
        size = original.size

        gt_overlay = self._build_overlay(
            image_dir / "ground_truth", self.colors["ground_truth"], size
        )
        pred_overlay = self._build_overlay(
            image_dir / "inference", self.colors["prediction"], size
        )

        # Predictions are composited last, so they sit on top of ground truth.
        result = Image.alpha_composite(original, gt_overlay)
        result = Image.alpha_composite(result, pred_overlay)
        result = self._add_legend(result)

        if output_path is None:
            output_path = image_dir / "comparison.png"
        result.convert("RGB").save(output_path)
        logger.debug("Saved comparison to %s", output_path)
        return output_path

    def _build_overlay(
        self,
        mask_dir: Path,
        color: tuple[int, int, int, int],
        size: tuple[int, int],
    ) -> Image.Image:
        """Composite every mask listed in ``mask_dir/metadata.json`` into one layer.

        Args:
            mask_dir: Directory holding ``metadata.json`` and the mask files it
                lists under ``"masks"`` (each entry needs a ``"filename"``).
            color: RGBA colour; its alpha is the opacity of a fully-on pixel.
            size: ``(width, height)`` of the target image.

        Returns:
            An RGBA image of ``size``. It is fully transparent when the
            directory or metadata is missing. Entries without a filename, and
            mask files that do not exist, are skipped.
        """
        overlay = Image.new("RGBA", size, (0, 0, 0, 0))
        meta_path = mask_dir / "metadata.json"
        if not meta_path.exists():  # also covers a missing mask_dir
            return overlay

        with open(meta_path, encoding="utf-8") as f:
            meta = json.load(f)

        max_alpha = color[3]
        for mask_info in meta.get("masks", []):
            filename = mask_info.get("filename")
            if not filename:
                continue
            mask_path = mask_dir / filename
            if not mask_path.exists():
                continue

            with Image.open(mask_path) as mask_img:
                mask = mask_img.convert("L")
            if mask.size != size:
                # putalpha() requires identical sizes; nearest-neighbour keeps
                # the mask's values crisp instead of blending edges.
                logger.warning(
                    "Resizing mask %s from %s to %s", mask_path, mask.size, size
                )
                mask = mask.resize(size, Image.NEAREST)

            # putalpha() replaces the alpha channel wholesale, which would
            # discard the colour's own alpha and draw masks fully opaque.
            # Scale the mask so that a fully-on pixel gets exactly max_alpha.
            alpha = mask.point(lambda v: v * max_alpha // 255)
            colored_mask = Image.new("RGBA", size, color)
            colored_mask.putalpha(alpha)
            overlay = Image.alpha_composite(overlay, colored_mask)
        return overlay

    def _add_legend(self, image: Image.Image) -> Image.Image:
        """Add color legend to image.

        A white strip of fixed height is appended below ``image``. A colour
        swatch and a label are drawn in that strip for ground truth and for
        prediction.

        Args:
            image: Input image

        Returns:
            New RGB image with the legend strip underneath
        """
        legend_height = 60
        legend_img = Image.new(
            "RGB", (image.width, image.height + legend_height), (255, 255, 255)
        )
        legend_img.paste(image, (0, 0))
        draw = ImageDraw.Draw(legend_img)

        x_offset = 10
        y_offset = image.height + 10
        items = [
            ("Ground Truth", self.colors["ground_truth"][:3]),
            ("Prediction", self.colors["prediction"][:3]),
        ]
        for label, color in items:
            # 30x30 colour swatch followed by its text label.
            draw.rectangle(
                [x_offset, y_offset, x_offset + 30, y_offset + 30], fill=color
            )
            draw.text((x_offset + 40, y_offset + 5), label, fill=(0, 0, 0))
            x_offset += 200

        return legend_img

    def generate_all_comparisons(self, cache_dir: Path) -> list[Path]:
        """Generate comparisons for all images in cache.

        Walks ``cache_dir/<class>/<image>/`` in sorted order, so the output
        order is deterministic. Failures for individual images are logged
        with their traceback and skipped; they do not abort the batch.

        Args:
            cache_dir: Cache directory

        Returns:
            List of paths to generated comparisons
        """
        comparison_paths: list[Path] = []
        for class_dir in sorted(cache_dir.iterdir()):
            if not class_dir.is_dir():
                continue
            for image_dir in sorted(class_dir.iterdir()):
                if not image_dir.is_dir():
                    continue
                try:
                    comparison_paths.append(self.create_comparison(image_dir))
                except Exception as e:
                    # Deliberate best-effort boundary: one bad sample must not
                    # stop the rest of the batch.
                    logger.exception(
                        "Failed to create comparison for %s: %s", image_dir, e
                    )

        logger.info("Generated %d comparison images", len(comparison_paths))
        return comparison_paths