Implement metrics evaluation system - CVAT extraction, SAM3 inference, metrics calculation, visualization, and main pipeline
Browse files- metrics_evaluation/extraction/cvat_extractor.py +450 -0
- metrics_evaluation/inference/sam3_inference.py +265 -0
- metrics_evaluation/metrics/metrics_calculator.py +419 -0
- metrics_evaluation/run_evaluation.py +271 -0
- metrics_evaluation/utils/__init__.py +0 -0
- metrics_evaluation/utils/logging_config.py +44 -0
- metrics_evaluation/visualization/visual_comparison.py +168 -0
metrics_evaluation/extraction/cvat_extractor.py
ADDED
|
@@ -0,0 +1,450 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CVAT data extraction for SAM3 evaluation."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import random
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
from ..config.config_models import EvaluationConfig
|
| 10 |
+
from ..cvat_api.client import CVATClient
|
| 11 |
+
from ..schema.core.annotation.mask import Mask
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class CVATExtractor:
    """Extract annotated images and masks from CVAT.

    Pipeline: connect -> find_training_project -> discover_images ->
    sample_images -> download_images_and_masks. Results are cached on disk
    under config.get_cache_path() as <class>/<task>_frame_<n>/image.jpg with
    a ground_truth/ subdirectory of per-instance mask PNGs + metadata.json.
    """

    def __init__(self, config: EvaluationConfig):
        """Initialize extractor with configuration.

        Args:
            config: Evaluation configuration

        Raises:
            ValueError: If configuration is invalid
        """
        self.config = config
        # Both are populated lazily: client by connect(), project_id by
        # find_training_project().
        self.client: CVATClient | None = None
        self.project_id: int | None = None

    def connect(self) -> None:
        """Connect to CVAT API.

        Credentials are read from the environment (optionally via a .env
        file loaded with python-dotenv): CVAT_USERNAME / CVAT_PASSWORD.

        Raises:
            ConnectionError: If connection fails
            ValueError: If credentials are invalid
        """
        import os
        from dotenv import load_dotenv

        load_dotenv()

        username = os.getenv("CVAT_USERNAME")
        password = os.getenv("CVAT_PASSWORD")

        if not username or not password:
            raise ValueError(
                "CVAT credentials not found in .env file. "
                "Required: CVAT_USERNAME, CVAT_PASSWORD"
            )

        try:
            self.client = CVATClient(
                host=self.config.cvat.url,
                credentials=(username, password),
                organization=self.config.cvat.organization,
            )
            logger.info(f"Connected to CVAT at {self.config.cvat.url}")
        except Exception as e:
            raise ConnectionError(f"Failed to connect to CVAT: {e}") from e

    def find_training_project(self) -> int:
        """Find AI training project in CVAT.

        Matches by case-insensitive substring against
        config.cvat.project_name_filter; on multiple matches the first
        listed project wins (with a warning).

        Returns:
            Project ID

        Raises:
            ValueError: If no suitable project found
        """
        if not self.client:
            raise ValueError("Not connected to CVAT. Call connect() first.")

        projects = self.client.projects.list()
        filter_str = self.config.cvat.project_name_filter.lower()

        matching_projects = [
            p for p in projects if filter_str in p.name.lower()
        ]

        if not matching_projects:
            available = [p.name for p in projects]
            raise ValueError(
                f"No project found with '{filter_str}' in name.\n"
                f"Available projects: {available}"
            )

        if len(matching_projects) > 1:
            names = [p.name for p in matching_projects]
            logger.warning(
                f"Multiple matching projects found: {names}. Using first one."
            )

        project = matching_projects[0]
        self.project_id = project.id
        logger.info(f"Using project: {project.name} (ID: {project.id})")
        return project.id

    def discover_images(self) -> dict[str, list[dict[str, Any]]]:
        """Discover images with target labels.

        Walks every task/job in the project, groups shape annotations by
        frame, and records each frame containing at least one mask-type
        shape for a configured class. A frame with several target labels is
        listed once per matching class.

        Returns:
            Dict mapping class names to lists of image metadata
            (task_id, job_id, frame_id, task_name, labels)

        Raises:
            ValueError: If no images found with target labels
        """
        if not self.client or not self.project_id:
            raise ValueError("Must connect and find project first")

        tasks = self.client.tasks.list(project_id=self.project_id)

        if not tasks:
            raise ValueError(f"No tasks found in project {self.project_id}")

        logger.info(f"Found {len(tasks)} tasks in project")

        # Collect all images with annotations
        class_images: dict[str, list[dict[str, Any]]] = {
            class_name: [] for class_name in self.config.classes.keys()
        }

        # NOTE(review): annotations are fetched per job here and fetched
        # again per image in _extract_masks — confirm the duplicate API
        # round-trips are acceptable for the expected project sizes.
        for task in tasks:
            try:
                jobs = self.client.jobs.list(task_id=task.id)

                for job in jobs:
                    annotations = self.client.annotations.get_job_annotations(job.id)

                    # Get frames with annotations
                    if not annotations or not hasattr(annotations, 'shapes'):
                        continue

                    # Group annotations by frame
                    frame_annotations: dict[int, list] = {}
                    for shape in annotations.shapes:
                        frame_id = shape.frame
                        if frame_id not in frame_annotations:
                            frame_annotations[frame_id] = []
                        frame_annotations[frame_id].append(shape)

                    # Check which classes are present in each frame
                    # (only mask-type shapes count; boxes/polygons are ignored)
                    for frame_id, shapes in frame_annotations.items():
                        labels_in_frame = {shape.label_name for shape in shapes if hasattr(shape, 'type') and shape.type == 'mask'}

                        for class_name in self.config.classes.keys():
                            if class_name in labels_in_frame:
                                class_images[class_name].append({
                                    "task_id": task.id,
                                    "job_id": job.id,
                                    "frame_id": frame_id,
                                    "task_name": task.name,
                                    "labels": list(labels_in_frame),
                                })

            except Exception as e:
                # Best-effort: one broken task should not abort discovery.
                logger.warning(f"Error processing task {task.id}: {e}")
                continue

        # Log discovered counts
        for class_name, images in class_images.items():
            logger.info(f"Found {len(images)} images with label '{class_name}'")

        # Check if we have enough images
        for class_name, images in class_images.items():
            requested = self.config.classes[class_name]
            if len(images) < requested:
                logger.warning(
                    f"Class '{class_name}': Requested {requested} images but only found {len(images)}"
                )

        return class_images

    def sample_images(
        self, class_images: dict[str, list[dict[str, Any]]]
    ) -> dict[str, list[dict[str, Any]]]:
        """Randomly sample images for each class.

        Sample size is min(config.classes[name], available). Uses the
        module-level `random` generator, so runs are not reproducible
        unless the caller seeds it.

        Args:
            class_images: Dict mapping class names to image lists

        Returns:
            Dict with sampled images
        """
        sampled = {}

        for class_name, images in class_images.items():
            requested = self.config.classes[class_name]
            available = len(images)

            if available == 0:
                logger.error(f"No images available for class '{class_name}'")
                sampled[class_name] = []
                continue

            sample_size = min(requested, available)
            sampled[class_name] = random.sample(images, sample_size)

            logger.info(
                f"Sampled {sample_size}/{requested} images for class '{class_name}'"
            )

        return sampled

    def download_images_and_masks(
        self, sampled_images: dict[str, list[dict[str, Any]]]
    ) -> dict[str, list[Path]]:
        """Download images and extract ground truth masks.

        Per-image failures are logged and skipped; only images whose
        download AND mask extraction both succeed appear in the result.

        Args:
            sampled_images: Dict of sampled image metadata

        Returns:
            Dict mapping class names to lists of image paths

        Raises:
            ValueError: If download or extraction fails critically
        """
        if not self.client:
            raise ValueError("Not connected to CVAT")

        cache_dir = self.config.get_cache_path()
        cache_dir.mkdir(parents=True, exist_ok=True)

        downloaded_paths: dict[str, list[Path]] = {
            class_name: [] for class_name in self.config.classes.keys()
        }

        total_images = sum(len(images) for images in sampled_images.values())
        processed = 0

        for class_name, images in sampled_images.items():
            for img_meta in images:
                processed += 1
                logger.info(
                    f"Processing {processed}/{total_images}: "
                    f"{class_name} - Task {img_meta['task_id']} Frame {img_meta['frame_id']}"
                )

                try:
                    image_path = self._download_image(class_name, img_meta, cache_dir)
                    self._extract_masks(class_name, img_meta, image_path)
                    downloaded_paths[class_name].append(image_path)

                except Exception as e:
                    logger.error(
                        f"Failed to process {class_name} image "
                        f"(task={img_meta['task_id']}, frame={img_meta['frame_id']}): {e}"
                    )
                    continue

        # Log final counts
        for class_name, paths in downloaded_paths.items():
            logger.info(f"Successfully processed {len(paths)} images for '{class_name}'")

        return downloaded_paths

    def _download_image(
        self, class_name: str, img_meta: dict[str, Any], cache_dir: Path
    ) -> Path:
        """Download single image.

        Idempotent: returns the cached file if it already exists.

        Args:
            class_name: Class label
            img_meta: Image metadata
            cache_dir: Cache directory

        Returns:
            Path to downloaded image

        Raises:
            ValueError: If download fails
        """
        # Create output directory
        image_name = f"{img_meta['task_name']}_frame_{img_meta['frame_id']:06d}"
        output_dir = cache_dir / class_name / image_name
        output_dir.mkdir(parents=True, exist_ok=True)

        image_path = output_dir / "image.jpg"

        # Check cache
        if image_path.exists():
            logger.debug(f"Image already cached: {image_path}")
            return image_path

        # Download from CVAT
        if not self.client:
            raise ValueError("Client not initialized")

        try:
            # NOTE(review): get_frame is assumed to return raw encoded image
            # bytes (written verbatim as JPEG) — confirm against the client API.
            image_data = self.client.tasks.get_frame(
                img_meta["task_id"], img_meta["frame_id"]
            )

            with open(image_path, "wb") as f:
                f.write(image_data)

            logger.debug(f"Downloaded image to {image_path}")
            return image_path

        except Exception as e:
            raise ValueError(
                f"Failed to download image from task {img_meta['task_id']} "
                f"frame {img_meta['frame_id']}: {e}"
            ) from e

    def _extract_masks(
        self, class_name: str, img_meta: dict[str, Any], image_path: Path
    ) -> None:
        """Extract ground truth masks for image.

        Writes one PNG per mask instance plus metadata.json into a
        ground_truth/ directory next to the image. Idempotent: skips work
        if metadata.json already exists.

        Args:
            class_name: Class label
            img_meta: Image metadata
            image_path: Path to image file

        Raises:
            ValueError: If mask extraction fails
        """
        output_dir = image_path.parent / "ground_truth"
        output_dir.mkdir(parents=True, exist_ok=True)

        # Check if already extracted
        metadata_path = output_dir / "metadata.json"
        if metadata_path.exists():
            logger.debug(f"Masks already extracted: {output_dir}")
            return

        if not self.client:
            raise ValueError("Client not initialized")

        # Get annotations for this job
        try:
            annotations = self.client.annotations.get_job_annotations(
                img_meta["job_id"]
            )

            if not annotations or not hasattr(annotations, 'shapes'):
                raise ValueError("No annotations found")

            # Filter masks for this frame
            frame_masks = [
                shape
                for shape in annotations.shapes
                if shape.frame == img_meta["frame_id"]
                and hasattr(shape, 'type')
                and shape.type == 'mask'
            ]

            if not frame_masks:
                logger.warning(
                    f"No mask annotations found for frame {img_meta['frame_id']}"
                )
                # Create empty metadata
                with open(metadata_path, "w") as f:
                    json.dump({"masks": []}, f, indent=2)
                return

            # Get image dimensions
            from PIL import Image
            with Image.open(image_path) as img:
                width, height = img.size

            # Extract each mask; instance_idx disambiguates multiple
            # instances of the same label in one frame.
            mask_metadata = []
            label_counts: dict[str, int] = {}

            for shape in frame_masks:
                label = shape.label_name
                if label not in label_counts:
                    label_counts[label] = 0

                instance_idx = label_counts[label]
                label_counts[label] += 1

                # Convert CVAT RLE to mask
                if not hasattr(shape, 'points') or not shape.points:
                    logger.warning(f"Shape missing points data, skipping")
                    continue

                mask_filename = f"mask_{label}_{instance_idx}.png"
                mask_path = output_dir / mask_filename

                try:
                    # NOTE(review): the returned Mask object is unused —
                    # from_cvat_api_rle presumably writes the PNG to
                    # file_path as a side effect; confirm against Mask docs.
                    mask = Mask.from_cvat_api_rle(
                        cvat_rle=shape.points,
                        width=width,
                        height=height,
                        file_path=str(mask_path),
                    )

                    mask_metadata.append({
                        "filename": mask_filename,
                        "label": label,
                        "instance_idx": instance_idx,
                    })

                    logger.debug(f"Extracted mask: {mask_filename}")

                except Exception as e:
                    logger.error(f"Failed to convert mask for label {label}: {e}")
                    continue

            # Save metadata
            with open(metadata_path, "w") as f:
                json.dump(
                    {
                        "image": str(image_path.name),
                        "width": width,
                        "height": height,
                        "masks": mask_metadata,
                    },
                    f,
                    indent=2,
                )

            logger.info(f"Extracted {len(mask_metadata)} masks to {output_dir}")

        except Exception as e:
            raise ValueError(f"Failed to extract masks: {e}") from e

    def run_extraction(self) -> dict[str, list[Path]]:
        """Run complete extraction pipeline.

        Returns:
            Dict mapping class names to lists of image paths

        Raises:
            Exception: If any critical step fails
        """
        logger.info("Starting CVAT extraction pipeline")

        # Connect
        self.connect()

        # Find project
        self.find_training_project()

        # Discover images
        class_images = self.discover_images()

        # Sample
        sampled = self.sample_images(class_images)

        # Download and extract
        paths = self.download_images_and_masks(sampled)

        logger.info("CVAT extraction complete")
        return paths
