File size: 5,556 Bytes
6f98a26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
"""Visual comparison generation for evaluation."""

import json
import logging
from pathlib import Path

import numpy as np
from PIL import Image, ImageDraw, ImageFont

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class VisualComparator:
    """Generate visual comparisons between ground truth and predictions.

    Layout read by :meth:`create_comparison` for one image directory::

        <image_dir>/image.jpg
        <image_dir>/ground_truth/metadata.json   (+ mask image files)
        <image_dir>/inference/metadata.json      (+ mask image files)

    Each ``metadata.json`` is expected to contain a ``"masks"`` list whose
    entries carry a ``"filename"`` key naming a mask image stored next to
    the metadata file.
    """

    # Geometry of the legend strip appended below the comparison image.
    _LEGEND_HEIGHT = 60
    _LEGEND_MARGIN = 10
    _LEGEND_BOX_SIZE = 30
    _LEGEND_ITEM_SPACING = 200

    def __init__(self):
        """Initialize comparator with the RGBA colors used for overlays."""
        self.colors = {
            "ground_truth": (0, 255, 0, 128),  # Green
            "prediction": (255, 0, 0, 128),     # Red
            "true_positive": (255, 255, 0, 128),  # Yellow
            "false_positive": (255, 0, 0, 128),   # Red
            "false_negative": (0, 0, 255, 128),   # Blue
        }

    def create_comparison(
        self, image_dir: Path, output_path: Path | None = None
    ) -> Path:
        """Create visual comparison for image.

        The original image is overlaid first with the ground-truth masks and
        then with the prediction masks, a legend strip is appended, and the
        result is written to disk as an RGB image.

        Args:
            image_dir: Directory containing image and masks
            output_path: Optional output path (default: image_dir/comparison.png)

        Returns:
            Path to generated comparison image

        Raises:
            ValueError: If required files are missing
        """
        image_path = image_dir / "image.jpg"
        if not image_path.exists():
            raise ValueError(f"Image not found: {image_path}")

        # convert() returns an independent in-memory copy, so the underlying
        # file can be closed right away instead of lingering until GC.
        with Image.open(image_path) as source:
            original = source.convert("RGBA")

        gt_overlay = self._build_overlay(
            image_dir / "ground_truth", original.size, self.colors["ground_truth"]
        )
        pred_overlay = self._build_overlay(
            image_dir / "inference", original.size, self.colors["prediction"]
        )

        # Predictions are drawn last so they sit on top of the ground truth.
        result = Image.alpha_composite(original, gt_overlay)
        result = Image.alpha_composite(result, pred_overlay)

        result = self._add_legend(result)

        if output_path is None:
            output_path = image_dir / "comparison.png"

        # Allow callers to point at a directory that does not exist yet.
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)

        result.convert("RGB").save(output_path)
        logger.debug("Saved comparison to %s", output_path)

        return output_path

    def _build_overlay(
        self,
        mask_dir: Path,
        size: tuple[int, int],
        color: tuple[int, int, int, int],
    ) -> Image.Image:
        """Composite every mask listed in ``mask_dir/metadata.json``.

        Args:
            mask_dir: Directory holding ``metadata.json`` and the mask files
            size: ``(width, height)`` of the image the overlay is drawn over
            color: RGBA color; its alpha component is the overlay opacity
                where the mask is fully set

        Returns:
            RGBA overlay of the given size; fully transparent when the
            directory or its metadata file does not exist
        """
        overlay = Image.new("RGBA", size, (0, 0, 0, 0))

        meta_path = mask_dir / "metadata.json"
        if not meta_path.exists():
            # Also covers a missing mask_dir.
            return overlay

        with open(meta_path, encoding="utf-8") as f:
            metadata = json.load(f)

        opacity = color[3]

        for mask_info in metadata.get("masks", []):
            filename = mask_info.get("filename")
            if not filename:
                logger.warning("Mask entry without filename in %s", meta_path)
                continue

            mask_path = mask_dir / filename
            if not mask_path.exists():
                continue

            with Image.open(mask_path) as mask_file:
                mask = mask_file.convert("L")

            if mask.size != size:
                # putalpha() requires matching dimensions; nearest-neighbour
                # keeps mask edges hard instead of blurring them.
                logger.warning(
                    "Mask %s has size %s, expected %s; resizing",
                    mask_path, mask.size, size,
                )
                mask = mask.resize(size, Image.NEAREST)

            # putalpha() replaces the alpha channel outright, so the mask is
            # scaled by the configured opacity first; otherwise the overlay
            # would be fully opaque and hide everything beneath it.
            alpha = mask.point(lambda value: value * opacity // 255)

            layer = Image.new("RGBA", size, color)
            layer.putalpha(alpha)
            overlay = Image.alpha_composite(overlay, layer)

        return overlay

    def _add_legend(self, image: Image.Image) -> Image.Image:
        """Add color legend to image.

        A white strip is appended below the image containing one color box
        and label each for the ground-truth and prediction colors.

        Args:
            image: Input image

        Returns:
            New RGB image that is ``_LEGEND_HEIGHT`` pixels taller than the
            input, with the legend drawn in the added strip
        """
        legend_img = Image.new(
            "RGB",
            (image.width, image.height + self._LEGEND_HEIGHT),
            (255, 255, 255),
        )
        legend_img.paste(image, (0, 0))

        draw = ImageDraw.Draw(legend_img)

        x_offset = self._LEGEND_MARGIN
        y_offset = image.height + self._LEGEND_MARGIN

        # The legend boxes are opaque, so only the RGB part of each color is used.
        items = [
            ("Ground Truth", self.colors["ground_truth"][:3]),
            ("Prediction", self.colors["prediction"][:3]),
        ]

        for label, color in items:
            draw.rectangle(
                [
                    x_offset,
                    y_offset,
                    x_offset + self._LEGEND_BOX_SIZE,
                    y_offset + self._LEGEND_BOX_SIZE,
                ],
                fill=color,
            )
            draw.text(
                (x_offset + self._LEGEND_BOX_SIZE + 10, y_offset + 5),
                label,
                fill=(0, 0, 0),
            )
            x_offset += self._LEGEND_ITEM_SPACING

        return legend_img

    def generate_all_comparisons(self, cache_dir: Path) -> list[Path]:
        """Generate comparisons for all images in cache.

        The cache is walked two levels deep: every sub-directory of
        ``cache_dir`` is scanned, and each directory inside it is passed to
        :meth:`create_comparison`. Directories are visited in sorted order so
        the returned list is deterministic. A failure for one image is logged
        and does not stop the remaining images from being processed.

        Args:
            cache_dir: Cache directory

        Returns:
            List of paths to generated comparisons
        """
        comparison_paths = []

        for class_dir in sorted(cache_dir.iterdir()):
            if not class_dir.is_dir():
                continue

            for image_dir in sorted(class_dir.iterdir()):
                if not image_dir.is_dir():
                    continue

                # Best-effort batch: one broken image must not abort the run.
                try:
                    comparison_path = self.create_comparison(image_dir)
                except Exception as e:
                    logger.error(
                        "Failed to create comparison for %s: %s", image_dir, e
                    )
                    continue

                comparison_paths.append(comparison_path)

        logger.info("Generated %d comparison images", len(comparison_paths))
        return comparison_paths