|
|
"""Visual comparison generation for evaluation.""" |
|
|
|
|
|
import json |
|
|
import logging |
|
|
from pathlib import Path |
|
|
|
|
|
import numpy as np |
|
|
from PIL import Image, ImageDraw, ImageFont |
|
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class VisualComparator: |
|
|
"""Generate visual comparisons between ground truth and predictions.""" |
|
|
|
|
|
def __init__(self): |
|
|
"""Initialize comparator.""" |
|
|
self.colors = { |
|
|
"ground_truth": (0, 255, 0, 128), |
|
|
"prediction": (255, 0, 0, 128), |
|
|
"true_positive": (255, 255, 0, 128), |
|
|
"false_positive": (255, 0, 0, 128), |
|
|
"false_negative": (0, 0, 255, 128), |
|
|
} |
|
|
|
|
|
def create_comparison( |
|
|
self, image_dir: Path, output_path: Path | None = None |
|
|
) -> Path: |
|
|
"""Create visual comparison for image. |
|
|
|
|
|
Args: |
|
|
image_dir: Directory containing image and masks |
|
|
output_path: Optional output path (default: image_dir/comparison.png) |
|
|
|
|
|
Returns: |
|
|
Path to generated comparison image |
|
|
|
|
|
Raises: |
|
|
ValueError: If required files are missing |
|
|
""" |
|
|
|
|
|
image_path = image_dir / "image.jpg" |
|
|
if not image_path.exists(): |
|
|
raise ValueError(f"Image not found: {image_path}") |
|
|
|
|
|
original = Image.open(image_path).convert("RGBA") |
|
|
width, height = original.size |
|
|
|
|
|
|
|
|
gt_overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0)) |
|
|
pred_overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0)) |
|
|
|
|
|
|
|
|
gt_dir = image_dir / "ground_truth" |
|
|
if gt_dir.exists(): |
|
|
gt_meta_path = gt_dir / "metadata.json" |
|
|
if gt_meta_path.exists(): |
|
|
with open(gt_meta_path) as f: |
|
|
gt_meta = json.load(f) |
|
|
|
|
|
for mask_info in gt_meta.get("masks", []): |
|
|
mask_path = gt_dir / mask_info["filename"] |
|
|
if not mask_path.exists(): |
|
|
continue |
|
|
|
|
|
mask = Image.open(mask_path).convert("L") |
|
|
colored_mask = Image.new("RGBA", (width, height), self.colors["ground_truth"]) |
|
|
colored_mask.putalpha(mask) |
|
|
gt_overlay = Image.alpha_composite(gt_overlay, colored_mask) |
|
|
|
|
|
|
|
|
pred_dir = image_dir / "inference" |
|
|
if pred_dir.exists(): |
|
|
pred_meta_path = pred_dir / "metadata.json" |
|
|
if pred_meta_path.exists(): |
|
|
with open(pred_meta_path) as f: |
|
|
pred_meta = json.load(f) |
|
|
|
|
|
for mask_info in pred_meta.get("masks", []): |
|
|
mask_path = pred_dir / mask_info["filename"] |
|
|
if not mask_path.exists(): |
|
|
continue |
|
|
|
|
|
mask = Image.open(mask_path).convert("L") |
|
|
colored_mask = Image.new("RGBA", (width, height), self.colors["prediction"]) |
|
|
colored_mask.putalpha(mask) |
|
|
pred_overlay = Image.alpha_composite(pred_overlay, colored_mask) |
|
|
|
|
|
|
|
|
result = Image.alpha_composite(original, gt_overlay) |
|
|
result = Image.alpha_composite(result, pred_overlay) |
|
|
|
|
|
|
|
|
result = self._add_legend(result) |
|
|
|
|
|
|
|
|
if output_path is None: |
|
|
output_path = image_dir / "comparison.png" |
|
|
|
|
|
result.convert("RGB").save(output_path) |
|
|
logger.debug(f"Saved comparison to {output_path}") |
|
|
|
|
|
return output_path |
|
|
|
|
|
def _add_legend(self, image: Image.Image) -> Image.Image: |
|
|
"""Add color legend to image. |
|
|
|
|
|
Args: |
|
|
image: Input image |
|
|
|
|
|
Returns: |
|
|
Image with legend |
|
|
""" |
|
|
|
|
|
legend_height = 60 |
|
|
legend_img = Image.new("RGB", (image.width, image.height + legend_height), (255, 255, 255)) |
|
|
legend_img.paste(image, (0, 0)) |
|
|
|
|
|
draw = ImageDraw.Draw(legend_img) |
|
|
|
|
|
|
|
|
x_offset = 10 |
|
|
y_offset = image.height + 10 |
|
|
|
|
|
items = [ |
|
|
("Ground Truth", self.colors["ground_truth"][:3]), |
|
|
("Prediction", self.colors["prediction"][:3]), |
|
|
] |
|
|
|
|
|
for label, color in items: |
|
|
|
|
|
draw.rectangle([x_offset, y_offset, x_offset + 30, y_offset + 30], fill=color) |
|
|
|
|
|
|
|
|
draw.text((x_offset + 40, y_offset + 5), label, fill=(0, 0, 0)) |
|
|
|
|
|
x_offset += 200 |
|
|
|
|
|
return legend_img |
|
|
|
|
|
def generate_all_comparisons(self, cache_dir: Path) -> list[Path]: |
|
|
"""Generate comparisons for all images in cache. |
|
|
|
|
|
Args: |
|
|
cache_dir: Cache directory |
|
|
|
|
|
Returns: |
|
|
List of paths to generated comparisons |
|
|
""" |
|
|
comparison_paths = [] |
|
|
|
|
|
for class_dir in cache_dir.iterdir(): |
|
|
if not class_dir.is_dir(): |
|
|
continue |
|
|
|
|
|
for image_dir in class_dir.iterdir(): |
|
|
if not image_dir.is_dir(): |
|
|
continue |
|
|
|
|
|
try: |
|
|
comparison_path = self.create_comparison(image_dir) |
|
|
comparison_paths.append(comparison_path) |
|
|
except Exception as e: |
|
|
logger.error(f"Failed to create comparison for {image_dir}: {e}") |
|
|
continue |
|
|
|
|
|
logger.info(f"Generated {len(comparison_paths)} comparison images") |
|
|
return comparison_paths |
|
|
|