|
metrics_evaluation/inference/sam3_inference.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SAM3 inference for evaluation."""
|
| 2 |
+
|
| 3 |
+
import base64
|
| 4 |
+
import json
|
| 5 |
+
import logging
|
| 6 |
+
import time
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import requests
|
| 10 |
+
from PIL import Image
|
| 11 |
+
|
| 12 |
+
from ..config.config_models import EvaluationConfig
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class SAM3Inferencer:
    """Run SAM3 inference on images.

    Sends base64-encoded images to an HTTP inference endpoint
    (config.sam3.endpoint) with retries, and persists the returned
    per-instance masks as PNGs plus a metadata.json alongside each image.
    """

    def __init__(self, config: EvaluationConfig):
        """Initialize inferencer with configuration.

        Args:
            config: Evaluation configuration
        """
        self.config = config
        self.endpoint = config.sam3.endpoint
        self.timeout = config.sam3.timeout
        self.retry_attempts = config.sam3.retry_attempts

    def infer_single_image(self, image_path: Path, classes: list[str]) -> list[dict]:
        """Run SAM3 inference on single image.

        Args:
            image_path: Path to image file
            classes: List of class names to detect

        Returns:
            List of detection results with masks

        Raises:
            ValueError: If inference fails after retries
        """
        # Load and encode image
        with open(image_path, "rb") as f:
            image_data = f.read()

        image_b64 = base64.b64encode(image_data).decode()

        # Prepare request
        payload = {
            "inputs": image_b64,
            "parameters": {"classes": classes}
        }

        # Retry with exponential backoff (1s, 2s, 4s, ...)
        last_error = None
        for attempt in range(self.retry_attempts):
            try:
                start_time = time.time()

                response = requests.post(
                    self.endpoint,
                    json=payload,
                    timeout=self.timeout
                )

                elapsed = time.time() - start_time

                if response.status_code == 200:
                    results = response.json()
                    logger.debug(
                        f"Inference successful for {image_path.name} "
                        f"({elapsed:.2f}s, {len(results)} detections)"
                    )
                    return results

                else:
                    last_error = f"HTTP {response.status_code}: {response.text}"
                    logger.warning(f"Inference failed (attempt {attempt + 1}): {last_error}")

            except requests.Timeout:
                last_error = f"Request timeout after {self.timeout}s"
                logger.warning(f"Inference timeout (attempt {attempt + 1})")

            except Exception as e:
                last_error = str(e)
                logger.warning(f"Inference error (attempt {attempt + 1}): {e}")

            # Exponential backoff
            if attempt < self.retry_attempts - 1:
                sleep_time = 2 ** attempt
                logger.debug(f"Retrying in {sleep_time}s...")
                time.sleep(sleep_time)

        raise ValueError(
            f"Inference failed after {self.retry_attempts} attempts for {image_path}: {last_error}"
        )

    def save_inference_results(
        self,
        results: list[dict],
        output_dir: Path,
        image_width: int,
        image_height: int
    ) -> None:
        """Save inference results as masks.

        Each result's base64 mask is decoded, converted to grayscale,
        resized to the image dimensions if necessary, and written as
        mask_<label>_<instance>.png with a metadata.json index.

        Args:
            results: SAM3 detection results
            output_dir: Directory to save masks
            image_width: Image width
            image_height: Image height

        Raises:
            ValueError: If mask conversion fails
        """
        # BUGFIX: `io` used to be imported only inside run_inference_batch,
        # where the import bound a *local* name of that function — this
        # method's reference to io.BytesIO therefore raised NameError at
        # runtime. Import it where it is used (local imports match the
        # style used elsewhere in this package).
        import io

        output_dir.mkdir(parents=True, exist_ok=True)

        mask_metadata = []
        label_counts: dict[str, int] = {}

        for result in results:
            label = result.get("label", "unknown")
            score = result.get("score", 0.0)
            mask_b64 = result.get("mask")

            if not mask_b64:
                logger.warning(f"Result missing mask data for label {label}")
                continue

            # Count instances per label
            if label not in label_counts:
                label_counts[label] = 0

            instance_idx = label_counts[label]
            label_counts[label] += 1

            # Decode mask
            try:
                mask_data = base64.b64decode(mask_b64)
                mask_img = Image.open(io.BytesIO(mask_data))

                # Convert to L mode (grayscale) if needed
                if mask_img.mode != 'L':
                    mask_img = mask_img.convert('L')

                # Validate dimensions
                if mask_img.size != (image_width, image_height):
                    logger.warning(
                        f"Mask dimension mismatch: expected {image_width}x{image_height}, "
                        f"got {mask_img.size}. Resizing."
                    )
                    # NEAREST keeps the mask binary (no interpolated values).
                    mask_img = mask_img.resize((image_width, image_height), Image.NEAREST)

                # Save as PNG
                mask_filename = f"mask_{label}_{instance_idx}.png"
                mask_path = output_dir / mask_filename
                mask_img.save(mask_path)

                mask_metadata.append({
                    "filename": mask_filename,
                    "label": label,
                    "instance_idx": instance_idx,
                    "score": score,
                })

                logger.debug(f"Saved inference mask: {mask_filename}")

            except Exception as e:
                logger.error(f"Failed to save mask for label {label}: {e}")
                continue

        # Save metadata
        metadata_path = output_dir / "metadata.json"
        with open(metadata_path, "w") as f:
            json.dump(
                {
                    "width": image_width,
                    "height": image_height,
                    "masks": mask_metadata,
                },
                f,
                indent=2,
            )

        logger.info(f"Saved {len(mask_metadata)} inference masks to {output_dir}")

    def run_inference_batch(
        self, image_paths: dict[str, list[Path]], force: bool = False
    ) -> dict[str, int]:
        """Run inference on batch of images.

        Results are written to an inference/ directory next to each image;
        images whose inference/metadata.json already exists are skipped
        unless `force` is set.

        Args:
            image_paths: Dict mapping class names to image paths
            force: Force re-inference even if results exist

        Returns:
            Dict with inference statistics (total/successful/failed/skipped)

        Raises:
            ValueError: If no images provided
        """
        total_images = sum(len(paths) for paths in image_paths.values())
        if total_images == 0:
            raise ValueError("No images provided for inference")

        logger.info(f"Starting inference on {total_images} images")

        stats = {
            "total": total_images,
            "successful": 0,
            "failed": 0,
            "skipped": 0,
        }

        processed = 0

        for class_name, paths in image_paths.items():
            for image_path in paths:
                processed += 1
                logger.info(
                    f"Inference {processed}/{total_images}: {image_path.parent.name}"
                )

                inference_dir = image_path.parent / "inference"
                metadata_path = inference_dir / "metadata.json"

                # Check cache
                if not force and metadata_path.exists():
                    logger.debug(f"Inference results already exist: {inference_dir}")
                    stats["skipped"] += 1
                    continue

                try:
                    # Get image dimensions
                    with Image.open(image_path) as img:
                        width, height = img.size

                    # Run inference
                    results = self.infer_single_image(
                        image_path,
                        list(self.config.classes.keys())
                    )

                    # Save results
                    self.save_inference_results(
                        results, inference_dir, width, height
                    )

                    stats["successful"] += 1

                except Exception as e:
                    logger.error(f"Inference failed for {image_path}: {e}")
                    stats["failed"] += 1
                    continue

        logger.info(
            f"Inference complete: {stats['successful']} successful, "
            f"{stats['failed']} failed, {stats['skipped']} skipped"
        )

        return stats
|
metrics_evaluation/metrics/metrics_calculator.py
ADDED
|
@@ -0,0 +1,419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Metrics calculation for SAM3 evaluation."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from scipy.optimize import linear_sum_assignment
|
| 11 |
+
|
| 12 |
+
from ..config.config_models import EvaluationConfig
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MetricsCalculator:
    """Calculate instance-segmentation metrics from cached evaluation data.

    Expects a cache layout of ``<cache>/<class>/<image>/`` where each image
    directory holds ``ground_truth/`` and ``inference/`` subdirectories, each
    containing a ``metadata.json`` that lists per-instance mask files.
    """

    def __init__(self, config: EvaluationConfig):
        """Initialize calculator with configuration.

        Args:
            config: Evaluation configuration providing ``metrics.iou_thresholds``
        """
        self.config = config
        # IoU thresholds at which instance matching is evaluated (e.g. [0.5, 0.75]).
        self.iou_thresholds = config.metrics.iou_thresholds

    def calculate_iou(self, mask1: np.ndarray, mask2: np.ndarray) -> float:
        """Calculate IoU between two binary masks.

        Args:
            mask1: First binary mask
            mask2: Second binary mask

        Returns:
            IoU score between 0 and 1 (0.0 when both masks are empty)
        """
        intersection = np.logical_and(mask1, mask2).sum()
        union = np.logical_or(mask1, mask2).sum()

        if union == 0:
            # Both masks empty: define IoU as 0 rather than dividing by zero.
            return 0.0

        return float(intersection / union)

    def match_instances(
        self,
        gt_masks: list[np.ndarray],
        pred_masks: list[np.ndarray],
        iou_threshold: float
    ) -> dict[str, Any]:
        """Match predicted instances to ground truth via optimal assignment.

        Args:
            gt_masks: List of ground truth masks
            pred_masks: List of predicted masks
            iou_threshold: Minimum IoU for a pairing to count as a match

        Returns:
            Dict with ``matches`` (list of {gt_idx, pred_idx, iou}),
            ``unmatched_gt``, ``unmatched_pred``, ``true_positives``,
            ``false_positives`` and ``false_negatives``.
        """
        # Degenerate cases collapse into one branch: with no GT every
        # prediction is a false positive; with no predictions every GT
        # instance is a false negative; with neither, all fields are empty.
        if not gt_masks or not pred_masks:
            return {
                "matches": [],
                "unmatched_gt": list(range(len(gt_masks))),
                "unmatched_pred": list(range(len(pred_masks))),
                "true_positives": 0,
                "false_positives": len(pred_masks),
                "false_negatives": len(gt_masks),
            }

        # Pairwise IoU matrix (rows: ground truth, columns: predictions).
        iou_matrix = np.zeros((len(gt_masks), len(pred_masks)))
        for i, gt_mask in enumerate(gt_masks):
            for j, pred_mask in enumerate(pred_masks):
                iou_matrix[i, j] = self.calculate_iou(gt_mask, pred_mask)

        # Hungarian algorithm; negate so the minimizer maximizes total IoU.
        gt_indices, pred_indices = linear_sum_assignment(-iou_matrix)

        # Keep only assignments that clear the IoU threshold.
        matches = []
        for gt_idx, pred_idx in zip(gt_indices, pred_indices):
            iou = iou_matrix[gt_idx, pred_idx]
            if iou >= iou_threshold:
                matches.append({
                    "gt_idx": int(gt_idx),
                    "pred_idx": int(pred_idx),
                    "iou": float(iou),
                })

        matched_gt = {m["gt_idx"] for m in matches}
        matched_pred = {m["pred_idx"] for m in matches}

        unmatched_gt = [i for i in range(len(gt_masks)) if i not in matched_gt]
        unmatched_pred = [i for i in range(len(pred_masks)) if i not in matched_pred]

        return {
            "matches": matches,
            "unmatched_gt": unmatched_gt,
            "unmatched_pred": unmatched_pred,
            "true_positives": len(matches),
            "false_positives": len(unmatched_pred),
            "false_negatives": len(unmatched_gt),
        }

    def _load_masks_by_label(
        self, mask_dir: Path, metadata: dict[str, Any], kind: str
    ) -> dict[str, list[np.ndarray]]:
        """Load binarized masks listed in ``metadata``, grouped by label.

        Args:
            mask_dir: Directory containing the mask image files
            metadata: Parsed metadata.json with a ``masks`` list of
                ``{label, filename}`` entries
            kind: Mask source name ("Ground truth" / "Inference"), used only
                in warning messages

        Returns:
            Mapping from label to list of boolean mask arrays.
        """
        masks_by_label: dict[str, list[np.ndarray]] = {}

        for mask_info in metadata.get("masks", []):
            mask_path = mask_dir / mask_info["filename"]

            if not mask_path.exists():
                # Missing individual masks are tolerated; just warn and skip.
                logger.warning(f"{kind} mask not found: {mask_path}")
                continue

            # Close the file handle promptly; np.array() copies the pixels.
            with Image.open(mask_path) as mask_img:
                mask_array = np.array(mask_img) > 0  # Binarize

            masks_by_label.setdefault(mask_info["label"], []).append(mask_array)

        return masks_by_label

    def calculate_image_metrics(
        self, image_dir: Path
    ) -> dict[str, Any] | None:
        """Calculate metrics for a single image.

        Args:
            image_dir: Directory containing ground_truth and inference subdirs

        Returns:
            Dict with metrics for this image, or None when inference results
            are missing (the image is skipped by the caller).

        Raises:
            ValueError: If the ground truth metadata file is missing
        """
        gt_dir = image_dir / "ground_truth"
        inf_dir = image_dir / "inference"

        gt_meta_path = gt_dir / "metadata.json"
        inf_meta_path = inf_dir / "metadata.json"

        if not gt_meta_path.exists():
            raise ValueError(f"Ground truth metadata not found: {gt_meta_path}")

        if not inf_meta_path.exists():
            # Inference may have failed for this image; skip instead of abort.
            logger.warning(f"Inference metadata not found: {inf_meta_path}")
            return None

        with open(gt_meta_path) as f:
            gt_meta = json.load(f)

        with open(inf_meta_path) as f:
            inf_meta = json.load(f)

        # Group masks by label for per-class matching.
        gt_by_label = self._load_masks_by_label(gt_dir, gt_meta, "Ground truth")
        inf_by_label = self._load_masks_by_label(inf_dir, inf_meta, "Inference")

        results: dict[str, Any] = {
            "image_name": image_dir.name,
            "class": image_dir.parent.name,  # cache layout: <class>/<image>/
            "by_threshold": {},
        }

        # Evaluate every label seen in either GT or predictions, so spurious
        # predicted classes still count as false positives.
        all_labels = set(gt_by_label) | set(inf_by_label)

        for threshold in self.iou_thresholds:
            threshold_results: dict[str, Any] = {
                "iou_threshold": threshold,
                "by_label": {},
                "total": {
                    "true_positives": 0,
                    "false_positives": 0,
                    "false_negatives": 0,
                },
            }

            for label in all_labels:
                gt_masks = gt_by_label.get(label, [])
                pred_masks = inf_by_label.get(label, [])

                matching = self.match_instances(gt_masks, pred_masks, threshold)

                threshold_results["by_label"][label] = {
                    "gt_count": len(gt_masks),
                    "pred_count": len(pred_masks),
                    "true_positives": matching["true_positives"],
                    "false_positives": matching["false_positives"],
                    "false_negatives": matching["false_negatives"],
                    "matches": matching["matches"],
                }

                threshold_results["total"]["true_positives"] += matching["true_positives"]
                threshold_results["total"]["false_positives"] += matching["false_positives"]
                threshold_results["total"]["false_negatives"] += matching["false_negatives"]

            results["by_threshold"][str(threshold)] = threshold_results

        return results

    @staticmethod
    def _prf(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
        """Return (precision, recall, f1); each is 0.0 when its denominator is 0."""
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        return precision, recall, f1

    def calculate_aggregate_metrics(
        self, image_results: list[dict[str, Any]]
    ) -> dict[str, Any]:
        """Calculate aggregate metrics across all images.

        Args:
            image_results: List of per-image metrics from calculate_image_metrics

        Returns:
            Dict with per-threshold, per-label and overall precision/recall/F1
            plus a confusion matrix.
        """
        aggregate: dict[str, Any] = {
            "total_images": len(image_results),
            "by_threshold": {},
        }

        for threshold in self.iou_thresholds:
            threshold_str = str(threshold)

            label_stats: dict[str, dict] = {}
            total_tp = 0
            total_fp = 0
            total_fn = 0

            # Sum TP/FP/FN counts across all images, overall and per label.
            for img_result in image_results:
                threshold_data = img_result["by_threshold"][threshold_str]

                for label, label_data in threshold_data["by_label"].items():
                    stats = label_stats.setdefault(label, {
                        "tp": 0,
                        "fp": 0,
                        "fn": 0,
                        "gt_total": 0,
                        "pred_total": 0,
                    })
                    stats["tp"] += label_data["true_positives"]
                    stats["fp"] += label_data["false_positives"]
                    stats["fn"] += label_data["false_negatives"]
                    stats["gt_total"] += label_data["gt_count"]
                    stats["pred_total"] += label_data["pred_count"]

                total_tp += threshold_data["total"]["true_positives"]
                total_fp += threshold_data["total"]["false_positives"]
                total_fn += threshold_data["total"]["false_negatives"]

            # Derive precision/recall/F1 per label.
            for stats in label_stats.values():
                precision, recall, f1 = self._prf(stats["tp"], stats["fp"], stats["fn"])
                stats["precision"] = precision
                stats["recall"] = recall
                stats["f1"] = f1

            overall_precision, overall_recall, overall_f1 = self._prf(
                total_tp, total_fp, total_fn
            )

            confusion_matrix = self._build_confusion_matrix(image_results, threshold_str)

            aggregate["by_threshold"][threshold_str] = {
                "iou_threshold": threshold,
                "by_label": label_stats,
                "overall": {
                    "true_positives": total_tp,
                    "false_positives": total_fp,
                    "false_negatives": total_fn,
                    "precision": overall_precision,
                    "recall": overall_recall,
                    "f1": overall_f1,
                    # NOTE(review): "map"/"mar" are dataset-level precision and
                    # recall, not COCO-style averages over recall levels.
                    "map": overall_precision,  # Simplified mAP
                    "mar": overall_recall,  # Simplified mAR
                },
                "confusion_matrix": confusion_matrix,
            }

        return aggregate

    def _build_confusion_matrix(
        self, image_results: list[dict[str, Any]], threshold_str: str
    ) -> dict[str, Any]:
        """Build confusion matrix for given threshold.

        Only the diagonal (true positives) is populated; cross-class
        confusions are not tracked by the matching stage.

        Args:
            image_results: List of per-image metrics
            threshold_str: IoU threshold as string

        Returns:
            Confusion matrix data: sorted ``labels`` and nested-list ``matrix``
        """
        # Collect the union of labels seen at this threshold.
        all_labels: set = set()
        for img_result in image_results:
            threshold_data = img_result["by_threshold"][threshold_str]
            all_labels.update(threshold_data["by_label"].keys())

        labels = sorted(all_labels)
        n_labels = len(labels)

        matrix = np.zeros((n_labels, n_labels), dtype=int)

        for img_result in image_results:
            threshold_data = img_result["by_threshold"][threshold_str]

            for i, gt_label in enumerate(labels):
                if gt_label not in threshold_data["by_label"]:
                    continue

                label_data = threshold_data["by_label"][gt_label]

                # True positives go on the diagonal.
                matrix[i, i] += label_data["true_positives"]

                # False negatives (missed) - simplified representation.
                # In a full implementation, we'd track which class they were
                # predicted as.

        return {
            "labels": labels,
            "matrix": matrix.tolist(),
        }

    def run_evaluation(self, cache_dir: Path) -> dict[str, Any]:
        """Run complete metrics evaluation over a cache directory.

        Args:
            cache_dir: Cache directory with ground truth and inference results

        Returns:
            Complete metrics results: ``per_image`` list and ``aggregate`` dict

        Raises:
            ValueError: If the cache directory is missing or yields no usable
                image results
        """
        if not cache_dir.exists():
            raise ValueError(f"Cache directory not found: {cache_dir}")

        logger.info(f"Calculating metrics from {cache_dir}")

        image_results = []

        # Sort directory listings for a deterministic evaluation order.
        for class_dir in sorted(cache_dir.iterdir()):
            if not class_dir.is_dir():
                continue

            for image_dir in sorted(class_dir.iterdir()):
                if not image_dir.is_dir():
                    continue

                try:
                    metrics = self.calculate_image_metrics(image_dir)
                except Exception as e:
                    # One broken image must not abort the whole evaluation.
                    logger.error(f"Failed to calculate metrics for {image_dir}: {e}")
                    continue

                if metrics:
                    image_results.append(metrics)

        if not image_results:
            raise ValueError("No valid image results found for metrics calculation")

        logger.info(f"Calculated metrics for {len(image_results)} images")

        aggregate = self.calculate_aggregate_metrics(image_results)

        return {
            "per_image": image_results,
            "aggregate": aggregate,
        }
|
metrics_evaluation/run_evaluation.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Main execution script for SAM3 metrics evaluation."""
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import json
|
| 6 |
+
import logging
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
from config.config_loader import load_config
|
| 11 |
+
from extraction.cvat_extractor import CVATExtractor
|
| 12 |
+
from inference.sam3_inference import SAM3Inferencer
|
| 13 |
+
from metrics.metrics_calculator import MetricsCalculator
|
| 14 |
+
from utils.logging_config import setup_logging
|
| 15 |
+
from visualization.visual_comparison import VisualComparator
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def write_metrics_summary(metrics: dict, output_path: Path) -> None:
    """Write a human-readable metrics summary report to a text file.

    Args:
        metrics: Metrics dictionary containing an "aggregate" section
        output_path: Path to the report file to write
    """
    bar = "=" * 80
    dash = "-" * 80

    # Assemble the whole report in memory, then write it in one pass.
    lines: list[str] = [
        bar + "\n",
        "SAM3 EVALUATION METRICS SUMMARY\n",
        bar + "\n\n",
    ]

    aggregate = metrics["aggregate"]
    lines.append(f"Total Images Evaluated: {aggregate['total_images']}\n\n")

    for threshold_data in aggregate["by_threshold"].values():
        iou = threshold_data["iou_threshold"]
        lines.append(f"\n{bar}\n")
        lines.append(f"IoU Threshold: {iou:.0%}\n")
        lines.append(f"{bar}\n\n")

        # Overall counts and derived scores for this threshold.
        overall = threshold_data["overall"]
        lines.append("Overall Metrics:\n")
        lines.append(f"  True Positives: {overall['true_positives']}\n")
        lines.append(f"  False Positives: {overall['false_positives']}\n")
        lines.append(f"  False Negatives: {overall['false_negatives']}\n")
        lines.append(f"  Precision: {overall['precision']:.2%}\n")
        lines.append(f"  Recall: {overall['recall']:.2%}\n")
        lines.append(f"  F1-Score: {overall['f1']:.2%}\n")
        lines.append(f"  mAP: {overall['map']:.2%}\n")
        lines.append(f"  mAR: {overall['mar']:.2%}\n\n")

        # Fixed-width per-class table.
        lines.append("Per-Class Metrics:\n")
        lines.append(dash + "\n")
        lines.append(
            f"{'Class':<20} {'GT':>6} {'Pred':>6} {'TP':>6} {'FP':>6} "
            f"{'FN':>6} {'Prec':>8} {'Rec':>8} {'F1':>8}\n"
        )
        lines.append(dash + "\n")

        for label, stats in sorted(threshold_data["by_label"].items()):
            lines.append(
                f"{label:<20} {stats['gt_total']:>6} {stats['pred_total']:>6} "
                f"{stats['tp']:>6} {stats['fp']:>6} {stats['fn']:>6} "
                f"{stats['precision']:>8.2%} {stats['recall']:>8.2%} "
                f"{stats['f1']:>8.2%}\n"
            )

        lines.append("\n")

        cm = threshold_data["confusion_matrix"]
        labels = cm["labels"]
        matrix = cm["matrix"]

        if labels:
            lines.append("Confusion Matrix:\n")
            lines.append(dash + "\n")

            # Column header: one fixed-width cell per predicted label.
            header = "Actual \\ Pred |" + "".join(
                f" {label[:10]:>10} |" for label in labels
            )
            lines.append(header + "\n")
            lines.append("-" * len(header) + "\n")

            # One row per actual (ground-truth) label.
            for row_idx, actual_label in enumerate(labels):
                cells = "".join(
                    f" {matrix[row_idx][col]:>10} |" for col in range(len(labels))
                )
                lines.append(f"{actual_label[:13]:>13} |" + cells + "\n")

            lines.append("\n")

    lines.append(bar + "\n")
    lines.append("END OF REPORT\n")
    lines.append(bar + "\n")

    with open(output_path, "w") as f:
        f.writelines(lines)

    logging.getLogger(__name__).info(f"Wrote metrics summary to {output_path}")
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def main() -> int:
    """Run the full evaluation pipeline: extract, infer, score, visualize.

    Returns:
        Exit code (0 for success, 1 for failure, 130 when interrupted)
    """
    parser = argparse.ArgumentParser(
        description="Run SAM3 metrics evaluation against CVAT ground truth"
    )
    parser.add_argument(
        "--config",
        type=str,
        default="config/config.json",
        help="Path to configuration file"
    )
    parser.add_argument(
        "--force-download",
        action="store_true",
        help="Force re-download images from CVAT"
    )
    parser.add_argument(
        "--force-inference",
        action="store_true",
        help="Force re-run SAM3 inference"
    )
    parser.add_argument(
        "--skip-inference",
        action="store_true",
        help="Skip inference, use cached results"
    )
    parser.add_argument(
        "--visualize",
        action="store_true",
        help="Generate visual comparisons"
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Logging level"
    )

    args = parser.parse_args()

    # Load configuration first; logging is not set up yet, so report to stderr.
    try:
        config = load_config(args.config)
    except Exception as e:
        print(f"ERROR: Failed to load configuration: {e}", file=sys.stderr)
        return 1

    # Setup logging (console + persistent log file inside the cache dir).
    cache_dir = config.get_cache_path()
    log_file = cache_dir / "evaluation_log.txt"
    setup_logging(log_file, getattr(logging, args.log_level))

    logger.info("=" * 80)
    logger.info("SAM3 METRICS EVALUATION")
    logger.info("=" * 80)

    try:
        # Phase 1: Extract ground truth data from CVAT.
        logger.info("\n" + "=" * 80)
        logger.info("PHASE 1: CVAT Data Extraction")
        logger.info("=" * 80)

        extractor = CVATExtractor(config)

        if args.force_download:
            logger.info("Force download enabled - will re-download all images")

        image_paths = extractor.run_extraction()

        total_extracted = sum(len(paths) for paths in image_paths.values())
        logger.info(f"Extraction complete: {total_extracted} images extracted")

        if total_extracted == 0:
            logger.error("No images extracted. Aborting.")
            return 1

        # Phase 2: Run SAM3 inference (skippable when cached results exist).
        if not args.skip_inference:
            logger.info("\n" + "=" * 80)
            logger.info("PHASE 2: SAM3 Inference")
            logger.info("=" * 80)

            inferencer = SAM3Inferencer(config)
            stats = inferencer.run_inference_batch(image_paths, args.force_inference)

            logger.info(
                f"Inference complete: {stats['successful']} successful, "
                f"{stats['failed']} failed, {stats['skipped']} skipped"
            )

            # Skipped images still have cached results, so only abort when
            # neither fresh nor cached inferences are available.
            if stats['successful'] == 0 and stats['skipped'] == 0:
                logger.error("No successful inferences. Aborting.")
                return 1
        else:
            logger.info("Skipping inference (--skip-inference)")

        # Phase 3: Calculate metrics against the ground truth.
        logger.info("\n" + "=" * 80)
        logger.info("PHASE 3: Metrics Calculation")
        logger.info("=" * 80)

        calculator = MetricsCalculator(config)
        metrics = calculator.run_evaluation(cache_dir)

        # Save detailed metrics as JSON.
        metrics_json_path = cache_dir / "metrics_detailed.json"
        with open(metrics_json_path, "w") as f:
            json.dump(metrics, f, indent=2)
        logger.info(f"Saved detailed metrics to {metrics_json_path}")

        # Write human-readable summary report.
        metrics_summary_path = cache_dir / "metrics_summary.txt"
        write_metrics_summary(metrics, metrics_summary_path)

        # Phase 4: Visualization (optional).
        if args.visualize or config.output.generate_visualizations:
            logger.info("\n" + "=" * 80)
            logger.info("PHASE 4: Visual Comparisons")
            logger.info("=" * 80)

            comparator = VisualComparator()
            comparison_paths = comparator.generate_all_comparisons(cache_dir)
            logger.info(f"Generated {len(comparison_paths)} visual comparisons")

        # Final summary.
        logger.info("\n" + "=" * 80)
        logger.info("EVALUATION COMPLETE")
        logger.info("=" * 80)

        aggregate = metrics["aggregate"]
        logger.info(f"Images evaluated: {aggregate['total_images']}")

        # Show headline metrics at 50% IoU when configured; otherwise fall
        # back to the first available threshold. (Indexing "0.5" directly
        # raised KeyError when 0.5 was not among the configured thresholds.)
        by_threshold = aggregate["by_threshold"]
        summary_key = "0.5" if "0.5" in by_threshold else next(iter(by_threshold))
        summary_data = by_threshold[summary_key]
        overall = summary_data["overall"]

        logger.info(f"\nMetrics at {summary_data['iou_threshold']:.0%} IoU:")
        logger.info(f"  Precision: {overall['precision']:.2%}")
        logger.info(f"  Recall: {overall['recall']:.2%}")
        logger.info(f"  F1-Score: {overall['f1']:.2%}")
        logger.info(f"  mAP: {overall['map']:.2%}")
        logger.info(f"  mAR: {overall['mar']:.2%}")

        logger.info("\nResults saved to:")
        logger.info(f"  Metrics Summary: {metrics_summary_path}")
        logger.info(f"  Detailed JSON: {metrics_json_path}")
        logger.info(f"  Execution Log: {log_file}")

        return 0

    except KeyboardInterrupt:
        logger.warning("\nEvaluation interrupted by user")
        return 130

    except Exception as e:
        logger.error(f"\nEvaluation failed with error: {e}", exc_info=True)
        return 1
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
# Script entry point: propagate main()'s exit code to the shell / CI runner.
if __name__ == "__main__":
    sys.exit(main())
