Thibaut committed on
Commit
6f98a26
·
1 Parent(s): 5e12a05

Implement metrics evaluation system - CVAT extraction, SAM3 inference, metrics calculation, visualization, and main pipeline

Browse files
metrics_evaluation/extraction/cvat_extractor.py ADDED
@@ -0,0 +1,450 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CVAT data extraction for SAM3 evaluation."""
2
+
3
+ import json
4
+ import logging
5
+ import random
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ from ..config.config_models import EvaluationConfig
10
+ from ..cvat_api.client import CVATClient
11
+ from ..schema.core.annotation.mask import Mask
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class CVATExtractor:
    """Extract annotated images and ground-truth masks from CVAT."""

    def __init__(self, config: EvaluationConfig):
        """Initialize extractor with configuration.

        Args:
            config: Evaluation configuration

        Raises:
            ValueError: If configuration is invalid
        """
        self.config = config
        # Populated lazily: connect() sets client,
        # find_training_project() sets project_id.
        self.client: CVATClient | None = None
        self.project_id: int | None = None

    def connect(self) -> None:
        """Connect to CVAT API.

        Credentials are read from CVAT_USERNAME / CVAT_PASSWORD in the
        environment (a .env file is loaded if present).

        Raises:
            ConnectionError: If connection fails
            ValueError: If credentials are invalid
        """
        # Imported locally so callers that never connect do not require
        # python-dotenv at import time.
        import os

        from dotenv import load_dotenv

        load_dotenv()

        username = os.getenv("CVAT_USERNAME")
        password = os.getenv("CVAT_PASSWORD")

        if not username or not password:
            raise ValueError(
                "CVAT credentials not found in .env file. "
                "Required: CVAT_USERNAME, CVAT_PASSWORD"
            )

        try:
            self.client = CVATClient(
                host=self.config.cvat.url,
                credentials=(username, password),
                organization=self.config.cvat.organization,
            )
            logger.info(f"Connected to CVAT at {self.config.cvat.url}")
        except Exception as e:
            raise ConnectionError(f"Failed to connect to CVAT: {e}") from e

    def find_training_project(self) -> int:
        """Find AI training project in CVAT.

        Selects the first project whose name contains the configured
        filter string (case-insensitive).

        Returns:
            Project ID

        Raises:
            ValueError: If no suitable project found
        """
        if not self.client:
            raise ValueError("Not connected to CVAT. Call connect() first.")

        projects = self.client.projects.list()
        filter_str = self.config.cvat.project_name_filter.lower()

        matching_projects = [
            p for p in projects if filter_str in p.name.lower()
        ]

        if not matching_projects:
            available = [p.name for p in projects]
            raise ValueError(
                f"No project found with '{filter_str}' in name.\n"
                f"Available projects: {available}"
            )

        if len(matching_projects) > 1:
            names = [p.name for p in matching_projects]
            logger.warning(
                f"Multiple matching projects found: {names}. Using first one."
            )

        project = matching_projects[0]
        self.project_id = project.id
        logger.info(f"Using project: {project.name} (ID: {project.id})")
        return project.id

    def discover_images(self) -> dict[str, list[dict[str, Any]]]:
        """Discover images with target labels.

        Only shapes of type 'mask' count towards a frame's label set.

        Returns:
            Dict mapping class names to lists of image metadata

        Raises:
            ValueError: If no tasks exist in the selected project
        """
        if not self.client or not self.project_id:
            raise ValueError("Must connect and find project first")

        tasks = self.client.tasks.list(project_id=self.project_id)

        if not tasks:
            raise ValueError(f"No tasks found in project {self.project_id}")

        logger.info(f"Found {len(tasks)} tasks in project")

        # Collect all images with annotations, keyed by target class name.
        class_images: dict[str, list[dict[str, Any]]] = {
            class_name: [] for class_name in self.config.classes.keys()
        }

        for task in tasks:
            try:
                jobs = self.client.jobs.list(task_id=task.id)

                for job in jobs:
                    annotations = self.client.annotations.get_job_annotations(job.id)

                    if not annotations or not hasattr(annotations, 'shapes'):
                        continue

                    # Group annotations by frame
                    frame_annotations: dict[int, list] = {}
                    for shape in annotations.shapes:
                        frame_annotations.setdefault(shape.frame, []).append(shape)

                    # Check which classes are present in each frame
                    for frame_id, shapes in frame_annotations.items():
                        labels_in_frame = {
                            shape.label_name
                            for shape in shapes
                            if hasattr(shape, 'type') and shape.type == 'mask'
                        }

                        for class_name in self.config.classes.keys():
                            if class_name in labels_in_frame:
                                class_images[class_name].append({
                                    "task_id": task.id,
                                    "job_id": job.id,
                                    "frame_id": frame_id,
                                    "task_name": task.name,
                                    "labels": list(labels_in_frame),
                                })

            except Exception as e:
                # A broken task must not abort discovery of the others.
                logger.warning(f"Error processing task {task.id}: {e}")
                continue

        # Log discovered counts
        for class_name, images in class_images.items():
            logger.info(f"Found {len(images)} images with label '{class_name}'")

        # Warn (do not fail) when fewer images than requested are available
        for class_name, images in class_images.items():
            requested = self.config.classes[class_name]
            if len(images) < requested:
                logger.warning(
                    f"Class '{class_name}': Requested {requested} images but only found {len(images)}"
                )

        return class_images

    def sample_images(
        self, class_images: dict[str, list[dict[str, Any]]]
    ) -> dict[str, list[dict[str, Any]]]:
        """Randomly sample images for each class.

        Args:
            class_images: Dict mapping class names to image lists

        Returns:
            Dict with sampled images (at most the configured count per class)
        """
        sampled = {}

        for class_name, images in class_images.items():
            requested = self.config.classes[class_name]
            available = len(images)

            if available == 0:
                logger.error(f"No images available for class '{class_name}'")
                sampled[class_name] = []
                continue

            # Never sample more than what exists.
            sample_size = min(requested, available)
            sampled[class_name] = random.sample(images, sample_size)

            logger.info(
                f"Sampled {sample_size}/{requested} images for class '{class_name}'"
            )

        return sampled

    def download_images_and_masks(
        self, sampled_images: dict[str, list[dict[str, Any]]]
    ) -> dict[str, list[Path]]:
        """Download images and extract ground truth masks.

        Individual per-image failures are logged and skipped; only images
        that were fully processed appear in the returned lists.

        Args:
            sampled_images: Dict of sampled image metadata

        Returns:
            Dict mapping class names to lists of image paths

        Raises:
            ValueError: If not connected to CVAT
        """
        if not self.client:
            raise ValueError("Not connected to CVAT")

        cache_dir = self.config.get_cache_path()
        cache_dir.mkdir(parents=True, exist_ok=True)

        downloaded_paths: dict[str, list[Path]] = {
            class_name: [] for class_name in self.config.classes.keys()
        }

        total_images = sum(len(images) for images in sampled_images.values())
        processed = 0

        for class_name, images in sampled_images.items():
            for img_meta in images:
                processed += 1
                logger.info(
                    f"Processing {processed}/{total_images}: "
                    f"{class_name} - Task {img_meta['task_id']} Frame {img_meta['frame_id']}"
                )

                try:
                    image_path = self._download_image(class_name, img_meta, cache_dir)
                    self._extract_masks(class_name, img_meta, image_path)
                    downloaded_paths[class_name].append(image_path)

                except Exception as e:
                    logger.error(
                        f"Failed to process {class_name} image "
                        f"(task={img_meta['task_id']}, frame={img_meta['frame_id']}): {e}"
                    )
                    continue

        # Log final counts
        for class_name, paths in downloaded_paths.items():
            logger.info(f"Successfully processed {len(paths)} images for '{class_name}'")

        return downloaded_paths

    def _download_image(
        self, class_name: str, img_meta: dict[str, Any], cache_dir: Path
    ) -> Path:
        """Download single image.

        Already-downloaded images (present in the cache) are not fetched
        again.

        Args:
            class_name: Class label
            img_meta: Image metadata
            cache_dir: Cache directory

        Returns:
            Path to downloaded image

        Raises:
            ValueError: If download fails
        """
        # Create output directory: <cache>/<class>/<task>_frame_<nnnnnn>/
        image_name = f"{img_meta['task_name']}_frame_{img_meta['frame_id']:06d}"
        output_dir = cache_dir / class_name / image_name
        output_dir.mkdir(parents=True, exist_ok=True)

        image_path = output_dir / "image.jpg"

        # Check cache
        if image_path.exists():
            logger.debug(f"Image already cached: {image_path}")
            return image_path

        # Download from CVAT
        if not self.client:
            raise ValueError("Client not initialized")

        try:
            image_data = self.client.tasks.get_frame(
                img_meta["task_id"], img_meta["frame_id"]
            )

            with open(image_path, "wb") as f:
                f.write(image_data)

            logger.debug(f"Downloaded image to {image_path}")
            return image_path

        except Exception as e:
            raise ValueError(
                f"Failed to download image from task {img_meta['task_id']} "
                f"frame {img_meta['frame_id']}: {e}"
            ) from e

    def _extract_masks(
        self, class_name: str, img_meta: dict[str, Any], image_path: Path
    ) -> None:
        """Extract ground truth masks for image.

        Writes one PNG per mask shape plus a metadata.json index into a
        'ground_truth' sibling directory of the image. A pre-existing
        metadata.json marks the extraction as already done.

        Args:
            class_name: Class label
            img_meta: Image metadata
            image_path: Path to image file

        Raises:
            ValueError: If mask extraction fails
        """
        output_dir = image_path.parent / "ground_truth"
        output_dir.mkdir(parents=True, exist_ok=True)

        # Check if already extracted
        metadata_path = output_dir / "metadata.json"
        if metadata_path.exists():
            logger.debug(f"Masks already extracted: {output_dir}")
            return

        if not self.client:
            raise ValueError("Client not initialized")

        # Get annotations for this job
        try:
            annotations = self.client.annotations.get_job_annotations(
                img_meta["job_id"]
            )

            if not annotations or not hasattr(annotations, 'shapes'):
                raise ValueError("No annotations found")

            # Filter masks for this frame
            frame_masks = [
                shape
                for shape in annotations.shapes
                if shape.frame == img_meta["frame_id"]
                and hasattr(shape, 'type')
                and shape.type == 'mask'
            ]

            if not frame_masks:
                logger.warning(
                    f"No mask annotations found for frame {img_meta['frame_id']}"
                )
                # Create empty metadata so the cache check skips next time
                with open(metadata_path, "w") as f:
                    json.dump({"masks": []}, f, indent=2)
                return

            # Get image dimensions
            from PIL import Image
            with Image.open(image_path) as img:
                width, height = img.size

            # Extract each mask
            mask_metadata = []
            label_counts: dict[str, int] = {}

            for shape in frame_masks:
                label = shape.label_name
                # Per-label instance counter for unique filenames
                instance_idx = label_counts.get(label, 0)
                label_counts[label] = instance_idx + 1

                # Convert CVAT RLE to mask
                if not hasattr(shape, 'points') or not shape.points:
                    logger.warning(
                        f"Shape for label {label} missing points data, skipping"
                    )
                    continue

                mask_filename = f"mask_{label}_{instance_idx}.png"
                mask_path = output_dir / mask_filename

                try:
                    # from_cvat_api_rle receives the destination path; the
                    # returned Mask object is not used here (presumably the
                    # PNG is written as a side effect — TODO confirm against
                    # the Mask implementation).
                    Mask.from_cvat_api_rle(
                        cvat_rle=shape.points,
                        width=width,
                        height=height,
                        file_path=str(mask_path),
                    )

                    mask_metadata.append({
                        "filename": mask_filename,
                        "label": label,
                        "instance_idx": instance_idx,
                    })

                    logger.debug(f"Extracted mask: {mask_filename}")

                except Exception as e:
                    logger.error(f"Failed to convert mask for label {label}: {e}")
                    continue

            # Save metadata
            with open(metadata_path, "w") as f:
                json.dump(
                    {
                        "image": str(image_path.name),
                        "width": width,
                        "height": height,
                        "masks": mask_metadata,
                    },
                    f,
                    indent=2,
                )

            logger.info(f"Extracted {len(mask_metadata)} masks to {output_dir}")

        except Exception as e:
            raise ValueError(f"Failed to extract masks: {e}") from e

    def run_extraction(self) -> dict[str, list[Path]]:
        """Run complete extraction pipeline.

        Returns:
            Dict mapping class names to lists of image paths

        Raises:
            Exception: If any critical step fails
        """
        logger.info("Starting CVAT extraction pipeline")

        # Connect
        self.connect()

        # Find project
        self.find_training_project()

        # Discover images
        class_images = self.discover_images()

        # Sample
        sampled = self.sample_images(class_images)

        # Download and extract
        paths = self.download_images_and_masks(sampled)

        logger.info("CVAT extraction complete")
        return paths
metrics_evaluation/inference/sam3_inference.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAM3 inference for evaluation."""
2
+
3
+ import base64
4
+ import json
5
+ import logging
6
+ import time
7
+ from pathlib import Path
8
+
9
+ import requests
10
+ from PIL import Image
11
+
12
+ from ..config.config_models import EvaluationConfig
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class SAM3Inferencer:
    """Run SAM3 inference on images."""

    def __init__(self, config: EvaluationConfig):
        """Initialize inferencer with configuration.

        Args:
            config: Evaluation configuration
        """
        self.config = config
        self.endpoint = config.sam3.endpoint
        self.timeout = config.sam3.timeout
        self.retry_attempts = config.sam3.retry_attempts

    def infer_single_image(self, image_path: Path, classes: list[str]) -> list[dict]:
        """Run SAM3 inference on single image.

        Retries with exponential backoff (1s, 2s, 4s, ...) on HTTP errors,
        timeouts, and transport failures.

        Args:
            image_path: Path to image file
            classes: List of class names to detect

        Returns:
            List of detection results with masks

        Raises:
            ValueError: If inference fails after retries
        """
        # Load and encode image as base64 for the JSON payload
        with open(image_path, "rb") as f:
            image_data = f.read()

        image_b64 = base64.b64encode(image_data).decode()

        # Prepare request
        payload = {
            "inputs": image_b64,
            "parameters": {"classes": classes}
        }

        # Retry loop; last_error keeps the most recent failure for the
        # final exception message.
        last_error = None
        for attempt in range(self.retry_attempts):
            try:
                start_time = time.time()

                response = requests.post(
                    self.endpoint,
                    json=payload,
                    timeout=self.timeout
                )

                elapsed = time.time() - start_time

                if response.status_code == 200:
                    results = response.json()
                    logger.debug(
                        f"Inference successful for {image_path.name} "
                        f"({elapsed:.2f}s, {len(results)} detections)"
                    )
                    return results

                else:
                    last_error = f"HTTP {response.status_code}: {response.text}"
                    logger.warning(f"Inference failed (attempt {attempt + 1}): {last_error}")

            except requests.Timeout:
                last_error = f"Request timeout after {self.timeout}s"
                logger.warning(f"Inference timeout (attempt {attempt + 1})")

            except Exception as e:
                last_error = str(e)
                logger.warning(f"Inference error (attempt {attempt + 1}): {e}")

            # Exponential backoff before the next attempt (not after the last)
            if attempt < self.retry_attempts - 1:
                sleep_time = 2 ** attempt
                logger.debug(f"Retrying in {sleep_time}s...")
                time.sleep(sleep_time)

        raise ValueError(
            f"Inference failed after {self.retry_attempts} attempts for {image_path}: {last_error}"
        )

    def save_inference_results(
        self,
        results: list[dict],
        output_dir: Path,
        image_width: int,
        image_height: int
    ) -> None:
        """Save inference results as masks.

        Each detection's base64-encoded mask is decoded, converted to
        grayscale, resized to the image dimensions if needed, and written
        as a PNG alongside a metadata.json index.

        Args:
            results: SAM3 detection results
            output_dir: Directory to save masks
            image_width: Image width
            image_height: Image height

        Raises:
            ValueError: If mask conversion fails
        """
        # BUGFIX: `io` was previously imported only inside
        # run_inference_batch(), a function-local import that is invisible
        # here — io.BytesIO below raised NameError at runtime.
        import io

        output_dir.mkdir(parents=True, exist_ok=True)

        mask_metadata = []
        label_counts: dict[str, int] = {}

        for result in results:
            label = result.get("label", "unknown")
            score = result.get("score", 0.0)
            mask_b64 = result.get("mask")

            if not mask_b64:
                logger.warning(f"Result missing mask data for label {label}")
                continue

            # Per-label instance counter for unique filenames
            if label not in label_counts:
                label_counts[label] = 0

            instance_idx = label_counts[label]
            label_counts[label] += 1

            # Decode mask
            try:
                mask_data = base64.b64decode(mask_b64)
                mask_img = Image.open(io.BytesIO(mask_data))

                # Convert to L mode (grayscale) if needed
                if mask_img.mode != 'L':
                    mask_img = mask_img.convert('L')

                # Validate dimensions
                if mask_img.size != (image_width, image_height):
                    logger.warning(
                        f"Mask dimension mismatch: expected {image_width}x{image_height}, "
                        f"got {mask_img.size}. Resizing."
                    )
                    # NEAREST keeps the mask binary (no interpolated values)
                    mask_img = mask_img.resize((image_width, image_height), Image.NEAREST)

                # Save as PNG
                mask_filename = f"mask_{label}_{instance_idx}.png"
                mask_path = output_dir / mask_filename
                mask_img.save(mask_path)

                mask_metadata.append({
                    "filename": mask_filename,
                    "label": label,
                    "instance_idx": instance_idx,
                    "score": score,
                })

                logger.debug(f"Saved inference mask: {mask_filename}")

            except Exception as e:
                logger.error(f"Failed to save mask for label {label}: {e}")
                continue

        # Save metadata
        metadata_path = output_dir / "metadata.json"
        with open(metadata_path, "w") as f:
            json.dump(
                {
                    "width": image_width,
                    "height": image_height,
                    "masks": mask_metadata,
                },
                f,
                indent=2,
            )

        logger.info(f"Saved {len(mask_metadata)} inference masks to {output_dir}")

    def run_inference_batch(
        self, image_paths: dict[str, list[Path]], force: bool = False
    ) -> dict[str, int]:
        """Run inference on batch of images.

        Results are written next to each image under an 'inference'
        directory; images with an existing metadata.json are skipped
        unless `force` is set.

        Args:
            image_paths: Dict mapping class names to image paths
            force: Force re-inference even if results exist

        Returns:
            Dict with inference statistics

        Raises:
            ValueError: If no images provided
        """
        total_images = sum(len(paths) for paths in image_paths.values())
        if total_images == 0:
            raise ValueError("No images provided for inference")

        logger.info(f"Starting inference on {total_images} images")

        stats = {
            "total": total_images,
            "successful": 0,
            "failed": 0,
            "skipped": 0,
        }

        processed = 0

        for class_name, paths in image_paths.items():
            for image_path in paths:
                processed += 1
                logger.info(
                    f"Inference {processed}/{total_images}: {image_path.parent.name}"
                )

                inference_dir = image_path.parent / "inference"
                metadata_path = inference_dir / "metadata.json"

                # Check cache
                if not force and metadata_path.exists():
                    logger.debug(f"Inference results already exist: {inference_dir}")
                    stats["skipped"] += 1
                    continue

                try:
                    # Get image dimensions
                    with Image.open(image_path) as img:
                        width, height = img.size

                    # Run inference
                    results = self.infer_single_image(
                        image_path,
                        list(self.config.classes.keys())
                    )

                    # Save results
                    self.save_inference_results(
                        results, inference_dir, width, height
                    )

                    stats["successful"] += 1

                except Exception as e:
                    logger.error(f"Inference failed for {image_path}: {e}")
                    stats["failed"] += 1
                    continue

        logger.info(
            f"Inference complete: {stats['successful']} successful, "
            f"{stats['failed']} failed, {stats['skipped']} skipped"
        )

        return stats
