sam3 / metrics_evaluation /visualization /visual_comparison.py
Thibaut's picture
Implement metrics evaluation system - CVAT extraction, SAM3 inference, metrics calculation, visualization, and main pipeline
6f98a26
"""Visual comparison generation for evaluation."""
import json
import logging
from pathlib import Path
import numpy as np
from PIL import Image, ImageDraw, ImageFont
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class VisualComparator:
    """Generate visual comparisons between ground truth and predictions.

    Ground-truth masks are painted in green and predicted masks in red, each as
    a semi-transparent overlay on the original image, and a legend strip is
    appended below the image.
    """

    def __init__(self):
        """Initialize comparator with its RGBA colour palette."""
        # RGBA tuples. The alpha component is the maximum opacity an overlay
        # pixel reaches where its mask is fully "on". Only "ground_truth" and
        # "prediction" are used by this class's methods.
        self.colors = {
            "ground_truth": (0, 255, 0, 128),  # Green
            "prediction": (255, 0, 0, 128),  # Red
            "true_positive": (255, 255, 0, 128),  # Yellow
            "false_positive": (255, 0, 0, 128),  # Red
            "false_negative": (0, 0, 255, 128),  # Blue
        }

    def create_comparison(
        self, image_dir: Path, output_path: Path | None = None
    ) -> Path:
        """Create visual comparison for image.

        Expects ``image_dir/image.jpg``. Masks are read from
        ``image_dir/ground_truth`` and ``image_dir/inference``, each described
        by a ``metadata.json`` with a ``"masks"`` list of ``{"filename": ...}``
        entries. Missing mask directories, metadata files or mask files are
        skipped silently.

        Args:
            image_dir: Directory containing image and masks
            output_path: Optional output path (default: image_dir/comparison.png).
                Its parent directory is created if needed.

        Returns:
            Path to generated comparison image

        Raises:
            ValueError: If required files are missing
        """
        image_path = image_dir / "image.jpg"
        if not image_path.exists():
            raise ValueError(f"Image not found: {image_path}")

        # Context manager releases the file handle as soon as pixels are copied.
        with Image.open(image_path) as img:
            original = img.convert("RGBA")
        size = original.size

        gt_overlay = self._build_overlay(
            image_dir / "ground_truth", self.colors["ground_truth"], size
        )
        pred_overlay = self._build_overlay(
            image_dir / "inference", self.colors["prediction"], size
        )

        # Predictions are composited last, so they sit on top of ground truth.
        result = Image.alpha_composite(original, gt_overlay)
        result = Image.alpha_composite(result, pred_overlay)
        result = self._add_legend(result)

        if output_path is None:
            output_path = image_dir / "comparison.png"
        output_path = Path(output_path)  # also accept plain string paths
        output_path.parent.mkdir(parents=True, exist_ok=True)
        result.convert("RGB").save(output_path)
        logger.debug("Saved comparison to %s", output_path)
        return output_path

    def _build_overlay(
        self,
        mask_dir: Path,
        color: tuple[int, int, int, int],
        size: tuple[int, int],
    ) -> Image.Image:
        """Paint every mask listed in ``mask_dir/metadata.json`` onto one RGBA layer.

        Args:
            mask_dir: Directory holding ``metadata.json`` and the mask files it lists.
            color: RGBA fill colour; its alpha caps the opacity of mask pixels.
            size: ``(width, height)`` of the image the overlay will cover.

        Returns:
            A transparent RGBA image of ``size`` with each available mask painted
            in ``color``. A missing directory, metadata file or mask file simply
            contributes nothing.
        """
        overlay = Image.new("RGBA", size, (0, 0, 0, 0))
        meta_path = mask_dir / "metadata.json"
        if not meta_path.exists():  # also covers a missing mask_dir
            return overlay

        with open(meta_path, encoding="utf-8") as f:
            meta = json.load(f)

        max_alpha = color[3]
        for mask_info in meta.get("masks", []):
            mask_path = mask_dir / mask_info["filename"]
            if not mask_path.exists():
                continue
            with Image.open(mask_path) as raw:
                mask = raw.convert("L")
            if mask.size != size:
                # putalpha() requires matching sizes; NEAREST keeps edges crisp.
                mask = mask.resize(size, Image.NEAREST)
            # putalpha() replaces the alpha channel outright, so scale mask
            # values into [0, max_alpha]. Without this the colour's alpha is
            # lost, "on" pixels become fully opaque, and the prediction layer
            # completely hides ground truth wherever they overlap.
            alpha = mask.point(lambda v: v * max_alpha // 255)
            colored = Image.new("RGBA", size, color)
            colored.putalpha(alpha)
            overlay = Image.alpha_composite(overlay, colored)
        return overlay

    def _add_legend(self, image: Image.Image) -> Image.Image:
        """Add color legend to image.

        Args:
            image: Input image

        Returns:
            New RGB image, 60 px taller than ``image``, with a white legend
            strip at the bottom showing a swatch and label per overlay colour.
        """
        legend_height = 60
        legend_img = Image.new(
            "RGB", (image.width, image.height + legend_height), (255, 255, 255)
        )
        legend_img.paste(image, (0, 0))
        draw = ImageDraw.Draw(legend_img)

        x_offset = 10
        y_offset = image.height + 10
        items = [
            ("Ground Truth", self.colors["ground_truth"][:3]),
            ("Prediction", self.colors["prediction"][:3]),
        ]
        for label, color in items:
            # 30x30 swatch followed by its label, items spaced 200 px apart.
            draw.rectangle(
                [x_offset, y_offset, x_offset + 30, y_offset + 30], fill=color
            )
            draw.text((x_offset + 40, y_offset + 5), label, fill=(0, 0, 0))
            x_offset += 200
        return legend_img

    def generate_all_comparisons(self, cache_dir: Path) -> list[Path]:
        """Generate comparisons for all images in cache.

        Walks ``cache_dir/<class>/<image>/`` in sorted order so that output is
        reproducible. Failures on individual images are logged with traceback
        and skipped.

        Args:
            cache_dir: Cache directory

        Returns:
            List of paths to generated comparisons
        """
        comparison_paths: list[Path] = []
        class_dirs = sorted(p for p in cache_dir.iterdir() if p.is_dir())
        for class_dir in class_dirs:
            image_dirs = sorted(p for p in class_dir.iterdir() if p.is_dir())
            for image_dir in image_dirs:
                try:
                    comparison_paths.append(self.create_comparison(image_dir))
                except Exception:
                    # Best-effort batch: one bad image must not abort the rest.
                    logger.exception("Failed to create comparison for %s", image_dir)
        logger.info("Generated %d comparison images", len(comparison_paths))
        return comparison_paths