|
metrics_evaluation/utils/__init__.py
ADDED
|
File without changes
|
metrics_evaluation/utils/logging_config.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Logging configuration for evaluation."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import sys
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def setup_logging(log_file: Path | None = None, level: int = logging.INFO) -> None:
|
| 9 |
+
"""Set up logging configuration.
|
| 10 |
+
|
| 11 |
+
Args:
|
| 12 |
+
log_file: Optional path to log file
|
| 13 |
+
level: Logging level
|
| 14 |
+
"""
|
| 15 |
+
# Create formatters
|
| 16 |
+
detailed_formatter = logging.Formatter(
|
| 17 |
+
"%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
| 18 |
+
datefmt="%Y-%m-%d %H:%M:%S"
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
simple_formatter = logging.Formatter(
|
| 22 |
+
"%(levelname)s: %(message)s"
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
# Get root logger
|
| 26 |
+
logger = logging.getLogger()
|
| 27 |
+
logger.setLevel(level)
|
| 28 |
+
|
| 29 |
+
# Remove existing handlers
|
| 30 |
+
logger.handlers.clear()
|
| 31 |
+
|
| 32 |
+
# Console handler
|
| 33 |
+
console_handler = logging.StreamHandler(sys.stdout)
|
| 34 |
+
console_handler.setLevel(level)
|
| 35 |
+
console_handler.setFormatter(simple_formatter)
|
| 36 |
+
logger.addHandler(console_handler)
|
| 37 |
+
|
| 38 |
+
# File handler
|
| 39 |
+
if log_file:
|
| 40 |
+
log_file.parent.mkdir(parents=True, exist_ok=True)
|
| 41 |
+
file_handler = logging.FileHandler(log_file, mode="w")
|
| 42 |
+
file_handler.setLevel(logging.DEBUG)
|
| 43 |
+
file_handler.setFormatter(detailed_formatter)
|
| 44 |
+
logger.addHandler(file_handler)
|
metrics_evaluation/visualization/visual_comparison.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Visual comparison generation for evaluation."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class VisualComparator:
    """Generate visual comparisons between ground truth and predictions.

    Expects a per-image directory layout:
        image_dir/
            image.jpg
            ground_truth/metadata.json + mask PNGs
            inference/metadata.json + mask PNGs
    where each metadata.json holds {"masks": [{"filename": ...}, ...]}.
    """

    def __init__(self):
        """Initialize comparator with RGBA overlay colors (4th channel = opacity)."""
        self.colors = {
            "ground_truth": (0, 255, 0, 128),  # Green
            "prediction": (255, 0, 0, 128),  # Red
            "true_positive": (255, 255, 0, 128),  # Yellow
            "false_positive": (255, 0, 0, 128),  # Red
            "false_negative": (0, 0, 255, 128),  # Blue
        }

    def _build_overlay(
        self,
        mask_dir: Path,
        color: tuple[int, int, int, int],
        size: tuple[int, int],
    ) -> Image.Image:
        """Composite all masks listed in mask_dir/metadata.json into one RGBA overlay.

        Args:
            mask_dir: Directory containing mask PNGs and a metadata.json with
                a "masks" list of {"filename": ...} entries.
            color: RGBA fill color; the alpha channel sets overlay opacity.
            size: (width, height) of the target image.

        Returns:
            RGBA overlay image, fully transparent where no mask covers.
            Missing directory, metadata, or mask files are skipped silently
            (best-effort, matching the evaluation pipeline's tolerance).
        """
        overlay = Image.new("RGBA", size, (0, 0, 0, 0))

        meta_path = mask_dir / "metadata.json"
        if not mask_dir.exists() or not meta_path.exists():
            return overlay

        with open(meta_path) as f:
            meta = json.load(f)

        opacity = color[3]
        for mask_info in meta.get("masks", []):
            mask_path = mask_dir / mask_info["filename"]
            if not mask_path.exists():
                continue

            mask = Image.open(mask_path).convert("L")
            if mask.size != size:
                # Defensive: masks should match the image; resize rather
                # than let putalpha raise on a mismatch.
                mask = mask.resize(size)

            colored = Image.new("RGBA", size, color)
            # Scale the binary mask by the configured opacity so the overlay
            # stays semi-transparent. Using putalpha(mask) directly would
            # discard the color's alpha (128), making covered pixels fully
            # opaque and hiding the image and any earlier overlay beneath.
            colored.putalpha(mask.point(lambda v: v * opacity // 255))
            overlay = Image.alpha_composite(overlay, colored)

        return overlay

    def create_comparison(
        self, image_dir: Path, output_path: Path | None = None
    ) -> Path:
        """Create visual comparison for image.

        Args:
            image_dir: Directory containing image and masks
            output_path: Optional output path (default: image_dir/comparison.png)

        Returns:
            Path to generated comparison image

        Raises:
            ValueError: If required files are missing
        """
        image_path = image_dir / "image.jpg"
        if not image_path.exists():
            raise ValueError(f"Image not found: {image_path}")

        original = Image.open(image_path).convert("RGBA")

        # Build semi-transparent overlays for both annotation sources.
        gt_overlay = self._build_overlay(
            image_dir / "ground_truth", self.colors["ground_truth"], original.size
        )
        pred_overlay = self._build_overlay(
            image_dir / "inference", self.colors["prediction"], original.size
        )

        # Composite: ground truth first, predictions on top.
        result = Image.alpha_composite(original, gt_overlay)
        result = Image.alpha_composite(result, pred_overlay)

        result = self._add_legend(result)

        if output_path is None:
            output_path = image_dir / "comparison.png"

        result.convert("RGB").save(output_path)
        logger.debug(f"Saved comparison to {output_path}")

        return output_path

    def _add_legend(self, image: Image.Image) -> Image.Image:
        """Add color legend to image.

        Args:
            image: Input image

        Returns:
            New RGB image with a white legend strip appended below.
        """
        legend_height = 60
        legend_img = Image.new(
            "RGB", (image.width, image.height + legend_height), (255, 255, 255)
        )
        legend_img.paste(image, (0, 0))

        draw = ImageDraw.Draw(legend_img)

        x_offset = 10
        y_offset = image.height + 10

        items = [
            ("Ground Truth", self.colors["ground_truth"][:3]),
            ("Prediction", self.colors["prediction"][:3]),
        ]

        for label, color in items:
            # Color swatch followed by its label, using PIL's default font.
            draw.rectangle(
                [x_offset, y_offset, x_offset + 30, y_offset + 30], fill=color
            )
            draw.text((x_offset + 40, y_offset + 5), label, fill=(0, 0, 0))
            x_offset += 200

        return legend_img

    def generate_all_comparisons(self, cache_dir: Path) -> list[Path]:
        """Generate comparisons for all images in cache.

        The cache is laid out as cache_dir/<class>/<image>/; each image
        directory is processed independently and failures are logged and
        skipped so one bad entry does not abort the batch.

        Args:
            cache_dir: Cache directory

        Returns:
            List of paths to generated comparisons
        """
        comparison_paths: list[Path] = []

        for class_dir in cache_dir.iterdir():
            if not class_dir.is_dir():
                continue

            for image_dir in class_dir.iterdir():
                if not image_dir.is_dir():
                    continue

                try:
                    comparison_paths.append(self.create_comparison(image_dir))
                except Exception as e:
                    # Best-effort batch: record the failure and move on.
                    logger.error(f"Failed to create comparison for {image_dir}: {e}")
                    continue

        logger.info(f"Generated {len(comparison_paths)} comparison images")
        return comparison_paths
|