metrics_evaluation/metrics/metrics_calculator.py ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Metrics calculation for SAM3 evaluation."""
2
+
3
+ import json
4
+ import logging
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+ from PIL import Image
10
+ from scipy.optimize import linear_sum_assignment
11
+
12
+ from ..config.config_models import EvaluationConfig
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class MetricsCalculator:
    """Calculate segmentation metrics."""

    def __init__(self, config: EvaluationConfig):
        """Initialize calculator with configuration.

        Args:
            config: Evaluation configuration
        """
        self.config = config
        self.iou_thresholds = config.metrics.iou_thresholds

    def calculate_iou(self, mask1: np.ndarray, mask2: np.ndarray) -> float:
        """Calculate IoU between two binary masks.

        Args:
            mask1: First binary mask
            mask2: Second binary mask

        Returns:
            IoU score between 0 and 1 (0.0 when both masks are empty)
        """
        intersection = np.logical_and(mask1, mask2).sum()
        union = np.logical_or(mask1, mask2).sum()

        # Two empty masks have an undefined ratio; treat as no overlap.
        if union == 0:
            return 0.0

        return float(intersection / union)

    def match_instances(
        self,
        gt_masks: list[np.ndarray],
        pred_masks: list[np.ndarray],
        iou_threshold: float
    ) -> dict[str, Any]:
        """Match predicted instances to ground truth.

        Uses the Hungarian algorithm for an IoU-optimal one-to-one
        assignment, then discards pairs below the threshold.

        Args:
            gt_masks: List of ground truth masks
            pred_masks: List of predicted masks
            iou_threshold: Minimum IoU for match

        Returns:
            Dict with matching results (matches, unmatched indices, TP/FP/FN)
        """
        # Degenerate cases: one or both sides empty
        if len(gt_masks) == 0 and len(pred_masks) == 0:
            return {
                "matches": [],
                "unmatched_gt": [],
                "unmatched_pred": [],
                "true_positives": 0,
                "false_positives": 0,
                "false_negatives": 0,
            }

        if len(gt_masks) == 0:
            return {
                "matches": [],
                "unmatched_gt": [],
                "unmatched_pred": list(range(len(pred_masks))),
                "true_positives": 0,
                "false_positives": len(pred_masks),
                "false_negatives": 0,
            }

        if len(pred_masks) == 0:
            return {
                "matches": [],
                "unmatched_gt": list(range(len(gt_masks))),
                "unmatched_pred": [],
                "true_positives": 0,
                "false_positives": 0,
                "false_negatives": len(gt_masks),
            }

        # Compute pairwise IoU matrix (gt x pred)
        iou_matrix = np.zeros((len(gt_masks), len(pred_masks)))

        for i, gt_mask in enumerate(gt_masks):
            for j, pred_mask in enumerate(pred_masks):
                iou_matrix[i, j] = self.calculate_iou(gt_mask, pred_mask)

        # Hungarian algorithm maximizes total IoU (negated for the
        # cost-minimizing solver)
        gt_indices, pred_indices = linear_sum_assignment(-iou_matrix)

        # Filter matches by threshold
        matches = []
        for gt_idx, pred_idx in zip(gt_indices, pred_indices):
            iou = iou_matrix[gt_idx, pred_idx]
            if iou >= iou_threshold:
                matches.append({
                    "gt_idx": int(gt_idx),
                    "pred_idx": int(pred_idx),
                    "iou": float(iou),
                })

        matched_gt = {m["gt_idx"] for m in matches}
        matched_pred = {m["pred_idx"] for m in matches}

        unmatched_gt = [i for i in range(len(gt_masks)) if i not in matched_gt]
        unmatched_pred = [i for i in range(len(pred_masks)) if i not in matched_pred]

        return {
            "matches": matches,
            "unmatched_gt": unmatched_gt,
            "unmatched_pred": unmatched_pred,
            "true_positives": len(matches),
            "false_positives": len(unmatched_pred),
            "false_negatives": len(unmatched_gt),
        }

    def calculate_image_metrics(
        self, image_dir: Path
    ) -> dict[str, Any] | None:
        """Calculate metrics for single image.

        Args:
            image_dir: Directory containing ground_truth and inference subdirs

        Returns:
            Dict with metrics for this image, or None when inference
            results are missing for the image

        Raises:
            ValueError: If ground truth metadata is missing
        """
        gt_dir = image_dir / "ground_truth"
        inf_dir = image_dir / "inference"

        # Load metadata
        gt_meta_path = gt_dir / "metadata.json"
        inf_meta_path = inf_dir / "metadata.json"

        if not gt_meta_path.exists():
            raise ValueError(f"Ground truth metadata not found: {gt_meta_path}")

        if not inf_meta_path.exists():
            logger.warning(f"Inference metadata not found: {inf_meta_path}")
            return None

        with open(gt_meta_path) as f:
            gt_meta = json.load(f)

        with open(inf_meta_path) as f:
            inf_meta = json.load(f)

        # Group masks by label
        gt_by_label: dict[str, list[np.ndarray]] = {}
        inf_by_label: dict[str, list[np.ndarray]] = {}

        # Load ground truth masks
        for mask_info in gt_meta.get("masks", []):
            label = mask_info["label"]
            mask_path = gt_dir / mask_info["filename"]

            if not mask_path.exists():
                logger.warning(f"Ground truth mask not found: {mask_path}")
                continue

            # Context manager closes the file handle (Image.open is lazy)
            with Image.open(mask_path) as mask_img:
                mask_array = np.array(mask_img) > 0  # Binarize

            gt_by_label.setdefault(label, []).append(mask_array)

        # Load inference masks
        for mask_info in inf_meta.get("masks", []):
            label = mask_info["label"]
            mask_path = inf_dir / mask_info["filename"]

            if not mask_path.exists():
                logger.warning(f"Inference mask not found: {mask_path}")
                continue

            with Image.open(mask_path) as mask_img:
                mask_array = np.array(mask_img) > 0  # Binarize

            inf_by_label.setdefault(label, []).append(mask_array)

        # Calculate metrics at each IoU threshold
        results = {
            "image_name": image_dir.name,
            "class": image_dir.parent.name,
            "by_threshold": {},
        }

        all_labels = set(gt_by_label.keys()) | set(inf_by_label.keys())

        for threshold in self.iou_thresholds:
            threshold_results = {
                "iou_threshold": threshold,
                "by_label": {},
                "total": {
                    "true_positives": 0,
                    "false_positives": 0,
                    "false_negatives": 0,
                },
            }

            for label in all_labels:
                gt_masks = gt_by_label.get(label, [])
                pred_masks = inf_by_label.get(label, [])

                matching = self.match_instances(gt_masks, pred_masks, threshold)

                threshold_results["by_label"][label] = {
                    "gt_count": len(gt_masks),
                    "pred_count": len(pred_masks),
                    "true_positives": matching["true_positives"],
                    "false_positives": matching["false_positives"],
                    "false_negatives": matching["false_negatives"],
                    "matches": matching["matches"],
                }

                # Add to totals
                threshold_results["total"]["true_positives"] += matching["true_positives"]
                threshold_results["total"]["false_positives"] += matching["false_positives"]
                threshold_results["total"]["false_negatives"] += matching["false_negatives"]

            results["by_threshold"][str(threshold)] = threshold_results

        return results

    def calculate_aggregate_metrics(
        self, image_results: list[dict[str, Any]]
    ) -> dict[str, Any]:
        """Calculate aggregate metrics across all images.

        Args:
            image_results: List of per-image metrics

        Returns:
            Dict with aggregate metrics (per-label and overall
            precision/recall/F1 for each IoU threshold)
        """
        aggregate = {
            "total_images": len(image_results),
            "by_threshold": {},
        }

        for threshold in self.iou_thresholds:
            threshold_str = str(threshold)

            # Aggregate counts
            label_stats: dict[str, dict] = {}
            total_tp = 0
            total_fp = 0
            total_fn = 0

            for img_result in image_results:
                threshold_data = img_result["by_threshold"][threshold_str]

                for label, label_data in threshold_data["by_label"].items():
                    if label not in label_stats:
                        label_stats[label] = {
                            "tp": 0,
                            "fp": 0,
                            "fn": 0,
                            "gt_total": 0,
                            "pred_total": 0,
                        }

                    label_stats[label]["tp"] += label_data["true_positives"]
                    label_stats[label]["fp"] += label_data["false_positives"]
                    label_stats[label]["fn"] += label_data["false_negatives"]
                    label_stats[label]["gt_total"] += label_data["gt_count"]
                    label_stats[label]["pred_total"] += label_data["pred_count"]

                total_tp += threshold_data["total"]["true_positives"]
                total_fp += threshold_data["total"]["false_positives"]
                total_fn += threshold_data["total"]["false_negatives"]

            # Calculate precision, recall, F1 per label (0.0 on empty denominators)
            for label, stats in label_stats.items():
                precision = stats["tp"] / (stats["tp"] + stats["fp"]) if (stats["tp"] + stats["fp"]) > 0 else 0.0
                recall = stats["tp"] / (stats["tp"] + stats["fn"]) if (stats["tp"] + stats["fn"]) > 0 else 0.0
                f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

                stats["precision"] = precision
                stats["recall"] = recall
                stats["f1"] = f1

            # Calculate overall metrics
            overall_precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0.0
            overall_recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0.0
            overall_f1 = 2 * overall_precision * overall_recall / (overall_precision + overall_recall) if (overall_precision + overall_recall) > 0 else 0.0

            # Build confusion matrix
            confusion_matrix = self._build_confusion_matrix(image_results, threshold_str)

            aggregate["by_threshold"][threshold_str] = {
                "iou_threshold": threshold,
                "by_label": label_stats,
                "overall": {
                    "true_positives": total_tp,
                    "false_positives": total_fp,
                    "false_negatives": total_fn,
                    "precision": overall_precision,
                    "recall": overall_recall,
                    "f1": overall_f1,
                    # NOTE: these are single-threshold proxies, not a true
                    # COCO-style mAP/mAR averaged over thresholds/recalls.
                    "map": overall_precision,  # Simplified mAP
                    "mar": overall_recall,  # Simplified mAR
                },
                "confusion_matrix": confusion_matrix,
            }

        return aggregate

    def _build_confusion_matrix(
        self, image_results: list[dict[str, Any]], threshold_str: str
    ) -> dict[str, Any]:
        """Build confusion matrix for given threshold.

        Simplified: only true positives are counted (on the diagonal);
        cross-class confusions are not tracked.

        Args:
            image_results: List of per-image metrics
            threshold_str: IoU threshold as string

        Returns:
            Confusion matrix data (sorted label list plus nested int lists)
        """
        # Collect all labels seen at this threshold
        all_labels = set()
        for img_result in image_results:
            threshold_data = img_result["by_threshold"][threshold_str]
            all_labels.update(threshold_data["by_label"].keys())

        labels = sorted(all_labels)
        n_labels = len(labels)

        # Initialize matrix
        matrix = np.zeros((n_labels, n_labels), dtype=int)

        # Fill matrix (simplified: just count matches)
        for img_result in image_results:
            threshold_data = img_result["by_threshold"][threshold_str]

            for i, gt_label in enumerate(labels):
                if gt_label not in threshold_data["by_label"]:
                    continue

                label_data = threshold_data["by_label"][gt_label]

                # True positives go on diagonal
                matrix[i, i] += label_data["true_positives"]

                # False negatives (missed) - simplified representation
                # In a full implementation, we'd track which class they were predicted as

        return {
            "labels": labels,
            "matrix": matrix.tolist(),
        }

    def run_evaluation(self, cache_dir: Path) -> dict[str, Any]:
        """Run complete metrics evaluation.

        Args:
            cache_dir: Cache directory with ground truth and inference results

        Returns:
            Complete metrics results

        Raises:
            ValueError: If cache directory is invalid or no image yields metrics
        """
        if not cache_dir.exists():
            raise ValueError(f"Cache directory not found: {cache_dir}")

        logger.info(f"Calculating metrics from {cache_dir}")

        # Find all image directories: <cache>/<class>/<image>/
        image_results = []

        for class_dir in cache_dir.iterdir():
            if not class_dir.is_dir():
                continue

            for image_dir in class_dir.iterdir():
                if not image_dir.is_dir():
                    continue

                try:
                    metrics = self.calculate_image_metrics(image_dir)
                    if metrics:
                        image_results.append(metrics)
                except Exception as e:
                    logger.error(f"Failed to calculate metrics for {image_dir}: {e}")
                    continue

        if not image_results:
            raise ValueError("No valid image results found for metrics calculation")

        logger.info(f"Calculated metrics for {len(image_results)} images")

        # Calculate aggregate
        aggregate = self.calculate_aggregate_metrics(image_results)

        return {
            "per_image": image_results,
            "aggregate": aggregate,
        }
metrics_evaluation/run_evaluation.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Main execution script for SAM3 metrics evaluation."""
3
+
4
+ import argparse
5
+ import json
6
+ import logging
7
+ import sys
8
+ from pathlib import Path
9
+
10
+ from config.config_loader import load_config
11
+ from extraction.cvat_extractor import CVATExtractor
12
+ from inference.sam3_inference import SAM3Inferencer
13
+ from metrics.metrics_calculator import MetricsCalculator
14
+ from utils.logging_config import setup_logging
15
+ from visualization.visual_comparison import VisualComparator
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
def write_metrics_summary(metrics: dict, output_path: Path) -> None:
    """Write human-readable metrics summary.

    Renders, for each configured IoU threshold: the overall counts and
    scores, a per-class table, and a confusion matrix (when labels exist).

    Args:
        metrics: Metrics dictionary
        output_path: Path to output file
    """
    bar = "=" * 80
    dash = "-" * 80
    aggregate = metrics["aggregate"]

    with open(output_path, "w") as out:
        out.write(bar + "\n")
        out.write("SAM3 EVALUATION METRICS SUMMARY\n")
        out.write(bar + "\n\n")
        out.write(f"Total Images Evaluated: {aggregate['total_images']}\n\n")

        # One section per IoU threshold (the dict key itself is not needed).
        for thr_data in aggregate["by_threshold"].values():
            iou = thr_data["iou_threshold"]
            out.write(f"\n{bar}\n")
            out.write(f"IoU Threshold: {iou:.0%}\n")
            out.write(f"{bar}\n\n")

            overall = thr_data["overall"]

            out.write("Overall Metrics:\n")
            for title, key in (
                ("True Positives", "true_positives"),
                ("False Positives", "false_positives"),
                ("False Negatives", "false_negatives"),
            ):
                out.write(f"  {title}: {overall[key]}\n")
            for title, key in (
                ("Precision", "precision"),
                ("Recall", "recall"),
                ("F1-Score", "f1"),
                ("mAP", "map"),
            ):
                out.write(f"  {title}: {overall[key]:.2%}\n")
            out.write(f"  mAR: {overall['mar']:.2%}\n\n")

            out.write("Per-Class Metrics:\n")
            out.write(dash + "\n")
            out.write(
                f"{'Class':<20} {'GT':>6} {'Pred':>6} {'TP':>6} {'FP':>6} "
                f"{'FN':>6} {'Prec':>8} {'Rec':>8} {'F1':>8}\n"
            )
            out.write(dash + "\n")

            # Sort classes alphabetically so the table is stable run-to-run.
            for name, stats in sorted(thr_data["by_label"].items()):
                columns = " ".join(
                    [
                        f"{name:<20}",
                        f"{stats['gt_total']:>6}",
                        f"{stats['pred_total']:>6}",
                        f"{stats['tp']:>6}",
                        f"{stats['fp']:>6}",
                        f"{stats['fn']:>6}",
                        f"{stats['precision']:>8.2%}",
                        f"{stats['recall']:>8.2%}",
                        f"{stats['f1']:>8.2%}",
                    ]
                )
                out.write(columns + "\n")

            out.write("\n")

            # Confusion Matrix
            cm = thr_data["confusion_matrix"]
            cm_labels = cm["labels"]
            cm_matrix = cm["matrix"]

            if cm_labels:
                out.write("Confusion Matrix:\n")
                out.write(dash + "\n")

                # Header row: label names truncated to 10 characters.
                header = "Actual \\ Pred |" + "".join(
                    f" {lbl[:10]:>10} |" for lbl in cm_labels
                )
                out.write(header + "\n")
                out.write("-" * len(header) + "\n")

                # One row per actual class.
                for actual, row_vals in zip(cm_labels, cm_matrix):
                    cells = "".join(
                        f" {row_vals[j]:>10} |" for j in range(len(cm_labels))
                    )
                    out.write(f"{actual[:13]:>13} |" + cells + "\n")

                out.write("\n")

        out.write(bar + "\n")
        out.write("END OF REPORT\n")
        out.write(bar + "\n")

    logger.info(f"Wrote metrics summary to {output_path}")
104
+
105
+
106
def main() -> int:
    """Main execution function.

    Pipeline: (1) extract annotated images from CVAT, (2) run SAM3
    inference, (3) compute metrics, (4) optionally render visual
    comparisons. Each phase aborts the run with a non-zero exit code
    if it produces nothing usable.

    Returns:
        Exit code (0 for success, non-zero for failure)
    """
    parser = argparse.ArgumentParser(
        description="Run SAM3 metrics evaluation against CVAT ground truth"
    )
    parser.add_argument(
        "--config",
        type=str,
        default="config/config.json",
        help="Path to configuration file"
    )
    parser.add_argument(
        "--force-download",
        action="store_true",
        help="Force re-download images from CVAT"
    )
    parser.add_argument(
        "--force-inference",
        action="store_true",
        help="Force re-run SAM3 inference"
    )
    parser.add_argument(
        "--skip-inference",
        action="store_true",
        help="Skip inference, use cached results"
    )
    parser.add_argument(
        "--visualize",
        action="store_true",
        help="Generate visual comparisons"
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Logging level"
    )

    args = parser.parse_args()

    # Load configuration before logging is set up, so failures go to stderr.
    try:
        config = load_config(args.config)
    except Exception as e:
        print(f"ERROR: Failed to load configuration: {e}", file=sys.stderr)
        return 1

    # Setup logging — the log file lives inside the evaluation cache.
    cache_dir = config.get_cache_path()
    log_file = cache_dir / "evaluation_log.txt"
    # Safe getattr: --log-level choices are restricted to valid level names.
    setup_logging(log_file, getattr(logging, args.log_level))

    logger.info("=" * 80)
    logger.info("SAM3 METRICS EVALUATION")
    logger.info("=" * 80)

    try:
        # Phase 1: Extract from CVAT
        logger.info("\n" + "=" * 80)
        logger.info("PHASE 1: CVAT Data Extraction")
        logger.info("=" * 80)

        extractor = CVATExtractor(config)

        if args.force_download:
            # NOTE(review): the flag is only logged here — run_extraction()
            # is called without it. Confirm --force-download is honored
            # (e.g. read from config or passed as an argument).
            logger.info("Force download enabled - will re-download all images")

        image_paths = extractor.run_extraction()

        # image_paths maps class name -> list of extracted image paths.
        total_extracted = sum(len(paths) for paths in image_paths.values())
        logger.info(f"Extraction complete: {total_extracted} images extracted")

        if total_extracted == 0:
            logger.error("No images extracted. Aborting.")
            return 1

        # Phase 2: Run SAM3 Inference
        if not args.skip_inference:
            logger.info("\n" + "=" * 80)
            logger.info("PHASE 2: SAM3 Inference")
            logger.info("=" * 80)

            inferencer = SAM3Inferencer(config)
            stats = inferencer.run_inference_batch(image_paths, args.force_inference)

            logger.info(
                f"Inference complete: {stats['successful']} successful, "
                f"{stats['failed']} failed, {stats['skipped']} skipped"
            )

            # Skipped inferences count as usable cached results.
            if stats['successful'] == 0 and stats['skipped'] == 0:
                logger.error("No successful inferences. Aborting.")
                return 1
        else:
            logger.info("Skipping inference (--skip-inference)")

        # Phase 3: Calculate Metrics
        logger.info("\n" + "=" * 80)
        logger.info("PHASE 3: Metrics Calculation")
        logger.info("=" * 80)

        calculator = MetricsCalculator(config)
        metrics = calculator.run_evaluation(cache_dir)

        # Save detailed metrics (machine-readable JSON)
        metrics_json_path = cache_dir / "metrics_detailed.json"
        with open(metrics_json_path, "w") as f:
            json.dump(metrics, f, indent=2)
        logger.info(f"Saved detailed metrics to {metrics_json_path}")

        # Write summary (human-readable text report)
        metrics_summary_path = cache_dir / "metrics_summary.txt"
        write_metrics_summary(metrics, metrics_summary_path)

        # Phase 4: Visualization (optional; CLI flag OR config can enable it)
        if args.visualize or config.output.generate_visualizations:
            logger.info("\n" + "=" * 80)
            logger.info("PHASE 4: Visual Comparisons")
            logger.info("=" * 80)

            comparator = VisualComparator()
            comparison_paths = comparator.generate_all_comparisons(cache_dir)
            logger.info(f"Generated {len(comparison_paths)} visual comparisons")

        # Summary
        logger.info("\n" + "=" * 80)
        logger.info("EVALUATION COMPLETE")
        logger.info("=" * 80)

        aggregate = metrics["aggregate"]
        logger.info(f"Images evaluated: {aggregate['total_images']}")

        # Show metrics at 50% IoU
        # NOTE(review): assumes "0.5" is always among the configured IoU
        # thresholds — a KeyError here would be swallowed by the generic
        # except below and reported as a failure. TODO confirm.
        threshold_50 = aggregate["by_threshold"]["0.5"]
        overall = threshold_50["overall"]

        logger.info(f"\nMetrics at 50% IoU:")
        logger.info(f"  Precision: {overall['precision']:.2%}")
        logger.info(f"  Recall: {overall['recall']:.2%}")
        logger.info(f"  F1-Score: {overall['f1']:.2%}")
        logger.info(f"  mAP: {overall['map']:.2%}")
        logger.info(f"  mAR: {overall['mar']:.2%}")

        logger.info(f"\nResults saved to:")
        logger.info(f"  Metrics Summary: {metrics_summary_path}")
        logger.info(f"  Detailed JSON: {metrics_json_path}")
        logger.info(f"  Execution Log: {log_file}")

        return 0

    except KeyboardInterrupt:
        # 130 = conventional exit code for SIGINT (128 + 2).
        logger.warning("\nEvaluation interrupted by user")
        return 130

    except Exception as e:
        # Top-level boundary: log full traceback and fail the process.
        logger.error(f"\nEvaluation failed with error: {e}", exc_info=True)
        return 1


if __name__ == "__main__":
    sys.exit(main())
metrics_evaluation/utils/__init__.py ADDED
File without changes
metrics_evaluation/utils/logging_config.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Logging configuration for evaluation."""
2
+
3
+ import logging
4
+ import sys
5
+ from pathlib import Path
6
+
7
+
8
+ def setup_logging(log_file: Path | None = None, level: int = logging.INFO) -> None:
9
+ """Set up logging configuration.
10
+
11
+ Args:
12
+ log_file: Optional path to log file
13
+ level: Logging level
14
+ """
15
+ # Create formatters
16
+ detailed_formatter = logging.Formatter(
17
+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
18
+ datefmt="%Y-%m-%d %H:%M:%S"
19
+ )
20
+
21
+ simple_formatter = logging.Formatter(
22
+ "%(levelname)s: %(message)s"
23
+ )
24
+
25
+ # Get root logger
26
+ logger = logging.getLogger()
27
+ logger.setLevel(level)
28
+
29
+ # Remove existing handlers
30
+ logger.handlers.clear()
31
+
32
+ # Console handler
33
+ console_handler = logging.StreamHandler(sys.stdout)
34
+ console_handler.setLevel(level)
35
+ console_handler.setFormatter(simple_formatter)
36
+ logger.addHandler(console_handler)
37
+
38
+ # File handler
39
+ if log_file:
40
+ log_file.parent.mkdir(parents=True, exist_ok=True)
41
+ file_handler = logging.FileHandler(log_file, mode="w")
42
+ file_handler.setLevel(logging.DEBUG)
43
+ file_handler.setFormatter(detailed_formatter)
44
+ logger.addHandler(file_handler)
metrics_evaluation/visualization/visual_comparison.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Visual comparison generation for evaluation."""
2
+
3
+ import json
4
+ import logging
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ from PIL import Image, ImageDraw, ImageFont
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
class VisualComparator:
    """Generate visual comparisons between ground truth and predictions.

    Overlays ground-truth masks (green) and predicted masks (red) on the
    source image and appends a color legend. Works directly off the
    evaluation cache layout: ``<cache>/<class>/<image>/{image.jpg,
    ground_truth/, inference/}``.
    """

    def __init__(self):
        """Initialize comparator."""
        # RGBA overlay colors; the 4th component (128) suggests intended
        # 50% transparency. Only "ground_truth" and "prediction" are used
        # by this class — the TP/FP/FN entries are currently unreferenced.
        self.colors = {
            "ground_truth": (0, 255, 0, 128),  # Green
            "prediction": (255, 0, 0, 128),  # Red
            "true_positive": (255, 255, 0, 128),  # Yellow
            "false_positive": (255, 0, 0, 128),  # Red
            "false_negative": (0, 0, 255, 128),  # Blue
        }

    def create_comparison(
        self, image_dir: Path, output_path: Path | None = None
    ) -> Path:
        """Create visual comparison for image.

        Composites all ground-truth masks, then all prediction masks, over
        the original image, adds a legend strip, and saves as PNG.

        Args:
            image_dir: Directory containing image and masks
            output_path: Optional output path (default: image_dir/comparison.png)

        Returns:
            Path to generated comparison image

        Raises:
            ValueError: If required files are missing
        """
        # Load original image (RGBA so alpha compositing works)
        image_path = image_dir / "image.jpg"
        if not image_path.exists():
            raise ValueError(f"Image not found: {image_path}")

        original = Image.open(image_path).convert("RGBA")
        width, height = original.size

        # Create fully-transparent overlays to accumulate masks into
        gt_overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0))
        pred_overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0))

        # Load ground truth masks, listed in ground_truth/metadata.json;
        # missing metadata or mask files are silently skipped (best effort).
        gt_dir = image_dir / "ground_truth"
        if gt_dir.exists():
            gt_meta_path = gt_dir / "metadata.json"
            if gt_meta_path.exists():
                with open(gt_meta_path) as f:
                    gt_meta = json.load(f)

                for mask_info in gt_meta.get("masks", []):
                    mask_path = gt_dir / mask_info["filename"]
                    if not mask_path.exists():
                        continue

                    # NOTE(review): putalpha() *replaces* the overlay's alpha
                    # channel with the mask values — if masks are binary
                    # 0/255 the overlay is fully opaque where set, and the
                    # 128 alpha from self.colors is discarded. Also assumes
                    # mask size matches the image size. TODO confirm intended.
                    mask = Image.open(mask_path).convert("L")
                    colored_mask = Image.new("RGBA", (width, height), self.colors["ground_truth"])
                    colored_mask.putalpha(mask)
                    gt_overlay = Image.alpha_composite(gt_overlay, colored_mask)

        # Load prediction masks from inference/metadata.json (same scheme)
        pred_dir = image_dir / "inference"
        if pred_dir.exists():
            pred_meta_path = pred_dir / "metadata.json"
            if pred_meta_path.exists():
                with open(pred_meta_path) as f:
                    pred_meta = json.load(f)

                for mask_info in pred_meta.get("masks", []):
                    mask_path = pred_dir / mask_info["filename"]
                    if not mask_path.exists():
                        continue

                    mask = Image.open(mask_path).convert("L")
                    colored_mask = Image.new("RGBA", (width, height), self.colors["prediction"])
                    colored_mask.putalpha(mask)
                    pred_overlay = Image.alpha_composite(pred_overlay, colored_mask)

        # Composite images: ground truth first, predictions drawn on top
        result = Image.alpha_composite(original, gt_overlay)
        result = Image.alpha_composite(result, pred_overlay)

        # Add legend
        result = self._add_legend(result)

        # Save (converted to RGB — PNG output drops the alpha channel)
        if output_path is None:
            output_path = image_dir / "comparison.png"

        result.convert("RGB").save(output_path)
        logger.debug(f"Saved comparison to {output_path}")

        return output_path

    def _add_legend(self, image: Image.Image) -> Image.Image:
        """Add color legend to image.

        Extends the canvas downward by a white strip and draws one colored
        swatch + label per overlay category.

        Args:
            image: Input image

        Returns:
            Image with legend (height grows by 60 px)
        """
        # Create legend area below the image
        legend_height = 60
        legend_img = Image.new("RGB", (image.width, image.height + legend_height), (255, 255, 255))
        legend_img.paste(image, (0, 0))

        draw = ImageDraw.Draw(legend_img)

        # Draw legend items starting at the strip's top-left
        x_offset = 10
        y_offset = image.height + 10

        # Alpha is dropped ([:3]) — legend swatches are fully opaque
        items = [
            ("Ground Truth", self.colors["ground_truth"][:3]),
            ("Prediction", self.colors["prediction"][:3]),
        ]

        for label, color in items:
            # Draw color box
            draw.rectangle([x_offset, y_offset, x_offset + 30, y_offset + 30], fill=color)

            # Draw label (default PIL font, black text)
            draw.text((x_offset + 40, y_offset + 5), label, fill=(0, 0, 0))

            # Fixed 200 px spacing between legend entries
            x_offset += 200

        return legend_img

    def generate_all_comparisons(self, cache_dir: Path) -> list[Path]:
        """Generate comparisons for all images in cache.

        Walks ``cache_dir/<class>/<image>/`` and renders one comparison per
        image directory; failures are logged and skipped.

        Args:
            cache_dir: Cache directory

        Returns:
            List of paths to generated comparisons
        """
        comparison_paths = []

        for class_dir in cache_dir.iterdir():
            if not class_dir.is_dir():
                continue

            for image_dir in class_dir.iterdir():
                if not image_dir.is_dir():
                    continue

                # One broken image must not abort the whole batch
                try:
                    comparison_path = self.create_comparison(image_dir)
                    comparison_paths.append(comparison_path)
                except Exception as e:
                    logger.error(f"Failed to create comparison for {image_dir}: {e}")
                    continue

        logger.info(f"Generated {len(comparison_paths)} comparison images")
        return comparison_paths