Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- sam3/__init__.py +7 -0
- sam3/agent/__init__.py +1 -0
- sam3/agent/agent_core.py +563 -0
- sam3/agent/client_llm.py +205 -0
- sam3/agent/client_sam3.py +138 -0
- sam3/agent/helpers/__init__.py +1 -0
- sam3/agent/helpers/boxes.py +438 -0
- sam3/agent/helpers/color_map.py +150 -0
- sam3/agent/helpers/keypoints.py +244 -0
- sam3/agent/helpers/mask_overlap_removal.py +128 -0
- sam3/agent/helpers/masks.py +560 -0
- sam3/agent/helpers/memory.py +87 -0
- sam3/agent/helpers/rle.py +122 -0
- sam3/agent/helpers/roi_align.py +75 -0
- sam3/agent/helpers/rotated_boxes.py +533 -0
- sam3/agent/helpers/som_utils.py +406 -0
- sam3/agent/helpers/visualizer.py +1662 -0
- sam3/agent/helpers/zoom_in.py +195 -0
- sam3/agent/inference.py +65 -0
- sam3/agent/system_prompts/system_prompt.txt +242 -0
- sam3/agent/system_prompts/system_prompt_iterative_checking.txt +26 -0
- sam3/agent/viz.py +114 -0
- sam3/eval/__init__.py +1 -0
- sam3/eval/cgf1_eval.py +703 -0
- sam3/eval/coco_eval.py +916 -0
- sam3/eval/coco_eval_offline.py +181 -0
- sam3/eval/coco_reindex.py +230 -0
- sam3/eval/coco_writer.py +352 -0
- sam3/eval/conversion_util.py +211 -0
- sam3/eval/demo_eval.py +658 -0
- sam3/eval/hota_eval_toolkit/__init__.py +1 -0
- sam3/eval/hota_eval_toolkit/run_ytvis_eval.py +114 -0
- sam3/eval/hota_eval_toolkit/trackeval/__init__.py +4 -0
- sam3/eval/hota_eval_toolkit/trackeval/_timing.py +68 -0
- sam3/eval/hota_eval_toolkit/trackeval/datasets/__init__.py +4 -0
- sam3/eval/hota_eval_toolkit/trackeval/datasets/_base_dataset.py +379 -0
- sam3/eval/hota_eval_toolkit/trackeval/datasets/tao_ow.py +891 -0
- sam3/eval/hota_eval_toolkit/trackeval/datasets/youtube_vis.py +524 -0
- sam3/eval/hota_eval_toolkit/trackeval/eval.py +395 -0
- sam3/eval/hota_eval_toolkit/trackeval/metrics/__init__.py +4 -0
- sam3/eval/hota_eval_toolkit/trackeval/metrics/_base_metric.py +145 -0
- sam3/eval/hota_eval_toolkit/trackeval/metrics/count.py +48 -0
- sam3/eval/hota_eval_toolkit/trackeval/metrics/hota.py +291 -0
- sam3/eval/hota_eval_toolkit/trackeval/utils.py +195 -0
- sam3/eval/postprocessors.py +648 -0
- sam3/eval/saco_veval_eval.py +155 -0
- sam3/eval/saco_veval_evaluators.py +838 -0
- sam3/eval/teta_eval_toolkit/__init__.py +5 -0
- sam3/eval/teta_eval_toolkit/_timing.py +69 -0
.gitattributes
CHANGED
|
@@ -1,2 +1,3 @@
|
|
| 1 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 2 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 1 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 2 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
sam3/perflib/tests/assets/masks.tiff filter=lfs diff=lfs merge=lfs -text
|
sam3/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
from .model_builder import build_sam3_image_model
|
| 4 |
+
|
| 5 |
+
__version__ = "0.1.0"
|
| 6 |
+
|
| 7 |
+
__all__ = ["build_sam3_image_model"]
|
sam3/agent/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
sam3/agent/agent_core.py
ADDED
|
@@ -0,0 +1,563 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
import copy
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
import cv2
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
from .client_llm import send_generate_request
|
| 11 |
+
from .client_sam3 import call_sam_service
|
| 12 |
+
from .viz import visualize
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def save_debug_messages(messages_list, debug, debug_folder_path, debug_jsonl_path):
|
| 16 |
+
"""Save messages to debug jsonl file if debug is enabled"""
|
| 17 |
+
if debug and debug_jsonl_path:
|
| 18 |
+
# Ensure the debug directory exists before writing
|
| 19 |
+
os.makedirs(debug_folder_path, exist_ok=True)
|
| 20 |
+
with open(debug_jsonl_path, "w") as f:
|
| 21 |
+
for msg in messages_list:
|
| 22 |
+
f.write(json.dumps(msg, indent=4) + "\n")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def cleanup_debug_files(debug, debug_folder_path, debug_jsonl_path):
|
| 26 |
+
"""Clean up debug files when function successfully returns"""
|
| 27 |
+
if debug and debug_folder_path:
|
| 28 |
+
try:
|
| 29 |
+
if os.path.exists(debug_jsonl_path):
|
| 30 |
+
os.remove(debug_jsonl_path)
|
| 31 |
+
if os.path.exists(debug_folder_path):
|
| 32 |
+
os.rmdir(debug_folder_path)
|
| 33 |
+
except Exception as e:
|
| 34 |
+
print(f"Warning: Could not clean up debug files: {e}")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def count_images(messages):
|
| 38 |
+
"""Count the total number of images present in the messages history."""
|
| 39 |
+
total = 0
|
| 40 |
+
for message in messages:
|
| 41 |
+
# Check if message has content (should be a list)
|
| 42 |
+
if "content" in message and isinstance(message["content"], list):
|
| 43 |
+
# Iterate through each content item
|
| 44 |
+
for content_item in message["content"]:
|
| 45 |
+
# Check if content item is a dict with type "image"
|
| 46 |
+
if (
|
| 47 |
+
isinstance(content_item, dict)
|
| 48 |
+
and content_item.get("type") == "image"
|
| 49 |
+
):
|
| 50 |
+
total += 1
|
| 51 |
+
return total
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _prune_messages_for_next_round(
|
| 55 |
+
messages_list,
|
| 56 |
+
used_text_prompts,
|
| 57 |
+
latest_sam3_text_prompt,
|
| 58 |
+
img_path,
|
| 59 |
+
initial_text_prompt,
|
| 60 |
+
):
|
| 61 |
+
"""Return a new messages list that contains only:
|
| 62 |
+
1) messages[:2] (with optional warning text added to the second message's content)
|
| 63 |
+
2) the latest assistant message (and everything after it) that contains a segment_phrase tool call
|
| 64 |
+
"""
|
| 65 |
+
# There should not be more than 10 messages in the conversation history
|
| 66 |
+
assert len(messages_list) < 10
|
| 67 |
+
|
| 68 |
+
# Part 1: always keep the first two message JSONs
|
| 69 |
+
part1 = copy.deepcopy(messages_list[:2])
|
| 70 |
+
|
| 71 |
+
# Part 2: search backwards for the latest assistant message containing a segment_phrase tool call
|
| 72 |
+
part2_start_idx = None
|
| 73 |
+
for idx in range(len(messages_list) - 1, 1, -1):
|
| 74 |
+
msg = messages_list[idx]
|
| 75 |
+
# We only consider assistant messages with a "content" list
|
| 76 |
+
if msg.get("role") != "assistant" or "content" not in msg:
|
| 77 |
+
continue
|
| 78 |
+
# Look for any content element that is a text containing the segment_phrase tool call
|
| 79 |
+
for content in msg["content"]:
|
| 80 |
+
if (
|
| 81 |
+
isinstance(content, dict)
|
| 82 |
+
and content.get("type") == "text"
|
| 83 |
+
and "<tool>" in content.get("text", "")
|
| 84 |
+
and "segment_phrase" in content.get("text", "")
|
| 85 |
+
):
|
| 86 |
+
part2_start_idx = idx
|
| 87 |
+
break
|
| 88 |
+
if part2_start_idx is not None:
|
| 89 |
+
break
|
| 90 |
+
|
| 91 |
+
part2 = messages_list[part2_start_idx:] if part2_start_idx is not None else []
|
| 92 |
+
|
| 93 |
+
# Part 3: decide whether to add warning text to the second message in part1
|
| 94 |
+
previously_used = (
|
| 95 |
+
[p for p in used_text_prompts if p != latest_sam3_text_prompt]
|
| 96 |
+
if latest_sam3_text_prompt
|
| 97 |
+
else list(used_text_prompts)
|
| 98 |
+
)
|
| 99 |
+
if part2 and len(previously_used) > 0:
|
| 100 |
+
warning_text = f'Note that we have previously called the segment_phrase tool with each "text_prompt" in this list: {list(previously_used)}, but none of the generated results were satisfactory. So make sure that you do not use any of these phrases as the "text_prompt" to call the segment_phrase tool again.'
|
| 101 |
+
# Replace the second message entirely to keep exactly 2 content items
|
| 102 |
+
part1[1] = {
|
| 103 |
+
"role": "user",
|
| 104 |
+
"content": [
|
| 105 |
+
{"type": "image", "image": img_path},
|
| 106 |
+
{
|
| 107 |
+
"type": "text",
|
| 108 |
+
"text": f"The above image is the raw input image. The initial user input query is: '{initial_text_prompt}'."
|
| 109 |
+
+ " "
|
| 110 |
+
+ warning_text,
|
| 111 |
+
},
|
| 112 |
+
],
|
| 113 |
+
}
|
| 114 |
+
assert len(part1[1]["content"]) == 2
|
| 115 |
+
|
| 116 |
+
# Build the new messages list: part1 (with optional warning), then part2
|
| 117 |
+
new_messages = list(part1)
|
| 118 |
+
new_messages.extend(part2)
|
| 119 |
+
return new_messages
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def agent_inference(
|
| 123 |
+
img_path: str,
|
| 124 |
+
initial_text_prompt: str,
|
| 125 |
+
debug: bool = False,
|
| 126 |
+
send_generate_request=send_generate_request,
|
| 127 |
+
call_sam_service=call_sam_service,
|
| 128 |
+
max_generations: int = 100,
|
| 129 |
+
output_dir="../../sam3_agent_out",
|
| 130 |
+
):
|
| 131 |
+
"""
|
| 132 |
+
Given a text prompt and an image, this tool will perform all aspects of agentic problem solving,
|
| 133 |
+
while saving sam3 and MLLM outputs to their respective directories.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
img_path: Path to the input image
|
| 137 |
+
initial_text_prompt: Initial text prompt from the user
|
| 138 |
+
debug: Whether to enable debug mode
|
| 139 |
+
max_generations: Maximum number of send_generate_request calls allowed (default: 100)
|
| 140 |
+
"""
|
| 141 |
+
# setup dir
|
| 142 |
+
sam_output_dir = os.path.join(output_dir, "sam_out")
|
| 143 |
+
error_save_dir = os.path.join(output_dir, "none_out")
|
| 144 |
+
debug_save_dir = os.path.join(output_dir, "agent_debug_out")
|
| 145 |
+
os.makedirs(sam_output_dir, exist_ok=True)
|
| 146 |
+
os.makedirs(error_save_dir, exist_ok=True)
|
| 147 |
+
os.makedirs(debug_save_dir, exist_ok=True)
|
| 148 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 149 |
+
MLLM_SYSTEM_PROMPT_PATH = os.path.join(
|
| 150 |
+
current_dir, "system_prompts/system_prompt.txt"
|
| 151 |
+
)
|
| 152 |
+
ITERATIVE_CHECKING_SYSTEM_PROMPT_PATH = os.path.join(
|
| 153 |
+
current_dir, "system_prompts/system_prompt_iterative_checking.txt"
|
| 154 |
+
)
|
| 155 |
+
# init variables
|
| 156 |
+
PATH_TO_LATEST_OUTPUT_JSON = ""
|
| 157 |
+
LATEST_SAM3_TEXT_PROMPT = ""
|
| 158 |
+
USED_TEXT_PROMPTS = (
|
| 159 |
+
set()
|
| 160 |
+
) # Track all previously used text prompts for segment_phrase
|
| 161 |
+
generation_count = 0 # Counter for number of send_generate_request calls
|
| 162 |
+
|
| 163 |
+
# debug setup
|
| 164 |
+
debug_folder_path = None
|
| 165 |
+
debug_jsonl_path = None
|
| 166 |
+
if debug:
|
| 167 |
+
debug_folder_path = os.path.join(
|
| 168 |
+
debug_save_dir, f"{img_path.rsplit('/', 1)[-1].rsplit('.', 1)[0]}"
|
| 169 |
+
)
|
| 170 |
+
debug_jsonl_path = os.path.join(debug_folder_path, "debug_history.json")
|
| 171 |
+
os.makedirs(debug_folder_path, exist_ok=True)
|
| 172 |
+
|
| 173 |
+
# The helper functions are now defined outside the agent_inference function
|
| 174 |
+
with open(MLLM_SYSTEM_PROMPT_PATH, "r") as f:
|
| 175 |
+
system_prompt = f.read().strip()
|
| 176 |
+
with open(ITERATIVE_CHECKING_SYSTEM_PROMPT_PATH, "r") as f:
|
| 177 |
+
iterative_checking_system_prompt = f.read().strip()
|
| 178 |
+
|
| 179 |
+
# Construct the initial message list
|
| 180 |
+
messages = [
|
| 181 |
+
{"role": "system", "content": system_prompt},
|
| 182 |
+
{
|
| 183 |
+
"role": "user",
|
| 184 |
+
"content": [
|
| 185 |
+
{"type": "image", "image": img_path},
|
| 186 |
+
{
|
| 187 |
+
"type": "text",
|
| 188 |
+
"text": f"The above image is the raw input image. The initial user input query is: '{initial_text_prompt}'.",
|
| 189 |
+
},
|
| 190 |
+
],
|
| 191 |
+
},
|
| 192 |
+
]
|
| 193 |
+
print(f"> Text prompt: {initial_text_prompt}")
|
| 194 |
+
print(f"> Image path: {img_path}")
|
| 195 |
+
|
| 196 |
+
print("\n\n")
|
| 197 |
+
print("-" * 30 + f" Round {str(generation_count + 1)}" + "-" * 30)
|
| 198 |
+
print("\n\n")
|
| 199 |
+
generated_text = send_generate_request(messages)
|
| 200 |
+
print(f"\n>>> MLLM Response [start]\n{generated_text}\n<<< MLLM Response [end]\n")
|
| 201 |
+
while generated_text is not None:
|
| 202 |
+
save_debug_messages(messages, debug, debug_folder_path, debug_jsonl_path)
|
| 203 |
+
assert (
|
| 204 |
+
"<tool>" in generated_text,
|
| 205 |
+
f"Generated text does not contain <tool> tag: {generated_text}",
|
| 206 |
+
)
|
| 207 |
+
generated_text = generated_text.split("</tool>", 1)[0] + "</tool>"
|
| 208 |
+
tool_call_json_str = (
|
| 209 |
+
generated_text.split("<tool>")[-1]
|
| 210 |
+
.split("</tool>")[0]
|
| 211 |
+
.strip()
|
| 212 |
+
.replace(r"}}}", r"}}") # remove extra } if any
|
| 213 |
+
)
|
| 214 |
+
try:
|
| 215 |
+
tool_call = json.loads(tool_call_json_str)
|
| 216 |
+
except json.JSONDecodeError:
|
| 217 |
+
raise ValueError(f"Invalid JSON in tool call: {tool_call_json_str}")
|
| 218 |
+
|
| 219 |
+
if PATH_TO_LATEST_OUTPUT_JSON == "":
|
| 220 |
+
# The first tool call must be segment_phrase or report_no_mask
|
| 221 |
+
assert (
|
| 222 |
+
tool_call["name"] == "segment_phrase"
|
| 223 |
+
or tool_call["name"] == "report_no_mask"
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
if tool_call["name"] == "segment_phrase":
|
| 227 |
+
print("🔍 Calling segment_phrase tool...")
|
| 228 |
+
assert list(tool_call["parameters"].keys()) == ["text_prompt"]
|
| 229 |
+
|
| 230 |
+
# Check if this text_prompt has been used before
|
| 231 |
+
current_text_prompt = tool_call["parameters"]["text_prompt"]
|
| 232 |
+
if current_text_prompt in USED_TEXT_PROMPTS:
|
| 233 |
+
print(
|
| 234 |
+
f"❌ Text prompt '{current_text_prompt}' has been used before. Requesting a different prompt."
|
| 235 |
+
)
|
| 236 |
+
duplicate_prompt_message = f"You have previously used '{current_text_prompt}' as your text_prompt to call the segment_phrase tool. You may not use it again. Please call the segment_phrase tool again with a different, perhaps more general, or more creative simple noun phrase prompt, while adhering to all the rules stated in the system prompt. You must also never use any of the following text_prompt(s): {str(list(USED_TEXT_PROMPTS))}."
|
| 237 |
+
messages.append(
|
| 238 |
+
{
|
| 239 |
+
"role": "assistant",
|
| 240 |
+
"content": [{"type": "text", "text": generated_text}],
|
| 241 |
+
}
|
| 242 |
+
)
|
| 243 |
+
messages.append(
|
| 244 |
+
{
|
| 245 |
+
"role": "user",
|
| 246 |
+
"content": [{"type": "text", "text": duplicate_prompt_message}],
|
| 247 |
+
}
|
| 248 |
+
)
|
| 249 |
+
else:
|
| 250 |
+
# Add the text_prompt to the set of used prompts
|
| 251 |
+
USED_TEXT_PROMPTS.add(current_text_prompt)
|
| 252 |
+
LATEST_SAM3_TEXT_PROMPT = current_text_prompt
|
| 253 |
+
PATH_TO_LATEST_OUTPUT_JSON = call_sam_service(
|
| 254 |
+
image_path=img_path,
|
| 255 |
+
text_prompt=current_text_prompt,
|
| 256 |
+
output_folder_path=sam_output_dir,
|
| 257 |
+
)
|
| 258 |
+
sam3_outputs = json.load(open(PATH_TO_LATEST_OUTPUT_JSON, "r"))
|
| 259 |
+
sam3_output_image_path = sam3_outputs["output_image_path"]
|
| 260 |
+
num_masks = len(sam3_outputs["pred_boxes"])
|
| 261 |
+
|
| 262 |
+
messages.append(
|
| 263 |
+
{
|
| 264 |
+
"role": "assistant",
|
| 265 |
+
"content": [{"type": "text", "text": generated_text}],
|
| 266 |
+
}
|
| 267 |
+
)
|
| 268 |
+
if num_masks == 0:
|
| 269 |
+
print("❌ No masks generated by SAM3, reporting no mask to Qwen.")
|
| 270 |
+
sam3_output_text_message = f"The segment_phrase tool did not generate any masks for the text_prompt '{current_text_prompt}'. Now, please call the segment_phrase tool again with a different, perhaps more general, or more creative simple noun phrase text_prompt, while adhering to all the rules stated in the system prompt. Please be reminded that the original user query was '{initial_text_prompt}'."
|
| 271 |
+
messages.append(
|
| 272 |
+
{
|
| 273 |
+
"role": "user",
|
| 274 |
+
"content": [
|
| 275 |
+
{"type": "text", "text": sam3_output_text_message}
|
| 276 |
+
],
|
| 277 |
+
}
|
| 278 |
+
)
|
| 279 |
+
else:
|
| 280 |
+
sam3_output_text_message = rf"The segment_phrase tool generated {num_masks} available masks. All {num_masks} available masks are rendered in this image below, now you must analyze the {num_masks} available mask(s) carefully, compare them against the raw input image and the original user query, and determine your next action. Please be reminded that the original user query was '{initial_text_prompt}'."
|
| 281 |
+
messages.append(
|
| 282 |
+
{
|
| 283 |
+
"role": "user",
|
| 284 |
+
"content": [
|
| 285 |
+
{"type": "text", "text": sam3_output_text_message},
|
| 286 |
+
{"type": "image", "image": sam3_output_image_path},
|
| 287 |
+
],
|
| 288 |
+
}
|
| 289 |
+
)
|
| 290 |
+
print("\n\n>>> sam3_output_text_message:\n", sam3_output_text_message)
|
| 291 |
+
|
| 292 |
+
elif tool_call["name"] == "examine_each_mask":
|
| 293 |
+
print("🔍 Calling examine_each_mask tool...")
|
| 294 |
+
assert LATEST_SAM3_TEXT_PROMPT != ""
|
| 295 |
+
|
| 296 |
+
# Make sure that the last message is a image
|
| 297 |
+
assert (
|
| 298 |
+
messages[-1]["content"][1]["type"] == "image"
|
| 299 |
+
), "Second content element should be an image"
|
| 300 |
+
messages.pop() # Remove the last user message
|
| 301 |
+
# Add simplified replacement message
|
| 302 |
+
simplified_message = {
|
| 303 |
+
"role": "user",
|
| 304 |
+
"content": [
|
| 305 |
+
{
|
| 306 |
+
"type": "text",
|
| 307 |
+
"text": "The segment_phrase tool generated several masks. Now you must analyze the mask(s) carefully, compare them against the raw input image and the original user query, and determine your next action.",
|
| 308 |
+
}
|
| 309 |
+
],
|
| 310 |
+
}
|
| 311 |
+
messages.append(simplified_message)
|
| 312 |
+
|
| 313 |
+
current_outputs = json.load(open(PATH_TO_LATEST_OUTPUT_JSON, "r"))
|
| 314 |
+
num_masks = len(current_outputs["pred_masks"])
|
| 315 |
+
masks_to_keep = []
|
| 316 |
+
|
| 317 |
+
# MLLM check the mask one by one
|
| 318 |
+
for i in range(num_masks):
|
| 319 |
+
print(f"🔍 Checking mask {i+1}/{num_masks}...")
|
| 320 |
+
image_w_mask_i, image_w_zoomed_in_mask_i = visualize(current_outputs, i)
|
| 321 |
+
|
| 322 |
+
image_w_zoomed_in_mask_i_path = os.path.join(
|
| 323 |
+
sam_output_dir, rf"{LATEST_SAM3_TEXT_PROMPT}.png".replace("/", "_")
|
| 324 |
+
).replace(".png", f"_zoom_in_mask_{i + 1}.png")
|
| 325 |
+
image_w_mask_i_path = os.path.join(
|
| 326 |
+
sam_output_dir, rf"{LATEST_SAM3_TEXT_PROMPT}.png".replace("/", "_")
|
| 327 |
+
).replace(".png", f"_selected_mask_{i + 1}.png")
|
| 328 |
+
image_w_zoomed_in_mask_i.save(image_w_zoomed_in_mask_i_path)
|
| 329 |
+
image_w_mask_i.save(image_w_mask_i_path)
|
| 330 |
+
|
| 331 |
+
iterative_checking_messages = [
|
| 332 |
+
{"role": "system", "content": iterative_checking_system_prompt},
|
| 333 |
+
{
|
| 334 |
+
"role": "user",
|
| 335 |
+
"content": [
|
| 336 |
+
{"type": "text", "text": f"The raw input image: "},
|
| 337 |
+
{"type": "image", "image": img_path},
|
| 338 |
+
{
|
| 339 |
+
"type": "text",
|
| 340 |
+
"text": f"The initial user input query is: '{initial_text_prompt}'",
|
| 341 |
+
},
|
| 342 |
+
{
|
| 343 |
+
"type": "text",
|
| 344 |
+
"text": f"Image with the predicted segmentation mask rendered on it: ",
|
| 345 |
+
},
|
| 346 |
+
{"type": "image", "image": image_w_mask_i_path},
|
| 347 |
+
{
|
| 348 |
+
"type": "text",
|
| 349 |
+
"text": f"Image with the zoomed-in mask: ",
|
| 350 |
+
},
|
| 351 |
+
{"type": "image", "image": image_w_zoomed_in_mask_i_path},
|
| 352 |
+
],
|
| 353 |
+
},
|
| 354 |
+
]
|
| 355 |
+
checking_generated_text = send_generate_request(
|
| 356 |
+
iterative_checking_messages
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
# Process the generated text to determine if the mask should be kept or rejected
|
| 360 |
+
if checking_generated_text is None:
|
| 361 |
+
raise ValueError(
|
| 362 |
+
"Generated text is None, which is unexpected. Please check the Qwen server and the input parameters."
|
| 363 |
+
)
|
| 364 |
+
print(f"Generated text for mask {i+1}: {checking_generated_text}")
|
| 365 |
+
verdict = (
|
| 366 |
+
checking_generated_text.split("<verdict>")[-1]
|
| 367 |
+
.split("</verdict>")[0]
|
| 368 |
+
.strip()
|
| 369 |
+
)
|
| 370 |
+
if "Accept" in verdict:
|
| 371 |
+
assert not "Reject" in verdict
|
| 372 |
+
print(f"Mask {i+1} accepted, keeping it in the outputs.")
|
| 373 |
+
masks_to_keep.append(i)
|
| 374 |
+
elif "Reject" in verdict:
|
| 375 |
+
assert not "Accept" in verdict
|
| 376 |
+
print(f"Mask {i+1} rejected, removing it from the outputs.")
|
| 377 |
+
else:
|
| 378 |
+
raise ValueError(
|
| 379 |
+
f"Unexpected verdict in generated text: {checking_generated_text}. Expected 'Accept' or 'Reject'."
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
updated_outputs = {
|
| 383 |
+
"original_image_path": current_outputs["original_image_path"],
|
| 384 |
+
"orig_img_h": current_outputs["orig_img_h"],
|
| 385 |
+
"orig_img_w": current_outputs["orig_img_w"],
|
| 386 |
+
"pred_boxes": [current_outputs["pred_boxes"][i] for i in masks_to_keep],
|
| 387 |
+
"pred_scores": [
|
| 388 |
+
current_outputs["pred_scores"][i] for i in masks_to_keep
|
| 389 |
+
],
|
| 390 |
+
"pred_masks": [current_outputs["pred_masks"][i] for i in masks_to_keep],
|
| 391 |
+
}
|
| 392 |
+
|
| 393 |
+
image_w_check_masks = visualize(updated_outputs)
|
| 394 |
+
image_w_check_masks_path = os.path.join(
|
| 395 |
+
sam_output_dir, rf"{LATEST_SAM3_TEXT_PROMPT}.png"
|
| 396 |
+
).replace(
|
| 397 |
+
".png",
|
| 398 |
+
f"_selected_masks_{'-'.join(map(str, [i+1 for i in masks_to_keep]))}.png".replace(
|
| 399 |
+
"/", "_"
|
| 400 |
+
),
|
| 401 |
+
)
|
| 402 |
+
image_w_check_masks.save(image_w_check_masks_path)
|
| 403 |
+
# save the updated json outputs and append to message history
|
| 404 |
+
messages.append(
|
| 405 |
+
{
|
| 406 |
+
"role": "assistant",
|
| 407 |
+
"content": [{"type": "text", "text": generated_text}],
|
| 408 |
+
}
|
| 409 |
+
)
|
| 410 |
+
if len(masks_to_keep) == 0:
|
| 411 |
+
messages.append(
|
| 412 |
+
{
|
| 413 |
+
"role": "user",
|
| 414 |
+
"content": [
|
| 415 |
+
{
|
| 416 |
+
"type": "text",
|
| 417 |
+
"text": f"The original user query was: '{initial_text_prompt}'. The examine_each_mask tool examined and rejected all of the masks generated by the segment_phrase tool. Now, please call the segment_phrase tool again with a different, perhaps more general, or more creative simple noun phrase text_prompt, while adhering to all the rules stated in the system prompt.",
|
| 418 |
+
}
|
| 419 |
+
],
|
| 420 |
+
}
|
| 421 |
+
)
|
| 422 |
+
else:
|
| 423 |
+
messages.append(
|
| 424 |
+
{
|
| 425 |
+
"role": "user",
|
| 426 |
+
"content": [
|
| 427 |
+
{
|
| 428 |
+
"type": "text",
|
| 429 |
+
"text": f"The original user query was: '{initial_text_prompt}'. After calling the examine_each_mask tool on the available masks, the number of available masks is now {len(masks_to_keep)}. All {len(masks_to_keep)} available masks are rendered in this image below, now you must analyze the {len(masks_to_keep)} available mask(s) carefully, compare them against the raw input image and the original user query, and determine your next action.",
|
| 430 |
+
},
|
| 431 |
+
{"type": "image", "image": image_w_check_masks_path},
|
| 432 |
+
],
|
| 433 |
+
}
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
# Create a new filename based on the original path to avoid filename length issues
|
| 437 |
+
base_path = PATH_TO_LATEST_OUTPUT_JSON
|
| 438 |
+
# Remove any existing "masks_" suffix to avoid duplication
|
| 439 |
+
if "masks_" in base_path:
|
| 440 |
+
base_path = base_path.split("masks_")[0] + ".json"
|
| 441 |
+
# Create new filename with current masks; use a clearer suffix when empty
|
| 442 |
+
if len(masks_to_keep) == 0:
|
| 443 |
+
PATH_TO_LATEST_OUTPUT_JSON = base_path.replace(
|
| 444 |
+
".json", "masks_none.json"
|
| 445 |
+
)
|
| 446 |
+
else:
|
| 447 |
+
PATH_TO_LATEST_OUTPUT_JSON = base_path.replace(
|
| 448 |
+
".json", f"masks_{'_'.join(map(str, masks_to_keep))}.json"
|
| 449 |
+
)
|
| 450 |
+
json.dump(updated_outputs, open(PATH_TO_LATEST_OUTPUT_JSON, "w"), indent=4)
|
| 451 |
+
|
| 452 |
+
elif tool_call["name"] == "select_masks_and_return":
|
| 453 |
+
print("🔍 Calling select_masks_and_return tool...")
|
| 454 |
+
current_outputs = json.load(open(PATH_TO_LATEST_OUTPUT_JSON, "r"))
|
| 455 |
+
|
| 456 |
+
assert list(tool_call["parameters"].keys()) == ["final_answer_masks"]
|
| 457 |
+
masks_to_keep = tool_call["parameters"]["final_answer_masks"]
|
| 458 |
+
|
| 459 |
+
# Keep only valid mask indices, remove duplicates, and preserve deterministic ascending order
|
| 460 |
+
available_masks = set(range(1, len(current_outputs["pred_masks"]) + 1))
|
| 461 |
+
masks_to_keep = sorted({i for i in masks_to_keep if i in available_masks})
|
| 462 |
+
# Change this to a update message telling the model to try again along with information about errors made.
|
| 463 |
+
|
| 464 |
+
final_outputs = {
|
| 465 |
+
"original_image_path": current_outputs["original_image_path"],
|
| 466 |
+
"orig_img_h": current_outputs["orig_img_h"],
|
| 467 |
+
"orig_img_w": current_outputs["orig_img_w"],
|
| 468 |
+
"pred_boxes": [
|
| 469 |
+
current_outputs["pred_boxes"][i - 1] for i in masks_to_keep
|
| 470 |
+
],
|
| 471 |
+
"pred_scores": [
|
| 472 |
+
current_outputs["pred_scores"][i - 1] for i in masks_to_keep
|
| 473 |
+
],
|
| 474 |
+
"pred_masks": [
|
| 475 |
+
current_outputs["pred_masks"][i - 1] for i in masks_to_keep
|
| 476 |
+
],
|
| 477 |
+
}
|
| 478 |
+
|
| 479 |
+
rendered_final_output = visualize(final_outputs)
|
| 480 |
+
messages.append(
|
| 481 |
+
{
|
| 482 |
+
"role": "assistant",
|
| 483 |
+
"content": [{"type": "text", "text": generated_text}],
|
| 484 |
+
}
|
| 485 |
+
)
|
| 486 |
+
|
| 487 |
+
# Clean up debug files before successful return
|
| 488 |
+
cleanup_debug_files(debug, debug_folder_path, debug_jsonl_path)
|
| 489 |
+
return messages, final_outputs, rendered_final_output
|
| 490 |
+
|
| 491 |
+
elif tool_call["name"] == "report_no_mask":
|
| 492 |
+
print("🔍 Calling report_no_mask tool...")
|
| 493 |
+
height, width = cv2.imread(img_path).shape[:2]
|
| 494 |
+
final_outputs = {
|
| 495 |
+
"original_image_path": img_path,
|
| 496 |
+
"orig_img_h": height,
|
| 497 |
+
"orig_img_w": width,
|
| 498 |
+
"pred_boxes": [],
|
| 499 |
+
"pred_scores": [],
|
| 500 |
+
"pred_masks": [],
|
| 501 |
+
}
|
| 502 |
+
rendered_final_output = Image.open(img_path)
|
| 503 |
+
messages.append(
|
| 504 |
+
{
|
| 505 |
+
"role": "assistant",
|
| 506 |
+
"content": [{"type": "text", "text": generated_text}],
|
| 507 |
+
}
|
| 508 |
+
)
|
| 509 |
+
return messages, final_outputs, rendered_final_output
|
| 510 |
+
|
| 511 |
+
else:
|
| 512 |
+
raise ValueError(f"Unknown tool call: {tool_call['name']}")
|
| 513 |
+
|
| 514 |
+
# sometimes the MLLM don't know when to stop, and generates multiple tool calls in one round, so we need to split the generated text by </tool> and only keep the first one
|
| 515 |
+
|
| 516 |
+
for message in messages:
|
| 517 |
+
if message["role"] == "assistant" and "content" in message:
|
| 518 |
+
for content in message["content"]:
|
| 519 |
+
if (
|
| 520 |
+
isinstance(content, dict)
|
| 521 |
+
and content.get("type") == "text"
|
| 522 |
+
and "text" in content
|
| 523 |
+
):
|
| 524 |
+
content["text"] = (
|
| 525 |
+
content["text"].split("</tool>", 1)[0] + "</tool>\n\n"
|
| 526 |
+
)
|
| 527 |
+
# Prune the messages history before the next MLLM generation round according to the 3-part rules.
|
| 528 |
+
# This keeps history compact and ensures the model sees only the allowed parts.
|
| 529 |
+
messages = _prune_messages_for_next_round(
|
| 530 |
+
messages,
|
| 531 |
+
USED_TEXT_PROMPTS,
|
| 532 |
+
LATEST_SAM3_TEXT_PROMPT,
|
| 533 |
+
img_path,
|
| 534 |
+
initial_text_prompt,
|
| 535 |
+
)
|
| 536 |
+
# make sure there can never be more than 2 images in the context
|
| 537 |
+
assert count_images(messages) <= 2
|
| 538 |
+
generation_count += 1
|
| 539 |
+
if generation_count > max_generations:
|
| 540 |
+
raise ValueError(
|
| 541 |
+
f"Exceeded maximum number of allowed generation requests ({max_generations})"
|
| 542 |
+
)
|
| 543 |
+
|
| 544 |
+
print("\n\n")
|
| 545 |
+
print("-" * 30 + f" Round {str(generation_count + 1)}" + "-" * 30)
|
| 546 |
+
print("\n\n")
|
| 547 |
+
generated_text = send_generate_request(messages)
|
| 548 |
+
print(
|
| 549 |
+
f"\n>>> MLLM Response [start]\n{generated_text}\n<<< MLLM Response [end]\n"
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
print("\n\n>>> SAM 3 Agent execution ended.\n\n")
|
| 553 |
+
|
| 554 |
+
error_save_path = os.path.join(
|
| 555 |
+
error_save_dir,
|
| 556 |
+
f"{img_path.rsplit('/', 1)[-1].rsplit('.', 1)[0]}_error_history.json",
|
| 557 |
+
)
|
| 558 |
+
with open(error_save_path, "w") as f:
|
| 559 |
+
json.dump(messages, f, indent=4)
|
| 560 |
+
print("Saved messages history that caused error to:", error_save_path)
|
| 561 |
+
raise ValueError(
|
| 562 |
+
rf"Generated text is None, which is unexpected. Please check the Qwen server and the input parameters for image path: {img_path} and initial text prompt: {initial_text_prompt}."
|
| 563 |
+
)
|
sam3/agent/client_llm.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
import base64
|
| 4 |
+
import os
|
| 5 |
+
from typing import Any, Optional
|
| 6 |
+
|
| 7 |
+
from openai import OpenAI
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_image_base64_and_mime(image_path):
|
| 11 |
+
"""Convert image file to base64 string and get MIME type"""
|
| 12 |
+
try:
|
| 13 |
+
# Get MIME type based on file extension
|
| 14 |
+
ext = os.path.splitext(image_path)[1].lower()
|
| 15 |
+
mime_types = {
|
| 16 |
+
".jpg": "image/jpeg",
|
| 17 |
+
".jpeg": "image/jpeg",
|
| 18 |
+
".png": "image/png",
|
| 19 |
+
".gif": "image/gif",
|
| 20 |
+
".webp": "image/webp",
|
| 21 |
+
".bmp": "image/bmp",
|
| 22 |
+
}
|
| 23 |
+
mime_type = mime_types.get(ext, "image/jpeg") # Default to JPEG
|
| 24 |
+
|
| 25 |
+
# Convert image to base64
|
| 26 |
+
with open(image_path, "rb") as image_file:
|
| 27 |
+
base64_data = base64.b64encode(image_file.read()).decode("utf-8")
|
| 28 |
+
return base64_data, mime_type
|
| 29 |
+
except Exception as e:
|
| 30 |
+
print(f"Error converting image to base64: {e}")
|
| 31 |
+
return None, None
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def send_generate_request(
|
| 35 |
+
messages,
|
| 36 |
+
server_url=None,
|
| 37 |
+
model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
| 38 |
+
api_key=None,
|
| 39 |
+
max_tokens=4096,
|
| 40 |
+
):
|
| 41 |
+
"""
|
| 42 |
+
Sends a request to the OpenAI-compatible API endpoint using the OpenAI client library.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
server_url (str): The base URL of the server, e.g. "http://127.0.0.1:8000"
|
| 46 |
+
messages (list): A list of message dicts, each containing role and content.
|
| 47 |
+
model (str): The model to use for generation (default: "llama-4")
|
| 48 |
+
max_tokens (int): Maximum number of tokens to generate (default: 4096)
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
str: The generated response text from the server.
|
| 52 |
+
"""
|
| 53 |
+
# Process messages to convert image paths to base64
|
| 54 |
+
processed_messages = []
|
| 55 |
+
for message in messages:
|
| 56 |
+
processed_message = message.copy()
|
| 57 |
+
if message["role"] == "user" and "content" in message:
|
| 58 |
+
processed_content = []
|
| 59 |
+
for c in message["content"]:
|
| 60 |
+
if isinstance(c, dict) and c.get("type") == "image":
|
| 61 |
+
# Convert image path to base64 format
|
| 62 |
+
image_path = c["image"]
|
| 63 |
+
|
| 64 |
+
print("image_path", image_path)
|
| 65 |
+
new_image_path = image_path.replace(
|
| 66 |
+
"?", "%3F"
|
| 67 |
+
) # Escape ? in the path
|
| 68 |
+
|
| 69 |
+
# Read the image file and convert to base64
|
| 70 |
+
try:
|
| 71 |
+
base64_image, mime_type = get_image_base64_and_mime(
|
| 72 |
+
new_image_path
|
| 73 |
+
)
|
| 74 |
+
if base64_image is None:
|
| 75 |
+
print(
|
| 76 |
+
f"Warning: Could not convert image to base64: {new_image_path}"
|
| 77 |
+
)
|
| 78 |
+
continue
|
| 79 |
+
|
| 80 |
+
# Create the proper image_url structure with base64 data
|
| 81 |
+
processed_content.append(
|
| 82 |
+
{
|
| 83 |
+
"type": "image_url",
|
| 84 |
+
"image_url": {
|
| 85 |
+
"url": f"data:{mime_type};base64,{base64_image}",
|
| 86 |
+
"detail": "high",
|
| 87 |
+
},
|
| 88 |
+
}
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
except FileNotFoundError:
|
| 92 |
+
print(f"Warning: Image file not found: {new_image_path}")
|
| 93 |
+
continue
|
| 94 |
+
except Exception as e:
|
| 95 |
+
print(f"Warning: Error processing image {new_image_path}: {e}")
|
| 96 |
+
continue
|
| 97 |
+
else:
|
| 98 |
+
processed_content.append(c)
|
| 99 |
+
|
| 100 |
+
processed_message["content"] = processed_content
|
| 101 |
+
processed_messages.append(processed_message)
|
| 102 |
+
|
| 103 |
+
# Create OpenAI client with custom base URL
|
| 104 |
+
client = OpenAI(api_key=api_key, base_url=server_url)
|
| 105 |
+
|
| 106 |
+
try:
|
| 107 |
+
print(f"🔍 Calling model {model}...")
|
| 108 |
+
response = client.chat.completions.create(
|
| 109 |
+
model=model,
|
| 110 |
+
messages=processed_messages,
|
| 111 |
+
max_completion_tokens=max_tokens,
|
| 112 |
+
n=1,
|
| 113 |
+
)
|
| 114 |
+
# print(f"Received response: {response.choices[0].message}")
|
| 115 |
+
|
| 116 |
+
# Extract the response content
|
| 117 |
+
if response.choices and len(response.choices) > 0:
|
| 118 |
+
return response.choices[0].message.content
|
| 119 |
+
else:
|
| 120 |
+
print(f"Unexpected response format: {response}")
|
| 121 |
+
return None
|
| 122 |
+
|
| 123 |
+
except Exception as e:
|
| 124 |
+
print(f"Request failed: {e}")
|
| 125 |
+
return None
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def send_direct_request(
|
| 129 |
+
llm: Any,
|
| 130 |
+
messages: list[dict[str, Any]],
|
| 131 |
+
sampling_params: Any,
|
| 132 |
+
) -> Optional[str]:
|
| 133 |
+
"""
|
| 134 |
+
Run inference on a vLLM model instance directly without using a server.
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
llm: Initialized vLLM LLM instance (passed from external initialization)
|
| 138 |
+
messages: List of message dicts with role and content (OpenAI format)
|
| 139 |
+
sampling_params: vLLM SamplingParams instance (initialized externally)
|
| 140 |
+
|
| 141 |
+
Returns:
|
| 142 |
+
str: Generated response text, or None if inference fails
|
| 143 |
+
"""
|
| 144 |
+
try:
|
| 145 |
+
# Process messages to handle images (convert to base64 if needed)
|
| 146 |
+
processed_messages = []
|
| 147 |
+
for message in messages:
|
| 148 |
+
processed_message = message.copy()
|
| 149 |
+
if message["role"] == "user" and "content" in message:
|
| 150 |
+
processed_content = []
|
| 151 |
+
for c in message["content"]:
|
| 152 |
+
if isinstance(c, dict) and c.get("type") == "image":
|
| 153 |
+
# Convert image path to base64 format
|
| 154 |
+
image_path = c["image"]
|
| 155 |
+
new_image_path = image_path.replace("?", "%3F")
|
| 156 |
+
|
| 157 |
+
try:
|
| 158 |
+
base64_image, mime_type = get_image_base64_and_mime(
|
| 159 |
+
new_image_path
|
| 160 |
+
)
|
| 161 |
+
if base64_image is None:
|
| 162 |
+
print(
|
| 163 |
+
f"Warning: Could not convert image: {new_image_path}"
|
| 164 |
+
)
|
| 165 |
+
continue
|
| 166 |
+
|
| 167 |
+
# vLLM expects image_url format
|
| 168 |
+
processed_content.append(
|
| 169 |
+
{
|
| 170 |
+
"type": "image_url",
|
| 171 |
+
"image_url": {
|
| 172 |
+
"url": f"data:{mime_type};base64,{base64_image}"
|
| 173 |
+
},
|
| 174 |
+
}
|
| 175 |
+
)
|
| 176 |
+
except Exception as e:
|
| 177 |
+
print(
|
| 178 |
+
f"Warning: Error processing image {new_image_path}: {e}"
|
| 179 |
+
)
|
| 180 |
+
continue
|
| 181 |
+
else:
|
| 182 |
+
processed_content.append(c)
|
| 183 |
+
|
| 184 |
+
processed_message["content"] = processed_content
|
| 185 |
+
processed_messages.append(processed_message)
|
| 186 |
+
|
| 187 |
+
print("🔍 Running direct inference with vLLM...")
|
| 188 |
+
|
| 189 |
+
# Run inference using vLLM's chat interface
|
| 190 |
+
outputs = llm.chat(
|
| 191 |
+
messages=processed_messages,
|
| 192 |
+
sampling_params=sampling_params,
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
# Extract the generated text from the first output
|
| 196 |
+
if outputs and len(outputs) > 0:
|
| 197 |
+
generated_text = outputs[0].outputs[0].text
|
| 198 |
+
return generated_text
|
| 199 |
+
else:
|
| 200 |
+
print(f"Unexpected output format: {outputs}")
|
| 201 |
+
return None
|
| 202 |
+
|
| 203 |
+
except Exception as e:
|
| 204 |
+
print(f"Direct inference failed: {e}")
|
| 205 |
+
return None
|
sam3/agent/client_sam3.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
from sam3.model.box_ops import box_xyxy_to_xywh
|
| 10 |
+
from sam3.train.masks_ops import rle_encode
|
| 11 |
+
|
| 12 |
+
from .helpers.mask_overlap_removal import remove_overlapping_masks
|
| 13 |
+
from .viz import visualize
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def sam3_inference(processor, image_path, text_prompt):
|
| 17 |
+
"""Run SAM 3 image inference with text prompts and format the outputs"""
|
| 18 |
+
image = Image.open(image_path)
|
| 19 |
+
orig_img_w, orig_img_h = image.size
|
| 20 |
+
|
| 21 |
+
# model inference
|
| 22 |
+
inference_state = processor.set_image(image)
|
| 23 |
+
inference_state = processor.set_text_prompt(
|
| 24 |
+
state=inference_state, prompt=text_prompt
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
# format and assemble outputs
|
| 28 |
+
pred_boxes_xyxy = torch.stack(
|
| 29 |
+
[
|
| 30 |
+
inference_state["boxes"][:, 0] / orig_img_w,
|
| 31 |
+
inference_state["boxes"][:, 1] / orig_img_h,
|
| 32 |
+
inference_state["boxes"][:, 2] / orig_img_w,
|
| 33 |
+
inference_state["boxes"][:, 3] / orig_img_h,
|
| 34 |
+
],
|
| 35 |
+
dim=-1,
|
| 36 |
+
) # normalized in range [0, 1]
|
| 37 |
+
pred_boxes_xywh = box_xyxy_to_xywh(pred_boxes_xyxy).tolist()
|
| 38 |
+
pred_masks = rle_encode(inference_state["masks"].squeeze(1))
|
| 39 |
+
pred_masks = [m["counts"] for m in pred_masks]
|
| 40 |
+
outputs = {
|
| 41 |
+
"orig_img_h": orig_img_h,
|
| 42 |
+
"orig_img_w": orig_img_w,
|
| 43 |
+
"pred_boxes": pred_boxes_xywh,
|
| 44 |
+
"pred_masks": pred_masks,
|
| 45 |
+
"pred_scores": inference_state["scores"].tolist(),
|
| 46 |
+
}
|
| 47 |
+
return outputs
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def call_sam_service(
|
| 51 |
+
sam3_processor,
|
| 52 |
+
image_path: str,
|
| 53 |
+
text_prompt: str,
|
| 54 |
+
output_folder_path: str = "sam3_output",
|
| 55 |
+
):
|
| 56 |
+
"""
|
| 57 |
+
Loads an image, sends it with a text prompt to the service,
|
| 58 |
+
saves the results, and renders the visualization.
|
| 59 |
+
"""
|
| 60 |
+
print(f"📞 Loading image '{image_path}' and sending with prompt '{text_prompt}'...")
|
| 61 |
+
|
| 62 |
+
text_prompt_for_save_path = (
|
| 63 |
+
text_prompt.replace("/", "_") if "/" in text_prompt else text_prompt
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
os.makedirs(
|
| 67 |
+
os.path.join(output_folder_path, image_path.replace("/", "-")), exist_ok=True
|
| 68 |
+
)
|
| 69 |
+
output_json_path = os.path.join(
|
| 70 |
+
output_folder_path,
|
| 71 |
+
image_path.replace("/", "-"),
|
| 72 |
+
rf"{text_prompt_for_save_path}.json",
|
| 73 |
+
)
|
| 74 |
+
output_image_path = os.path.join(
|
| 75 |
+
output_folder_path,
|
| 76 |
+
image_path.replace("/", "-"),
|
| 77 |
+
rf"{text_prompt_for_save_path}.png",
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
try:
|
| 81 |
+
# Send the image and text prompt as a multipart/form-data request
|
| 82 |
+
serialized_response = sam3_inference(sam3_processor, image_path, text_prompt)
|
| 83 |
+
|
| 84 |
+
# 1. Prepare the response dictionary
|
| 85 |
+
serialized_response = remove_overlapping_masks(serialized_response)
|
| 86 |
+
serialized_response = {
|
| 87 |
+
"original_image_path": image_path,
|
| 88 |
+
"output_image_path": output_image_path,
|
| 89 |
+
**serialized_response,
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
# 2. Reorder predictions by scores (highest to lowest) if scores are available
|
| 93 |
+
if "pred_scores" in serialized_response and serialized_response["pred_scores"]:
|
| 94 |
+
# Create indices sorted by scores in descending order
|
| 95 |
+
score_indices = sorted(
|
| 96 |
+
range(len(serialized_response["pred_scores"])),
|
| 97 |
+
key=lambda i: serialized_response["pred_scores"][i],
|
| 98 |
+
reverse=True,
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
# Reorder all three lists based on the sorted indices
|
| 102 |
+
serialized_response["pred_scores"] = [
|
| 103 |
+
serialized_response["pred_scores"][i] for i in score_indices
|
| 104 |
+
]
|
| 105 |
+
serialized_response["pred_boxes"] = [
|
| 106 |
+
serialized_response["pred_boxes"][i] for i in score_indices
|
| 107 |
+
]
|
| 108 |
+
serialized_response["pred_masks"] = [
|
| 109 |
+
serialized_response["pred_masks"][i] for i in score_indices
|
| 110 |
+
]
|
| 111 |
+
|
| 112 |
+
# 3. Remove any invalid RLE masks that is too short (shorter than 5 characters)
|
| 113 |
+
valid_masks = []
|
| 114 |
+
valid_boxes = []
|
| 115 |
+
valid_scores = []
|
| 116 |
+
for i, rle in enumerate(serialized_response["pred_masks"]):
|
| 117 |
+
if len(rle) > 4:
|
| 118 |
+
valid_masks.append(rle)
|
| 119 |
+
valid_boxes.append(serialized_response["pred_boxes"][i])
|
| 120 |
+
valid_scores.append(serialized_response["pred_scores"][i])
|
| 121 |
+
serialized_response["pred_masks"] = valid_masks
|
| 122 |
+
serialized_response["pred_boxes"] = valid_boxes
|
| 123 |
+
serialized_response["pred_scores"] = valid_scores
|
| 124 |
+
|
| 125 |
+
with open(output_json_path, "w") as f:
|
| 126 |
+
json.dump(serialized_response, f, indent=4)
|
| 127 |
+
print(f"✅ Raw JSON response saved to '{output_json_path}'")
|
| 128 |
+
|
| 129 |
+
# 4. Render and save visualizations on the image and save it in the SAM3 output folder
|
| 130 |
+
print("🔍 Rendering visualizations on the image ...")
|
| 131 |
+
viz_image = visualize(serialized_response)
|
| 132 |
+
os.makedirs(os.path.dirname(output_image_path), exist_ok=True)
|
| 133 |
+
viz_image.save(output_image_path)
|
| 134 |
+
print("✅ Saved visualization at:", output_image_path)
|
| 135 |
+
except Exception as e:
|
| 136 |
+
print(f"❌ Error calling service: {e}")
|
| 137 |
+
|
| 138 |
+
return output_json_path
|
sam3/agent/helpers/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
sam3/agent/helpers/boxes.py
ADDED
|
@@ -0,0 +1,438 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from enum import IntEnum, unique
|
| 5 |
+
from typing import List, Tuple, Union
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from torch import device
|
| 10 |
+
|
| 11 |
+
_RawBoxType = Union[List[float], Tuple[float, ...], torch.Tensor, np.ndarray]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@unique
|
| 15 |
+
class BoxMode(IntEnum):
|
| 16 |
+
"""
|
| 17 |
+
Enum of different ways to represent a box.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
XYXY_ABS = 0
|
| 21 |
+
"""
|
| 22 |
+
(x0, y0, x1, y1) in absolute floating points coordinates.
|
| 23 |
+
The coordinates in range [0, width or height].
|
| 24 |
+
"""
|
| 25 |
+
XYWH_ABS = 1
|
| 26 |
+
"""
|
| 27 |
+
(x0, y0, w, h) in absolute floating points coordinates.
|
| 28 |
+
"""
|
| 29 |
+
XYXY_REL = 2
|
| 30 |
+
"""
|
| 31 |
+
Not yet supported!
|
| 32 |
+
(x0, y0, x1, y1) in range [0, 1]. They are relative to the size of the image.
|
| 33 |
+
"""
|
| 34 |
+
XYWH_REL = 3
|
| 35 |
+
"""
|
| 36 |
+
Not yet supported!
|
| 37 |
+
(x0, y0, w, h) in range [0, 1]. They are relative to the size of the image.
|
| 38 |
+
"""
|
| 39 |
+
XYWHA_ABS = 4
|
| 40 |
+
"""
|
| 41 |
+
(xc, yc, w, h, a) in absolute floating points coordinates.
|
| 42 |
+
(xc, yc) is the center of the rotated box, and the angle a is in degrees ccw.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
@staticmethod
|
| 46 |
+
def convert(
|
| 47 |
+
box: _RawBoxType, from_mode: "BoxMode", to_mode: "BoxMode"
|
| 48 |
+
) -> _RawBoxType:
|
| 49 |
+
"""
|
| 50 |
+
Args:
|
| 51 |
+
box: can be a k-tuple, k-list or an Nxk array/tensor, where k = 4 or 5
|
| 52 |
+
from_mode, to_mode (BoxMode)
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
The converted box of the same type.
|
| 56 |
+
"""
|
| 57 |
+
if from_mode == to_mode:
|
| 58 |
+
return box
|
| 59 |
+
|
| 60 |
+
original_type = type(box)
|
| 61 |
+
is_numpy = isinstance(box, np.ndarray)
|
| 62 |
+
single_box = isinstance(box, (list, tuple))
|
| 63 |
+
if single_box:
|
| 64 |
+
assert len(box) == 4 or len(box) == 5, (
|
| 65 |
+
"BoxMode.convert takes either a k-tuple/list or an Nxk array/tensor,"
|
| 66 |
+
" where k == 4 or 5"
|
| 67 |
+
)
|
| 68 |
+
arr = torch.tensor(box)[None, :]
|
| 69 |
+
else:
|
| 70 |
+
# avoid modifying the input box
|
| 71 |
+
if is_numpy:
|
| 72 |
+
arr = torch.from_numpy(np.asarray(box)).clone()
|
| 73 |
+
else:
|
| 74 |
+
arr = box.clone()
|
| 75 |
+
|
| 76 |
+
assert to_mode not in [
|
| 77 |
+
BoxMode.XYXY_REL,
|
| 78 |
+
BoxMode.XYWH_REL,
|
| 79 |
+
] and from_mode not in [
|
| 80 |
+
BoxMode.XYXY_REL,
|
| 81 |
+
BoxMode.XYWH_REL,
|
| 82 |
+
], "Relative mode not yet supported!"
|
| 83 |
+
|
| 84 |
+
if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS:
|
| 85 |
+
assert (
|
| 86 |
+
arr.shape[-1] == 5
|
| 87 |
+
), "The last dimension of input shape must be 5 for XYWHA format"
|
| 88 |
+
original_dtype = arr.dtype
|
| 89 |
+
arr = arr.double()
|
| 90 |
+
|
| 91 |
+
w = arr[:, 2]
|
| 92 |
+
h = arr[:, 3]
|
| 93 |
+
a = arr[:, 4]
|
| 94 |
+
c = torch.abs(torch.cos(a * math.pi / 180.0))
|
| 95 |
+
s = torch.abs(torch.sin(a * math.pi / 180.0))
|
| 96 |
+
# This basically computes the horizontal bounding rectangle of the rotated box
|
| 97 |
+
new_w = c * w + s * h
|
| 98 |
+
new_h = c * h + s * w
|
| 99 |
+
|
| 100 |
+
# convert center to top-left corner
|
| 101 |
+
arr[:, 0] -= new_w / 2.0
|
| 102 |
+
arr[:, 1] -= new_h / 2.0
|
| 103 |
+
# bottom-right corner
|
| 104 |
+
arr[:, 2] = arr[:, 0] + new_w
|
| 105 |
+
arr[:, 3] = arr[:, 1] + new_h
|
| 106 |
+
|
| 107 |
+
arr = arr[:, :4].to(dtype=original_dtype)
|
| 108 |
+
elif from_mode == BoxMode.XYWH_ABS and to_mode == BoxMode.XYWHA_ABS:
|
| 109 |
+
original_dtype = arr.dtype
|
| 110 |
+
arr = arr.double()
|
| 111 |
+
arr[:, 0] += arr[:, 2] / 2.0
|
| 112 |
+
arr[:, 1] += arr[:, 3] / 2.0
|
| 113 |
+
angles = torch.zeros((arr.shape[0], 1), dtype=arr.dtype)
|
| 114 |
+
arr = torch.cat((arr, angles), axis=1).to(dtype=original_dtype)
|
| 115 |
+
else:
|
| 116 |
+
if to_mode == BoxMode.XYXY_ABS and from_mode == BoxMode.XYWH_ABS:
|
| 117 |
+
arr[:, 2] += arr[:, 0]
|
| 118 |
+
arr[:, 3] += arr[:, 1]
|
| 119 |
+
elif from_mode == BoxMode.XYXY_ABS and to_mode == BoxMode.XYWH_ABS:
|
| 120 |
+
arr[:, 2] -= arr[:, 0]
|
| 121 |
+
arr[:, 3] -= arr[:, 1]
|
| 122 |
+
else:
|
| 123 |
+
raise NotImplementedError(
|
| 124 |
+
"Conversion from BoxMode {} to {} is not supported yet".format(
|
| 125 |
+
from_mode, to_mode
|
| 126 |
+
)
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
if single_box:
|
| 130 |
+
return original_type(arr.flatten().tolist())
|
| 131 |
+
if is_numpy:
|
| 132 |
+
return arr.numpy()
|
| 133 |
+
else:
|
| 134 |
+
return arr
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class Boxes:
|
| 138 |
+
"""
|
| 139 |
+
This structure stores a list of boxes as a Nx4 torch.Tensor.
|
| 140 |
+
It supports some common methods about boxes
|
| 141 |
+
(`area`, `clip`, `nonempty`, etc),
|
| 142 |
+
and also behaves like a Tensor
|
| 143 |
+
(support indexing, `to(device)`, `.device`, and iteration over all boxes)
|
| 144 |
+
|
| 145 |
+
Attributes:
|
| 146 |
+
tensor (torch.Tensor): float matrix of Nx4. Each row is (x1, y1, x2, y2).
|
| 147 |
+
"""
|
| 148 |
+
|
| 149 |
+
def __init__(self, tensor: torch.Tensor):
|
| 150 |
+
"""
|
| 151 |
+
Args:
|
| 152 |
+
tensor (Tensor[float]): a Nx4 matrix. Each row is (x1, y1, x2, y2).
|
| 153 |
+
"""
|
| 154 |
+
if not isinstance(tensor, torch.Tensor):
|
| 155 |
+
tensor = torch.as_tensor(
|
| 156 |
+
tensor, dtype=torch.float32, device=torch.device("cpu")
|
| 157 |
+
)
|
| 158 |
+
else:
|
| 159 |
+
tensor = tensor.to(torch.float32)
|
| 160 |
+
if tensor.numel() == 0:
|
| 161 |
+
# Use reshape, so we don't end up creating a new tensor that does not depend on
|
| 162 |
+
# the inputs (and consequently confuses jit)
|
| 163 |
+
tensor = tensor.reshape((-1, 4)).to(dtype=torch.float32)
|
| 164 |
+
assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
|
| 165 |
+
|
| 166 |
+
self.tensor = tensor
|
| 167 |
+
|
| 168 |
+
def clone(self) -> "Boxes":
|
| 169 |
+
"""
|
| 170 |
+
Clone the Boxes.
|
| 171 |
+
|
| 172 |
+
Returns:
|
| 173 |
+
Boxes
|
| 174 |
+
"""
|
| 175 |
+
return Boxes(self.tensor.clone())
|
| 176 |
+
|
| 177 |
+
def to(self, device: torch.device):
|
| 178 |
+
# Boxes are assumed float32 and does not support to(dtype)
|
| 179 |
+
return Boxes(self.tensor.to(device=device))
|
| 180 |
+
|
| 181 |
+
def area(self) -> torch.Tensor:
|
| 182 |
+
"""
|
| 183 |
+
Computes the area of all the boxes.
|
| 184 |
+
|
| 185 |
+
Returns:
|
| 186 |
+
torch.Tensor: a vector with areas of each box.
|
| 187 |
+
"""
|
| 188 |
+
box = self.tensor
|
| 189 |
+
area = (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1])
|
| 190 |
+
return area
|
| 191 |
+
|
| 192 |
+
def clip(self, box_size: Tuple[int, int]) -> None:
|
| 193 |
+
"""
|
| 194 |
+
Clip (in place) the boxes by limiting x coordinates to the range [0, width]
|
| 195 |
+
and y coordinates to the range [0, height].
|
| 196 |
+
|
| 197 |
+
Args:
|
| 198 |
+
box_size (height, width): The clipping box's size.
|
| 199 |
+
"""
|
| 200 |
+
assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!"
|
| 201 |
+
h, w = box_size
|
| 202 |
+
x1 = self.tensor[:, 0].clamp(min=0, max=w)
|
| 203 |
+
y1 = self.tensor[:, 1].clamp(min=0, max=h)
|
| 204 |
+
x2 = self.tensor[:, 2].clamp(min=0, max=w)
|
| 205 |
+
y2 = self.tensor[:, 3].clamp(min=0, max=h)
|
| 206 |
+
self.tensor = torch.stack((x1, y1, x2, y2), dim=-1)
|
| 207 |
+
|
| 208 |
+
def nonempty(self, threshold: float = 0.0) -> torch.Tensor:
|
| 209 |
+
"""
|
| 210 |
+
Find boxes that are non-empty.
|
| 211 |
+
A box is considered empty, if either of its side is no larger than threshold.
|
| 212 |
+
|
| 213 |
+
Returns:
|
| 214 |
+
Tensor:
|
| 215 |
+
a binary vector which represents whether each box is empty
|
| 216 |
+
(False) or non-empty (True).
|
| 217 |
+
"""
|
| 218 |
+
box = self.tensor
|
| 219 |
+
widths = box[:, 2] - box[:, 0]
|
| 220 |
+
heights = box[:, 3] - box[:, 1]
|
| 221 |
+
keep = (widths > threshold) & (heights > threshold)
|
| 222 |
+
return keep
|
| 223 |
+
|
| 224 |
+
def __getitem__(self, item) -> "Boxes":
|
| 225 |
+
"""
|
| 226 |
+
Args:
|
| 227 |
+
item: int, slice, or a BoolTensor
|
| 228 |
+
|
| 229 |
+
Returns:
|
| 230 |
+
Boxes: Create a new :class:`Boxes` by indexing.
|
| 231 |
+
|
| 232 |
+
The following usage are allowed:
|
| 233 |
+
|
| 234 |
+
1. `new_boxes = boxes[3]`: return a `Boxes` which contains only one box.
|
| 235 |
+
2. `new_boxes = boxes[2:10]`: return a slice of boxes.
|
| 236 |
+
3. `new_boxes = boxes[vector]`, where vector is a torch.BoolTensor
|
| 237 |
+
with `length = len(boxes)`. Nonzero elements in the vector will be selected.
|
| 238 |
+
|
| 239 |
+
Note that the returned Boxes might share storage with this Boxes,
|
| 240 |
+
subject to Pytorch's indexing semantics.
|
| 241 |
+
"""
|
| 242 |
+
if isinstance(item, int):
|
| 243 |
+
return Boxes(self.tensor[item].view(1, -1))
|
| 244 |
+
b = self.tensor[item]
|
| 245 |
+
assert (
|
| 246 |
+
b.dim() == 2
|
| 247 |
+
), "Indexing on Boxes with {} failed to return a matrix!".format(item)
|
| 248 |
+
return Boxes(b)
|
| 249 |
+
|
| 250 |
+
def __len__(self) -> int:
|
| 251 |
+
return self.tensor.shape[0]
|
| 252 |
+
|
| 253 |
+
def __repr__(self) -> str:
|
| 254 |
+
return "Boxes(" + str(self.tensor) + ")"
|
| 255 |
+
|
| 256 |
+
def inside_box(
|
| 257 |
+
self, box_size: Tuple[int, int], boundary_threshold: int = 0
|
| 258 |
+
) -> torch.Tensor:
|
| 259 |
+
"""
|
| 260 |
+
Args:
|
| 261 |
+
box_size (height, width): Size of the reference box.
|
| 262 |
+
boundary_threshold (int): Boxes that extend beyond the reference box
|
| 263 |
+
boundary by more than boundary_threshold are considered "outside".
|
| 264 |
+
|
| 265 |
+
Returns:
|
| 266 |
+
a binary vector, indicating whether each box is inside the reference box.
|
| 267 |
+
"""
|
| 268 |
+
height, width = box_size
|
| 269 |
+
inds_inside = (
|
| 270 |
+
(self.tensor[..., 0] >= -boundary_threshold)
|
| 271 |
+
& (self.tensor[..., 1] >= -boundary_threshold)
|
| 272 |
+
& (self.tensor[..., 2] < width + boundary_threshold)
|
| 273 |
+
& (self.tensor[..., 3] < height + boundary_threshold)
|
| 274 |
+
)
|
| 275 |
+
return inds_inside
|
| 276 |
+
|
| 277 |
+
def get_centers(self) -> torch.Tensor:
|
| 278 |
+
"""
|
| 279 |
+
Returns:
|
| 280 |
+
The box centers in a Nx2 array of (x, y).
|
| 281 |
+
"""
|
| 282 |
+
return (self.tensor[:, :2] + self.tensor[:, 2:]) / 2
|
| 283 |
+
|
| 284 |
+
def scale(self, scale_x: float, scale_y: float) -> None:
|
| 285 |
+
"""
|
| 286 |
+
Scale the box with horizontal and vertical scaling factors
|
| 287 |
+
"""
|
| 288 |
+
self.tensor[:, 0::2] *= scale_x
|
| 289 |
+
self.tensor[:, 1::2] *= scale_y
|
| 290 |
+
|
| 291 |
+
@classmethod
|
| 292 |
+
def cat(cls, boxes_list: List["Boxes"]) -> "Boxes":
|
| 293 |
+
"""
|
| 294 |
+
Concatenates a list of Boxes into a single Boxes
|
| 295 |
+
|
| 296 |
+
Arguments:
|
| 297 |
+
boxes_list (list[Boxes])
|
| 298 |
+
|
| 299 |
+
Returns:
|
| 300 |
+
Boxes: the concatenated Boxes
|
| 301 |
+
"""
|
| 302 |
+
assert isinstance(boxes_list, (list, tuple))
|
| 303 |
+
if len(boxes_list) == 0:
|
| 304 |
+
return cls(torch.empty(0))
|
| 305 |
+
assert all([isinstance(box, Boxes) for box in boxes_list])
|
| 306 |
+
|
| 307 |
+
# use torch.cat (v.s. layers.cat) so the returned boxes never share storage with input
|
| 308 |
+
cat_boxes = cls(torch.cat([b.tensor for b in boxes_list], dim=0))
|
| 309 |
+
return cat_boxes
|
| 310 |
+
|
| 311 |
+
@property
|
| 312 |
+
def device(self) -> device:
|
| 313 |
+
return self.tensor.device
|
| 314 |
+
|
| 315 |
+
# type "Iterator[torch.Tensor]", yield, and iter() not supported by torchscript
|
| 316 |
+
# https://github.com/pytorch/pytorch/issues/18627
|
| 317 |
+
@torch.jit.unused
|
| 318 |
+
def __iter__(self):
|
| 319 |
+
"""
|
| 320 |
+
Yield a box as a Tensor of shape (4,) at a time.
|
| 321 |
+
"""
|
| 322 |
+
yield from self.tensor
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def pairwise_intersection(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
|
| 326 |
+
"""
|
| 327 |
+
Given two lists of boxes of size N and M,
|
| 328 |
+
compute the intersection area between __all__ N x M pairs of boxes.
|
| 329 |
+
The box order must be (xmin, ymin, xmax, ymax)
|
| 330 |
+
|
| 331 |
+
Args:
|
| 332 |
+
boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively.
|
| 333 |
+
|
| 334 |
+
Returns:
|
| 335 |
+
Tensor: intersection, sized [N,M].
|
| 336 |
+
"""
|
| 337 |
+
boxes1, boxes2 = boxes1.tensor, boxes2.tensor
|
| 338 |
+
width_height = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) - torch.max(
|
| 339 |
+
boxes1[:, None, :2], boxes2[:, :2]
|
| 340 |
+
) # [N,M,2]
|
| 341 |
+
|
| 342 |
+
width_height.clamp_(min=0) # [N,M,2]
|
| 343 |
+
intersection = width_height.prod(dim=2) # [N,M]
|
| 344 |
+
return intersection
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
|
| 348 |
+
# with slight modifications
|
| 349 |
+
def pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
|
| 350 |
+
"""
|
| 351 |
+
Given two lists of boxes of size N and M, compute the IoU
|
| 352 |
+
(intersection over union) between **all** N x M pairs of boxes.
|
| 353 |
+
The box order must be (xmin, ymin, xmax, ymax).
|
| 354 |
+
|
| 355 |
+
Args:
|
| 356 |
+
boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively.
|
| 357 |
+
|
| 358 |
+
Returns:
|
| 359 |
+
Tensor: IoU, sized [N,M].
|
| 360 |
+
"""
|
| 361 |
+
area1 = boxes1.area() # [N]
|
| 362 |
+
area2 = boxes2.area() # [M]
|
| 363 |
+
inter = pairwise_intersection(boxes1, boxes2)
|
| 364 |
+
|
| 365 |
+
# handle empty boxes
|
| 366 |
+
iou = torch.where(
|
| 367 |
+
inter > 0,
|
| 368 |
+
inter / (area1[:, None] + area2 - inter),
|
| 369 |
+
torch.zeros(1, dtype=inter.dtype, device=inter.device),
|
| 370 |
+
)
|
| 371 |
+
return iou
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def pairwise_ioa(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
|
| 375 |
+
"""
|
| 376 |
+
Similar to :func:`pariwise_iou` but compute the IoA (intersection over boxes2 area).
|
| 377 |
+
|
| 378 |
+
Args:
|
| 379 |
+
boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively.
|
| 380 |
+
|
| 381 |
+
Returns:
|
| 382 |
+
Tensor: IoA, sized [N,M].
|
| 383 |
+
"""
|
| 384 |
+
area2 = boxes2.area() # [M]
|
| 385 |
+
inter = pairwise_intersection(boxes1, boxes2)
|
| 386 |
+
|
| 387 |
+
# handle empty boxes
|
| 388 |
+
ioa = torch.where(
|
| 389 |
+
inter > 0, inter / area2, torch.zeros(1, dtype=inter.dtype, device=inter.device)
|
| 390 |
+
)
|
| 391 |
+
return ioa
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
def pairwise_point_box_distance(points: torch.Tensor, boxes: Boxes):
|
| 395 |
+
"""
|
| 396 |
+
Pairwise distance between N points and M boxes. The distance between a
|
| 397 |
+
point and a box is represented by the distance from the point to 4 edges
|
| 398 |
+
of the box. Distances are all positive when the point is inside the box.
|
| 399 |
+
|
| 400 |
+
Args:
|
| 401 |
+
points: Nx2 coordinates. Each row is (x, y)
|
| 402 |
+
boxes: M boxes
|
| 403 |
+
|
| 404 |
+
Returns:
|
| 405 |
+
Tensor: distances of size (N, M, 4). The 4 values are distances from
|
| 406 |
+
the point to the left, top, right, bottom of the box.
|
| 407 |
+
"""
|
| 408 |
+
x, y = points.unsqueeze(dim=2).unbind(dim=1) # (N, 1)
|
| 409 |
+
x0, y0, x1, y1 = boxes.tensor.unsqueeze(dim=0).unbind(dim=2) # (1, M)
|
| 410 |
+
return torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2)
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def matched_pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
|
| 414 |
+
"""
|
| 415 |
+
Compute pairwise intersection over union (IOU) of two sets of matched
|
| 416 |
+
boxes that have the same number of boxes.
|
| 417 |
+
Similar to :func:`pairwise_iou`, but computes only diagonal elements of the matrix.
|
| 418 |
+
|
| 419 |
+
Args:
|
| 420 |
+
boxes1 (Boxes): bounding boxes, sized [N,4].
|
| 421 |
+
boxes2 (Boxes): same length as boxes1
|
| 422 |
+
Returns:
|
| 423 |
+
Tensor: iou, sized [N].
|
| 424 |
+
"""
|
| 425 |
+
assert len(boxes1) == len(boxes2), (
|
| 426 |
+
"boxlists should have the same" "number of entries, got {}, {}".format(
|
| 427 |
+
len(boxes1), len(boxes2)
|
| 428 |
+
)
|
| 429 |
+
)
|
| 430 |
+
area1 = boxes1.area() # [N]
|
| 431 |
+
area2 = boxes2.area() # [N]
|
| 432 |
+
box1, box2 = boxes1.tensor, boxes2.tensor
|
| 433 |
+
lt = torch.max(box1[:, :2], box2[:, :2]) # [N,2]
|
| 434 |
+
rb = torch.min(box1[:, 2:], box2[:, 2:]) # [N,2]
|
| 435 |
+
wh = (rb - lt).clamp(min=0) # [N,2]
|
| 436 |
+
inter = wh[:, 0] * wh[:, 1] # [N]
|
| 437 |
+
iou = inter / (area1 + area2 - inter) # [N]
|
| 438 |
+
return iou
|
sam3/agent/helpers/color_map.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
An awesome colormap for really neat visualizations.
|
| 5 |
+
Copied from Detectron, and removed gray colors.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import random
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
__all__ = ["colormap", "random_color", "random_colors"]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# A list of 25 bright and sharp colors for segmentation masks,
|
| 16 |
+
# generated from the edges of the sRGB color space for maximum intensity.
|
| 17 |
+
_COLORS = (
|
| 18 |
+
np.array(
|
| 19 |
+
[
|
| 20 |
+
# The original 8 sharp colors
|
| 21 |
+
1.000,
|
| 22 |
+
1.000,
|
| 23 |
+
0.000, # 1. Yellow
|
| 24 |
+
0.000,
|
| 25 |
+
1.000,
|
| 26 |
+
0.000, # 2. Lime
|
| 27 |
+
0.000,
|
| 28 |
+
1.000,
|
| 29 |
+
1.000, # 3. Cyan
|
| 30 |
+
1.000,
|
| 31 |
+
0.000,
|
| 32 |
+
1.000, # 4. Magenta
|
| 33 |
+
1.000,
|
| 34 |
+
0.000,
|
| 35 |
+
0.000, # 5. Red
|
| 36 |
+
1.000,
|
| 37 |
+
0.498,
|
| 38 |
+
0.000, # 6. Orange
|
| 39 |
+
0.498,
|
| 40 |
+
1.000,
|
| 41 |
+
0.000, # 7. Chartreuse
|
| 42 |
+
0.000,
|
| 43 |
+
1.000,
|
| 44 |
+
0.498, # 8. Spring Green
|
| 45 |
+
1.000,
|
| 46 |
+
0.000,
|
| 47 |
+
0.498, # 9. Rose
|
| 48 |
+
0.498,
|
| 49 |
+
0.000,
|
| 50 |
+
1.000, # 10. Violet
|
| 51 |
+
0.753,
|
| 52 |
+
1.000,
|
| 53 |
+
0.000, # 11. Electric Lime
|
| 54 |
+
1.000,
|
| 55 |
+
0.753,
|
| 56 |
+
0.000, # 12. Vivid Orange
|
| 57 |
+
0.000,
|
| 58 |
+
1.000,
|
| 59 |
+
0.753, # 13. Turquoise
|
| 60 |
+
0.753,
|
| 61 |
+
0.000,
|
| 62 |
+
1.000, # 14. Bright Violet
|
| 63 |
+
1.000,
|
| 64 |
+
0.000,
|
| 65 |
+
0.753, # 15. Bright Pink
|
| 66 |
+
1.000,
|
| 67 |
+
0.251,
|
| 68 |
+
0.000, # 16. Fiery Orange
|
| 69 |
+
0.251,
|
| 70 |
+
1.000,
|
| 71 |
+
0.000, # 17. Bright Chartreuse
|
| 72 |
+
0.000,
|
| 73 |
+
1.000,
|
| 74 |
+
0.251, # 18. Malachite Green
|
| 75 |
+
0.251,
|
| 76 |
+
0.000,
|
| 77 |
+
1.000, # 19. Deep Violet
|
| 78 |
+
1.000,
|
| 79 |
+
0.000,
|
| 80 |
+
0.251, # 20. Hot Pink
|
| 81 |
+
]
|
| 82 |
+
)
|
| 83 |
+
.astype(np.float32)
|
| 84 |
+
.reshape(-1, 3)
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def colormap(rgb=False, maximum=255):
|
| 89 |
+
"""
|
| 90 |
+
Args:
|
| 91 |
+
rgb (bool): whether to return RGB colors or BGR colors.
|
| 92 |
+
maximum (int): either 255 or 1
|
| 93 |
+
|
| 94 |
+
Returns:
|
| 95 |
+
ndarray: a float32 array of Nx3 colors, in range [0, 255] or [0, 1]
|
| 96 |
+
"""
|
| 97 |
+
assert maximum in [255, 1], maximum
|
| 98 |
+
c = _COLORS * maximum
|
| 99 |
+
if not rgb:
|
| 100 |
+
c = c[:, ::-1]
|
| 101 |
+
return c
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def random_color(rgb=False, maximum=255):
|
| 105 |
+
"""
|
| 106 |
+
Args:
|
| 107 |
+
rgb (bool): whether to return RGB colors or BGR colors.
|
| 108 |
+
maximum (int): either 255 or 1
|
| 109 |
+
|
| 110 |
+
Returns:
|
| 111 |
+
ndarray: a vector of 3 numbers
|
| 112 |
+
"""
|
| 113 |
+
idx = np.random.randint(0, len(_COLORS))
|
| 114 |
+
ret = _COLORS[idx] * maximum
|
| 115 |
+
if not rgb:
|
| 116 |
+
ret = ret[::-1]
|
| 117 |
+
return ret
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def random_colors(N, rgb=False, maximum=255):
|
| 121 |
+
"""
|
| 122 |
+
Args:
|
| 123 |
+
N (int): number of unique colors needed
|
| 124 |
+
rgb (bool): whether to return RGB colors or BGR colors.
|
| 125 |
+
maximum (int): either 255 or 1
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
ndarray: a list of random_color
|
| 129 |
+
"""
|
| 130 |
+
indices = random.sample(range(len(_COLORS)), N)
|
| 131 |
+
ret = [_COLORS[i] * maximum for i in indices]
|
| 132 |
+
if not rgb:
|
| 133 |
+
ret = [x[::-1] for x in ret]
|
| 134 |
+
return ret
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
if __name__ == "__main__":
|
| 138 |
+
import cv2
|
| 139 |
+
|
| 140 |
+
size = 100
|
| 141 |
+
H, W = 10, 10
|
| 142 |
+
canvas = np.random.rand(H * size, W * size, 3).astype("float32")
|
| 143 |
+
for h in range(H):
|
| 144 |
+
for w in range(W):
|
| 145 |
+
idx = h * W + w
|
| 146 |
+
if idx >= len(_COLORS):
|
| 147 |
+
break
|
| 148 |
+
canvas[h * size : (h + 1) * size, w * size : (w + 1) * size] = _COLORS[idx]
|
| 149 |
+
cv2.imshow("a", canvas)
|
| 150 |
+
cv2.waitKey(0)
|
sam3/agent/helpers/keypoints.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
from typing import Any, List, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Keypoints:
|
| 11 |
+
"""
|
| 12 |
+
Stores keypoint **annotation** data. GT Instances have a `gt_keypoints` property
|
| 13 |
+
containing the x,y location and visibility flag of each keypoint. This tensor has shape
|
| 14 |
+
(N, K, 3) where N is the number of instances and K is the number of keypoints per instance.
|
| 15 |
+
|
| 16 |
+
The visibility flag follows the COCO format and must be one of three integers:
|
| 17 |
+
|
| 18 |
+
* v=0: not labeled (in which case x=y=0)
|
| 19 |
+
* v=1: labeled but not visible
|
| 20 |
+
* v=2: labeled and visible
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(self, keypoints: Union[torch.Tensor, np.ndarray, List[List[float]]]):
|
| 24 |
+
"""
|
| 25 |
+
Arguments:
|
| 26 |
+
keypoints: A Tensor, numpy array, or list of the x, y, and visibility of each keypoint.
|
| 27 |
+
The shape should be (N, K, 3) where N is the number of
|
| 28 |
+
instances, and K is the number of keypoints per instance.
|
| 29 |
+
"""
|
| 30 |
+
device = (
|
| 31 |
+
keypoints.device
|
| 32 |
+
if isinstance(keypoints, torch.Tensor)
|
| 33 |
+
else torch.device("cpu")
|
| 34 |
+
)
|
| 35 |
+
keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=device)
|
| 36 |
+
assert keypoints.dim() == 3 and keypoints.shape[2] == 3, keypoints.shape
|
| 37 |
+
self.tensor = keypoints
|
| 38 |
+
|
| 39 |
+
def __len__(self) -> int:
|
| 40 |
+
return self.tensor.size(0)
|
| 41 |
+
|
| 42 |
+
def to(self, *args: Any, **kwargs: Any) -> "Keypoints":
|
| 43 |
+
return type(self)(self.tensor.to(*args, **kwargs))
|
| 44 |
+
|
| 45 |
+
@property
|
| 46 |
+
def device(self) -> torch.device:
|
| 47 |
+
return self.tensor.device
|
| 48 |
+
|
| 49 |
+
def to_heatmap(self, boxes: torch.Tensor, heatmap_size: int) -> torch.Tensor:
|
| 50 |
+
"""
|
| 51 |
+
Convert keypoint annotations to a heatmap of one-hot labels for training,
|
| 52 |
+
as described in :paper:`Mask R-CNN`.
|
| 53 |
+
|
| 54 |
+
Arguments:
|
| 55 |
+
boxes: Nx4 tensor, the boxes to draw the keypoints to
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
heatmaps:
|
| 59 |
+
A tensor of shape (N, K), each element is integer spatial label
|
| 60 |
+
in the range [0, heatmap_size**2 - 1] for each keypoint in the input.
|
| 61 |
+
valid:
|
| 62 |
+
A tensor of shape (N, K) containing whether each keypoint is in the roi or not.
|
| 63 |
+
"""
|
| 64 |
+
return _keypoints_to_heatmap(self.tensor, boxes, heatmap_size)
|
| 65 |
+
|
| 66 |
+
def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Keypoints":
|
| 67 |
+
"""
|
| 68 |
+
Create a new `Keypoints` by indexing on this `Keypoints`.
|
| 69 |
+
|
| 70 |
+
The following usage are allowed:
|
| 71 |
+
|
| 72 |
+
1. `new_kpts = kpts[3]`: return a `Keypoints` which contains only one instance.
|
| 73 |
+
2. `new_kpts = kpts[2:10]`: return a slice of key points.
|
| 74 |
+
3. `new_kpts = kpts[vector]`, where vector is a torch.ByteTensor
|
| 75 |
+
with `length = len(kpts)`. Nonzero elements in the vector will be selected.
|
| 76 |
+
|
| 77 |
+
Note that the returned Keypoints might share storage with this Keypoints,
|
| 78 |
+
subject to Pytorch's indexing semantics.
|
| 79 |
+
"""
|
| 80 |
+
if isinstance(item, int):
|
| 81 |
+
return Keypoints([self.tensor[item]])
|
| 82 |
+
return Keypoints(self.tensor[item])
|
| 83 |
+
|
| 84 |
+
def __repr__(self) -> str:
|
| 85 |
+
s = self.__class__.__name__ + "("
|
| 86 |
+
s += "num_instances={})".format(len(self.tensor))
|
| 87 |
+
return s
|
| 88 |
+
|
| 89 |
+
@staticmethod
|
| 90 |
+
def cat(keypoints_list: List["Keypoints"]) -> "Keypoints":
|
| 91 |
+
"""
|
| 92 |
+
Concatenates a list of Keypoints into a single Keypoints
|
| 93 |
+
|
| 94 |
+
Arguments:
|
| 95 |
+
keypoints_list (list[Keypoints])
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
Keypoints: the concatenated Keypoints
|
| 99 |
+
"""
|
| 100 |
+
assert isinstance(keypoints_list, (list, tuple))
|
| 101 |
+
assert len(keypoints_list) > 0
|
| 102 |
+
assert all(isinstance(keypoints, Keypoints) for keypoints in keypoints_list)
|
| 103 |
+
|
| 104 |
+
cat_kpts = type(keypoints_list[0])(
|
| 105 |
+
torch.cat([kpts.tensor for kpts in keypoints_list], dim=0)
|
| 106 |
+
)
|
| 107 |
+
return cat_kpts
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _keypoints_to_heatmap(
|
| 111 |
+
keypoints: torch.Tensor, rois: torch.Tensor, heatmap_size: int
|
| 112 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 113 |
+
"""
|
| 114 |
+
Encode keypoint locations into a target heatmap for use in SoftmaxWithLoss across space.
|
| 115 |
+
|
| 116 |
+
Maps keypoints from the half-open interval [x1, x2) on continuous image coordinates to the
|
| 117 |
+
closed interval [0, heatmap_size - 1] on discrete image coordinates. We use the
|
| 118 |
+
continuous-discrete conversion from Heckbert 1990 ("What is the coordinate of a pixel?"):
|
| 119 |
+
d = floor(c) and c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate.
|
| 120 |
+
|
| 121 |
+
Arguments:
|
| 122 |
+
keypoints: tensor of keypoint locations in of shape (N, K, 3).
|
| 123 |
+
rois: Nx4 tensor of rois in xyxy format
|
| 124 |
+
heatmap_size: integer side length of square heatmap.
|
| 125 |
+
|
| 126 |
+
Returns:
|
| 127 |
+
heatmaps: A tensor of shape (N, K) containing an integer spatial label
|
| 128 |
+
in the range [0, heatmap_size**2 - 1] for each keypoint in the input.
|
| 129 |
+
valid: A tensor of shape (N, K) containing whether each keypoint is in
|
| 130 |
+
the roi or not.
|
| 131 |
+
"""
|
| 132 |
+
|
| 133 |
+
if rois.numel() == 0:
|
| 134 |
+
return rois.new().long(), rois.new().long()
|
| 135 |
+
offset_x = rois[:, 0]
|
| 136 |
+
offset_y = rois[:, 1]
|
| 137 |
+
scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
|
| 138 |
+
scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
|
| 139 |
+
|
| 140 |
+
offset_x = offset_x[:, None]
|
| 141 |
+
offset_y = offset_y[:, None]
|
| 142 |
+
scale_x = scale_x[:, None]
|
| 143 |
+
scale_y = scale_y[:, None]
|
| 144 |
+
|
| 145 |
+
x = keypoints[..., 0]
|
| 146 |
+
y = keypoints[..., 1]
|
| 147 |
+
|
| 148 |
+
x_boundary_inds = x == rois[:, 2][:, None]
|
| 149 |
+
y_boundary_inds = y == rois[:, 3][:, None]
|
| 150 |
+
|
| 151 |
+
x = (x - offset_x) * scale_x
|
| 152 |
+
x = x.floor().long()
|
| 153 |
+
y = (y - offset_y) * scale_y
|
| 154 |
+
y = y.floor().long()
|
| 155 |
+
|
| 156 |
+
x[x_boundary_inds] = heatmap_size - 1
|
| 157 |
+
y[y_boundary_inds] = heatmap_size - 1
|
| 158 |
+
|
| 159 |
+
valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
|
| 160 |
+
vis = keypoints[..., 2] > 0
|
| 161 |
+
valid = (valid_loc & vis).long()
|
| 162 |
+
|
| 163 |
+
lin_ind = y * heatmap_size + x
|
| 164 |
+
heatmaps = lin_ind * valid
|
| 165 |
+
|
| 166 |
+
return heatmaps, valid
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
@torch.jit.script_if_tracing
|
| 170 |
+
def heatmaps_to_keypoints(maps: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
|
| 171 |
+
"""
|
| 172 |
+
Extract predicted keypoint locations from heatmaps.
|
| 173 |
+
|
| 174 |
+
Args:
|
| 175 |
+
maps (Tensor): (#ROIs, #keypoints, POOL_H, POOL_W). The predicted heatmap of logits for
|
| 176 |
+
each ROI and each keypoint.
|
| 177 |
+
rois (Tensor): (#ROIs, 4). The box of each ROI.
|
| 178 |
+
|
| 179 |
+
Returns:
|
| 180 |
+
Tensor of shape (#ROIs, #keypoints, 4) with the last dimension corresponding to
|
| 181 |
+
(x, y, logit, score) for each keypoint.
|
| 182 |
+
|
| 183 |
+
When converting discrete pixel indices in an NxN image to a continuous keypoint coordinate,
|
| 184 |
+
we maintain consistency with :meth:`Keypoints.to_heatmap` by using the conversion from
|
| 185 |
+
Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate.
|
| 186 |
+
"""
|
| 187 |
+
|
| 188 |
+
offset_x = rois[:, 0]
|
| 189 |
+
offset_y = rois[:, 1]
|
| 190 |
+
|
| 191 |
+
widths = (rois[:, 2] - rois[:, 0]).clamp(min=1)
|
| 192 |
+
heights = (rois[:, 3] - rois[:, 1]).clamp(min=1)
|
| 193 |
+
widths_ceil = widths.ceil()
|
| 194 |
+
heights_ceil = heights.ceil()
|
| 195 |
+
|
| 196 |
+
num_rois, num_keypoints = maps.shape[:2]
|
| 197 |
+
xy_preds = maps.new_zeros(rois.shape[0], num_keypoints, 4)
|
| 198 |
+
|
| 199 |
+
width_corrections = widths / widths_ceil
|
| 200 |
+
height_corrections = heights / heights_ceil
|
| 201 |
+
|
| 202 |
+
keypoints_idx = torch.arange(num_keypoints, device=maps.device)
|
| 203 |
+
|
| 204 |
+
for i in range(num_rois):
|
| 205 |
+
outsize = (int(heights_ceil[i]), int(widths_ceil[i]))
|
| 206 |
+
roi_map = F.interpolate(
|
| 207 |
+
maps[[i]], size=outsize, mode="bicubic", align_corners=False
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
# Although semantically equivalent, `reshape` is used instead of `squeeze` due
|
| 211 |
+
# to limitation during ONNX export of `squeeze` in scripting mode
|
| 212 |
+
roi_map = roi_map.reshape(roi_map.shape[1:]) # keypoints x H x W
|
| 213 |
+
|
| 214 |
+
# softmax over the spatial region
|
| 215 |
+
max_score, _ = roi_map.view(num_keypoints, -1).max(1)
|
| 216 |
+
max_score = max_score.view(num_keypoints, 1, 1)
|
| 217 |
+
tmp_full_resolution = (roi_map - max_score).exp_()
|
| 218 |
+
tmp_pool_resolution = (maps[i] - max_score).exp_()
|
| 219 |
+
# Produce scores over the region H x W, but normalize with POOL_H x POOL_W,
|
| 220 |
+
# so that the scores of objects of different absolute sizes will be more comparable
|
| 221 |
+
roi_map_scores = tmp_full_resolution / tmp_pool_resolution.sum(
|
| 222 |
+
(1, 2), keepdim=True
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
w = roi_map.shape[2]
|
| 226 |
+
pos = roi_map.view(num_keypoints, -1).argmax(1)
|
| 227 |
+
|
| 228 |
+
x_int = pos % w
|
| 229 |
+
y_int = (pos - x_int) // w
|
| 230 |
+
|
| 231 |
+
assert (
|
| 232 |
+
roi_map_scores[keypoints_idx, y_int, x_int]
|
| 233 |
+
== roi_map_scores.view(num_keypoints, -1).max(1)[0]
|
| 234 |
+
).all()
|
| 235 |
+
|
| 236 |
+
x = (x_int.float() + 0.5) * width_corrections[i]
|
| 237 |
+
y = (y_int.float() + 0.5) * height_corrections[i]
|
| 238 |
+
|
| 239 |
+
xy_preds[i, :, 0] = x + offset_x[i]
|
| 240 |
+
xy_preds[i, :, 1] = y + offset_y[i]
|
| 241 |
+
xy_preds[i, :, 2] = roi_map[keypoints_idx, y_int, x_int]
|
| 242 |
+
xy_preds[i, :, 3] = roi_map_scores[keypoints_idx, y_int, x_int]
|
| 243 |
+
|
| 244 |
+
return xy_preds
|
sam3/agent/helpers/mask_overlap_removal.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
from typing import Dict, List
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
from pycocotools import mask as mask_utils
|
| 10 |
+
except Exception:
|
| 11 |
+
mask_utils = None
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def mask_intersection(
|
| 15 |
+
masks1: torch.Tensor, masks2: torch.Tensor, block_size: int = 16
|
| 16 |
+
) -> torch.Tensor:
|
| 17 |
+
assert masks1.shape[1:] == masks2.shape[1:]
|
| 18 |
+
assert masks1.dtype == torch.bool and masks2.dtype == torch.bool
|
| 19 |
+
N, M = masks1.shape[0], masks2.shape[0]
|
| 20 |
+
out = torch.zeros(N, M, device=masks1.device, dtype=torch.long)
|
| 21 |
+
for i in range(0, N, block_size):
|
| 22 |
+
for j in range(0, M, block_size):
|
| 23 |
+
a = masks1[i : i + block_size]
|
| 24 |
+
b = masks2[j : j + block_size]
|
| 25 |
+
inter = (a[:, None] & b[None, :]).flatten(-2).sum(-1)
|
| 26 |
+
out[i : i + block_size, j : j + block_size] = inter
|
| 27 |
+
return out
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def mask_iom(masks1: torch.Tensor, masks2: torch.Tensor) -> torch.Tensor:
|
| 31 |
+
assert masks1.shape[1:] == masks2.shape[1:]
|
| 32 |
+
assert masks1.dtype == torch.bool and masks2.dtype == torch.bool
|
| 33 |
+
inter = mask_intersection(masks1, masks2)
|
| 34 |
+
area1 = masks1.flatten(-2).sum(-1) # (N,)
|
| 35 |
+
area2 = masks2.flatten(-2).sum(-1) # (M,)
|
| 36 |
+
min_area = torch.min(area1[:, None], area2[None, :]).clamp_min(1)
|
| 37 |
+
return inter.float() / (min_area.float() + 1e-8)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _decode_single_mask(mask_repr, h: int, w: int) -> np.ndarray:
|
| 41 |
+
if isinstance(mask_repr, (list, tuple, np.ndarray)):
|
| 42 |
+
arr = np.array(mask_repr)
|
| 43 |
+
if arr.ndim != 2:
|
| 44 |
+
raise ValueError("Mask array must be 2D (H, W).")
|
| 45 |
+
return (arr > 0).astype(np.uint8)
|
| 46 |
+
|
| 47 |
+
if mask_utils is None:
|
| 48 |
+
raise ImportError(
|
| 49 |
+
"pycocotools is required to decode RLE mask strings. pip install pycocotools"
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
if not isinstance(mask_repr, (str, bytes)):
|
| 53 |
+
raise ValueError("Unsupported mask representation type for RLE decode.")
|
| 54 |
+
|
| 55 |
+
rle = {
|
| 56 |
+
"counts": mask_repr if isinstance(mask_repr, (str, bytes)) else str(mask_repr),
|
| 57 |
+
"size": [h, w],
|
| 58 |
+
}
|
| 59 |
+
decoded = mask_utils.decode(rle)
|
| 60 |
+
if decoded.ndim == 3:
|
| 61 |
+
decoded = decoded[:, :, 0]
|
| 62 |
+
return (decoded > 0).astype(np.uint8)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _decode_masks_to_torch_bool(pred_masks: List, h: int, w: int) -> torch.Tensor:
|
| 66 |
+
bin_masks = [_decode_single_mask(m, h, w) for m in pred_masks]
|
| 67 |
+
masks_np = np.stack(bin_masks, axis=0).astype(np.uint8) # (N, H, W)
|
| 68 |
+
return torch.from_numpy(masks_np > 0)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def remove_overlapping_masks(sample: Dict, iom_thresh: float = 0.3) -> Dict:
|
| 72 |
+
"""
|
| 73 |
+
Greedy keep: sort by score desc; keep a mask if IoM to all kept masks <= threshold.
|
| 74 |
+
If pred_masks has length 0 or 1, returns sample unchanged (no extra keys).
|
| 75 |
+
"""
|
| 76 |
+
# Basic presence checks
|
| 77 |
+
if "pred_masks" not in sample or not isinstance(sample["pred_masks"], list):
|
| 78 |
+
return sample # nothing to do / preserve as-is
|
| 79 |
+
|
| 80 |
+
pred_masks = sample["pred_masks"]
|
| 81 |
+
N = len(pred_masks)
|
| 82 |
+
|
| 83 |
+
# --- Early exit: 0 or 1 mask -> do NOT modify the JSON at all ---
|
| 84 |
+
if N <= 1:
|
| 85 |
+
return sample
|
| 86 |
+
|
| 87 |
+
# From here on we have at least 2 masks
|
| 88 |
+
h = int(sample["orig_img_h"])
|
| 89 |
+
w = int(sample["orig_img_w"])
|
| 90 |
+
pred_scores = sample.get("pred_scores", [1.0] * N) # fallback if scores missing
|
| 91 |
+
pred_boxes = sample.get("pred_boxes", None)
|
| 92 |
+
|
| 93 |
+
assert N == len(pred_scores), "pred_masks and pred_scores must have same length"
|
| 94 |
+
if pred_boxes is not None:
|
| 95 |
+
assert N == len(pred_boxes), "pred_masks and pred_boxes must have same length"
|
| 96 |
+
|
| 97 |
+
masks_bool = _decode_masks_to_torch_bool(pred_masks, h, w) # (N, H, W)
|
| 98 |
+
|
| 99 |
+
order = sorted(range(N), key=lambda i: float(pred_scores[i]), reverse=True)
|
| 100 |
+
kept_idx: List[int] = []
|
| 101 |
+
kept_masks: List[torch.Tensor] = []
|
| 102 |
+
|
| 103 |
+
for i in order:
|
| 104 |
+
cand = masks_bool[i].unsqueeze(0) # (1, H, W)
|
| 105 |
+
if len(kept_masks) == 0:
|
| 106 |
+
kept_idx.append(i)
|
| 107 |
+
kept_masks.append(masks_bool[i])
|
| 108 |
+
continue
|
| 109 |
+
|
| 110 |
+
kept_stack = torch.stack(kept_masks, dim=0) # (K, H, W)
|
| 111 |
+
iom_vals = mask_iom(cand, kept_stack).squeeze(0) # (K,)
|
| 112 |
+
if torch.any(iom_vals > iom_thresh):
|
| 113 |
+
continue # overlaps too much with a higher-scored kept mask
|
| 114 |
+
kept_idx.append(i)
|
| 115 |
+
kept_masks.append(masks_bool[i])
|
| 116 |
+
|
| 117 |
+
kept_idx_sorted = sorted(kept_idx)
|
| 118 |
+
|
| 119 |
+
# Build filtered JSON (this *does* modify fields; only for N>=2 case)
|
| 120 |
+
out = dict(sample)
|
| 121 |
+
out["pred_masks"] = [pred_masks[i] for i in kept_idx_sorted]
|
| 122 |
+
out["pred_scores"] = [pred_scores[i] for i in kept_idx_sorted]
|
| 123 |
+
if pred_boxes is not None:
|
| 124 |
+
out["pred_boxes"] = [pred_boxes[i] for i in kept_idx_sorted]
|
| 125 |
+
out["kept_indices"] = kept_idx_sorted
|
| 126 |
+
out["removed_indices"] = [i for i in range(N) if i not in set(kept_idx_sorted)]
|
| 127 |
+
out["iom_threshold"] = float(iom_thresh)
|
| 128 |
+
return out
|
sam3/agent/helpers/masks.py
ADDED
|
@@ -0,0 +1,560 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
import copy
|
| 4 |
+
import itertools
|
| 5 |
+
from typing import Any, Iterator, List, Union
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pycocotools.mask as mask_util
|
| 9 |
+
import torch
|
| 10 |
+
from torch import device
|
| 11 |
+
|
| 12 |
+
from .boxes import Boxes
|
| 13 |
+
from .memory import retry_if_cuda_oom
|
| 14 |
+
|
| 15 |
+
from .roi_align import ROIAlign
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def polygon_area(x, y):
|
| 19 |
+
# Using the shoelace formula
|
| 20 |
+
# https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates
|
| 21 |
+
return 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def polygons_to_bitmask(
|
| 25 |
+
polygons: List[np.ndarray], height: int, width: int
|
| 26 |
+
) -> np.ndarray:
|
| 27 |
+
"""
|
| 28 |
+
Args:
|
| 29 |
+
polygons (list[ndarray]): each array has shape (Nx2,)
|
| 30 |
+
height, width (int)
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
ndarray: a bool mask of shape (height, width)
|
| 34 |
+
"""
|
| 35 |
+
if len(polygons) == 0:
|
| 36 |
+
# COCOAPI does not support empty polygons
|
| 37 |
+
return np.zeros((height, width)).astype(bool)
|
| 38 |
+
rles = mask_util.frPyObjects(polygons, height, width)
|
| 39 |
+
rle = mask_util.merge(rles)
|
| 40 |
+
return mask_util.decode(rle).astype(bool)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def rasterize_polygons_within_box(
|
| 44 |
+
polygons: List[np.ndarray], box: np.ndarray, mask_size: int
|
| 45 |
+
) -> torch.Tensor:
|
| 46 |
+
"""
|
| 47 |
+
Rasterize the polygons into a mask image and
|
| 48 |
+
crop the mask content in the given box.
|
| 49 |
+
The cropped mask is resized to (mask_size, mask_size).
|
| 50 |
+
|
| 51 |
+
This function is used when generating training targets for mask head in Mask R-CNN.
|
| 52 |
+
Given original ground-truth masks for an image, new ground-truth mask
|
| 53 |
+
training targets in the size of `mask_size x mask_size`
|
| 54 |
+
must be provided for each predicted box. This function will be called to
|
| 55 |
+
produce such targets.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
polygons (list[ndarray[float]]): a list of polygons, which represents an instance.
|
| 59 |
+
box: 4-element numpy array
|
| 60 |
+
mask_size (int):
|
| 61 |
+
|
| 62 |
+
Returns:
|
| 63 |
+
Tensor: BoolTensor of shape (mask_size, mask_size)
|
| 64 |
+
"""
|
| 65 |
+
# 1. Shift the polygons w.r.t the boxes
|
| 66 |
+
w, h = box[2] - box[0], box[3] - box[1]
|
| 67 |
+
|
| 68 |
+
polygons = copy.deepcopy(polygons)
|
| 69 |
+
for p in polygons:
|
| 70 |
+
p[0::2] = p[0::2] - box[0]
|
| 71 |
+
p[1::2] = p[1::2] - box[1]
|
| 72 |
+
|
| 73 |
+
# 2. Rescale the polygons to the new box size
|
| 74 |
+
# max() to avoid division by small number
|
| 75 |
+
ratio_h = mask_size / max(h, 0.1)
|
| 76 |
+
ratio_w = mask_size / max(w, 0.1)
|
| 77 |
+
|
| 78 |
+
if ratio_h == ratio_w:
|
| 79 |
+
for p in polygons:
|
| 80 |
+
p *= ratio_h
|
| 81 |
+
else:
|
| 82 |
+
for p in polygons:
|
| 83 |
+
p[0::2] *= ratio_w
|
| 84 |
+
p[1::2] *= ratio_h
|
| 85 |
+
|
| 86 |
+
# 3. Rasterize the polygons with coco api
|
| 87 |
+
mask = polygons_to_bitmask(polygons, mask_size, mask_size)
|
| 88 |
+
mask = torch.from_numpy(mask)
|
| 89 |
+
return mask
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class BitMasks:
|
| 93 |
+
"""
|
| 94 |
+
This class stores the segmentation masks for all objects in one image, in
|
| 95 |
+
the form of bitmaps.
|
| 96 |
+
|
| 97 |
+
Attributes:
|
| 98 |
+
tensor: bool Tensor of N,H,W, representing N instances in the image.
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
+
def __init__(self, tensor: Union[torch.Tensor, np.ndarray]):
|
| 102 |
+
"""
|
| 103 |
+
Args:
|
| 104 |
+
tensor: bool Tensor of N,H,W, representing N instances in the image.
|
| 105 |
+
"""
|
| 106 |
+
if isinstance(tensor, torch.Tensor):
|
| 107 |
+
tensor = tensor.to(torch.bool)
|
| 108 |
+
else:
|
| 109 |
+
tensor = torch.as_tensor(
|
| 110 |
+
tensor, dtype=torch.bool, device=torch.device("cpu")
|
| 111 |
+
)
|
| 112 |
+
assert tensor.dim() == 3, tensor.size()
|
| 113 |
+
self.image_size = tensor.shape[1:]
|
| 114 |
+
self.tensor = tensor
|
| 115 |
+
|
| 116 |
+
@torch.jit.unused
|
| 117 |
+
def to(self, *args: Any, **kwargs: Any) -> "BitMasks":
|
| 118 |
+
return BitMasks(self.tensor.to(*args, **kwargs))
|
| 119 |
+
|
| 120 |
+
@property
|
| 121 |
+
def device(self) -> torch.device:
|
| 122 |
+
return self.tensor.device
|
| 123 |
+
|
| 124 |
+
@torch.jit.unused
|
| 125 |
+
def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "BitMasks":
|
| 126 |
+
"""
|
| 127 |
+
Returns:
|
| 128 |
+
BitMasks: Create a new :class:`BitMasks` by indexing.
|
| 129 |
+
|
| 130 |
+
The following usage are allowed:
|
| 131 |
+
|
| 132 |
+
1. `new_masks = masks[3]`: return a `BitMasks` which contains only one mask.
|
| 133 |
+
2. `new_masks = masks[2:10]`: return a slice of masks.
|
| 134 |
+
3. `new_masks = masks[vector]`, where vector is a torch.BoolTensor
|
| 135 |
+
with `length = len(masks)`. Nonzero elements in the vector will be selected.
|
| 136 |
+
|
| 137 |
+
Note that the returned object might share storage with this object,
|
| 138 |
+
subject to Pytorch's indexing semantics.
|
| 139 |
+
"""
|
| 140 |
+
if isinstance(item, int):
|
| 141 |
+
return BitMasks(self.tensor[item].unsqueeze(0))
|
| 142 |
+
m = self.tensor[item]
|
| 143 |
+
assert (
|
| 144 |
+
m.dim() == 3
|
| 145 |
+
), "Indexing on BitMasks with {} returns a tensor with shape {}!".format(
|
| 146 |
+
item, m.shape
|
| 147 |
+
)
|
| 148 |
+
return BitMasks(m)
|
| 149 |
+
|
| 150 |
+
@torch.jit.unused
|
| 151 |
+
def __iter__(self) -> torch.Tensor:
|
| 152 |
+
yield from self.tensor
|
| 153 |
+
|
| 154 |
+
@torch.jit.unused
|
| 155 |
+
def __repr__(self) -> str:
|
| 156 |
+
s = self.__class__.__name__ + "("
|
| 157 |
+
s += "num_instances={})".format(len(self.tensor))
|
| 158 |
+
return s
|
| 159 |
+
|
| 160 |
+
def __len__(self) -> int:
|
| 161 |
+
return self.tensor.shape[0]
|
| 162 |
+
|
| 163 |
+
def nonempty(self) -> torch.Tensor:
|
| 164 |
+
"""
|
| 165 |
+
Find masks that are non-empty.
|
| 166 |
+
|
| 167 |
+
Returns:
|
| 168 |
+
Tensor: a BoolTensor which represents
|
| 169 |
+
whether each mask is empty (False) or non-empty (True).
|
| 170 |
+
"""
|
| 171 |
+
return self.tensor.flatten(1).any(dim=1)
|
| 172 |
+
|
| 173 |
+
@staticmethod
|
| 174 |
+
def from_polygon_masks(
|
| 175 |
+
polygon_masks: Union["PolygonMasks", List[List[np.ndarray]]],
|
| 176 |
+
height: int,
|
| 177 |
+
width: int,
|
| 178 |
+
) -> "BitMasks":
|
| 179 |
+
"""
|
| 180 |
+
Args:
|
| 181 |
+
polygon_masks (list[list[ndarray]] or PolygonMasks)
|
| 182 |
+
height, width (int)
|
| 183 |
+
"""
|
| 184 |
+
if isinstance(polygon_masks, PolygonMasks):
|
| 185 |
+
polygon_masks = polygon_masks.polygons
|
| 186 |
+
masks = [polygons_to_bitmask(p, height, width) for p in polygon_masks]
|
| 187 |
+
if len(masks):
|
| 188 |
+
return BitMasks(torch.stack([torch.from_numpy(x) for x in masks]))
|
| 189 |
+
else:
|
| 190 |
+
return BitMasks(torch.empty(0, height, width, dtype=torch.bool))
|
| 191 |
+
|
| 192 |
+
@staticmethod
|
| 193 |
+
def from_roi_masks(roi_masks: "ROIMasks", height: int, width: int) -> "BitMasks":
|
| 194 |
+
"""
|
| 195 |
+
Args:
|
| 196 |
+
roi_masks:
|
| 197 |
+
height, width (int):
|
| 198 |
+
"""
|
| 199 |
+
return roi_masks.to_bitmasks(height, width)
|
| 200 |
+
|
| 201 |
+
def crop_and_resize(self, boxes: torch.Tensor, mask_size: int) -> torch.Tensor:
|
| 202 |
+
"""
|
| 203 |
+
Crop each bitmask by the given box, and resize results to (mask_size, mask_size).
|
| 204 |
+
This can be used to prepare training targets for Mask R-CNN.
|
| 205 |
+
It has less reconstruction error compared to rasterization with polygons.
|
| 206 |
+
However we observe no difference in accuracy,
|
| 207 |
+
but BitMasks requires more memory to store all the masks.
|
| 208 |
+
|
| 209 |
+
Args:
|
| 210 |
+
boxes (Tensor): Nx4 tensor storing the boxes for each mask
|
| 211 |
+
mask_size (int): the size of the rasterized mask.
|
| 212 |
+
|
| 213 |
+
Returns:
|
| 214 |
+
Tensor:
|
| 215 |
+
A bool tensor of shape (N, mask_size, mask_size), where
|
| 216 |
+
N is the number of predicted boxes for this image.
|
| 217 |
+
"""
|
| 218 |
+
assert len(boxes) == len(self), "{} != {}".format(len(boxes), len(self))
|
| 219 |
+
device = self.tensor.device
|
| 220 |
+
|
| 221 |
+
batch_inds = torch.arange(len(boxes), device=device).to(dtype=boxes.dtype)[
|
| 222 |
+
:, None
|
| 223 |
+
]
|
| 224 |
+
rois = torch.cat([batch_inds, boxes], dim=1) # Nx5
|
| 225 |
+
|
| 226 |
+
bit_masks = self.tensor.to(dtype=torch.float32)
|
| 227 |
+
rois = rois.to(device=device)
|
| 228 |
+
output = (
|
| 229 |
+
ROIAlign((mask_size, mask_size), 1.0, 0, aligned=True)
|
| 230 |
+
.forward(bit_masks[:, None, :, :], rois)
|
| 231 |
+
.squeeze(1)
|
| 232 |
+
)
|
| 233 |
+
output = output >= 0.5
|
| 234 |
+
return output
|
| 235 |
+
|
| 236 |
+
def get_bounding_boxes(self) -> Boxes:
|
| 237 |
+
"""
|
| 238 |
+
Returns:
|
| 239 |
+
Boxes: tight bounding boxes around bitmasks.
|
| 240 |
+
If a mask is empty, it's bounding box will be all zero.
|
| 241 |
+
"""
|
| 242 |
+
boxes = torch.zeros(self.tensor.shape[0], 4, dtype=torch.float32)
|
| 243 |
+
x_any = torch.any(self.tensor, dim=1)
|
| 244 |
+
y_any = torch.any(self.tensor, dim=2)
|
| 245 |
+
for idx in range(self.tensor.shape[0]):
|
| 246 |
+
x = torch.where(x_any[idx, :])[0]
|
| 247 |
+
y = torch.where(y_any[idx, :])[0]
|
| 248 |
+
if len(x) > 0 and len(y) > 0:
|
| 249 |
+
boxes[idx, :] = torch.as_tensor(
|
| 250 |
+
[x[0], y[0], x[-1] + 1, y[-1] + 1], dtype=torch.float32
|
| 251 |
+
)
|
| 252 |
+
return Boxes(boxes)
|
| 253 |
+
|
| 254 |
+
@staticmethod
|
| 255 |
+
def cat(bitmasks_list: List["BitMasks"]) -> "BitMasks":
|
| 256 |
+
"""
|
| 257 |
+
Concatenates a list of BitMasks into a single BitMasks
|
| 258 |
+
|
| 259 |
+
Arguments:
|
| 260 |
+
bitmasks_list (list[BitMasks])
|
| 261 |
+
|
| 262 |
+
Returns:
|
| 263 |
+
BitMasks: the concatenated BitMasks
|
| 264 |
+
"""
|
| 265 |
+
assert isinstance(bitmasks_list, (list, tuple))
|
| 266 |
+
assert len(bitmasks_list) > 0
|
| 267 |
+
assert all(isinstance(bitmask, BitMasks) for bitmask in bitmasks_list)
|
| 268 |
+
|
| 269 |
+
cat_bitmasks = type(bitmasks_list[0])(
|
| 270 |
+
torch.cat([bm.tensor for bm in bitmasks_list], dim=0)
|
| 271 |
+
)
|
| 272 |
+
return cat_bitmasks
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class PolygonMasks:
|
| 276 |
+
"""
|
| 277 |
+
This class stores the segmentation masks for all objects in one image, in the form of polygons.
|
| 278 |
+
|
| 279 |
+
Attributes:
|
| 280 |
+
polygons: list[list[ndarray]]. Each ndarray is a float64 vector representing a polygon.
|
| 281 |
+
"""
|
| 282 |
+
|
| 283 |
+
def __init__(self, polygons: List[List[Union[torch.Tensor, np.ndarray]]]):
|
| 284 |
+
"""
|
| 285 |
+
Arguments:
|
| 286 |
+
polygons (list[list[np.ndarray]]): The first
|
| 287 |
+
level of the list correspond to individual instances,
|
| 288 |
+
the second level to all the polygons that compose the
|
| 289 |
+
instance, and the third level to the polygon coordinates.
|
| 290 |
+
The third level array should have the format of
|
| 291 |
+
[x0, y0, x1, y1, ..., xn, yn] (n >= 3).
|
| 292 |
+
"""
|
| 293 |
+
if not isinstance(polygons, list):
|
| 294 |
+
raise ValueError(
|
| 295 |
+
"Cannot create PolygonMasks: Expect a list of list of polygons per image. "
|
| 296 |
+
"Got '{}' instead.".format(type(polygons))
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
def _make_array(t: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
|
| 300 |
+
# Use float64 for higher precision, because why not?
|
| 301 |
+
# Always put polygons on CPU (self.to is a no-op) since they
|
| 302 |
+
# are supposed to be small tensors.
|
| 303 |
+
# May need to change this assumption if GPU placement becomes useful
|
| 304 |
+
if isinstance(t, torch.Tensor):
|
| 305 |
+
t = t.cpu().numpy()
|
| 306 |
+
return np.asarray(t).astype("float64")
|
| 307 |
+
|
| 308 |
+
def process_polygons(
|
| 309 |
+
polygons_per_instance: List[Union[torch.Tensor, np.ndarray]],
|
| 310 |
+
) -> List[np.ndarray]:
|
| 311 |
+
if not isinstance(polygons_per_instance, list):
|
| 312 |
+
raise ValueError(
|
| 313 |
+
"Cannot create polygons: Expect a list of polygons per instance. "
|
| 314 |
+
"Got '{}' instead.".format(type(polygons_per_instance))
|
| 315 |
+
)
|
| 316 |
+
# transform each polygon to a numpy array
|
| 317 |
+
polygons_per_instance = [_make_array(p) for p in polygons_per_instance]
|
| 318 |
+
for polygon in polygons_per_instance:
|
| 319 |
+
if len(polygon) % 2 != 0 or len(polygon) < 6:
|
| 320 |
+
raise ValueError(
|
| 321 |
+
f"Cannot create a polygon from {len(polygon)} coordinates."
|
| 322 |
+
)
|
| 323 |
+
return polygons_per_instance
|
| 324 |
+
|
| 325 |
+
self.polygons: List[List[np.ndarray]] = [
|
| 326 |
+
process_polygons(polygons_per_instance)
|
| 327 |
+
for polygons_per_instance in polygons
|
| 328 |
+
]
|
| 329 |
+
|
| 330 |
+
def to(self, *args: Any, **kwargs: Any) -> "PolygonMasks":
|
| 331 |
+
return self
|
| 332 |
+
|
| 333 |
+
@property
|
| 334 |
+
def device(self) -> torch.device:
|
| 335 |
+
return torch.device("cpu")
|
| 336 |
+
|
| 337 |
+
def get_bounding_boxes(self) -> Boxes:
|
| 338 |
+
"""
|
| 339 |
+
Returns:
|
| 340 |
+
Boxes: tight bounding boxes around polygon masks.
|
| 341 |
+
"""
|
| 342 |
+
boxes = torch.zeros(len(self.polygons), 4, dtype=torch.float32)
|
| 343 |
+
for idx, polygons_per_instance in enumerate(self.polygons):
|
| 344 |
+
minxy = torch.as_tensor([float("inf"), float("inf")], dtype=torch.float32)
|
| 345 |
+
maxxy = torch.zeros(2, dtype=torch.float32)
|
| 346 |
+
for polygon in polygons_per_instance:
|
| 347 |
+
coords = torch.from_numpy(polygon).view(-1, 2).to(dtype=torch.float32)
|
| 348 |
+
minxy = torch.min(minxy, torch.min(coords, dim=0).values)
|
| 349 |
+
maxxy = torch.max(maxxy, torch.max(coords, dim=0).values)
|
| 350 |
+
boxes[idx, :2] = minxy
|
| 351 |
+
boxes[idx, 2:] = maxxy
|
| 352 |
+
return Boxes(boxes)
|
| 353 |
+
|
| 354 |
+
def nonempty(self) -> torch.Tensor:
|
| 355 |
+
"""
|
| 356 |
+
Find masks that are non-empty.
|
| 357 |
+
|
| 358 |
+
Returns:
|
| 359 |
+
Tensor:
|
| 360 |
+
a BoolTensor which represents whether each mask is empty (False) or not (True).
|
| 361 |
+
"""
|
| 362 |
+
keep = [1 if len(polygon) > 0 else 0 for polygon in self.polygons]
|
| 363 |
+
return torch.from_numpy(np.asarray(keep, dtype=bool))
|
| 364 |
+
|
| 365 |
+
def __getitem__(
|
| 366 |
+
self, item: Union[int, slice, List[int], torch.BoolTensor]
|
| 367 |
+
) -> "PolygonMasks":
|
| 368 |
+
"""
|
| 369 |
+
Support indexing over the instances and return a `PolygonMasks` object.
|
| 370 |
+
`item` can be:
|
| 371 |
+
|
| 372 |
+
1. An integer. It will return an object with only one instance.
|
| 373 |
+
2. A slice. It will return an object with the selected instances.
|
| 374 |
+
3. A list[int]. It will return an object with the selected instances,
|
| 375 |
+
correpsonding to the indices in the list.
|
| 376 |
+
4. A vector mask of type BoolTensor, whose length is num_instances.
|
| 377 |
+
It will return an object with the instances whose mask is nonzero.
|
| 378 |
+
"""
|
| 379 |
+
if isinstance(item, int):
|
| 380 |
+
selected_polygons = [self.polygons[item]]
|
| 381 |
+
elif isinstance(item, slice):
|
| 382 |
+
selected_polygons = self.polygons[item]
|
| 383 |
+
elif isinstance(item, list):
|
| 384 |
+
selected_polygons = [self.polygons[i] for i in item]
|
| 385 |
+
elif isinstance(item, torch.Tensor):
|
| 386 |
+
# Polygons is a list, so we have to move the indices back to CPU.
|
| 387 |
+
if item.dtype == torch.bool:
|
| 388 |
+
assert item.dim() == 1, item.shape
|
| 389 |
+
item = item.nonzero().squeeze(1).cpu().numpy().tolist()
|
| 390 |
+
elif item.dtype in [torch.int32, torch.int64]:
|
| 391 |
+
item = item.cpu().numpy().tolist()
|
| 392 |
+
else:
|
| 393 |
+
raise ValueError(
|
| 394 |
+
"Unsupported tensor dtype={} for indexing!".format(item.dtype)
|
| 395 |
+
)
|
| 396 |
+
selected_polygons = [self.polygons[i] for i in item]
|
| 397 |
+
return PolygonMasks(selected_polygons)
|
| 398 |
+
|
| 399 |
+
def __iter__(self) -> Iterator[List[np.ndarray]]:
|
| 400 |
+
"""
|
| 401 |
+
Yields:
|
| 402 |
+
list[ndarray]: the polygons for one instance.
|
| 403 |
+
Each Tensor is a float64 vector representing a polygon.
|
| 404 |
+
"""
|
| 405 |
+
return iter(self.polygons)
|
| 406 |
+
|
| 407 |
+
def __repr__(self) -> str:
|
| 408 |
+
s = self.__class__.__name__ + "("
|
| 409 |
+
s += "num_instances={})".format(len(self.polygons))
|
| 410 |
+
return s
|
| 411 |
+
|
| 412 |
+
def __len__(self) -> int:
|
| 413 |
+
return len(self.polygons)
|
| 414 |
+
|
| 415 |
+
def crop_and_resize(self, boxes: torch.Tensor, mask_size: int) -> torch.Tensor:
|
| 416 |
+
"""
|
| 417 |
+
Crop each mask by the given box, and resize results to (mask_size, mask_size).
|
| 418 |
+
This can be used to prepare training targets for Mask R-CNN.
|
| 419 |
+
|
| 420 |
+
Args:
|
| 421 |
+
boxes (Tensor): Nx4 tensor storing the boxes for each mask
|
| 422 |
+
mask_size (int): the size of the rasterized mask.
|
| 423 |
+
|
| 424 |
+
Returns:
|
| 425 |
+
Tensor: A bool tensor of shape (N, mask_size, mask_size), where
|
| 426 |
+
N is the number of predicted boxes for this image.
|
| 427 |
+
"""
|
| 428 |
+
assert len(boxes) == len(self), "{} != {}".format(len(boxes), len(self))
|
| 429 |
+
|
| 430 |
+
device = boxes.device
|
| 431 |
+
# Put boxes on the CPU, as the polygon representation is not efficient GPU-wise
|
| 432 |
+
# (several small tensors for representing a single instance mask)
|
| 433 |
+
boxes = boxes.to(torch.device("cpu"))
|
| 434 |
+
|
| 435 |
+
results = [
|
| 436 |
+
rasterize_polygons_within_box(poly, box.numpy(), mask_size)
|
| 437 |
+
for poly, box in zip(self.polygons, boxes)
|
| 438 |
+
]
|
| 439 |
+
"""
|
| 440 |
+
poly: list[list[float]], the polygons for one instance
|
| 441 |
+
box: a tensor of shape (4,)
|
| 442 |
+
"""
|
| 443 |
+
if len(results) == 0:
|
| 444 |
+
return torch.empty(0, mask_size, mask_size, dtype=torch.bool, device=device)
|
| 445 |
+
return torch.stack(results, dim=0).to(device=device)
|
| 446 |
+
|
| 447 |
+
def area(self):
|
| 448 |
+
"""
|
| 449 |
+
Computes area of the mask.
|
| 450 |
+
Only works with Polygons, using the shoelace formula:
|
| 451 |
+
https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates
|
| 452 |
+
|
| 453 |
+
Returns:
|
| 454 |
+
Tensor: a vector, area for each instance
|
| 455 |
+
"""
|
| 456 |
+
|
| 457 |
+
area = []
|
| 458 |
+
for polygons_per_instance in self.polygons:
|
| 459 |
+
area_per_instance = 0
|
| 460 |
+
for p in polygons_per_instance:
|
| 461 |
+
area_per_instance += polygon_area(p[0::2], p[1::2])
|
| 462 |
+
area.append(area_per_instance)
|
| 463 |
+
|
| 464 |
+
return torch.tensor(area)
|
| 465 |
+
|
| 466 |
+
@staticmethod
|
| 467 |
+
def cat(polymasks_list: List["PolygonMasks"]) -> "PolygonMasks":
|
| 468 |
+
"""
|
| 469 |
+
Concatenates a list of PolygonMasks into a single PolygonMasks
|
| 470 |
+
|
| 471 |
+
Arguments:
|
| 472 |
+
polymasks_list (list[PolygonMasks])
|
| 473 |
+
|
| 474 |
+
Returns:
|
| 475 |
+
PolygonMasks: the concatenated PolygonMasks
|
| 476 |
+
"""
|
| 477 |
+
assert isinstance(polymasks_list, (list, tuple))
|
| 478 |
+
assert len(polymasks_list) > 0
|
| 479 |
+
assert all(isinstance(polymask, PolygonMasks) for polymask in polymasks_list)
|
| 480 |
+
|
| 481 |
+
cat_polymasks = type(polymasks_list[0])(
|
| 482 |
+
list(itertools.chain.from_iterable(pm.polygons for pm in polymasks_list))
|
| 483 |
+
)
|
| 484 |
+
return cat_polymasks
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
class ROIMasks:
|
| 488 |
+
"""
|
| 489 |
+
Represent masks by N smaller masks defined in some ROIs. Once ROI boxes are given,
|
| 490 |
+
full-image bitmask can be obtained by "pasting" the mask on the region defined
|
| 491 |
+
by the corresponding ROI box.
|
| 492 |
+
"""
|
| 493 |
+
|
| 494 |
+
def __init__(self, tensor: torch.Tensor):
|
| 495 |
+
"""
|
| 496 |
+
Args:
|
| 497 |
+
tensor: (N, M, M) mask tensor that defines the mask within each ROI.
|
| 498 |
+
"""
|
| 499 |
+
if tensor.dim() != 3:
|
| 500 |
+
raise ValueError("ROIMasks must take a masks of 3 dimension.")
|
| 501 |
+
self.tensor = tensor
|
| 502 |
+
|
| 503 |
+
def to(self, device: torch.device) -> "ROIMasks":
|
| 504 |
+
return ROIMasks(self.tensor.to(device))
|
| 505 |
+
|
| 506 |
+
@property
|
| 507 |
+
def device(self) -> device:
|
| 508 |
+
return self.tensor.device
|
| 509 |
+
|
| 510 |
+
def __len__(self):
|
| 511 |
+
return self.tensor.shape[0]
|
| 512 |
+
|
| 513 |
+
def __getitem__(self, item) -> "ROIMasks":
|
| 514 |
+
"""
|
| 515 |
+
Returns:
|
| 516 |
+
ROIMasks: Create a new :class:`ROIMasks` by indexing.
|
| 517 |
+
|
| 518 |
+
The following usage are allowed:
|
| 519 |
+
|
| 520 |
+
1. `new_masks = masks[2:10]`: return a slice of masks.
|
| 521 |
+
2. `new_masks = masks[vector]`, where vector is a torch.BoolTensor
|
| 522 |
+
with `length = len(masks)`. Nonzero elements in the vector will be selected.
|
| 523 |
+
|
| 524 |
+
Note that the returned object might share storage with this object,
|
| 525 |
+
subject to Pytorch's indexing semantics.
|
| 526 |
+
"""
|
| 527 |
+
t = self.tensor[item]
|
| 528 |
+
if t.dim() != 3:
|
| 529 |
+
raise ValueError(
|
| 530 |
+
f"Indexing on ROIMasks with {item} returns a tensor with shape {t.shape}!"
|
| 531 |
+
)
|
| 532 |
+
return ROIMasks(t)
|
| 533 |
+
|
| 534 |
+
@torch.jit.unused
|
| 535 |
+
def __repr__(self) -> str:
|
| 536 |
+
s = self.__class__.__name__ + "("
|
| 537 |
+
s += "num_instances={})".format(len(self.tensor))
|
| 538 |
+
return s
|
| 539 |
+
|
| 540 |
+
@torch.jit.unused
|
| 541 |
+
def to_bitmasks(self, boxes: torch.Tensor, height, width, threshold=0.5):
|
| 542 |
+
"""
|
| 543 |
+
Args: see documentation of :func:`paste_masks_in_image`.
|
| 544 |
+
"""
|
| 545 |
+
from detectron2.layers.mask_ops import (
|
| 546 |
+
_paste_masks_tensor_shape,
|
| 547 |
+
paste_masks_in_image,
|
| 548 |
+
)
|
| 549 |
+
|
| 550 |
+
if torch.jit.is_tracing():
|
| 551 |
+
if isinstance(height, torch.Tensor):
|
| 552 |
+
paste_func = _paste_masks_tensor_shape
|
| 553 |
+
else:
|
| 554 |
+
paste_func = paste_masks_in_image
|
| 555 |
+
else:
|
| 556 |
+
paste_func = retry_if_cuda_oom(paste_masks_in_image)
|
| 557 |
+
bitmasks = paste_func(
|
| 558 |
+
self.tensor, boxes.tensor, (height, width), threshold=threshold
|
| 559 |
+
)
|
| 560 |
+
return BitMasks(bitmasks)
|
sam3/agent/helpers/memory.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from contextlib import contextmanager
|
| 5 |
+
from functools import wraps
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
__all__ = ["retry_if_cuda_oom"]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@contextmanager
|
| 13 |
+
def _ignore_torch_cuda_oom():
|
| 14 |
+
"""
|
| 15 |
+
A context which ignores CUDA OOM exception from pytorch.
|
| 16 |
+
"""
|
| 17 |
+
try:
|
| 18 |
+
yield
|
| 19 |
+
except RuntimeError as e:
|
| 20 |
+
# NOTE: the string may change?
|
| 21 |
+
if "CUDA out of memory. " in str(e):
|
| 22 |
+
pass
|
| 23 |
+
else:
|
| 24 |
+
raise
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def retry_if_cuda_oom(func):
|
| 28 |
+
"""
|
| 29 |
+
Makes a function retry itself after encountering
|
| 30 |
+
pytorch's CUDA OOM error.
|
| 31 |
+
It will first retry after calling `torch.cuda.empty_cache()`.
|
| 32 |
+
|
| 33 |
+
If that still fails, it will then retry by trying to convert inputs to CPUs.
|
| 34 |
+
In this case, it expects the function to dispatch to CPU implementation.
|
| 35 |
+
The return values may become CPU tensors as well and it's user's
|
| 36 |
+
responsibility to convert it back to CUDA tensor if needed.
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
func: a stateless callable that takes tensor-like objects as arguments
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
a callable which retries `func` if OOM is encountered.
|
| 43 |
+
|
| 44 |
+
Examples:
|
| 45 |
+
::
|
| 46 |
+
output = retry_if_cuda_oom(some_torch_function)(input1, input2)
|
| 47 |
+
# output may be on CPU even if inputs are on GPU
|
| 48 |
+
|
| 49 |
+
Note:
|
| 50 |
+
1. When converting inputs to CPU, it will only look at each argument and check
|
| 51 |
+
if it has `.device` and `.to` for conversion. Nested structures of tensors
|
| 52 |
+
are not supported.
|
| 53 |
+
|
| 54 |
+
2. Since the function might be called more than once, it has to be
|
| 55 |
+
stateless.
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
def maybe_to_cpu(x):
|
| 59 |
+
try:
|
| 60 |
+
like_gpu_tensor = x.device.type == "cuda" and hasattr(x, "to")
|
| 61 |
+
except AttributeError:
|
| 62 |
+
like_gpu_tensor = False
|
| 63 |
+
if like_gpu_tensor:
|
| 64 |
+
return x.to(device="cpu")
|
| 65 |
+
else:
|
| 66 |
+
return x
|
| 67 |
+
|
| 68 |
+
@wraps(func)
|
| 69 |
+
def wrapped(*args, **kwargs):
|
| 70 |
+
with _ignore_torch_cuda_oom():
|
| 71 |
+
return func(*args, **kwargs)
|
| 72 |
+
|
| 73 |
+
# Clear cache and retry
|
| 74 |
+
torch.cuda.empty_cache()
|
| 75 |
+
with _ignore_torch_cuda_oom():
|
| 76 |
+
return func(*args, **kwargs)
|
| 77 |
+
|
| 78 |
+
# Try on CPU. This slows down the code significantly, therefore print a notice.
|
| 79 |
+
logger = logging.getLogger(__name__)
|
| 80 |
+
logger.info(
|
| 81 |
+
"Attempting to copy inputs of {} to CPU due to CUDA OOM".format(str(func))
|
| 82 |
+
)
|
| 83 |
+
new_args = (maybe_to_cpu(x) for x in args)
|
| 84 |
+
new_kwargs = {k: maybe_to_cpu(v) for k, v in kwargs.items()}
|
| 85 |
+
return func(*new_args, **new_kwargs)
|
| 86 |
+
|
| 87 |
+
return wrapped
|
sam3/agent/helpers/rle.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
"""Some utilities for RLE encoding that doesn't require downloading the masks to the cpu"""
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from pycocotools import mask as mask_util
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@torch.no_grad()
|
| 11 |
+
def rle_encode(orig_mask, return_areas=False):
|
| 12 |
+
"""Encodes a collection of masks in RLE format
|
| 13 |
+
|
| 14 |
+
This function emulates the behavior of the COCO API's encode function, but
|
| 15 |
+
is executed partially on the GPU for faster execution.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
mask (torch.Tensor): A mask of shape (N, H, W) with dtype=torch.bool
|
| 19 |
+
return_areas (bool): If True, add the areas of the masks as a part of
|
| 20 |
+
the RLE output dict under the "area" key. Default is False.
|
| 21 |
+
|
| 22 |
+
Returns:
|
| 23 |
+
str: The RLE encoded masks
|
| 24 |
+
"""
|
| 25 |
+
assert orig_mask.ndim == 3, "Mask must be of shape (N, H, W)"
|
| 26 |
+
assert orig_mask.dtype == torch.bool, "Mask must have dtype=torch.bool"
|
| 27 |
+
|
| 28 |
+
if orig_mask.numel() == 0:
|
| 29 |
+
return []
|
| 30 |
+
|
| 31 |
+
# First, transpose the spatial dimensions.
|
| 32 |
+
# This is necessary because the COCO API uses Fortran order
|
| 33 |
+
mask = orig_mask.transpose(1, 2)
|
| 34 |
+
|
| 35 |
+
# Flatten the mask
|
| 36 |
+
flat_mask = mask.reshape(mask.shape[0], -1)
|
| 37 |
+
if return_areas:
|
| 38 |
+
mask_areas = flat_mask.sum(-1).tolist()
|
| 39 |
+
# Find the indices where the mask changes
|
| 40 |
+
differences = torch.ones(
|
| 41 |
+
mask.shape[0], flat_mask.shape[1] + 1, device=mask.device, dtype=torch.bool
|
| 42 |
+
)
|
| 43 |
+
differences[:, 1:-1] = flat_mask[:, :-1] != flat_mask[:, 1:]
|
| 44 |
+
differences[:, 0] = flat_mask[:, 0]
|
| 45 |
+
_, change_indices = torch.where(differences)
|
| 46 |
+
|
| 47 |
+
try:
|
| 48 |
+
boundaries = torch.cumsum(differences.sum(-1), 0).cpu()
|
| 49 |
+
except RuntimeError as _:
|
| 50 |
+
boundaries = torch.cumsum(differences.cpu().sum(-1), 0)
|
| 51 |
+
|
| 52 |
+
change_indices_clone = change_indices.clone()
|
| 53 |
+
# First pass computes the RLEs on GPU, in a flatten format
|
| 54 |
+
for i in range(mask.shape[0]):
|
| 55 |
+
# Get the change indices for this batch item
|
| 56 |
+
beg = 0 if i == 0 else boundaries[i - 1].item()
|
| 57 |
+
end = boundaries[i].item()
|
| 58 |
+
change_indices[beg + 1 : end] -= change_indices_clone[beg : end - 1]
|
| 59 |
+
|
| 60 |
+
# Now we can split the RLES of each batch item, and convert them to strings
|
| 61 |
+
# No more gpu at this point
|
| 62 |
+
change_indices = change_indices.tolist()
|
| 63 |
+
|
| 64 |
+
batch_rles = []
|
| 65 |
+
# Process each mask in the batch separately
|
| 66 |
+
for i in range(mask.shape[0]):
|
| 67 |
+
beg = 0 if i == 0 else boundaries[i - 1].item()
|
| 68 |
+
end = boundaries[i].item()
|
| 69 |
+
run_lengths = change_indices[beg:end]
|
| 70 |
+
|
| 71 |
+
uncompressed_rle = {"counts": run_lengths, "size": list(orig_mask.shape[1:])}
|
| 72 |
+
h, w = uncompressed_rle["size"]
|
| 73 |
+
rle = mask_util.frPyObjects(uncompressed_rle, h, w)
|
| 74 |
+
rle["counts"] = rle["counts"].decode("utf-8")
|
| 75 |
+
if return_areas:
|
| 76 |
+
rle["area"] = mask_areas[i]
|
| 77 |
+
batch_rles.append(rle)
|
| 78 |
+
|
| 79 |
+
return batch_rles
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def robust_rle_encode(masks):
|
| 83 |
+
"""Encodes a collection of masks in RLE format. Uses the gpu version fist, falls back to the cpu version if it fails"""
|
| 84 |
+
|
| 85 |
+
assert masks.ndim == 3, "Mask must be of shape (N, H, W)"
|
| 86 |
+
assert masks.dtype == torch.bool, "Mask must have dtype=torch.bool"
|
| 87 |
+
|
| 88 |
+
try:
|
| 89 |
+
return rle_encode(masks)
|
| 90 |
+
except RuntimeError as _:
|
| 91 |
+
masks = masks.cpu().numpy()
|
| 92 |
+
rles = [
|
| 93 |
+
mask_util.encode(
|
| 94 |
+
np.array(mask[:, :, np.newaxis], dtype=np.uint8, order="F")
|
| 95 |
+
)[0]
|
| 96 |
+
for mask in masks
|
| 97 |
+
]
|
| 98 |
+
for rle in rles:
|
| 99 |
+
rle["counts"] = rle["counts"].decode("utf-8")
|
| 100 |
+
return rles
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def ann_to_rle(segm, im_info):
|
| 104 |
+
"""Convert annotation which can be polygons, uncompressed RLE to RLE.
|
| 105 |
+
Args:
|
| 106 |
+
ann (dict) : annotation object
|
| 107 |
+
Returns:
|
| 108 |
+
ann (rle)
|
| 109 |
+
"""
|
| 110 |
+
h, w = im_info["height"], im_info["width"]
|
| 111 |
+
if isinstance(segm, list):
|
| 112 |
+
# polygon -- a single object might consist of multiple parts
|
| 113 |
+
# we merge all parts into one mask rle code
|
| 114 |
+
rles = mask_util.frPyObjects(segm, h, w)
|
| 115 |
+
rle = mask_util.merge(rles)
|
| 116 |
+
elif isinstance(segm["counts"], list):
|
| 117 |
+
# uncompressed RLE
|
| 118 |
+
rle = mask_util.frPyObjects(segm, h, w)
|
| 119 |
+
else:
|
| 120 |
+
# rle
|
| 121 |
+
rle = segm
|
| 122 |
+
return rle
|
sam3/agent/helpers/roi_align.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torchvision.ops import roi_align
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# NOTE: torchvision's RoIAlign has a different default aligned=False
|
| 8 |
+
class ROIAlign(nn.Module):
|
| 9 |
+
def __init__(self, output_size, spatial_scale, sampling_ratio, aligned=True):
|
| 10 |
+
"""
|
| 11 |
+
Args:
|
| 12 |
+
output_size (tuple): h, w
|
| 13 |
+
spatial_scale (float): scale the input boxes by this number
|
| 14 |
+
sampling_ratio (int): number of inputs samples to take for each output
|
| 15 |
+
sample. 0 to take samples densely.
|
| 16 |
+
aligned (bool): if False, use the legacy implementation in
|
| 17 |
+
Detectron. If True, align the results more perfectly.
|
| 18 |
+
|
| 19 |
+
Note:
|
| 20 |
+
The meaning of aligned=True:
|
| 21 |
+
|
| 22 |
+
Given a continuous coordinate c, its two neighboring pixel indices (in our
|
| 23 |
+
pixel model) are computed by floor(c - 0.5) and ceil(c - 0.5). For example,
|
| 24 |
+
c=1.3 has pixel neighbors with discrete indices [0] and [1] (which are sampled
|
| 25 |
+
from the underlying signal at continuous coordinates 0.5 and 1.5). But the original
|
| 26 |
+
roi_align (aligned=False) does not subtract the 0.5 when computing neighboring
|
| 27 |
+
pixel indices and therefore it uses pixels with a slightly incorrect alignment
|
| 28 |
+
(relative to our pixel model) when performing bilinear interpolation.
|
| 29 |
+
|
| 30 |
+
With `aligned=True`,
|
| 31 |
+
we first appropriately scale the ROI and then shift it by -0.5
|
| 32 |
+
prior to calling roi_align. This produces the correct neighbors; see
|
| 33 |
+
detectron2/tests/test_roi_align.py for verification.
|
| 34 |
+
|
| 35 |
+
The difference does not make a difference to the model's performance if
|
| 36 |
+
ROIAlign is used together with conv layers.
|
| 37 |
+
"""
|
| 38 |
+
super().__init__()
|
| 39 |
+
self.output_size = output_size
|
| 40 |
+
self.spatial_scale = spatial_scale
|
| 41 |
+
self.sampling_ratio = sampling_ratio
|
| 42 |
+
self.aligned = aligned
|
| 43 |
+
|
| 44 |
+
from torchvision import __version__
|
| 45 |
+
|
| 46 |
+
version = tuple(int(x) for x in __version__.split(".")[:2])
|
| 47 |
+
# https://github.com/pytorch/vision/pull/2438
|
| 48 |
+
assert version >= (0, 7), "Require torchvision >= 0.7"
|
| 49 |
+
|
| 50 |
+
def forward(self, input, rois):
|
| 51 |
+
"""
|
| 52 |
+
Args:
|
| 53 |
+
input: NCHW images
|
| 54 |
+
rois: Bx5 boxes. First column is the index into N. The other 4 columns are xyxy.
|
| 55 |
+
"""
|
| 56 |
+
assert rois.dim() == 2 and rois.size(1) == 5
|
| 57 |
+
if input.is_quantized:
|
| 58 |
+
input = input.dequantize()
|
| 59 |
+
return roi_align(
|
| 60 |
+
input,
|
| 61 |
+
rois.to(dtype=input.dtype),
|
| 62 |
+
self.output_size,
|
| 63 |
+
self.spatial_scale,
|
| 64 |
+
self.sampling_ratio,
|
| 65 |
+
self.aligned,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
def __repr__(self):
|
| 69 |
+
tmpstr = self.__class__.__name__ + "("
|
| 70 |
+
tmpstr += "output_size=" + str(self.output_size)
|
| 71 |
+
tmpstr += ", spatial_scale=" + str(self.spatial_scale)
|
| 72 |
+
tmpstr += ", sampling_ratio=" + str(self.sampling_ratio)
|
| 73 |
+
tmpstr += ", aligned=" + str(self.aligned)
|
| 74 |
+
tmpstr += ")"
|
| 75 |
+
return tmpstr
|
sam3/agent/helpers/rotated_boxes.py
ADDED
|
@@ -0,0 +1,533 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
from typing import List, Tuple
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
# from detectron2.layers.rotated_boxes import pairwise_iou_rotated
|
| 11 |
+
|
| 12 |
+
from .boxes import Boxes
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def pairwise_iou_rotated(boxes1, boxes2):
|
| 16 |
+
"""
|
| 17 |
+
Return intersection-over-union (Jaccard index) of boxes.
|
| 18 |
+
|
| 19 |
+
Both sets of boxes are expected to be in
|
| 20 |
+
(x_center, y_center, width, height, angle) format.
|
| 21 |
+
|
| 22 |
+
Arguments:
|
| 23 |
+
boxes1 (Tensor[N, 5])
|
| 24 |
+
boxes2 (Tensor[M, 5])
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
iou (Tensor[N, M]): the NxM matrix containing the pairwise
|
| 28 |
+
IoU values for every element in boxes1 and boxes2
|
| 29 |
+
"""
|
| 30 |
+
return torch.ops.detectron2.box_iou_rotated(boxes1, boxes2)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class RotatedBoxes(Boxes):
|
| 34 |
+
"""
|
| 35 |
+
This structure stores a list of rotated boxes as a Nx5 torch.Tensor.
|
| 36 |
+
It supports some common methods about boxes
|
| 37 |
+
(`area`, `clip`, `nonempty`, etc),
|
| 38 |
+
and also behaves like a Tensor
|
| 39 |
+
(support indexing, `to(device)`, `.device`, and iteration over all boxes)
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
def __init__(self, tensor: torch.Tensor):
|
| 43 |
+
"""
|
| 44 |
+
Args:
|
| 45 |
+
tensor (Tensor[float]): a Nx5 matrix. Each row is
|
| 46 |
+
(x_center, y_center, width, height, angle),
|
| 47 |
+
in which angle is represented in degrees.
|
| 48 |
+
While there's no strict range restriction for it,
|
| 49 |
+
the recommended principal range is between [-180, 180) degrees.
|
| 50 |
+
|
| 51 |
+
Assume we have a horizontal box B = (x_center, y_center, width, height),
|
| 52 |
+
where width is along the x-axis and height is along the y-axis.
|
| 53 |
+
The rotated box B_rot (x_center, y_center, width, height, angle)
|
| 54 |
+
can be seen as:
|
| 55 |
+
|
| 56 |
+
1. When angle == 0:
|
| 57 |
+
B_rot == B
|
| 58 |
+
2. When angle > 0:
|
| 59 |
+
B_rot is obtained by rotating B w.r.t its center by :math:`|angle|` degrees CCW;
|
| 60 |
+
3. When angle < 0:
|
| 61 |
+
B_rot is obtained by rotating B w.r.t its center by :math:`|angle|` degrees CW.
|
| 62 |
+
|
| 63 |
+
Mathematically, since the right-handed coordinate system for image space
|
| 64 |
+
is (y, x), where y is top->down and x is left->right, the 4 vertices of the
|
| 65 |
+
rotated rectangle :math:`(yr_i, xr_i)` (i = 1, 2, 3, 4) can be obtained from
|
| 66 |
+
the vertices of the horizontal rectangle :math:`(y_i, x_i)` (i = 1, 2, 3, 4)
|
| 67 |
+
in the following way (:math:`\\theta = angle*\\pi/180` is the angle in radians,
|
| 68 |
+
:math:`(y_c, x_c)` is the center of the rectangle):
|
| 69 |
+
|
| 70 |
+
.. math::
|
| 71 |
+
|
| 72 |
+
yr_i = \\cos(\\theta) (y_i - y_c) - \\sin(\\theta) (x_i - x_c) + y_c,
|
| 73 |
+
|
| 74 |
+
xr_i = \\sin(\\theta) (y_i - y_c) + \\cos(\\theta) (x_i - x_c) + x_c,
|
| 75 |
+
|
| 76 |
+
which is the standard rigid-body rotation transformation.
|
| 77 |
+
|
| 78 |
+
Intuitively, the angle is
|
| 79 |
+
(1) the rotation angle from y-axis in image space
|
| 80 |
+
to the height vector (top->down in the box's local coordinate system)
|
| 81 |
+
of the box in CCW, and
|
| 82 |
+
(2) the rotation angle from x-axis in image space
|
| 83 |
+
to the width vector (left->right in the box's local coordinate system)
|
| 84 |
+
of the box in CCW.
|
| 85 |
+
|
| 86 |
+
More intuitively, consider the following horizontal box ABCD represented
|
| 87 |
+
in (x1, y1, x2, y2): (3, 2, 7, 4),
|
| 88 |
+
covering the [3, 7] x [2, 4] region of the continuous coordinate system
|
| 89 |
+
which looks like this:
|
| 90 |
+
|
| 91 |
+
.. code:: none
|
| 92 |
+
|
| 93 |
+
O--------> x
|
| 94 |
+
|
|
| 95 |
+
| A---B
|
| 96 |
+
| | |
|
| 97 |
+
| D---C
|
| 98 |
+
|
|
| 99 |
+
v y
|
| 100 |
+
|
| 101 |
+
Note that each capital letter represents one 0-dimensional geometric point
|
| 102 |
+
instead of a 'square pixel' here.
|
| 103 |
+
|
| 104 |
+
In the example above, using (x, y) to represent a point we have:
|
| 105 |
+
|
| 106 |
+
.. math::
|
| 107 |
+
|
| 108 |
+
O = (0, 0), A = (3, 2), B = (7, 2), C = (7, 4), D = (3, 4)
|
| 109 |
+
|
| 110 |
+
We name vector AB = vector DC as the width vector in box's local coordinate system, and
|
| 111 |
+
vector AD = vector BC as the height vector in box's local coordinate system. Initially,
|
| 112 |
+
when angle = 0 degree, they're aligned with the positive directions of x-axis and y-axis
|
| 113 |
+
in the image space, respectively.
|
| 114 |
+
|
| 115 |
+
For better illustration, we denote the center of the box as E,
|
| 116 |
+
|
| 117 |
+
.. code:: none
|
| 118 |
+
|
| 119 |
+
O--------> x
|
| 120 |
+
|
|
| 121 |
+
| A---B
|
| 122 |
+
| | E |
|
| 123 |
+
| D---C
|
| 124 |
+
|
|
| 125 |
+
v y
|
| 126 |
+
|
| 127 |
+
where the center E = ((3+7)/2, (2+4)/2) = (5, 3).
|
| 128 |
+
|
| 129 |
+
Also,
|
| 130 |
+
|
| 131 |
+
.. math::
|
| 132 |
+
|
| 133 |
+
width = |AB| = |CD| = 7 - 3 = 4,
|
| 134 |
+
height = |AD| = |BC| = 4 - 2 = 2.
|
| 135 |
+
|
| 136 |
+
Therefore, the corresponding representation for the same shape in rotated box in
|
| 137 |
+
(x_center, y_center, width, height, angle) format is:
|
| 138 |
+
|
| 139 |
+
(5, 3, 4, 2, 0),
|
| 140 |
+
|
| 141 |
+
Now, let's consider (5, 3, 4, 2, 90), which is rotated by 90 degrees
|
| 142 |
+
CCW (counter-clockwise) by definition. It looks like this:
|
| 143 |
+
|
| 144 |
+
.. code:: none
|
| 145 |
+
|
| 146 |
+
O--------> x
|
| 147 |
+
| B-C
|
| 148 |
+
| | |
|
| 149 |
+
| |E|
|
| 150 |
+
| | |
|
| 151 |
+
| A-D
|
| 152 |
+
v y
|
| 153 |
+
|
| 154 |
+
The center E is still located at the same point (5, 3), while the vertices
|
| 155 |
+
ABCD are rotated by 90 degrees CCW with regard to E:
|
| 156 |
+
A = (4, 5), B = (4, 1), C = (6, 1), D = (6, 5)
|
| 157 |
+
|
| 158 |
+
Here, 90 degrees can be seen as the CCW angle to rotate from y-axis to
|
| 159 |
+
vector AD or vector BC (the top->down height vector in box's local coordinate system),
|
| 160 |
+
or the CCW angle to rotate from x-axis to vector AB or vector DC (the left->right
|
| 161 |
+
width vector in box's local coordinate system).
|
| 162 |
+
|
| 163 |
+
.. math::
|
| 164 |
+
|
| 165 |
+
width = |AB| = |CD| = 5 - 1 = 4,
|
| 166 |
+
height = |AD| = |BC| = 6 - 4 = 2.
|
| 167 |
+
|
| 168 |
+
Next, how about (5, 3, 4, 2, -90), which is rotated by 90 degrees CW (clockwise)
|
| 169 |
+
by definition? It looks like this:
|
| 170 |
+
|
| 171 |
+
.. code:: none
|
| 172 |
+
|
| 173 |
+
O--------> x
|
| 174 |
+
| D-A
|
| 175 |
+
| | |
|
| 176 |
+
| |E|
|
| 177 |
+
| | |
|
| 178 |
+
| C-B
|
| 179 |
+
v y
|
| 180 |
+
|
| 181 |
+
The center E is still located at the same point (5, 3), while the vertices
|
| 182 |
+
ABCD are rotated by 90 degrees CW with regard to E:
|
| 183 |
+
A = (6, 1), B = (6, 5), C = (4, 5), D = (4, 1)
|
| 184 |
+
|
| 185 |
+
.. math::
|
| 186 |
+
|
| 187 |
+
width = |AB| = |CD| = 5 - 1 = 4,
|
| 188 |
+
height = |AD| = |BC| = 6 - 4 = 2.
|
| 189 |
+
|
| 190 |
+
This covers exactly the same region as (5, 3, 4, 2, 90) does, and their IoU
|
| 191 |
+
will be 1. However, these two will generate different RoI Pooling results and
|
| 192 |
+
should not be treated as an identical box.
|
| 193 |
+
|
| 194 |
+
On the other hand, it's easy to see that (X, Y, W, H, A) is identical to
|
| 195 |
+
(X, Y, W, H, A+360N), for any integer N. For example (5, 3, 4, 2, 270) would be
|
| 196 |
+
identical to (5, 3, 4, 2, -90), because rotating the shape 270 degrees CCW is
|
| 197 |
+
equivalent to rotating the same shape 90 degrees CW.
|
| 198 |
+
|
| 199 |
+
We could rotate further to get (5, 3, 4, 2, 180), or (5, 3, 4, 2, -180):
|
| 200 |
+
|
| 201 |
+
.. code:: none
|
| 202 |
+
|
| 203 |
+
O--------> x
|
| 204 |
+
|
|
| 205 |
+
| C---D
|
| 206 |
+
| | E |
|
| 207 |
+
| B---A
|
| 208 |
+
|
|
| 209 |
+
v y
|
| 210 |
+
|
| 211 |
+
.. math::
|
| 212 |
+
|
| 213 |
+
A = (7, 4), B = (3, 4), C = (3, 2), D = (7, 2),
|
| 214 |
+
|
| 215 |
+
width = |AB| = |CD| = 7 - 3 = 4,
|
| 216 |
+
height = |AD| = |BC| = 4 - 2 = 2.
|
| 217 |
+
|
| 218 |
+
Finally, this is a very inaccurate (heavily quantized) illustration of
|
| 219 |
+
how (5, 3, 4, 2, 60) looks like in case anyone wonders:
|
| 220 |
+
|
| 221 |
+
.. code:: none
|
| 222 |
+
|
| 223 |
+
O--------> x
|
| 224 |
+
| B\
|
| 225 |
+
| / C
|
| 226 |
+
| /E /
|
| 227 |
+
| A /
|
| 228 |
+
| `D
|
| 229 |
+
v y
|
| 230 |
+
|
| 231 |
+
It's still a rectangle with center of (5, 3), width of 4 and height of 2,
|
| 232 |
+
but its angle (and thus orientation) is somewhere between
|
| 233 |
+
(5, 3, 4, 2, 0) and (5, 3, 4, 2, 90).
|
| 234 |
+
"""
|
| 235 |
+
device = (
|
| 236 |
+
tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
|
| 237 |
+
)
|
| 238 |
+
tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
|
| 239 |
+
if tensor.numel() == 0:
|
| 240 |
+
# Use reshape, so we don't end up creating a new tensor that does not depend on
|
| 241 |
+
# the inputs (and consequently confuses jit)
|
| 242 |
+
tensor = tensor.reshape((0, 5)).to(dtype=torch.float32, device=device)
|
| 243 |
+
assert tensor.dim() == 2 and tensor.size(-1) == 5, tensor.size()
|
| 244 |
+
|
| 245 |
+
self.tensor = tensor
|
| 246 |
+
|
| 247 |
+
def clone(self) -> "RotatedBoxes":
|
| 248 |
+
"""
|
| 249 |
+
Clone the RotatedBoxes.
|
| 250 |
+
|
| 251 |
+
Returns:
|
| 252 |
+
RotatedBoxes
|
| 253 |
+
"""
|
| 254 |
+
return RotatedBoxes(self.tensor.clone())
|
| 255 |
+
|
| 256 |
+
def to(self, device: torch.device, non_blocking: bool = False):
|
| 257 |
+
# Boxes are assumed float32 and does not support to(dtype)
|
| 258 |
+
return RotatedBoxes(self.tensor.to(device=device, non_blocking=non_blocking))
|
| 259 |
+
|
| 260 |
+
def area(self) -> torch.Tensor:
|
| 261 |
+
"""
|
| 262 |
+
Computes the area of all the boxes.
|
| 263 |
+
|
| 264 |
+
Returns:
|
| 265 |
+
torch.Tensor: a vector with areas of each box.
|
| 266 |
+
"""
|
| 267 |
+
box = self.tensor
|
| 268 |
+
area = box[:, 2] * box[:, 3]
|
| 269 |
+
return area
|
| 270 |
+
|
| 271 |
+
# Avoid in-place operations so that we can torchscript; NOTE: this creates a new tensor
|
| 272 |
+
def normalize_angles(self) -> None:
|
| 273 |
+
"""
|
| 274 |
+
Restrict angles to the range of [-180, 180) degrees
|
| 275 |
+
"""
|
| 276 |
+
angle_tensor = (self.tensor[:, 4] + 180.0) % 360.0 - 180.0
|
| 277 |
+
self.tensor = torch.cat((self.tensor[:, :4], angle_tensor[:, None]), dim=1)
|
| 278 |
+
|
| 279 |
+
def clip(
|
| 280 |
+
self, box_size: Tuple[int, int], clip_angle_threshold: float = 1.0
|
| 281 |
+
) -> None:
|
| 282 |
+
"""
|
| 283 |
+
Clip (in place) the boxes by limiting x coordinates to the range [0, width]
|
| 284 |
+
and y coordinates to the range [0, height].
|
| 285 |
+
|
| 286 |
+
For RRPN:
|
| 287 |
+
Only clip boxes that are almost horizontal with a tolerance of
|
| 288 |
+
clip_angle_threshold to maintain backward compatibility.
|
| 289 |
+
|
| 290 |
+
Rotated boxes beyond this threshold are not clipped for two reasons:
|
| 291 |
+
|
| 292 |
+
1. There are potentially multiple ways to clip a rotated box to make it
|
| 293 |
+
fit within the image.
|
| 294 |
+
2. It's tricky to make the entire rectangular box fit within the image
|
| 295 |
+
and still be able to not leave out pixels of interest.
|
| 296 |
+
|
| 297 |
+
Therefore we rely on ops like RoIAlignRotated to safely handle this.
|
| 298 |
+
|
| 299 |
+
Args:
|
| 300 |
+
box_size (height, width): The clipping box's size.
|
| 301 |
+
clip_angle_threshold:
|
| 302 |
+
Iff. abs(normalized(angle)) <= clip_angle_threshold (in degrees),
|
| 303 |
+
we do the clipping as horizontal boxes.
|
| 304 |
+
"""
|
| 305 |
+
h, w = box_size
|
| 306 |
+
|
| 307 |
+
# normalize angles to be within (-180, 180] degrees
|
| 308 |
+
self.normalize_angles()
|
| 309 |
+
|
| 310 |
+
idx = torch.where(torch.abs(self.tensor[:, 4]) <= clip_angle_threshold)[0]
|
| 311 |
+
|
| 312 |
+
# convert to (x1, y1, x2, y2)
|
| 313 |
+
x1 = self.tensor[idx, 0] - self.tensor[idx, 2] / 2.0
|
| 314 |
+
y1 = self.tensor[idx, 1] - self.tensor[idx, 3] / 2.0
|
| 315 |
+
x2 = self.tensor[idx, 0] + self.tensor[idx, 2] / 2.0
|
| 316 |
+
y2 = self.tensor[idx, 1] + self.tensor[idx, 3] / 2.0
|
| 317 |
+
|
| 318 |
+
# clip
|
| 319 |
+
x1.clamp_(min=0, max=w)
|
| 320 |
+
y1.clamp_(min=0, max=h)
|
| 321 |
+
x2.clamp_(min=0, max=w)
|
| 322 |
+
y2.clamp_(min=0, max=h)
|
| 323 |
+
|
| 324 |
+
# convert back to (xc, yc, w, h)
|
| 325 |
+
self.tensor[idx, 0] = (x1 + x2) / 2.0
|
| 326 |
+
self.tensor[idx, 1] = (y1 + y2) / 2.0
|
| 327 |
+
# make sure widths and heights do not increase due to numerical errors
|
| 328 |
+
self.tensor[idx, 2] = torch.min(self.tensor[idx, 2], x2 - x1)
|
| 329 |
+
self.tensor[idx, 3] = torch.min(self.tensor[idx, 3], y2 - y1)
|
| 330 |
+
|
| 331 |
+
def nonempty(self, threshold: float = 0.0) -> torch.Tensor:
|
| 332 |
+
"""
|
| 333 |
+
Find boxes that are non-empty.
|
| 334 |
+
A box is considered empty, if either of its side is no larger than threshold.
|
| 335 |
+
|
| 336 |
+
Returns:
|
| 337 |
+
Tensor: a binary vector which represents
|
| 338 |
+
whether each box is empty (False) or non-empty (True).
|
| 339 |
+
"""
|
| 340 |
+
box = self.tensor
|
| 341 |
+
widths = box[:, 2]
|
| 342 |
+
heights = box[:, 3]
|
| 343 |
+
keep = (widths > threshold) & (heights > threshold)
|
| 344 |
+
return keep
|
| 345 |
+
|
| 346 |
+
def __getitem__(self, item) -> "RotatedBoxes":
|
| 347 |
+
"""
|
| 348 |
+
Returns:
|
| 349 |
+
RotatedBoxes: Create a new :class:`RotatedBoxes` by indexing.
|
| 350 |
+
|
| 351 |
+
The following usage are allowed:
|
| 352 |
+
|
| 353 |
+
1. `new_boxes = boxes[3]`: return a `RotatedBoxes` which contains only one box.
|
| 354 |
+
2. `new_boxes = boxes[2:10]`: return a slice of boxes.
|
| 355 |
+
3. `new_boxes = boxes[vector]`, where vector is a torch.ByteTensor
|
| 356 |
+
with `length = len(boxes)`. Nonzero elements in the vector will be selected.
|
| 357 |
+
|
| 358 |
+
Note that the returned RotatedBoxes might share storage with this RotatedBoxes,
|
| 359 |
+
subject to Pytorch's indexing semantics.
|
| 360 |
+
"""
|
| 361 |
+
if isinstance(item, int):
|
| 362 |
+
return RotatedBoxes(self.tensor[item].view(1, -1))
|
| 363 |
+
b = self.tensor[item]
|
| 364 |
+
assert (
|
| 365 |
+
b.dim() == 2
|
| 366 |
+
), "Indexing on RotatedBoxes with {} failed to return a matrix!".format(item)
|
| 367 |
+
return RotatedBoxes(b)
|
| 368 |
+
|
| 369 |
+
def __len__(self) -> int:
|
| 370 |
+
return self.tensor.shape[0]
|
| 371 |
+
|
| 372 |
+
def __repr__(self) -> str:
|
| 373 |
+
return "RotatedBoxes(" + str(self.tensor) + ")"
|
| 374 |
+
|
| 375 |
+
def inside_box(
|
| 376 |
+
self, box_size: Tuple[int, int], boundary_threshold: int = 0
|
| 377 |
+
) -> torch.Tensor:
|
| 378 |
+
"""
|
| 379 |
+
Args:
|
| 380 |
+
box_size (height, width): Size of the reference box covering
|
| 381 |
+
[0, width] x [0, height]
|
| 382 |
+
boundary_threshold (int): Boxes that extend beyond the reference box
|
| 383 |
+
boundary by more than boundary_threshold are considered "outside".
|
| 384 |
+
|
| 385 |
+
For RRPN, it might not be necessary to call this function since it's common
|
| 386 |
+
for rotated box to extend to outside of the image boundaries
|
| 387 |
+
(the clip function only clips the near-horizontal boxes)
|
| 388 |
+
|
| 389 |
+
Returns:
|
| 390 |
+
a binary vector, indicating whether each box is inside the reference box.
|
| 391 |
+
"""
|
| 392 |
+
height, width = box_size
|
| 393 |
+
|
| 394 |
+
cnt_x = self.tensor[..., 0]
|
| 395 |
+
cnt_y = self.tensor[..., 1]
|
| 396 |
+
half_w = self.tensor[..., 2] / 2.0
|
| 397 |
+
half_h = self.tensor[..., 3] / 2.0
|
| 398 |
+
a = self.tensor[..., 4]
|
| 399 |
+
c = torch.abs(torch.cos(a * math.pi / 180.0))
|
| 400 |
+
s = torch.abs(torch.sin(a * math.pi / 180.0))
|
| 401 |
+
# This basically computes the horizontal bounding rectangle of the rotated box
|
| 402 |
+
max_rect_dx = c * half_w + s * half_h
|
| 403 |
+
max_rect_dy = c * half_h + s * half_w
|
| 404 |
+
|
| 405 |
+
inds_inside = (
|
| 406 |
+
(cnt_x - max_rect_dx >= -boundary_threshold)
|
| 407 |
+
& (cnt_y - max_rect_dy >= -boundary_threshold)
|
| 408 |
+
& (cnt_x + max_rect_dx < width + boundary_threshold)
|
| 409 |
+
& (cnt_y + max_rect_dy < height + boundary_threshold)
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
return inds_inside
|
| 413 |
+
|
| 414 |
+
def get_centers(self) -> torch.Tensor:
|
| 415 |
+
"""
|
| 416 |
+
Returns:
|
| 417 |
+
The box centers in a Nx2 array of (x, y).
|
| 418 |
+
"""
|
| 419 |
+
return self.tensor[:, :2]
|
| 420 |
+
|
| 421 |
+
def scale(self, scale_x: float, scale_y: float) -> None:
|
| 422 |
+
"""
|
| 423 |
+
Scale the rotated box with horizontal and vertical scaling factors
|
| 424 |
+
Note: when scale_factor_x != scale_factor_y,
|
| 425 |
+
the rotated box does not preserve the rectangular shape when the angle
|
| 426 |
+
is not a multiple of 90 degrees under resize transformation.
|
| 427 |
+
Instead, the shape is a parallelogram (that has skew)
|
| 428 |
+
Here we make an approximation by fitting a rotated rectangle to the parallelogram.
|
| 429 |
+
"""
|
| 430 |
+
self.tensor[:, 0] *= scale_x
|
| 431 |
+
self.tensor[:, 1] *= scale_y
|
| 432 |
+
theta = self.tensor[:, 4] * math.pi / 180.0
|
| 433 |
+
c = torch.cos(theta)
|
| 434 |
+
s = torch.sin(theta)
|
| 435 |
+
|
| 436 |
+
# In image space, y is top->down and x is left->right
|
| 437 |
+
# Consider the local coordintate system for the rotated box,
|
| 438 |
+
# where the box center is located at (0, 0), and the four vertices ABCD are
|
| 439 |
+
# A(-w / 2, -h / 2), B(w / 2, -h / 2), C(w / 2, h / 2), D(-w / 2, h / 2)
|
| 440 |
+
# the midpoint of the left edge AD of the rotated box E is:
|
| 441 |
+
# E = (A+D)/2 = (-w / 2, 0)
|
| 442 |
+
# the midpoint of the top edge AB of the rotated box F is:
|
| 443 |
+
# F(0, -h / 2)
|
| 444 |
+
# To get the old coordinates in the global system, apply the rotation transformation
|
| 445 |
+
# (Note: the right-handed coordinate system for image space is yOx):
|
| 446 |
+
# (old_x, old_y) = (s * y + c * x, c * y - s * x)
|
| 447 |
+
# E(old) = (s * 0 + c * (-w/2), c * 0 - s * (-w/2)) = (-c * w / 2, s * w / 2)
|
| 448 |
+
# F(old) = (s * (-h / 2) + c * 0, c * (-h / 2) - s * 0) = (-s * h / 2, -c * h / 2)
|
| 449 |
+
# After applying the scaling factor (sfx, sfy):
|
| 450 |
+
# E(new) = (-sfx * c * w / 2, sfy * s * w / 2)
|
| 451 |
+
# F(new) = (-sfx * s * h / 2, -sfy * c * h / 2)
|
| 452 |
+
# The new width after scaling tranformation becomes:
|
| 453 |
+
|
| 454 |
+
# w(new) = |E(new) - O| * 2
|
| 455 |
+
# = sqrt[(sfx * c * w / 2)^2 + (sfy * s * w / 2)^2] * 2
|
| 456 |
+
# = sqrt[(sfx * c)^2 + (sfy * s)^2] * w
|
| 457 |
+
# i.e., scale_factor_w = sqrt[(sfx * c)^2 + (sfy * s)^2]
|
| 458 |
+
#
|
| 459 |
+
# For example,
|
| 460 |
+
# when angle = 0 or 180, |c| = 1, s = 0, scale_factor_w == scale_factor_x;
|
| 461 |
+
# when |angle| = 90, c = 0, |s| = 1, scale_factor_w == scale_factor_y
|
| 462 |
+
self.tensor[:, 2] *= torch.sqrt((scale_x * c) ** 2 + (scale_y * s) ** 2)
|
| 463 |
+
|
| 464 |
+
# h(new) = |F(new) - O| * 2
|
| 465 |
+
# = sqrt[(sfx * s * h / 2)^2 + (sfy * c * h / 2)^2] * 2
|
| 466 |
+
# = sqrt[(sfx * s)^2 + (sfy * c)^2] * h
|
| 467 |
+
# i.e., scale_factor_h = sqrt[(sfx * s)^2 + (sfy * c)^2]
|
| 468 |
+
#
|
| 469 |
+
# For example,
|
| 470 |
+
# when angle = 0 or 180, |c| = 1, s = 0, scale_factor_h == scale_factor_y;
|
| 471 |
+
# when |angle| = 90, c = 0, |s| = 1, scale_factor_h == scale_factor_x
|
| 472 |
+
self.tensor[:, 3] *= torch.sqrt((scale_x * s) ** 2 + (scale_y * c) ** 2)
|
| 473 |
+
|
| 474 |
+
# The angle is the rotation angle from y-axis in image space to the height
|
| 475 |
+
# vector (top->down in the box's local coordinate system) of the box in CCW.
|
| 476 |
+
#
|
| 477 |
+
# angle(new) = angle_yOx(O - F(new))
|
| 478 |
+
# = angle_yOx( (sfx * s * h / 2, sfy * c * h / 2) )
|
| 479 |
+
# = atan2(sfx * s * h / 2, sfy * c * h / 2)
|
| 480 |
+
# = atan2(sfx * s, sfy * c)
|
| 481 |
+
#
|
| 482 |
+
# For example,
|
| 483 |
+
# when sfx == sfy, angle(new) == atan2(s, c) == angle(old)
|
| 484 |
+
self.tensor[:, 4] = torch.atan2(scale_x * s, scale_y * c) * 180 / math.pi
|
| 485 |
+
|
| 486 |
+
@classmethod
|
| 487 |
+
def cat(cls, boxes_list: List["RotatedBoxes"]) -> "RotatedBoxes":
|
| 488 |
+
"""
|
| 489 |
+
Concatenates a list of RotatedBoxes into a single RotatedBoxes
|
| 490 |
+
|
| 491 |
+
Arguments:
|
| 492 |
+
boxes_list (list[RotatedBoxes])
|
| 493 |
+
|
| 494 |
+
Returns:
|
| 495 |
+
RotatedBoxes: the concatenated RotatedBoxes
|
| 496 |
+
"""
|
| 497 |
+
assert isinstance(boxes_list, (list, tuple))
|
| 498 |
+
if len(boxes_list) == 0:
|
| 499 |
+
return cls(torch.empty(0))
|
| 500 |
+
assert all([isinstance(box, RotatedBoxes) for box in boxes_list])
|
| 501 |
+
|
| 502 |
+
# use torch.cat (v.s. layers.cat) so the returned boxes never share storage with input
|
| 503 |
+
cat_boxes = cls(torch.cat([b.tensor for b in boxes_list], dim=0))
|
| 504 |
+
return cat_boxes
|
| 505 |
+
|
| 506 |
+
@property
|
| 507 |
+
def device(self) -> torch.device:
|
| 508 |
+
return self.tensor.device
|
| 509 |
+
|
| 510 |
+
@torch.jit.unused
|
| 511 |
+
def __iter__(self):
|
| 512 |
+
"""
|
| 513 |
+
Yield a box as a Tensor of shape (5,) at a time.
|
| 514 |
+
"""
|
| 515 |
+
yield from self.tensor
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
def pairwise_iou(boxes1: RotatedBoxes, boxes2: RotatedBoxes) -> None:
|
| 519 |
+
"""
|
| 520 |
+
Given two lists of rotated boxes of size N and M,
|
| 521 |
+
compute the IoU (intersection over union)
|
| 522 |
+
between **all** N x M pairs of boxes.
|
| 523 |
+
The box order must be (x_center, y_center, width, height, angle).
|
| 524 |
+
|
| 525 |
+
Args:
|
| 526 |
+
boxes1, boxes2 (RotatedBoxes):
|
| 527 |
+
two `RotatedBoxes`. Contains N & M rotated boxes, respectively.
|
| 528 |
+
|
| 529 |
+
Returns:
|
| 530 |
+
Tensor: IoU, sized [N,M].
|
| 531 |
+
"""
|
| 532 |
+
|
| 533 |
+
return pairwise_iou_rotated(boxes1.tensor, boxes2.tensor)
|
sam3/agent/helpers/som_utils.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
import colorsys
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import List, Tuple
|
| 6 |
+
|
| 7 |
+
import cv2
|
| 8 |
+
import matplotlib as mpl
|
| 9 |
+
import matplotlib.colors as mplc
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pycocotools.mask as mask_utils
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def rgb_to_hex(rgb_color):
|
| 15 |
+
"""
|
| 16 |
+
Convert a rgb color to hex color.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
rgb_color (tuple/list of ints): RGB color in tuple or list format.
|
| 20 |
+
|
| 21 |
+
Returns:
|
| 22 |
+
str: Hex color.
|
| 23 |
+
|
| 24 |
+
Example:
|
| 25 |
+
```
|
| 26 |
+
>>> rgb_to_hex((255, 0, 244))
|
| 27 |
+
'#ff00ff'
|
| 28 |
+
```
|
| 29 |
+
"""
|
| 30 |
+
return "#" + "".join([hex(c)[2:].zfill(2) for c in rgb_color])
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# DEFAULT_COLOR_HEX_TO_NAME = {
|
| 34 |
+
# rgb_to_hex((255, 0, 0)): "red",
|
| 35 |
+
# rgb_to_hex((0, 255, 0)): "lime",
|
| 36 |
+
# rgb_to_hex((0, 0, 255)): "blue",
|
| 37 |
+
# rgb_to_hex((255, 255, 0)): "yellow",
|
| 38 |
+
# rgb_to_hex((255, 0, 255)): "fuchsia",
|
| 39 |
+
# rgb_to_hex((0, 255, 255)): "aqua",
|
| 40 |
+
# rgb_to_hex((255, 165, 0)): "orange",
|
| 41 |
+
# rgb_to_hex((128, 0, 128)): "purple",
|
| 42 |
+
# rgb_to_hex((255, 215, 0)): "gold",
|
| 43 |
+
# }
|
| 44 |
+
|
| 45 |
+
# Assuming rgb_to_hex is a function that converts an (R, G, B) tuple to a hex string.
|
| 46 |
+
# For example: def rgb_to_hex(rgb): return '#%02x%02x%02x' % rgb
|
| 47 |
+
|
| 48 |
+
DEFAULT_COLOR_HEX_TO_NAME = {
|
| 49 |
+
# The top 20 approved colors
|
| 50 |
+
rgb_to_hex((255, 255, 0)): "yellow",
|
| 51 |
+
rgb_to_hex((0, 255, 0)): "lime",
|
| 52 |
+
rgb_to_hex((0, 255, 255)): "cyan",
|
| 53 |
+
rgb_to_hex((255, 0, 255)): "magenta",
|
| 54 |
+
rgb_to_hex((255, 0, 0)): "red",
|
| 55 |
+
rgb_to_hex((255, 127, 0)): "orange",
|
| 56 |
+
rgb_to_hex((127, 255, 0)): "chartreuse",
|
| 57 |
+
rgb_to_hex((0, 255, 127)): "spring green",
|
| 58 |
+
rgb_to_hex((255, 0, 127)): "rose",
|
| 59 |
+
rgb_to_hex((127, 0, 255)): "violet",
|
| 60 |
+
rgb_to_hex((192, 255, 0)): "electric lime",
|
| 61 |
+
rgb_to_hex((255, 192, 0)): "vivid orange",
|
| 62 |
+
rgb_to_hex((0, 255, 192)): "turquoise",
|
| 63 |
+
rgb_to_hex((192, 0, 255)): "bright violet",
|
| 64 |
+
rgb_to_hex((255, 0, 192)): "bright pink",
|
| 65 |
+
rgb_to_hex((255, 64, 0)): "fiery orange",
|
| 66 |
+
rgb_to_hex((64, 255, 0)): "bright chartreuse",
|
| 67 |
+
rgb_to_hex((0, 255, 64)): "malachite",
|
| 68 |
+
rgb_to_hex((64, 0, 255)): "deep violet",
|
| 69 |
+
rgb_to_hex((255, 0, 64)): "hot pink",
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
DEFAULT_COLOR_PALETTE = list(DEFAULT_COLOR_HEX_TO_NAME.keys())
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _validate_color_hex(color_hex: str):
|
| 77 |
+
color_hex = color_hex.lstrip("#")
|
| 78 |
+
if not all(c in "0123456789abcdefABCDEF" for c in color_hex):
|
| 79 |
+
raise ValueError("Invalid characters in color hash")
|
| 80 |
+
if len(color_hex) not in (3, 6):
|
| 81 |
+
raise ValueError("Invalid length of color hash")
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# copied from https://github.com/roboflow/supervision/blob/c8f557af0c61b5c03392bad2cc36c8835598b1e1/supervision/draw/color.py
|
| 85 |
+
@dataclass
|
| 86 |
+
class Color:
|
| 87 |
+
"""
|
| 88 |
+
Represents a color in RGB format.
|
| 89 |
+
|
| 90 |
+
Attributes:
|
| 91 |
+
r (int): Red channel.
|
| 92 |
+
g (int): Green channel.
|
| 93 |
+
b (int): Blue channel.
|
| 94 |
+
"""
|
| 95 |
+
|
| 96 |
+
r: int
|
| 97 |
+
g: int
|
| 98 |
+
b: int
|
| 99 |
+
|
| 100 |
+
@classmethod
|
| 101 |
+
def from_hex(cls, color_hex: str):
|
| 102 |
+
"""
|
| 103 |
+
Create a Color instance from a hex string.
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
color_hex (str): Hex string of the color.
|
| 107 |
+
|
| 108 |
+
Returns:
|
| 109 |
+
Color: Instance representing the color.
|
| 110 |
+
|
| 111 |
+
Example:
|
| 112 |
+
```
|
| 113 |
+
>>> Color.from_hex('#ff00ff')
|
| 114 |
+
Color(r=255, g=0, b=255)
|
| 115 |
+
```
|
| 116 |
+
"""
|
| 117 |
+
_validate_color_hex(color_hex)
|
| 118 |
+
color_hex = color_hex.lstrip("#")
|
| 119 |
+
if len(color_hex) == 3:
|
| 120 |
+
color_hex = "".join(c * 2 for c in color_hex)
|
| 121 |
+
r, g, b = (int(color_hex[i : i + 2], 16) for i in range(0, 6, 2))
|
| 122 |
+
return cls(r, g, b)
|
| 123 |
+
|
| 124 |
+
@classmethod
|
| 125 |
+
def to_hex(cls, color):
|
| 126 |
+
"""
|
| 127 |
+
Convert a Color instance to a hex string.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
color (Color): Color instance of color.
|
| 131 |
+
|
| 132 |
+
Returns:
|
| 133 |
+
Color: a hex string.
|
| 134 |
+
"""
|
| 135 |
+
return rgb_to_hex((color.r, color.g, color.b))
|
| 136 |
+
|
| 137 |
+
def as_rgb(self) -> Tuple[int, int, int]:
|
| 138 |
+
"""
|
| 139 |
+
Returns the color as an RGB tuple.
|
| 140 |
+
|
| 141 |
+
Returns:
|
| 142 |
+
Tuple[int, int, int]: RGB tuple.
|
| 143 |
+
|
| 144 |
+
Example:
|
| 145 |
+
```
|
| 146 |
+
>>> color.as_rgb()
|
| 147 |
+
(255, 0, 255)
|
| 148 |
+
```
|
| 149 |
+
"""
|
| 150 |
+
return self.r, self.g, self.b
|
| 151 |
+
|
| 152 |
+
def as_bgr(self) -> Tuple[int, int, int]:
|
| 153 |
+
"""
|
| 154 |
+
Returns the color as a BGR tuple.
|
| 155 |
+
|
| 156 |
+
Returns:
|
| 157 |
+
Tuple[int, int, int]: BGR tuple.
|
| 158 |
+
|
| 159 |
+
Example:
|
| 160 |
+
```
|
| 161 |
+
>>> color.as_bgr()
|
| 162 |
+
(255, 0, 255)
|
| 163 |
+
```
|
| 164 |
+
"""
|
| 165 |
+
return self.b, self.g, self.r
|
| 166 |
+
|
| 167 |
+
@classmethod
|
| 168 |
+
def white(cls):
|
| 169 |
+
return Color.from_hex(color_hex="#ffffff")
|
| 170 |
+
|
| 171 |
+
@classmethod
|
| 172 |
+
def black(cls):
|
| 173 |
+
return Color.from_hex(color_hex="#000000")
|
| 174 |
+
|
| 175 |
+
@classmethod
|
| 176 |
+
def red(cls):
|
| 177 |
+
return Color.from_hex(color_hex="#ff0000")
|
| 178 |
+
|
| 179 |
+
@classmethod
|
| 180 |
+
def green(cls):
|
| 181 |
+
return Color.from_hex(color_hex="#00ff00")
|
| 182 |
+
|
| 183 |
+
@classmethod
|
| 184 |
+
def blue(cls):
|
| 185 |
+
return Color.from_hex(color_hex="#0000ff")
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
@dataclass
|
| 189 |
+
class ColorPalette:
|
| 190 |
+
colors: List[Color]
|
| 191 |
+
|
| 192 |
+
@classmethod
|
| 193 |
+
def default(cls):
|
| 194 |
+
"""
|
| 195 |
+
Returns a default color palette.
|
| 196 |
+
|
| 197 |
+
Returns:
|
| 198 |
+
ColorPalette: A ColorPalette instance with default colors.
|
| 199 |
+
|
| 200 |
+
Example:
|
| 201 |
+
```
|
| 202 |
+
>>> ColorPalette.default()
|
| 203 |
+
ColorPalette(colors=[Color(r=255, g=0, b=0), Color(r=0, g=255, b=0), ...])
|
| 204 |
+
```
|
| 205 |
+
"""
|
| 206 |
+
return ColorPalette.from_hex(color_hex_list=DEFAULT_COLOR_PALETTE)
|
| 207 |
+
|
| 208 |
+
@classmethod
|
| 209 |
+
def from_hex(cls, color_hex_list: List[str]):
|
| 210 |
+
"""
|
| 211 |
+
Create a ColorPalette instance from a list of hex strings.
|
| 212 |
+
|
| 213 |
+
Args:
|
| 214 |
+
color_hex_list (List[str]): List of color hex strings.
|
| 215 |
+
|
| 216 |
+
Returns:
|
| 217 |
+
ColorPalette: A ColorPalette instance.
|
| 218 |
+
|
| 219 |
+
Example:
|
| 220 |
+
```
|
| 221 |
+
>>> ColorPalette.from_hex(['#ff0000', '#00ff00', '#0000ff'])
|
| 222 |
+
ColorPalette(colors=[Color(r=255, g=0, b=0), Color(r=0, g=255, b=0), ...])
|
| 223 |
+
```
|
| 224 |
+
"""
|
| 225 |
+
colors = [Color.from_hex(color_hex) for color_hex in color_hex_list]
|
| 226 |
+
return cls(colors)
|
| 227 |
+
|
| 228 |
+
def by_idx(self, idx: int) -> Color:
|
| 229 |
+
"""
|
| 230 |
+
Return the color at a given index in the palette.
|
| 231 |
+
|
| 232 |
+
Args:
|
| 233 |
+
idx (int): Index of the color in the palette.
|
| 234 |
+
|
| 235 |
+
Returns:
|
| 236 |
+
Color: Color at the given index.
|
| 237 |
+
|
| 238 |
+
Example:
|
| 239 |
+
```
|
| 240 |
+
>>> color_palette.by_idx(1)
|
| 241 |
+
Color(r=0, g=255, b=0)
|
| 242 |
+
```
|
| 243 |
+
"""
|
| 244 |
+
if idx < 0:
|
| 245 |
+
raise ValueError("idx argument should not be negative")
|
| 246 |
+
idx = idx % len(self.colors)
|
| 247 |
+
return self.colors[idx]
|
| 248 |
+
|
| 249 |
+
def find_farthest_color(self, img_array):
|
| 250 |
+
"""
|
| 251 |
+
Return the color that is the farthest from the given color.
|
| 252 |
+
|
| 253 |
+
Args:
|
| 254 |
+
img_array (np array): any *x3 np array, 3 is the RGB color channel.
|
| 255 |
+
|
| 256 |
+
Returns:
|
| 257 |
+
Color: Farthest color.
|
| 258 |
+
|
| 259 |
+
"""
|
| 260 |
+
# Reshape the image array for broadcasting
|
| 261 |
+
img_array = img_array.reshape((-1, 3))
|
| 262 |
+
|
| 263 |
+
# Convert colors dictionary to a NumPy array
|
| 264 |
+
color_values = np.array([[c.r, c.g, c.b] for c in self.colors])
|
| 265 |
+
|
| 266 |
+
# Calculate the Euclidean distance between the colors and each pixel in the image
|
| 267 |
+
# Broadcasting happens here: img_array shape is (num_pixels, 3), color_values shape is (num_colors, 3)
|
| 268 |
+
distances = np.sqrt(
|
| 269 |
+
np.sum((img_array[:, np.newaxis, :] - color_values) ** 2, axis=2)
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
# Average the distances for each color
|
| 273 |
+
mean_distances = np.mean(distances, axis=0)
|
| 274 |
+
|
| 275 |
+
# return the farthest color
|
| 276 |
+
farthest_idx = np.argmax(mean_distances)
|
| 277 |
+
farthest_color = self.colors[farthest_idx]
|
| 278 |
+
farthest_color_hex = Color.to_hex(farthest_color)
|
| 279 |
+
if farthest_color_hex in DEFAULT_COLOR_HEX_TO_NAME:
|
| 280 |
+
farthest_color_name = DEFAULT_COLOR_HEX_TO_NAME[farthest_color_hex]
|
| 281 |
+
else:
|
| 282 |
+
farthest_color_name = "unknown"
|
| 283 |
+
|
| 284 |
+
return farthest_color, farthest_color_name
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def draw_box(ax, box_coord, alpha=0.8, edge_color="g", line_style="-", linewidth=2.0):
|
| 288 |
+
x0, y0, width, height = box_coord
|
| 289 |
+
ax.add_patch(
|
| 290 |
+
mpl.patches.Rectangle(
|
| 291 |
+
(x0, y0),
|
| 292 |
+
width,
|
| 293 |
+
height,
|
| 294 |
+
fill=False,
|
| 295 |
+
edgecolor=edge_color,
|
| 296 |
+
linewidth=linewidth,
|
| 297 |
+
alpha=alpha,
|
| 298 |
+
linestyle=line_style,
|
| 299 |
+
)
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def draw_text(
|
| 304 |
+
ax,
|
| 305 |
+
text,
|
| 306 |
+
position,
|
| 307 |
+
font_size=None,
|
| 308 |
+
color="g",
|
| 309 |
+
horizontal_alignment="left",
|
| 310 |
+
rotation=0,
|
| 311 |
+
):
|
| 312 |
+
if not font_size:
|
| 313 |
+
font_size = mpl.rcParams["font.size"]
|
| 314 |
+
|
| 315 |
+
color = np.maximum(list(mplc.to_rgb(color)), 0.2)
|
| 316 |
+
color[np.argmax(color)] = max(0.8, np.max(color))
|
| 317 |
+
|
| 318 |
+
x, y = position
|
| 319 |
+
ax.text(
|
| 320 |
+
x,
|
| 321 |
+
y,
|
| 322 |
+
text,
|
| 323 |
+
size=font_size,
|
| 324 |
+
family="sans-serif",
|
| 325 |
+
bbox={"facecolor": "none", "alpha": 0.5, "pad": 0.7, "edgecolor": "none"},
|
| 326 |
+
verticalalignment="top",
|
| 327 |
+
horizontalalignment=horizontal_alignment,
|
| 328 |
+
color=color,
|
| 329 |
+
rotation=rotation,
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def draw_mask(
|
| 334 |
+
ax, rle, color, show_holes=True, alpha=0.15, upsample_factor=1.0, rle_upsampled=None
|
| 335 |
+
):
|
| 336 |
+
if isinstance(rle, dict):
|
| 337 |
+
mask = mask_utils.decode(rle)
|
| 338 |
+
elif isinstance(rle, np.ndarray):
|
| 339 |
+
mask = rle
|
| 340 |
+
else:
|
| 341 |
+
raise ValueError(f"Unsupported type for rle: {type(rle)}")
|
| 342 |
+
|
| 343 |
+
mask_upsampled = None
|
| 344 |
+
if upsample_factor > 1.0 and show_holes:
|
| 345 |
+
assert rle_upsampled is not None
|
| 346 |
+
if isinstance(rle_upsampled, dict):
|
| 347 |
+
mask_upsampled = mask_utils.decode(rle_upsampled)
|
| 348 |
+
elif isinstance(rle_upsampled, np.ndarray):
|
| 349 |
+
mask_upsampled = rle_upsampled
|
| 350 |
+
else:
|
| 351 |
+
raise ValueError(f"Unsupported type for rle: {type(rle)}")
|
| 352 |
+
|
| 353 |
+
if show_holes:
|
| 354 |
+
if mask_upsampled is None:
|
| 355 |
+
mask_upsampled = mask
|
| 356 |
+
h, w = mask_upsampled.shape
|
| 357 |
+
mask_img = np.zeros((h, w, 4))
|
| 358 |
+
mask_img[:, :, :-1] = color[np.newaxis, np.newaxis, :]
|
| 359 |
+
mask_img[:, :, -1] = mask_upsampled * alpha
|
| 360 |
+
ax.imshow(mask_img)
|
| 361 |
+
|
| 362 |
+
*_, contours, _ = cv2.findContours(
|
| 363 |
+
mask.astype(np.uint8).copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
|
| 364 |
+
)
|
| 365 |
+
upsampled_contours = [(cont + 0.5) * upsample_factor - 0.5 for cont in contours]
|
| 366 |
+
facecolor = (0, 0, 0, 0) if show_holes else color
|
| 367 |
+
if alpha > 0.8:
|
| 368 |
+
edge_color = _change_color_brightness(color, brightness_factor=-0.7)
|
| 369 |
+
else:
|
| 370 |
+
edge_color = color
|
| 371 |
+
for cont in upsampled_contours:
|
| 372 |
+
polygon = mpl.patches.Polygon(
|
| 373 |
+
[el[0] for el in cont],
|
| 374 |
+
edgecolor=edge_color,
|
| 375 |
+
linewidth=2.0,
|
| 376 |
+
facecolor=facecolor,
|
| 377 |
+
)
|
| 378 |
+
ax.add_patch(polygon)
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def _change_color_brightness(color, brightness_factor):
|
| 382 |
+
"""
|
| 383 |
+
Depending on the brightness_factor, gives a lighter or darker color i.e. a color with
|
| 384 |
+
less or more saturation than the original color.
|
| 385 |
+
|
| 386 |
+
Args:
|
| 387 |
+
color: color of the polygon. Refer to `matplotlib.colors` for a full list of
|
| 388 |
+
formats that are accepted.
|
| 389 |
+
brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of
|
| 390 |
+
0 will correspond to no change, a factor in [-1.0, 0) range will result in
|
| 391 |
+
a darker color and a factor in (0, 1.0] range will result in a lighter color.
|
| 392 |
+
|
| 393 |
+
Returns:
|
| 394 |
+
modified_color (tuple[double]): a tuple containing the RGB values of the
|
| 395 |
+
modified color. Each value in the tuple is in the [0.0, 1.0] range.
|
| 396 |
+
"""
|
| 397 |
+
assert brightness_factor >= -1.0 and brightness_factor <= 1.0
|
| 398 |
+
color = mplc.to_rgb(color)
|
| 399 |
+
polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
|
| 400 |
+
modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
|
| 401 |
+
modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
|
| 402 |
+
modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
|
| 403 |
+
modified_color = colorsys.hls_to_rgb(
|
| 404 |
+
polygon_color[0], modified_lightness, polygon_color[2]
|
| 405 |
+
)
|
| 406 |
+
return modified_color
|
sam3/agent/helpers/visualizer.py
ADDED
|
@@ -0,0 +1,1662 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
import colorsys
|
| 4 |
+
import logging
|
| 5 |
+
import math
|
| 6 |
+
import random
|
| 7 |
+
from enum import Enum, unique
|
| 8 |
+
|
| 9 |
+
import cv2
|
| 10 |
+
import matplotlib as mpl
|
| 11 |
+
import matplotlib.colors as mplc
|
| 12 |
+
import matplotlib.figure as mplfigure
|
| 13 |
+
import numpy as np
|
| 14 |
+
import pycocotools.mask as mask_util
|
| 15 |
+
import torch
|
| 16 |
+
from iopath.common.file_io import PathManager
|
| 17 |
+
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
| 18 |
+
from PIL import Image
|
| 19 |
+
|
| 20 |
+
from .boxes import Boxes, BoxMode
|
| 21 |
+
|
| 22 |
+
from .color_map import random_color
|
| 23 |
+
from .keypoints import Keypoints
|
| 24 |
+
from .masks import BitMasks, PolygonMasks
|
| 25 |
+
from .rotated_boxes import RotatedBoxes
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
__all__ = ["ColorMode", "VisImage", "Visualizer"]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
_SMALL_OBJECT_AREA_THRESH = 1000
|
| 34 |
+
_LARGE_MASK_AREA_THRESH = 120000
|
| 35 |
+
_OFF_WHITE = (1.0, 1.0, 240.0 / 255)
|
| 36 |
+
_BLACK = (0, 0, 0)
|
| 37 |
+
_RED = (1.0, 0, 0)
|
| 38 |
+
|
| 39 |
+
_KEYPOINT_THRESHOLD = 0.05
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@unique
|
| 43 |
+
class ColorMode(Enum):
|
| 44 |
+
"""
|
| 45 |
+
Enum of different color modes to use for instance visualizations.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
IMAGE = 0
|
| 49 |
+
"""
|
| 50 |
+
Picks a random color for every instance and overlay segmentations with low opacity.
|
| 51 |
+
"""
|
| 52 |
+
SEGMENTATION = 1
|
| 53 |
+
"""
|
| 54 |
+
Let instances of the same category have similar colors
|
| 55 |
+
(from metadata.thing_colors), and overlay them with
|
| 56 |
+
high opacity. This provides more attention on the quality of segmentation.
|
| 57 |
+
"""
|
| 58 |
+
IMAGE_BW = 2
|
| 59 |
+
"""
|
| 60 |
+
Same as IMAGE, but convert all areas without masks to gray-scale.
|
| 61 |
+
Only available for drawing per-instance mask predictions.
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class GenericMask:
|
| 66 |
+
"""
|
| 67 |
+
Attribute:
|
| 68 |
+
polygons (list[ndarray]): list[ndarray]: polygons for this mask.
|
| 69 |
+
Each ndarray has format [x, y, x, y, ...]
|
| 70 |
+
mask (ndarray): a binary mask
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
def __init__(self, mask_or_polygons, height, width):
|
| 74 |
+
self._mask = self._polygons = self._has_holes = None
|
| 75 |
+
self.height = height
|
| 76 |
+
self.width = width
|
| 77 |
+
|
| 78 |
+
m = mask_or_polygons
|
| 79 |
+
if isinstance(m, dict):
|
| 80 |
+
# RLEs
|
| 81 |
+
assert "counts" in m and "size" in m
|
| 82 |
+
if isinstance(m["counts"], list): # uncompressed RLEs
|
| 83 |
+
h, w = m["size"]
|
| 84 |
+
assert h == height and w == width
|
| 85 |
+
m = mask_util.frPyObjects(m, h, w)
|
| 86 |
+
self._mask = mask_util.decode(m)[:, :]
|
| 87 |
+
return
|
| 88 |
+
|
| 89 |
+
if isinstance(m, list): # list[ndarray]
|
| 90 |
+
self._polygons = [np.asarray(x).reshape(-1) for x in m]
|
| 91 |
+
return
|
| 92 |
+
|
| 93 |
+
if isinstance(m, np.ndarray): # assumed to be a binary mask
|
| 94 |
+
assert m.shape[1] != 2, m.shape
|
| 95 |
+
assert m.shape == (
|
| 96 |
+
height,
|
| 97 |
+
width,
|
| 98 |
+
), f"mask shape: {m.shape}, target dims: {height}, {width}"
|
| 99 |
+
self._mask = m.astype("uint8")
|
| 100 |
+
return
|
| 101 |
+
|
| 102 |
+
raise ValueError(
|
| 103 |
+
"GenericMask cannot handle object {} of type '{}'".format(m, type(m))
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
@property
|
| 107 |
+
def mask(self):
|
| 108 |
+
if self._mask is None:
|
| 109 |
+
self._mask = self.polygons_to_mask(self._polygons)
|
| 110 |
+
return self._mask
|
| 111 |
+
|
| 112 |
+
@property
|
| 113 |
+
def polygons(self):
|
| 114 |
+
if self._polygons is None:
|
| 115 |
+
self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
|
| 116 |
+
return self._polygons
|
| 117 |
+
|
| 118 |
+
@property
|
| 119 |
+
def has_holes(self):
|
| 120 |
+
if self._has_holes is None:
|
| 121 |
+
if self._mask is not None:
|
| 122 |
+
self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
|
| 123 |
+
else:
|
| 124 |
+
self._has_holes = (
|
| 125 |
+
False # if original format is polygon, does not have holes
|
| 126 |
+
)
|
| 127 |
+
return self._has_holes
|
| 128 |
+
|
| 129 |
+
def mask_to_polygons(self, mask):
|
| 130 |
+
# cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level
|
| 131 |
+
# hierarchy. External contours (boundary) of the object are placed in hierarchy-1.
|
| 132 |
+
# Internal contours (holes) are placed in hierarchy-2.
|
| 133 |
+
# cv2.CHAIN_APPROX_NONE flag gets vertices of polygons from contours.
|
| 134 |
+
mask = np.ascontiguousarray(
|
| 135 |
+
mask
|
| 136 |
+
) # some versions of cv2 does not support incontiguous arr
|
| 137 |
+
res = cv2.findContours(
|
| 138 |
+
mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE
|
| 139 |
+
)
|
| 140 |
+
hierarchy = res[-1]
|
| 141 |
+
if hierarchy is None: # empty mask
|
| 142 |
+
return [], False
|
| 143 |
+
has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
|
| 144 |
+
res = res[-2]
|
| 145 |
+
res = [x.flatten() for x in res]
|
| 146 |
+
# These coordinates from OpenCV are integers in range [0, W-1 or H-1].
|
| 147 |
+
# We add 0.5 to turn them into real-value coordinate space. A better solution
|
| 148 |
+
# would be to first +0.5 and then dilate the returned polygon by 0.5.
|
| 149 |
+
res = [x + 0.5 for x in res if len(x) >= 6]
|
| 150 |
+
return res, has_holes
|
| 151 |
+
|
| 152 |
+
def polygons_to_mask(self, polygons):
|
| 153 |
+
rle = mask_util.frPyObjects(polygons, self.height, self.width)
|
| 154 |
+
rle = mask_util.merge(rle)
|
| 155 |
+
return mask_util.decode(rle)[:, :]
|
| 156 |
+
|
| 157 |
+
def area(self):
|
| 158 |
+
return self.mask.sum()
|
| 159 |
+
|
| 160 |
+
def bbox(self):
|
| 161 |
+
p = mask_util.frPyObjects(self.polygons, self.height, self.width)
|
| 162 |
+
p = mask_util.merge(p)
|
| 163 |
+
bbox = mask_util.toBbox(p)
|
| 164 |
+
bbox[2] += bbox[0]
|
| 165 |
+
bbox[3] += bbox[1]
|
| 166 |
+
return bbox
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class _PanopticPrediction:
|
| 170 |
+
"""
|
| 171 |
+
Unify different panoptic annotation/prediction formats
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
def __init__(self, panoptic_seg, segments_info, metadata=None):
|
| 175 |
+
if segments_info is None:
|
| 176 |
+
assert metadata is not None
|
| 177 |
+
# If "segments_info" is None, we assume "panoptic_img" is a
|
| 178 |
+
# H*W int32 image storing the panoptic_id in the format of
|
| 179 |
+
# category_id * label_divisor + instance_id. We reserve -1 for
|
| 180 |
+
# VOID label.
|
| 181 |
+
label_divisor = metadata.label_divisor
|
| 182 |
+
segments_info = []
|
| 183 |
+
for panoptic_label in np.unique(panoptic_seg.numpy()):
|
| 184 |
+
if panoptic_label == -1:
|
| 185 |
+
# VOID region.
|
| 186 |
+
continue
|
| 187 |
+
pred_class = panoptic_label // label_divisor
|
| 188 |
+
isthing = (
|
| 189 |
+
pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
|
| 190 |
+
)
|
| 191 |
+
segments_info.append(
|
| 192 |
+
{
|
| 193 |
+
"id": int(panoptic_label),
|
| 194 |
+
"category_id": int(pred_class),
|
| 195 |
+
"isthing": bool(isthing),
|
| 196 |
+
}
|
| 197 |
+
)
|
| 198 |
+
del metadata
|
| 199 |
+
|
| 200 |
+
self._seg = panoptic_seg
|
| 201 |
+
|
| 202 |
+
self._sinfo = {s["id"]: s for s in segments_info} # seg id -> seg info
|
| 203 |
+
segment_ids, areas = torch.unique(panoptic_seg, sorted=True, return_counts=True)
|
| 204 |
+
areas = areas.numpy()
|
| 205 |
+
sorted_idxs = np.argsort(-areas)
|
| 206 |
+
self._seg_ids, self._seg_areas = segment_ids[sorted_idxs], areas[sorted_idxs]
|
| 207 |
+
self._seg_ids = self._seg_ids.tolist()
|
| 208 |
+
for sid, area in zip(self._seg_ids, self._seg_areas):
|
| 209 |
+
if sid in self._sinfo:
|
| 210 |
+
self._sinfo[sid]["area"] = float(area)
|
| 211 |
+
|
| 212 |
+
def non_empty_mask(self):
|
| 213 |
+
"""
|
| 214 |
+
Returns:
|
| 215 |
+
(H, W) array, a mask for all pixels that have a prediction
|
| 216 |
+
"""
|
| 217 |
+
empty_ids = []
|
| 218 |
+
for id in self._seg_ids:
|
| 219 |
+
if id not in self._sinfo:
|
| 220 |
+
empty_ids.append(id)
|
| 221 |
+
if len(empty_ids) == 0:
|
| 222 |
+
return np.zeros(self._seg.shape, dtype=np.uint8)
|
| 223 |
+
assert (
|
| 224 |
+
len(empty_ids) == 1
|
| 225 |
+
), ">1 ids corresponds to no labels. This is currently not supported"
|
| 226 |
+
return (self._seg != empty_ids[0]).numpy().astype(np.bool)
|
| 227 |
+
|
| 228 |
+
def semantic_masks(self):
|
| 229 |
+
for sid in self._seg_ids:
|
| 230 |
+
sinfo = self._sinfo.get(sid)
|
| 231 |
+
if sinfo is None or sinfo["isthing"]:
|
| 232 |
+
# Some pixels (e.g. id 0 in PanopticFPN) have no instance or semantic predictions.
|
| 233 |
+
continue
|
| 234 |
+
yield (self._seg == sid).numpy().astype(np.bool), sinfo
|
| 235 |
+
|
| 236 |
+
def instance_masks(self):
|
| 237 |
+
for sid in self._seg_ids:
|
| 238 |
+
sinfo = self._sinfo.get(sid)
|
| 239 |
+
if sinfo is None or not sinfo["isthing"]:
|
| 240 |
+
continue
|
| 241 |
+
mask = (self._seg == sid).numpy().astype(np.bool)
|
| 242 |
+
if mask.sum() > 0:
|
| 243 |
+
yield mask, sinfo
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def _create_text_labels(classes, scores, class_names, is_crowd=None):
|
| 247 |
+
"""
|
| 248 |
+
Args:
|
| 249 |
+
classes (list[int] or None):
|
| 250 |
+
scores (list[float] or None):
|
| 251 |
+
class_names (list[str] or None):
|
| 252 |
+
is_crowd (list[bool] or None):
|
| 253 |
+
|
| 254 |
+
Returns:
|
| 255 |
+
list[str] or None
|
| 256 |
+
"""
|
| 257 |
+
labels = None
|
| 258 |
+
if classes is not None:
|
| 259 |
+
if class_names is not None and len(class_names) > 0:
|
| 260 |
+
labels = [class_names[i] for i in classes]
|
| 261 |
+
else:
|
| 262 |
+
labels = [str(i) for i in classes]
|
| 263 |
+
if scores is not None:
|
| 264 |
+
if labels is None:
|
| 265 |
+
labels = ["{:.0f}%".format(s * 100) for s in scores]
|
| 266 |
+
else:
|
| 267 |
+
labels = ["{} {:.0f}%".format(l, s * 100) for l, s in zip(labels, scores)]
|
| 268 |
+
if labels is not None and is_crowd is not None:
|
| 269 |
+
labels = [l + ("|crowd" if crowd else "") for l, crowd in zip(labels, is_crowd)]
|
| 270 |
+
return labels
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class VisImage:
|
| 274 |
+
def __init__(self, img, scale=1.0):
|
| 275 |
+
"""
|
| 276 |
+
Args:
|
| 277 |
+
img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].
|
| 278 |
+
scale (float): scale the input image
|
| 279 |
+
"""
|
| 280 |
+
self.img = img
|
| 281 |
+
self.scale = scale
|
| 282 |
+
self.width, self.height = img.shape[1], img.shape[0]
|
| 283 |
+
self._setup_figure(img)
|
| 284 |
+
|
| 285 |
+
def _setup_figure(self, img):
|
| 286 |
+
"""
|
| 287 |
+
Args:
|
| 288 |
+
Same as in :meth:`__init__()`.
|
| 289 |
+
|
| 290 |
+
Returns:
|
| 291 |
+
fig (matplotlib.pyplot.figure): top level container for all the image plot elements.
|
| 292 |
+
ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system.
|
| 293 |
+
"""
|
| 294 |
+
fig = mplfigure.Figure(frameon=False)
|
| 295 |
+
self.dpi = fig.get_dpi()
|
| 296 |
+
# add a small 1e-2 to avoid precision lost due to matplotlib's truncation
|
| 297 |
+
# (https://github.com/matplotlib/matplotlib/issues/15363)
|
| 298 |
+
fig.set_size_inches(
|
| 299 |
+
(self.width * self.scale + 1e-2) / self.dpi,
|
| 300 |
+
(self.height * self.scale + 1e-2) / self.dpi,
|
| 301 |
+
)
|
| 302 |
+
self.canvas = FigureCanvasAgg(fig)
|
| 303 |
+
# self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
|
| 304 |
+
ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
|
| 305 |
+
ax.axis("off")
|
| 306 |
+
self.fig = fig
|
| 307 |
+
self.ax = ax
|
| 308 |
+
self.reset_image(img)
|
| 309 |
+
|
| 310 |
+
def reset_image(self, img):
|
| 311 |
+
"""
|
| 312 |
+
Args:
|
| 313 |
+
img: same as in __init__
|
| 314 |
+
"""
|
| 315 |
+
img = img.astype("uint8")
|
| 316 |
+
self.ax.imshow(
|
| 317 |
+
img, extent=(0, self.width, self.height, 0), interpolation="nearest"
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
def save(self, filepath):
|
| 321 |
+
"""
|
| 322 |
+
Args:
|
| 323 |
+
filepath (str): a string that contains the absolute path, including the file name, where
|
| 324 |
+
the visualized image will be saved.
|
| 325 |
+
"""
|
| 326 |
+
self.fig.savefig(filepath)
|
| 327 |
+
|
| 328 |
+
def get_image(self):
|
| 329 |
+
"""
|
| 330 |
+
Returns:
|
| 331 |
+
ndarray:
|
| 332 |
+
the visualized image of shape (H, W, 3) (RGB) in uint8 type.
|
| 333 |
+
The shape is scaled w.r.t the input image using the given `scale` argument.
|
| 334 |
+
"""
|
| 335 |
+
canvas = self.canvas
|
| 336 |
+
s, (width, height) = canvas.print_to_buffer()
|
| 337 |
+
# buf = io.BytesIO() # works for cairo backend
|
| 338 |
+
# canvas.print_rgba(buf)
|
| 339 |
+
# width, height = self.width, self.height
|
| 340 |
+
# s = buf.getvalue()
|
| 341 |
+
|
| 342 |
+
buffer = np.frombuffer(s, dtype="uint8")
|
| 343 |
+
|
| 344 |
+
img_rgba = buffer.reshape(height, width, 4)
|
| 345 |
+
rgb, alpha = np.split(img_rgba, [3], axis=2)
|
| 346 |
+
return rgb.astype("uint8")
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
class Visualizer:
|
| 350 |
+
"""
|
| 351 |
+
Visualizer that draws data about detection/segmentation on images.
|
| 352 |
+
|
| 353 |
+
It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}`
|
| 354 |
+
that draw primitive objects to images, as well as high-level wrappers like
|
| 355 |
+
`draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}`
|
| 356 |
+
that draw composite data in some pre-defined style.
|
| 357 |
+
|
| 358 |
+
Note that the exact visualization style for the high-level wrappers are subject to change.
|
| 359 |
+
Style such as color, opacity, label contents, visibility of labels, or even the visibility
|
| 360 |
+
of objects themselves (e.g. when the object is too small) may change according
|
| 361 |
+
to different heuristics, as long as the results still look visually reasonable.
|
| 362 |
+
|
| 363 |
+
To obtain a consistent style, you can implement custom drawing functions with the
|
| 364 |
+
abovementioned primitive methods instead. If you need more customized visualization
|
| 365 |
+
styles, you can process the data yourself following their format documented in
|
| 366 |
+
tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not
|
| 367 |
+
intend to satisfy everyone's preference on drawing styles.
|
| 368 |
+
|
| 369 |
+
This visualizer focuses on high rendering quality rather than performance. It is not
|
| 370 |
+
designed to be used for real-time applications.
|
| 371 |
+
"""
|
| 372 |
+
|
| 373 |
+
def __init__(
|
| 374 |
+
self,
|
| 375 |
+
img_rgb,
|
| 376 |
+
metadata=None,
|
| 377 |
+
scale=1.0,
|
| 378 |
+
instance_mode=ColorMode.IMAGE,
|
| 379 |
+
font_size_multiplier=1.3,
|
| 380 |
+
boarder_width_multiplier=1.5,
|
| 381 |
+
):
|
| 382 |
+
"""
|
| 383 |
+
Args:
|
| 384 |
+
img_rgb: a numpy array of shape (H, W, C), where H and W correspond to
|
| 385 |
+
the height and width of the image respectively. C is the number of
|
| 386 |
+
color channels. The image is required to be in RGB format since that
|
| 387 |
+
is a requirement of the Matplotlib library. The image is also expected
|
| 388 |
+
to be in the range [0, 255].
|
| 389 |
+
metadata (Metadata): dataset metadata (e.g. class names and colors)
|
| 390 |
+
instance_mode (ColorMode): defines one of the pre-defined style for drawing
|
| 391 |
+
instances on an image.
|
| 392 |
+
"""
|
| 393 |
+
self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
|
| 394 |
+
self.boarder_width_multiplier = boarder_width_multiplier
|
| 395 |
+
# if metadata is None:
|
| 396 |
+
# metadata = MetadataCatalog.get("__nonexist__")
|
| 397 |
+
# self.metadata = metadata
|
| 398 |
+
self.output = VisImage(self.img, scale=scale)
|
| 399 |
+
self.cpu_device = torch.device("cpu")
|
| 400 |
+
|
| 401 |
+
# too small texts are useless, therefore clamp to 9
|
| 402 |
+
self._default_font_size = (
|
| 403 |
+
max(np.sqrt(self.output.height * self.output.width) // 60, 15 // scale)
|
| 404 |
+
* font_size_multiplier
|
| 405 |
+
)
|
| 406 |
+
# self._default_font_size = 18
|
| 407 |
+
self._instance_mode = instance_mode
|
| 408 |
+
self.keypoint_threshold = _KEYPOINT_THRESHOLD
|
| 409 |
+
|
| 410 |
+
import matplotlib.colors as mcolors
|
| 411 |
+
|
| 412 |
+
css4_colors = mcolors.CSS4_COLORS
|
| 413 |
+
self.color_proposals = [
|
| 414 |
+
list(mcolors.hex2color(color)) for color in css4_colors.values()
|
| 415 |
+
]
|
| 416 |
+
|
| 417 |
+
def draw_instance_predictions(self, predictions):
|
| 418 |
+
"""
|
| 419 |
+
Draw instance-level prediction results on an image.
|
| 420 |
+
|
| 421 |
+
Args:
|
| 422 |
+
predictions (Instances): the output of an instance detection/segmentation
|
| 423 |
+
model. Following fields will be used to draw:
|
| 424 |
+
"pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle").
|
| 425 |
+
|
| 426 |
+
Returns:
|
| 427 |
+
output (VisImage): image object with visualizations.
|
| 428 |
+
"""
|
| 429 |
+
boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
|
| 430 |
+
scores = predictions.scores if predictions.has("scores") else None
|
| 431 |
+
classes = (
|
| 432 |
+
predictions.pred_classes.tolist()
|
| 433 |
+
if predictions.has("pred_classes")
|
| 434 |
+
else None
|
| 435 |
+
)
|
| 436 |
+
labels = _create_text_labels(
|
| 437 |
+
classes, scores, self.metadata.get("thing_classes", None)
|
| 438 |
+
)
|
| 439 |
+
keypoints = (
|
| 440 |
+
predictions.pred_keypoints if predictions.has("pred_keypoints") else None
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
keep = (scores > 0.5).cpu()
|
| 444 |
+
boxes = boxes[keep]
|
| 445 |
+
scores = scores[keep]
|
| 446 |
+
classes = np.array(classes)
|
| 447 |
+
classes = classes[np.array(keep)]
|
| 448 |
+
labels = np.array(labels)
|
| 449 |
+
labels = labels[np.array(keep)]
|
| 450 |
+
|
| 451 |
+
if predictions.has("pred_masks"):
|
| 452 |
+
masks = np.asarray(predictions.pred_masks)
|
| 453 |
+
masks = masks[np.array(keep)]
|
| 454 |
+
masks = [
|
| 455 |
+
GenericMask(x, self.output.height, self.output.width) for x in masks
|
| 456 |
+
]
|
| 457 |
+
else:
|
| 458 |
+
masks = None
|
| 459 |
+
|
| 460 |
+
if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get(
|
| 461 |
+
"thing_colors"
|
| 462 |
+
):
|
| 463 |
+
# if self.metadata.get("thing_colors"):
|
| 464 |
+
colors = [
|
| 465 |
+
self._jitter([x / 255 for x in self.metadata.thing_colors[c]])
|
| 466 |
+
for c in classes
|
| 467 |
+
]
|
| 468 |
+
alpha = 0.4
|
| 469 |
+
else:
|
| 470 |
+
colors = None
|
| 471 |
+
alpha = 0.4
|
| 472 |
+
|
| 473 |
+
if self._instance_mode == ColorMode.IMAGE_BW:
|
| 474 |
+
self.output.reset_image(
|
| 475 |
+
self._create_grayscale_image(
|
| 476 |
+
(predictions.pred_masks.any(dim=0) > 0).numpy()
|
| 477 |
+
if predictions.has("pred_masks")
|
| 478 |
+
else None
|
| 479 |
+
)
|
| 480 |
+
)
|
| 481 |
+
alpha = 0.3
|
| 482 |
+
|
| 483 |
+
self.overlay_instances(
|
| 484 |
+
masks=masks,
|
| 485 |
+
boxes=boxes,
|
| 486 |
+
labels=labels,
|
| 487 |
+
keypoints=keypoints,
|
| 488 |
+
assigned_colors=colors,
|
| 489 |
+
alpha=alpha,
|
| 490 |
+
)
|
| 491 |
+
return self.output
|
| 492 |
+
|
| 493 |
+
def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.7):
|
| 494 |
+
"""
|
| 495 |
+
Draw semantic segmentation predictions/labels.
|
| 496 |
+
|
| 497 |
+
Args:
|
| 498 |
+
sem_seg (Tensor or ndarray): the segmentation of shape (H, W).
|
| 499 |
+
Each value is the integer label of the pixel.
|
| 500 |
+
area_threshold (int): segments with less than `area_threshold` are not drawn.
|
| 501 |
+
alpha (float): the larger it is, the more opaque the segmentations are.
|
| 502 |
+
|
| 503 |
+
Returns:
|
| 504 |
+
output (VisImage): image object with visualizations.
|
| 505 |
+
"""
|
| 506 |
+
if isinstance(sem_seg, torch.Tensor):
|
| 507 |
+
sem_seg = sem_seg.numpy()
|
| 508 |
+
labels, areas = np.unique(sem_seg, return_counts=True)
|
| 509 |
+
sorted_idxs = np.argsort(-areas).tolist()
|
| 510 |
+
labels = labels[sorted_idxs]
|
| 511 |
+
for label in filter(lambda l: l < len(self.metadata.stuff_classes), labels):
|
| 512 |
+
try:
|
| 513 |
+
mask_color = [x / 255 for x in self.metadata.stuff_colors[label]]
|
| 514 |
+
except (AttributeError, IndexError):
|
| 515 |
+
mask_color = None
|
| 516 |
+
|
| 517 |
+
binary_mask = (sem_seg == label).astype(np.uint8)
|
| 518 |
+
text = self.metadata.stuff_classes[label]
|
| 519 |
+
self.draw_binary_mask(
|
| 520 |
+
binary_mask,
|
| 521 |
+
color=mask_color,
|
| 522 |
+
edge_color=_OFF_WHITE,
|
| 523 |
+
text=text,
|
| 524 |
+
alpha=alpha,
|
| 525 |
+
area_threshold=area_threshold,
|
| 526 |
+
)
|
| 527 |
+
return self.output
|
| 528 |
+
|
| 529 |
+
def draw_panoptic_seg(
|
| 530 |
+
self, panoptic_seg, segments_info, area_threshold=None, alpha=0.7
|
| 531 |
+
):
|
| 532 |
+
"""
|
| 533 |
+
Draw panoptic prediction annotations or results.
|
| 534 |
+
|
| 535 |
+
Args:
|
| 536 |
+
panoptic_seg (Tensor): of shape (height, width) where the values are ids for each
|
| 537 |
+
segment.
|
| 538 |
+
segments_info (list[dict] or None): Describe each segment in `panoptic_seg`.
|
| 539 |
+
If it is a ``list[dict]``, each dict contains keys "id", "category_id".
|
| 540 |
+
If None, category id of each pixel is computed by
|
| 541 |
+
``pixel // metadata.label_divisor``.
|
| 542 |
+
area_threshold (int): stuff segments with less than `area_threshold` are not drawn.
|
| 543 |
+
|
| 544 |
+
Returns:
|
| 545 |
+
output (VisImage): image object with visualizations.
|
| 546 |
+
"""
|
| 547 |
+
pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)
|
| 548 |
+
|
| 549 |
+
if self._instance_mode == ColorMode.IMAGE_BW:
|
| 550 |
+
self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask()))
|
| 551 |
+
|
| 552 |
+
# draw mask for all semantic segments first i.e. "stuff"
|
| 553 |
+
for mask, sinfo in pred.semantic_masks():
|
| 554 |
+
category_idx = sinfo["category_id"]
|
| 555 |
+
try:
|
| 556 |
+
mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]]
|
| 557 |
+
except AttributeError:
|
| 558 |
+
mask_color = None
|
| 559 |
+
|
| 560 |
+
text = (
|
| 561 |
+
self.metadata.stuff_classes[category_idx]
|
| 562 |
+
.replace("-other", "")
|
| 563 |
+
.replace("-merged", "")
|
| 564 |
+
)
|
| 565 |
+
self.draw_binary_mask(
|
| 566 |
+
mask,
|
| 567 |
+
color=mask_color,
|
| 568 |
+
edge_color=_OFF_WHITE,
|
| 569 |
+
text=text,
|
| 570 |
+
alpha=alpha,
|
| 571 |
+
area_threshold=area_threshold,
|
| 572 |
+
)
|
| 573 |
+
|
| 574 |
+
# draw mask for all instances second
|
| 575 |
+
all_instances = list(pred.instance_masks())
|
| 576 |
+
if len(all_instances) == 0:
|
| 577 |
+
return self.output
|
| 578 |
+
masks, sinfo = list(zip(*all_instances))
|
| 579 |
+
category_ids = [x["category_id"] for x in sinfo]
|
| 580 |
+
|
| 581 |
+
try:
|
| 582 |
+
scores = [x["score"] for x in sinfo]
|
| 583 |
+
except KeyError:
|
| 584 |
+
scores = None
|
| 585 |
+
class_names = [
|
| 586 |
+
name.replace("-other", "").replace("-merged", "")
|
| 587 |
+
for name in self.metadata.thing_classes
|
| 588 |
+
]
|
| 589 |
+
labels = _create_text_labels(
|
| 590 |
+
category_ids, scores, class_names, [x.get("iscrowd", 0) for x in sinfo]
|
| 591 |
+
)
|
| 592 |
+
|
| 593 |
+
try:
|
| 594 |
+
colors = [
|
| 595 |
+
self._jitter([x / 255 for x in self.metadata.thing_colors[c]])
|
| 596 |
+
for c in category_ids
|
| 597 |
+
]
|
| 598 |
+
except AttributeError:
|
| 599 |
+
colors = None
|
| 600 |
+
self.overlay_instances(
|
| 601 |
+
masks=masks, labels=labels, assigned_colors=colors, alpha=alpha
|
| 602 |
+
)
|
| 603 |
+
|
| 604 |
+
return self.output
|
| 605 |
+
|
| 606 |
+
draw_panoptic_seg_predictions = draw_panoptic_seg # backward compatibility
|
| 607 |
+
|
| 608 |
+
def draw_dataset_dict(self, dic):
|
| 609 |
+
"""
|
| 610 |
+
Draw annotations/segmentaions in Detectron2 Dataset format.
|
| 611 |
+
|
| 612 |
+
Args:
|
| 613 |
+
dic (dict): annotation/segmentation data of one image, in Detectron2 Dataset format.
|
| 614 |
+
|
| 615 |
+
Returns:
|
| 616 |
+
output (VisImage): image object with visualizations.
|
| 617 |
+
"""
|
| 618 |
+
annos = dic.get("annotations", None)
|
| 619 |
+
if annos:
|
| 620 |
+
if "segmentation" in annos[0]:
|
| 621 |
+
masks = [x["segmentation"] for x in annos]
|
| 622 |
+
else:
|
| 623 |
+
masks = None
|
| 624 |
+
if "keypoints" in annos[0]:
|
| 625 |
+
keypts = [x["keypoints"] for x in annos]
|
| 626 |
+
keypts = np.array(keypts).reshape(len(annos), -1, 3)
|
| 627 |
+
else:
|
| 628 |
+
keypts = None
|
| 629 |
+
|
| 630 |
+
boxes = [
|
| 631 |
+
(
|
| 632 |
+
BoxMode.convert(x["bbox"], x["bbox_mode"], BoxMode.XYXY_ABS)
|
| 633 |
+
if len(x["bbox"]) == 4
|
| 634 |
+
else x["bbox"]
|
| 635 |
+
)
|
| 636 |
+
for x in annos
|
| 637 |
+
]
|
| 638 |
+
|
| 639 |
+
colors = None
|
| 640 |
+
category_ids = [x["category_id"] for x in annos]
|
| 641 |
+
if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get(
|
| 642 |
+
"thing_colors"
|
| 643 |
+
):
|
| 644 |
+
colors = [
|
| 645 |
+
self._jitter([x / 255 for x in self.metadata.thing_colors[c]])
|
| 646 |
+
for c in category_ids
|
| 647 |
+
]
|
| 648 |
+
names = self.metadata.get("thing_classes", None)
|
| 649 |
+
labels = _create_text_labels(
|
| 650 |
+
category_ids,
|
| 651 |
+
scores=None,
|
| 652 |
+
class_names=names,
|
| 653 |
+
is_crowd=[x.get("iscrowd", 0) for x in annos],
|
| 654 |
+
)
|
| 655 |
+
self.overlay_instances(
|
| 656 |
+
labels=labels,
|
| 657 |
+
boxes=boxes,
|
| 658 |
+
masks=masks,
|
| 659 |
+
keypoints=keypts,
|
| 660 |
+
assigned_colors=colors,
|
| 661 |
+
)
|
| 662 |
+
|
| 663 |
+
sem_seg = dic.get("sem_seg", None)
|
| 664 |
+
if sem_seg is None and "sem_seg_file_name" in dic:
|
| 665 |
+
with PathManager.open(dic["sem_seg_file_name"], "rb") as f:
|
| 666 |
+
sem_seg = Image.open(f)
|
| 667 |
+
sem_seg = np.asarray(sem_seg, dtype="uint8")
|
| 668 |
+
if sem_seg is not None:
|
| 669 |
+
self.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.4)
|
| 670 |
+
|
| 671 |
+
pan_seg = dic.get("pan_seg", None)
|
| 672 |
+
if pan_seg is None and "pan_seg_file_name" in dic:
|
| 673 |
+
with PathManager.open(dic["pan_seg_file_name"], "rb") as f:
|
| 674 |
+
pan_seg = Image.open(f)
|
| 675 |
+
pan_seg = np.asarray(pan_seg)
|
| 676 |
+
from panopticapi.utils import rgb2id
|
| 677 |
+
|
| 678 |
+
pan_seg = rgb2id(pan_seg)
|
| 679 |
+
if pan_seg is not None:
|
| 680 |
+
segments_info = dic["segments_info"]
|
| 681 |
+
pan_seg = torch.tensor(pan_seg)
|
| 682 |
+
self.draw_panoptic_seg(pan_seg, segments_info, area_threshold=0, alpha=0.7)
|
| 683 |
+
return self.output
|
| 684 |
+
|
| 685 |
+
def overlay_instances(
|
| 686 |
+
self,
|
| 687 |
+
*,
|
| 688 |
+
boxes=None,
|
| 689 |
+
labels=None,
|
| 690 |
+
masks=None,
|
| 691 |
+
keypoints=None,
|
| 692 |
+
assigned_colors=None,
|
| 693 |
+
binary_masks=None,
|
| 694 |
+
alpha=0.5,
|
| 695 |
+
label_mode="1",
|
| 696 |
+
):
|
| 697 |
+
"""
|
| 698 |
+
Args:
|
| 699 |
+
boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`,
|
| 700 |
+
or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image,
|
| 701 |
+
or a :class:`RotatedBoxes`,
|
| 702 |
+
or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format
|
| 703 |
+
for the N objects in a single image,
|
| 704 |
+
labels (list[str]): the text to be displayed for each instance.
|
| 705 |
+
masks (masks-like object): Supported types are:
|
| 706 |
+
|
| 707 |
+
* :class:`detectron2.structures.PolygonMasks`,
|
| 708 |
+
:class:`detectron2.structures.BitMasks`.
|
| 709 |
+
* list[list[ndarray]]: contains the segmentation masks for all objects in one image.
|
| 710 |
+
The first level of the list corresponds to individual instances. The second
|
| 711 |
+
level to all the polygon that compose the instance, and the third level
|
| 712 |
+
to the polygon coordinates. The third level should have the format of
|
| 713 |
+
[x0, y0, x1, y1, ..., xn, yn] (n >= 3).
|
| 714 |
+
* list[ndarray]: each ndarray is a binary mask of shape (H, W).
|
| 715 |
+
* list[dict]: each dict is a COCO-style RLE.
|
| 716 |
+
keypoints (Keypoint or array like): an array-like object of shape (N, K, 3),
|
| 717 |
+
where the N is the number of instances and K is the number of keypoints.
|
| 718 |
+
The last dimension corresponds to (x, y, visibility or score).
|
| 719 |
+
assigned_colors (list[matplotlib.colors]): a list of colors, where each color
|
| 720 |
+
corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
|
| 721 |
+
for full list of formats that the colors are accepted in.
|
| 722 |
+
Returns:
|
| 723 |
+
output (VisImage): image object with visualizations.
|
| 724 |
+
"""
|
| 725 |
+
num_instances = 0
|
| 726 |
+
if boxes is not None:
|
| 727 |
+
boxes = self._convert_boxes(boxes)
|
| 728 |
+
num_instances = len(boxes)
|
| 729 |
+
if masks is not None:
|
| 730 |
+
masks = self._convert_masks(masks)
|
| 731 |
+
if num_instances:
|
| 732 |
+
assert len(masks) == num_instances
|
| 733 |
+
else:
|
| 734 |
+
num_instances = len(masks)
|
| 735 |
+
if keypoints is not None:
|
| 736 |
+
if num_instances:
|
| 737 |
+
assert len(keypoints) == num_instances
|
| 738 |
+
else:
|
| 739 |
+
num_instances = len(keypoints)
|
| 740 |
+
keypoints = self._convert_keypoints(keypoints)
|
| 741 |
+
if labels is not None:
|
| 742 |
+
assert len(labels) == num_instances
|
| 743 |
+
if assigned_colors is None:
|
| 744 |
+
assigned_colors = [
|
| 745 |
+
random_color(rgb=True, maximum=1) for _ in range(num_instances)
|
| 746 |
+
]
|
| 747 |
+
if num_instances == 0:
|
| 748 |
+
return labels, [], []
|
| 749 |
+
if boxes is not None and boxes.shape[1] == 5:
|
| 750 |
+
return self.overlay_rotated_instances(
|
| 751 |
+
boxes=boxes, labels=labels, assigned_colors=assigned_colors
|
| 752 |
+
)
|
| 753 |
+
|
| 754 |
+
# Display in largest to smallest order to reduce occlusion.
|
| 755 |
+
areas = None
|
| 756 |
+
if boxes is not None:
|
| 757 |
+
areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
|
| 758 |
+
elif masks is not None:
|
| 759 |
+
areas = np.asarray([x.area() for x in masks])
|
| 760 |
+
|
| 761 |
+
# if areas is not None:
|
| 762 |
+
# # sorted_idxs = np.argsort(areas).tolist()
|
| 763 |
+
# sorted_idxs = np.argsort(-areas).tolist()
|
| 764 |
+
# # Re-order overlapped instances in descending order.
|
| 765 |
+
# boxes = boxes[sorted_idxs] if boxes is not None else None
|
| 766 |
+
# labels = [labels[k] for k in sorted_idxs] if labels is not None else None
|
| 767 |
+
# masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None
|
| 768 |
+
# binary_masks = (
|
| 769 |
+
# [binary_masks[idx] for idx in sorted_idxs]
|
| 770 |
+
# if binary_masks is not None
|
| 771 |
+
# else None
|
| 772 |
+
# )
|
| 773 |
+
# assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]
|
| 774 |
+
# keypoints = keypoints[sorted_idxs] if keypoints is not None else None
|
| 775 |
+
|
| 776 |
+
marks = []
|
| 777 |
+
marks_position = []
|
| 778 |
+
added_positions = set()
|
| 779 |
+
for i in range(num_instances):
|
| 780 |
+
color = assigned_colors[i]
|
| 781 |
+
if boxes is not None:
|
| 782 |
+
self.draw_box(boxes[i], alpha=1, edge_color=color)
|
| 783 |
+
if binary_masks is None:
|
| 784 |
+
# draw number for non-mask instances
|
| 785 |
+
mark = self._draw_number_in_box(
|
| 786 |
+
boxes[i], i + 1, color=color, label_mode=label_mode
|
| 787 |
+
)
|
| 788 |
+
marks.append(mark)
|
| 789 |
+
|
| 790 |
+
if binary_masks is not None:
|
| 791 |
+
mark, mask_position = self._draw_number_in_mask(
|
| 792 |
+
binary_mask=binary_masks[i].astype("uint8"),
|
| 793 |
+
text=i + 1,
|
| 794 |
+
color=color,
|
| 795 |
+
added_positions=added_positions,
|
| 796 |
+
label_mode=label_mode,
|
| 797 |
+
)
|
| 798 |
+
marks.append(mark)
|
| 799 |
+
marks_position.append(mask_position)
|
| 800 |
+
|
| 801 |
+
self.draw_binary_mask(
|
| 802 |
+
binary_masks[i],
|
| 803 |
+
color=color,
|
| 804 |
+
edge_color=_OFF_WHITE,
|
| 805 |
+
alpha=alpha,
|
| 806 |
+
)
|
| 807 |
+
|
| 808 |
+
if masks is not None:
|
| 809 |
+
for segment in masks[i].polygons:
|
| 810 |
+
self.draw_polygon(
|
| 811 |
+
segment.reshape(-1, 2), color, alpha=0
|
| 812 |
+
) # alpha=0 so holes in masks are not colored
|
| 813 |
+
|
| 814 |
+
# draw keypoints
|
| 815 |
+
if keypoints is not None:
|
| 816 |
+
for keypoints_per_instance in keypoints:
|
| 817 |
+
self.draw_and_connect_keypoints(keypoints_per_instance)
|
| 818 |
+
|
| 819 |
+
# return labels, marks, sorted_idxs, marks_position
|
| 820 |
+
return labels, marks, marks_position
|
| 821 |
+
|
| 822 |
+
def overlay_rotated_instances(self, boxes=None, labels=None, assigned_colors=None):
|
| 823 |
+
"""
|
| 824 |
+
Args:
|
| 825 |
+
boxes (ndarray): an Nx5 numpy array of
|
| 826 |
+
(x_center, y_center, width, height, angle_degrees) format
|
| 827 |
+
for the N objects in a single image.
|
| 828 |
+
labels (list[str]): the text to be displayed for each instance.
|
| 829 |
+
assigned_colors (list[matplotlib.colors]): a list of colors, where each color
|
| 830 |
+
corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
|
| 831 |
+
for full list of formats that the colors are accepted in.
|
| 832 |
+
|
| 833 |
+
Returns:
|
| 834 |
+
output (VisImage): image object with visualizations.
|
| 835 |
+
"""
|
| 836 |
+
num_instances = len(boxes)
|
| 837 |
+
|
| 838 |
+
if assigned_colors is None:
|
| 839 |
+
assigned_colors = [
|
| 840 |
+
random_color(rgb=True, maximum=1) for _ in range(num_instances)
|
| 841 |
+
]
|
| 842 |
+
if num_instances == 0:
|
| 843 |
+
return self.output
|
| 844 |
+
|
| 845 |
+
# Display in largest to smallest order to reduce occlusion.
|
| 846 |
+
if boxes is not None:
|
| 847 |
+
areas = boxes[:, 2] * boxes[:, 3]
|
| 848 |
+
|
| 849 |
+
sorted_idxs = np.argsort(-areas).tolist()
|
| 850 |
+
# Re-order overlapped instances in descending order.
|
| 851 |
+
boxes = boxes[sorted_idxs]
|
| 852 |
+
labels = [labels[k] for k in sorted_idxs] if labels is not None else None
|
| 853 |
+
colors = [assigned_colors[idx] for idx in sorted_idxs]
|
| 854 |
+
|
| 855 |
+
for i in range(num_instances):
|
| 856 |
+
self.draw_rotated_box_with_label(
|
| 857 |
+
boxes[i],
|
| 858 |
+
edge_color=colors[i],
|
| 859 |
+
label=labels[i] if labels is not None else None,
|
| 860 |
+
)
|
| 861 |
+
|
| 862 |
+
return self.output
|
| 863 |
+
|
| 864 |
+
def draw_and_connect_keypoints(self, keypoints):
|
| 865 |
+
"""
|
| 866 |
+
Draws keypoints of an instance and follows the rules for keypoint connections
|
| 867 |
+
to draw lines between appropriate keypoints. This follows color heuristics for
|
| 868 |
+
line color.
|
| 869 |
+
|
| 870 |
+
Args:
|
| 871 |
+
keypoints (Tensor): a tensor of shape (K, 3), where K is the number of keypoints
|
| 872 |
+
and the last dimension corresponds to (x, y, probability).
|
| 873 |
+
|
| 874 |
+
Returns:
|
| 875 |
+
output (VisImage): image object with visualizations.
|
| 876 |
+
"""
|
| 877 |
+
visible = {}
|
| 878 |
+
keypoint_names = self.metadata.get("keypoint_names")
|
| 879 |
+
for idx, keypoint in enumerate(keypoints):
|
| 880 |
+
# draw keypoint
|
| 881 |
+
x, y, prob = keypoint
|
| 882 |
+
if prob > self.keypoint_threshold:
|
| 883 |
+
self.draw_circle((x, y), color=_RED)
|
| 884 |
+
if keypoint_names:
|
| 885 |
+
keypoint_name = keypoint_names[idx]
|
| 886 |
+
visible[keypoint_name] = (x, y)
|
| 887 |
+
|
| 888 |
+
if self.metadata.get("keypoint_connection_rules"):
|
| 889 |
+
for kp0, kp1, color in self.metadata.keypoint_connection_rules:
|
| 890 |
+
if kp0 in visible and kp1 in visible:
|
| 891 |
+
x0, y0 = visible[kp0]
|
| 892 |
+
x1, y1 = visible[kp1]
|
| 893 |
+
color = tuple(x / 255.0 for x in color)
|
| 894 |
+
self.draw_line([x0, x1], [y0, y1], color=color)
|
| 895 |
+
|
| 896 |
+
# draw lines from nose to mid-shoulder and mid-shoulder to mid-hip
|
| 897 |
+
# Note that this strategy is specific to person keypoints.
|
| 898 |
+
# For other keypoints, it should just do nothing
|
| 899 |
+
try:
|
| 900 |
+
ls_x, ls_y = visible["left_shoulder"]
|
| 901 |
+
rs_x, rs_y = visible["right_shoulder"]
|
| 902 |
+
mid_shoulder_x, mid_shoulder_y = (ls_x + rs_x) / 2, (ls_y + rs_y) / 2
|
| 903 |
+
except KeyError:
|
| 904 |
+
pass
|
| 905 |
+
else:
|
| 906 |
+
# draw line from nose to mid-shoulder
|
| 907 |
+
nose_x, nose_y = visible.get("nose", (None, None))
|
| 908 |
+
if nose_x is not None:
|
| 909 |
+
self.draw_line(
|
| 910 |
+
[nose_x, mid_shoulder_x], [nose_y, mid_shoulder_y], color=_RED
|
| 911 |
+
)
|
| 912 |
+
|
| 913 |
+
try:
|
| 914 |
+
# draw line from mid-shoulder to mid-hip
|
| 915 |
+
lh_x, lh_y = visible["left_hip"]
|
| 916 |
+
rh_x, rh_y = visible["right_hip"]
|
| 917 |
+
except KeyError:
|
| 918 |
+
pass
|
| 919 |
+
else:
|
| 920 |
+
mid_hip_x, mid_hip_y = (lh_x + rh_x) / 2, (lh_y + rh_y) / 2
|
| 921 |
+
self.draw_line(
|
| 922 |
+
[mid_hip_x, mid_shoulder_x], [mid_hip_y, mid_shoulder_y], color=_RED
|
| 923 |
+
)
|
| 924 |
+
return self.output
|
| 925 |
+
|
| 926 |
+
def mask_dims_from_binary(self, binary_mask):
|
| 927 |
+
ind_y, ind_x = np.where(binary_mask == 1)
|
| 928 |
+
min_ind_x = np.min(ind_x)
|
| 929 |
+
max_ind_x = np.max(ind_x)
|
| 930 |
+
min_ind_y = np.min(ind_y)
|
| 931 |
+
max_ind_y = np.max(ind_y)
|
| 932 |
+
return (max_ind_x - min_ind_x), (max_ind_y - min_ind_y)
|
| 933 |
+
|
| 934 |
+
def reposition_label(self, position, cur, binary_mask, move_count):
|
| 935 |
+
img_width, img_height = self.output.width, self.output.height
|
| 936 |
+
mask_width, mask_height = self.mask_dims_from_binary(binary_mask)
|
| 937 |
+
|
| 938 |
+
# set resposition thresholds
|
| 939 |
+
mask_width_limit, mask_height_limit = (
|
| 940 |
+
25,
|
| 941 |
+
25,
|
| 942 |
+
) # limit for width and height size for object covering
|
| 943 |
+
location_diff_threshold = 15 # limit for the distance between two labels
|
| 944 |
+
x_boundry_limit, y_boundry_limit = (
|
| 945 |
+
20,
|
| 946 |
+
20,
|
| 947 |
+
) # limit for the distancing the label from edges
|
| 948 |
+
|
| 949 |
+
offset_x = 15 # move in x direction
|
| 950 |
+
offset_y = 15 # move in y direction
|
| 951 |
+
|
| 952 |
+
x1, y1 = position
|
| 953 |
+
|
| 954 |
+
if (
|
| 955 |
+
mask_width < mask_width_limit
|
| 956 |
+
and mask_height < mask_height_limit
|
| 957 |
+
and move_count == 0
|
| 958 |
+
):
|
| 959 |
+
move_x = offset_x if offset_x + x1 < img_width else -offset_x
|
| 960 |
+
move_y = offset_y if offset_y + y1 < img_height else -offset_y
|
| 961 |
+
return (True, move_x, move_y)
|
| 962 |
+
|
| 963 |
+
for x2, y2 in cur:
|
| 964 |
+
if abs(x1 - x2) + abs(y1 - y2) < location_diff_threshold:
|
| 965 |
+
move_x = offset_x if x1 >= x2 else -offset_x
|
| 966 |
+
move_y = offset_y if y1 >= y2 else -offset_y
|
| 967 |
+
move_x = (
|
| 968 |
+
0
|
| 969 |
+
if x1 + move_x > img_width - x_boundry_limit
|
| 970 |
+
or x1 + move_x < x_boundry_limit
|
| 971 |
+
else move_x
|
| 972 |
+
)
|
| 973 |
+
move_y = (
|
| 974 |
+
0
|
| 975 |
+
if y1 + move_y > img_height - y_boundry_limit
|
| 976 |
+
or y1 + move_y < y_boundry_limit
|
| 977 |
+
else move_y
|
| 978 |
+
)
|
| 979 |
+
return (
|
| 980 |
+
True,
|
| 981 |
+
move_x,
|
| 982 |
+
move_y,
|
| 983 |
+
)
|
| 984 |
+
return (False, 0, 0)
|
| 985 |
+
|
| 986 |
+
def locate_label_position(self, original_position, added_positions, binary_mask):
|
| 987 |
+
if added_positions is None or binary_mask is None:
|
| 988 |
+
return original_position
|
| 989 |
+
|
| 990 |
+
x, y = original_position
|
| 991 |
+
|
| 992 |
+
move_count = 0
|
| 993 |
+
reposition, x_move, y_move = self.reposition_label(
|
| 994 |
+
(x, y), added_positions, binary_mask, move_count
|
| 995 |
+
)
|
| 996 |
+
while reposition and move_count < 10:
|
| 997 |
+
x += x_move
|
| 998 |
+
y += y_move
|
| 999 |
+
move_count += 1
|
| 1000 |
+
reposition, x_move, y_move = self.reposition_label(
|
| 1001 |
+
(x, y), added_positions, binary_mask, move_count
|
| 1002 |
+
)
|
| 1003 |
+
added_positions.add((x, y))
|
| 1004 |
+
return x, y
|
| 1005 |
+
|
| 1006 |
+
"""
|
| 1007 |
+
Primitive drawing functions:
|
| 1008 |
+
"""
|
| 1009 |
+
|
| 1010 |
+
def draw_text(
|
| 1011 |
+
self,
|
| 1012 |
+
text,
|
| 1013 |
+
position,
|
| 1014 |
+
added_positions=None,
|
| 1015 |
+
binary_mask=None,
|
| 1016 |
+
*,
|
| 1017 |
+
font_size=None,
|
| 1018 |
+
color="g",
|
| 1019 |
+
horizontal_alignment="center",
|
| 1020 |
+
rotation=0,
|
| 1021 |
+
):
|
| 1022 |
+
"""
|
| 1023 |
+
Args:
|
| 1024 |
+
text (str): class label
|
| 1025 |
+
position (tuple): a tuple of the x and y coordinates to place text on image.
|
| 1026 |
+
font_size (int, optional): font of the text. If not provided, a font size
|
| 1027 |
+
proportional to the image width is calculated and used.
|
| 1028 |
+
color: color of the text. Refer to `matplotlib.colors` for full list
|
| 1029 |
+
of formats that are accepted.
|
| 1030 |
+
horizontal_alignment (str): see `matplotlib.text.Text`
|
| 1031 |
+
rotation: rotation angle in degrees CCW
|
| 1032 |
+
|
| 1033 |
+
Returns:
|
| 1034 |
+
output (VisImage): image object with text drawn.
|
| 1035 |
+
"""
|
| 1036 |
+
if not font_size:
|
| 1037 |
+
font_size = self._default_font_size
|
| 1038 |
+
|
| 1039 |
+
# since the text background is dark, we don't want the text to be dark
|
| 1040 |
+
color = np.maximum(list(mplc.to_rgb(color)), 0.15)
|
| 1041 |
+
color[np.argmax(color)] = max(0.8, np.max(color))
|
| 1042 |
+
|
| 1043 |
+
def contrasting_color(rgb):
|
| 1044 |
+
"""Returns 'white' or 'black' depending on which color contrasts more with the given RGB value."""
|
| 1045 |
+
|
| 1046 |
+
# Decompose the RGB tuple
|
| 1047 |
+
R, G, B = rgb
|
| 1048 |
+
|
| 1049 |
+
# Calculate the Y value
|
| 1050 |
+
Y = 0.299 * R + 0.587 * G + 0.114 * B
|
| 1051 |
+
|
| 1052 |
+
# If Y value is greater than 128, it's closer to white so return black. Otherwise, return white.
|
| 1053 |
+
return "black" if Y > 128 else "white"
|
| 1054 |
+
|
| 1055 |
+
bbox_background = contrasting_color(color * 255)
|
| 1056 |
+
|
| 1057 |
+
x, y = self.locate_label_position(
|
| 1058 |
+
original_position=position,
|
| 1059 |
+
added_positions=added_positions,
|
| 1060 |
+
binary_mask=binary_mask,
|
| 1061 |
+
)
|
| 1062 |
+
|
| 1063 |
+
self.output.ax.text(
|
| 1064 |
+
x,
|
| 1065 |
+
y,
|
| 1066 |
+
text,
|
| 1067 |
+
size=font_size * self.output.scale,
|
| 1068 |
+
family="sans-serif",
|
| 1069 |
+
bbox={
|
| 1070 |
+
"facecolor": bbox_background,
|
| 1071 |
+
"alpha": 0.8,
|
| 1072 |
+
"pad": 0.7,
|
| 1073 |
+
"edgecolor": "none",
|
| 1074 |
+
},
|
| 1075 |
+
verticalalignment="top",
|
| 1076 |
+
horizontalalignment=horizontal_alignment,
|
| 1077 |
+
color=color,
|
| 1078 |
+
zorder=10,
|
| 1079 |
+
rotation=rotation,
|
| 1080 |
+
)
|
| 1081 |
+
return self.output
|
| 1082 |
+
|
| 1083 |
+
def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"):
|
| 1084 |
+
"""
|
| 1085 |
+
Args:
|
| 1086 |
+
box_coord (tuple): a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0
|
| 1087 |
+
are the coordinates of the image's top left corner. x1 and y1 are the
|
| 1088 |
+
coordinates of the image's bottom right corner.
|
| 1089 |
+
alpha (float): blending efficient. Smaller values lead to more transparent masks.
|
| 1090 |
+
edge_color: color of the outline of the box. Refer to `matplotlib.colors`
|
| 1091 |
+
for full list of formats that are accepted.
|
| 1092 |
+
line_style (string): the string to use to create the outline of the boxes.
|
| 1093 |
+
|
| 1094 |
+
Returns:
|
| 1095 |
+
output (VisImage): image object with box drawn.
|
| 1096 |
+
"""
|
| 1097 |
+
x0, y0, x1, y1 = box_coord
|
| 1098 |
+
width = x1 - x0
|
| 1099 |
+
height = y1 - y0
|
| 1100 |
+
|
| 1101 |
+
linewidth = max(self._default_font_size / 12, 1) * self.boarder_width_multiplier
|
| 1102 |
+
|
| 1103 |
+
self.output.ax.add_patch(
|
| 1104 |
+
mpl.patches.Rectangle(
|
| 1105 |
+
(x0, y0),
|
| 1106 |
+
width,
|
| 1107 |
+
height,
|
| 1108 |
+
fill=False,
|
| 1109 |
+
edgecolor=edge_color,
|
| 1110 |
+
linewidth=linewidth * self.output.scale,
|
| 1111 |
+
alpha=alpha,
|
| 1112 |
+
linestyle=line_style,
|
| 1113 |
+
)
|
| 1114 |
+
)
|
| 1115 |
+
return self.output
|
| 1116 |
+
|
| 1117 |
+
def draw_rotated_box_with_label(
|
| 1118 |
+
self, rotated_box, alpha=0.5, edge_color="g", line_style="-", label=None
|
| 1119 |
+
):
|
| 1120 |
+
"""
|
| 1121 |
+
Draw a rotated box with label on its top-left corner.
|
| 1122 |
+
|
| 1123 |
+
Args:
|
| 1124 |
+
rotated_box (tuple): a tuple containing (cnt_x, cnt_y, w, h, angle),
|
| 1125 |
+
where cnt_x and cnt_y are the center coordinates of the box.
|
| 1126 |
+
w and h are the width and height of the box. angle represents how
|
| 1127 |
+
many degrees the box is rotated CCW with regard to the 0-degree box.
|
| 1128 |
+
alpha (float): blending efficient. Smaller values lead to more transparent masks.
|
| 1129 |
+
edge_color: color of the outline of the box. Refer to `matplotlib.colors`
|
| 1130 |
+
for full list of formats that are accepted.
|
| 1131 |
+
line_style (string): the string to use to create the outline of the boxes.
|
| 1132 |
+
label (string): label for rotated box. It will not be rendered when set to None.
|
| 1133 |
+
|
| 1134 |
+
Returns:
|
| 1135 |
+
output (VisImage): image object with box drawn.
|
| 1136 |
+
"""
|
| 1137 |
+
cnt_x, cnt_y, w, h, angle = rotated_box
|
| 1138 |
+
area = w * h
|
| 1139 |
+
# use thinner lines when the box is small
|
| 1140 |
+
linewidth = self._default_font_size / (
|
| 1141 |
+
6 if area < _SMALL_OBJECT_AREA_THRESH * self.output.scale else 3
|
| 1142 |
+
)
|
| 1143 |
+
|
| 1144 |
+
theta = angle * math.pi / 180.0
|
| 1145 |
+
c = math.cos(theta)
|
| 1146 |
+
s = math.sin(theta)
|
| 1147 |
+
rect = [(-w / 2, h / 2), (-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2)]
|
| 1148 |
+
# x: left->right ; y: top->down
|
| 1149 |
+
rotated_rect = [
|
| 1150 |
+
(s * yy + c * xx + cnt_x, c * yy - s * xx + cnt_y) for (xx, yy) in rect
|
| 1151 |
+
]
|
| 1152 |
+
for k in range(4):
|
| 1153 |
+
j = (k + 1) % 4
|
| 1154 |
+
self.draw_line(
|
| 1155 |
+
[rotated_rect[k][0], rotated_rect[j][0]],
|
| 1156 |
+
[rotated_rect[k][1], rotated_rect[j][1]],
|
| 1157 |
+
color=edge_color,
|
| 1158 |
+
linestyle="--" if k == 1 else line_style,
|
| 1159 |
+
linewidth=linewidth,
|
| 1160 |
+
)
|
| 1161 |
+
|
| 1162 |
+
if label is not None:
|
| 1163 |
+
text_pos = rotated_rect[1] # topleft corner
|
| 1164 |
+
|
| 1165 |
+
height_ratio = h / np.sqrt(self.output.height * self.output.width)
|
| 1166 |
+
label_color = self._change_color_brightness(
|
| 1167 |
+
edge_color, brightness_factor=0.7
|
| 1168 |
+
)
|
| 1169 |
+
font_size = (
|
| 1170 |
+
np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
|
| 1171 |
+
* 0.5
|
| 1172 |
+
* self._default_font_size
|
| 1173 |
+
)
|
| 1174 |
+
self.draw_text(
|
| 1175 |
+
label, text_pos, color=label_color, font_size=font_size, rotation=angle
|
| 1176 |
+
)
|
| 1177 |
+
|
| 1178 |
+
return self.output
|
| 1179 |
+
|
| 1180 |
+
def draw_circle(self, circle_coord, color, radius=3):
|
| 1181 |
+
"""
|
| 1182 |
+
Args:
|
| 1183 |
+
circle_coord (list(int) or tuple(int)): contains the x and y coordinates
|
| 1184 |
+
of the center of the circle.
|
| 1185 |
+
color: color of the polygon. Refer to `matplotlib.colors` for a full list of
|
| 1186 |
+
formats that are accepted.
|
| 1187 |
+
radius (int): radius of the circle.
|
| 1188 |
+
|
| 1189 |
+
Returns:
|
| 1190 |
+
output (VisImage): image object with box drawn.
|
| 1191 |
+
"""
|
| 1192 |
+
x, y = circle_coord
|
| 1193 |
+
self.output.ax.add_patch(
|
| 1194 |
+
mpl.patches.Circle(circle_coord, radius=radius, fill=True, color=color)
|
| 1195 |
+
)
|
| 1196 |
+
return self.output
|
| 1197 |
+
|
| 1198 |
+
def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None):
|
| 1199 |
+
"""
|
| 1200 |
+
Args:
|
| 1201 |
+
x_data (list[int]): a list containing x values of all the points being drawn.
|
| 1202 |
+
Length of list should match the length of y_data.
|
| 1203 |
+
y_data (list[int]): a list containing y values of all the points being drawn.
|
| 1204 |
+
Length of list should match the length of x_data.
|
| 1205 |
+
color: color of the line. Refer to `matplotlib.colors` for a full list of
|
| 1206 |
+
formats that are accepted.
|
| 1207 |
+
linestyle: style of the line. Refer to `matplotlib.lines.Line2D`
|
| 1208 |
+
for a full list of formats that are accepted.
|
| 1209 |
+
linewidth (float or None): width of the line. When it's None,
|
| 1210 |
+
a default value will be computed and used.
|
| 1211 |
+
|
| 1212 |
+
Returns:
|
| 1213 |
+
output (VisImage): image object with line drawn.
|
| 1214 |
+
"""
|
| 1215 |
+
if linewidth is None:
|
| 1216 |
+
linewidth = self._default_font_size / 3
|
| 1217 |
+
linewidth = max(linewidth, 1)
|
| 1218 |
+
self.output.ax.add_line(
|
| 1219 |
+
mpl.lines.Line2D(
|
| 1220 |
+
x_data,
|
| 1221 |
+
y_data,
|
| 1222 |
+
linewidth=linewidth * self.output.scale,
|
| 1223 |
+
color=color,
|
| 1224 |
+
linestyle=linestyle,
|
| 1225 |
+
)
|
| 1226 |
+
)
|
| 1227 |
+
return self.output
|
| 1228 |
+
|
| 1229 |
+
def draw_binary_mask(
|
| 1230 |
+
self,
|
| 1231 |
+
binary_mask,
|
| 1232 |
+
color=None,
|
| 1233 |
+
*,
|
| 1234 |
+
edge_color=None,
|
| 1235 |
+
text=None,
|
| 1236 |
+
alpha=0.7,
|
| 1237 |
+
area_threshold=10,
|
| 1238 |
+
):
|
| 1239 |
+
"""
|
| 1240 |
+
Args:
|
| 1241 |
+
binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and
|
| 1242 |
+
W is the image width. Each value in the array is either a 0 or 1 value of uint8
|
| 1243 |
+
type.
|
| 1244 |
+
color: color of the mask. Refer to `matplotlib.colors` for a full list of
|
| 1245 |
+
formats that are accepted. If None, will pick a random color.
|
| 1246 |
+
edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
|
| 1247 |
+
full list of formats that are accepted.
|
| 1248 |
+
text (str): if None, will be drawn on the object
|
| 1249 |
+
alpha (float): blending efficient. Smaller values lead to more transparent masks.
|
| 1250 |
+
area_threshold (float): a connected component smaller than this area will not be shown.
|
| 1251 |
+
|
| 1252 |
+
Returns:
|
| 1253 |
+
output (VisImage): image object with mask drawn.
|
| 1254 |
+
"""
|
| 1255 |
+
if color is None:
|
| 1256 |
+
color = random_color(rgb=True, maximum=1)
|
| 1257 |
+
color = mplc.to_rgb(color)
|
| 1258 |
+
|
| 1259 |
+
has_valid_segment = False
|
| 1260 |
+
binary_mask = binary_mask.astype("uint8") # opencv needs uint8
|
| 1261 |
+
mask = GenericMask(binary_mask, self.output.height, self.output.width)
|
| 1262 |
+
shape2d = (binary_mask.shape[0], binary_mask.shape[1])
|
| 1263 |
+
|
| 1264 |
+
if not mask.has_holes:
|
| 1265 |
+
# draw polygons for regular masks
|
| 1266 |
+
for segment in mask.polygons:
|
| 1267 |
+
area = mask_util.area(
|
| 1268 |
+
mask_util.frPyObjects([segment], shape2d[0], shape2d[1])
|
| 1269 |
+
)
|
| 1270 |
+
if area < (area_threshold or 0):
|
| 1271 |
+
continue
|
| 1272 |
+
has_valid_segment = True
|
| 1273 |
+
segment = segment.reshape(-1, 2)
|
| 1274 |
+
self.draw_polygon(
|
| 1275 |
+
segment, color=color, edge_color=edge_color, alpha=alpha
|
| 1276 |
+
)
|
| 1277 |
+
else:
|
| 1278 |
+
# https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon
|
| 1279 |
+
rgba = np.zeros(shape2d + (4,), dtype="float32")
|
| 1280 |
+
rgba[:, :, :3] = color
|
| 1281 |
+
rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha
|
| 1282 |
+
has_valid_segment = True
|
| 1283 |
+
self.output.ax.imshow(
|
| 1284 |
+
rgba, extent=(0, self.output.width, self.output.height, 0)
|
| 1285 |
+
)
|
| 1286 |
+
|
| 1287 |
+
if text is not None and has_valid_segment:
|
| 1288 |
+
lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
|
| 1289 |
+
self._draw_text_in_mask(binary_mask, text, lighter_color)
|
| 1290 |
+
return self.output
|
| 1291 |
+
|
| 1292 |
+
def draw_binary_mask_with_number(
|
| 1293 |
+
self,
|
| 1294 |
+
binary_mask,
|
| 1295 |
+
color=None,
|
| 1296 |
+
*,
|
| 1297 |
+
edge_color=None,
|
| 1298 |
+
text=None,
|
| 1299 |
+
label_mode="1",
|
| 1300 |
+
alpha=0.1,
|
| 1301 |
+
anno_mode=["Mask"],
|
| 1302 |
+
area_threshold=10,
|
| 1303 |
+
):
|
| 1304 |
+
"""
|
| 1305 |
+
Args:
|
| 1306 |
+
binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and
|
| 1307 |
+
W is the image width. Each value in the array is either a 0 or 1 value of uint8
|
| 1308 |
+
type.
|
| 1309 |
+
color: color of the mask. Refer to `matplotlib.colors` for a full list of
|
| 1310 |
+
formats that are accepted. If None, will pick a random color.
|
| 1311 |
+
edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
|
| 1312 |
+
full list of formats that are accepted.
|
| 1313 |
+
text (str): if None, will be drawn on the object
|
| 1314 |
+
alpha (float): blending efficient. Smaller values lead to more transparent masks.
|
| 1315 |
+
area_threshold (float): a connected component smaller than this area will not be shown.
|
| 1316 |
+
|
| 1317 |
+
Returns:
|
| 1318 |
+
output (VisImage): image object with mask drawn.
|
| 1319 |
+
"""
|
| 1320 |
+
if color is None:
|
| 1321 |
+
randint = random.randint(0, len(self.color_proposals) - 1)
|
| 1322 |
+
color = self.color_proposals[randint]
|
| 1323 |
+
color = mplc.to_rgb(color)
|
| 1324 |
+
|
| 1325 |
+
has_valid_segment = True
|
| 1326 |
+
binary_mask = binary_mask.astype("uint8") # opencv needs uint8
|
| 1327 |
+
mask = GenericMask(binary_mask, self.output.height, self.output.width)
|
| 1328 |
+
shape2d = (binary_mask.shape[0], binary_mask.shape[1])
|
| 1329 |
+
bbox = mask.bbox()
|
| 1330 |
+
|
| 1331 |
+
if "Mask" in anno_mode:
|
| 1332 |
+
if not mask.has_holes:
|
| 1333 |
+
# draw polygons for regular masks
|
| 1334 |
+
for segment in mask.polygons:
|
| 1335 |
+
area = mask_util.area(
|
| 1336 |
+
mask_util.frPyObjects([segment], shape2d[0], shape2d[1])
|
| 1337 |
+
)
|
| 1338 |
+
if area < (area_threshold or 0):
|
| 1339 |
+
continue
|
| 1340 |
+
has_valid_segment = True
|
| 1341 |
+
segment = segment.reshape(-1, 2)
|
| 1342 |
+
self.draw_polygon(
|
| 1343 |
+
segment, color=color, edge_color=edge_color, alpha=alpha
|
| 1344 |
+
)
|
| 1345 |
+
else:
|
| 1346 |
+
# https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon
|
| 1347 |
+
rgba = np.zeros(shape2d + (4,), dtype="float32")
|
| 1348 |
+
rgba[:, :, :3] = color
|
| 1349 |
+
rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha
|
| 1350 |
+
has_valid_segment = True
|
| 1351 |
+
self.output.ax.imshow(
|
| 1352 |
+
rgba, extent=(0, self.output.width, self.output.height, 0)
|
| 1353 |
+
)
|
| 1354 |
+
|
| 1355 |
+
if "Box" in anno_mode:
|
| 1356 |
+
self.draw_box(bbox, edge_color=color, alpha=0.75)
|
| 1357 |
+
|
| 1358 |
+
if "Mark" in anno_mode:
|
| 1359 |
+
has_valid_segment = True
|
| 1360 |
+
else:
|
| 1361 |
+
has_valid_segment = False
|
| 1362 |
+
|
| 1363 |
+
if text is not None and has_valid_segment:
|
| 1364 |
+
# lighter_color = tuple([x*0.2 for x in color])
|
| 1365 |
+
lighter_color = [
|
| 1366 |
+
1,
|
| 1367 |
+
1,
|
| 1368 |
+
1,
|
| 1369 |
+
] # self._change_color_brightness(color, brightness_factor=0.7)
|
| 1370 |
+
self._draw_number_in_mask(
|
| 1371 |
+
binary_mask=binary_mask,
|
| 1372 |
+
text=text,
|
| 1373 |
+
color=lighter_color,
|
| 1374 |
+
label_mode=label_mode,
|
| 1375 |
+
)
|
| 1376 |
+
return self.output
|
| 1377 |
+
|
| 1378 |
+
def draw_soft_mask(self, soft_mask, color=None, *, text=None, alpha=0.5):
|
| 1379 |
+
"""
|
| 1380 |
+
Args:
|
| 1381 |
+
soft_mask (ndarray): float array of shape (H, W), each value in [0, 1].
|
| 1382 |
+
color: color of the mask. Refer to `matplotlib.colors` for a full list of
|
| 1383 |
+
formats that are accepted. If None, will pick a random color.
|
| 1384 |
+
text (str): if None, will be drawn on the object
|
| 1385 |
+
alpha (float): blending efficient. Smaller values lead to more transparent masks.
|
| 1386 |
+
|
| 1387 |
+
Returns:
|
| 1388 |
+
output (VisImage): image object with mask drawn.
|
| 1389 |
+
"""
|
| 1390 |
+
if color is None:
|
| 1391 |
+
color = random_color(rgb=True, maximum=1)
|
| 1392 |
+
color = mplc.to_rgb(color)
|
| 1393 |
+
|
| 1394 |
+
shape2d = (soft_mask.shape[0], soft_mask.shape[1])
|
| 1395 |
+
rgba = np.zeros(shape2d + (4,), dtype="float32")
|
| 1396 |
+
rgba[:, :, :3] = color
|
| 1397 |
+
rgba[:, :, 3] = soft_mask * alpha
|
| 1398 |
+
self.output.ax.imshow(
|
| 1399 |
+
rgba, extent=(0, self.output.width, self.output.height, 0)
|
| 1400 |
+
)
|
| 1401 |
+
|
| 1402 |
+
if text is not None:
|
| 1403 |
+
lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
|
| 1404 |
+
binary_mask = (soft_mask > 0.5).astype("uint8")
|
| 1405 |
+
self._draw_text_in_mask(binary_mask, text, lighter_color)
|
| 1406 |
+
return self.output
|
| 1407 |
+
|
| 1408 |
+
def draw_polygon(self, segment, color, edge_color=None, alpha=0.5):
|
| 1409 |
+
"""
|
| 1410 |
+
Args:
|
| 1411 |
+
segment: numpy array of shape Nx2, containing all the points in the polygon.
|
| 1412 |
+
color: color of the polygon. Refer to `matplotlib.colors` for a full list of
|
| 1413 |
+
formats that are accepted.
|
| 1414 |
+
edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
|
| 1415 |
+
full list of formats that are accepted. If not provided, a darker shade
|
| 1416 |
+
of the polygon color will be used instead.
|
| 1417 |
+
alpha (float): blending efficient. Smaller values lead to more transparent masks.
|
| 1418 |
+
|
| 1419 |
+
Returns:
|
| 1420 |
+
output (VisImage): image object with polygon drawn.
|
| 1421 |
+
"""
|
| 1422 |
+
if edge_color is None:
|
| 1423 |
+
# make edge color darker than the polygon color
|
| 1424 |
+
if alpha > 0.8:
|
| 1425 |
+
edge_color = self._change_color_brightness(
|
| 1426 |
+
color, brightness_factor=-0.7
|
| 1427 |
+
)
|
| 1428 |
+
else:
|
| 1429 |
+
edge_color = color
|
| 1430 |
+
edge_color = mplc.to_rgb(edge_color) + (1,)
|
| 1431 |
+
|
| 1432 |
+
polygon = mpl.patches.Polygon(
|
| 1433 |
+
segment,
|
| 1434 |
+
fill=True,
|
| 1435 |
+
facecolor=mplc.to_rgb(color) + (alpha,),
|
| 1436 |
+
edgecolor=edge_color,
|
| 1437 |
+
linewidth=max(self._default_font_size // 15 * self.output.scale, 1),
|
| 1438 |
+
)
|
| 1439 |
+
self.output.ax.add_patch(polygon)
|
| 1440 |
+
return self.output
|
| 1441 |
+
|
| 1442 |
+
"""
|
| 1443 |
+
Internal methods:
|
| 1444 |
+
"""
|
| 1445 |
+
|
| 1446 |
+
def _jitter(self, color):
|
| 1447 |
+
"""
|
| 1448 |
+
Randomly modifies given color to produce a slightly different color than the color given.
|
| 1449 |
+
|
| 1450 |
+
Args:
|
| 1451 |
+
color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color
|
| 1452 |
+
picked. The values in the list are in the [0.0, 1.0] range.
|
| 1453 |
+
|
| 1454 |
+
Returns:
|
| 1455 |
+
jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the
|
| 1456 |
+
color after being jittered. The values in the list are in the [0.0, 1.0] range.
|
| 1457 |
+
"""
|
| 1458 |
+
color = mplc.to_rgb(color)
|
| 1459 |
+
# np.random.seed(0)
|
| 1460 |
+
vec = np.random.rand(3)
|
| 1461 |
+
# better to do it in another color space
|
| 1462 |
+
vec = vec / np.linalg.norm(vec) * 0.5
|
| 1463 |
+
res = np.clip(vec + color, 0, 1)
|
| 1464 |
+
return tuple(res)
|
| 1465 |
+
|
| 1466 |
+
def _create_grayscale_image(self, mask=None):
|
| 1467 |
+
"""
|
| 1468 |
+
Create a grayscale version of the original image.
|
| 1469 |
+
The colors in masked area, if given, will be kept.
|
| 1470 |
+
"""
|
| 1471 |
+
img_bw = self.img.astype("f4").mean(axis=2)
|
| 1472 |
+
img_bw = np.stack([img_bw] * 3, axis=2)
|
| 1473 |
+
if mask is not None:
|
| 1474 |
+
img_bw[mask] = self.img[mask]
|
| 1475 |
+
return img_bw
|
| 1476 |
+
|
| 1477 |
+
def _change_color_brightness(self, color, brightness_factor):
|
| 1478 |
+
"""
|
| 1479 |
+
Depending on the brightness_factor, gives a lighter or darker color i.e. a color with
|
| 1480 |
+
less or more saturation than the original color.
|
| 1481 |
+
|
| 1482 |
+
Args:
|
| 1483 |
+
color: color of the polygon. Refer to `matplotlib.colors` for a full list of
|
| 1484 |
+
formats that are accepted.
|
| 1485 |
+
brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of
|
| 1486 |
+
0 will correspond to no change, a factor in [-1.0, 0) range will result in
|
| 1487 |
+
a darker color and a factor in (0, 1.0] range will result in a lighter color.
|
| 1488 |
+
|
| 1489 |
+
Returns:
|
| 1490 |
+
modified_color (tuple[double]): a tuple containing the RGB values of the
|
| 1491 |
+
modified color. Each value in the tuple is in the [0.0, 1.0] range.
|
| 1492 |
+
"""
|
| 1493 |
+
assert brightness_factor >= -1.0 and brightness_factor <= 1.0
|
| 1494 |
+
color = mplc.to_rgb(color)
|
| 1495 |
+
polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
|
| 1496 |
+
modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
|
| 1497 |
+
modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
|
| 1498 |
+
modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
|
| 1499 |
+
modified_color = colorsys.hls_to_rgb(
|
| 1500 |
+
polygon_color[0], modified_lightness, polygon_color[2]
|
| 1501 |
+
)
|
| 1502 |
+
return modified_color
|
| 1503 |
+
|
| 1504 |
+
def _convert_boxes(self, boxes):
|
| 1505 |
+
"""
|
| 1506 |
+
Convert different format of boxes to an NxB array, where B = 4 or 5 is the box dimension.
|
| 1507 |
+
"""
|
| 1508 |
+
if isinstance(boxes, Boxes) or isinstance(boxes, RotatedBoxes):
|
| 1509 |
+
return boxes.tensor.detach().numpy()
|
| 1510 |
+
else:
|
| 1511 |
+
return np.asarray(boxes)
|
| 1512 |
+
|
| 1513 |
+
def _convert_masks(self, masks_or_polygons):
|
| 1514 |
+
"""
|
| 1515 |
+
Convert different format of masks or polygons to a tuple of masks and polygons.
|
| 1516 |
+
|
| 1517 |
+
Returns:
|
| 1518 |
+
list[GenericMask]:
|
| 1519 |
+
"""
|
| 1520 |
+
|
| 1521 |
+
m = masks_or_polygons
|
| 1522 |
+
if isinstance(m, PolygonMasks):
|
| 1523 |
+
m = m.polygons
|
| 1524 |
+
if isinstance(m, BitMasks):
|
| 1525 |
+
m = m.tensor.numpy()
|
| 1526 |
+
if isinstance(m, torch.Tensor):
|
| 1527 |
+
m = m.numpy()
|
| 1528 |
+
ret = []
|
| 1529 |
+
for x in m:
|
| 1530 |
+
if isinstance(x, GenericMask):
|
| 1531 |
+
ret.append(x)
|
| 1532 |
+
else:
|
| 1533 |
+
ret.append(GenericMask(x, self.output.height, self.output.width))
|
| 1534 |
+
return ret
|
| 1535 |
+
|
| 1536 |
+
def _draw_number_in_box(self, box, text, color, label_mode="1"):
|
| 1537 |
+
"""
|
| 1538 |
+
Find proper places to draw text given a box.
|
| 1539 |
+
"""
|
| 1540 |
+
x0, y0, x1, y1 = box
|
| 1541 |
+
text_pos = (x0, y0) # if drawing boxes, put text on the box corner.
|
| 1542 |
+
horiz_align = "left"
|
| 1543 |
+
# for small objects, draw text at the side to avoid occlusion
|
| 1544 |
+
instance_area = (y1 - y0) * (x1 - x0)
|
| 1545 |
+
if (
|
| 1546 |
+
instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale
|
| 1547 |
+
or y1 - y0 < 40 * self.output.scale
|
| 1548 |
+
):
|
| 1549 |
+
if y1 >= self.output.height - 5:
|
| 1550 |
+
text_pos = (x1, y0)
|
| 1551 |
+
else:
|
| 1552 |
+
text_pos = (x0, y1)
|
| 1553 |
+
|
| 1554 |
+
height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width)
|
| 1555 |
+
lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
|
| 1556 |
+
font_size = (
|
| 1557 |
+
np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
|
| 1558 |
+
* 0.65
|
| 1559 |
+
* self._default_font_size
|
| 1560 |
+
)
|
| 1561 |
+
if label_mode == "a":
|
| 1562 |
+
text = self.number_to_string(int(text))
|
| 1563 |
+
else:
|
| 1564 |
+
text = text
|
| 1565 |
+
self.draw_text(
|
| 1566 |
+
text,
|
| 1567 |
+
text_pos,
|
| 1568 |
+
color=lighter_color,
|
| 1569 |
+
horizontal_alignment=horiz_align,
|
| 1570 |
+
font_size=font_size,
|
| 1571 |
+
)
|
| 1572 |
+
|
| 1573 |
+
return str(text)
|
| 1574 |
+
|
| 1575 |
+
@staticmethod
|
| 1576 |
+
def number_to_string(n):
|
| 1577 |
+
chars = []
|
| 1578 |
+
while n:
|
| 1579 |
+
n, remainder = divmod(n - 1, 26)
|
| 1580 |
+
chars.append(chr(97 + remainder))
|
| 1581 |
+
return "".join(reversed(chars))
|
| 1582 |
+
|
| 1583 |
+
def _draw_number_in_mask(
|
| 1584 |
+
self, binary_mask, text, color, added_positions=None, label_mode="1"
|
| 1585 |
+
):
|
| 1586 |
+
"""
|
| 1587 |
+
Find proper places to draw text given a binary mask.
|
| 1588 |
+
"""
|
| 1589 |
+
binary_mask = np.pad(binary_mask, ((1, 1), (1, 1)), "constant")
|
| 1590 |
+
mask_dt = cv2.distanceTransform(binary_mask, cv2.DIST_L2, 0)
|
| 1591 |
+
mask_dt = mask_dt[1:-1, 1:-1]
|
| 1592 |
+
max_dist = np.max(mask_dt)
|
| 1593 |
+
coords_y, coords_x = np.where(mask_dt == max_dist) # coords is [y, x]
|
| 1594 |
+
|
| 1595 |
+
if label_mode == "a":
|
| 1596 |
+
text = self.number_to_string(int(text))
|
| 1597 |
+
else:
|
| 1598 |
+
text = text
|
| 1599 |
+
|
| 1600 |
+
text_position = (
|
| 1601 |
+
coords_x[len(coords_x) // 2] + 2,
|
| 1602 |
+
coords_y[len(coords_y) // 2] - 6,
|
| 1603 |
+
)
|
| 1604 |
+
self.draw_text(
|
| 1605 |
+
text,
|
| 1606 |
+
text_position,
|
| 1607 |
+
added_positions=added_positions,
|
| 1608 |
+
binary_mask=binary_mask,
|
| 1609 |
+
color=color,
|
| 1610 |
+
)
|
| 1611 |
+
|
| 1612 |
+
return str(text), text_position
|
| 1613 |
+
|
| 1614 |
+
# _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8)
|
| 1615 |
+
# if stats[1:, -1].size == 0:
|
| 1616 |
+
# return
|
| 1617 |
+
# largest_component_id = np.argmax(stats[1:, -1]) + 1
|
| 1618 |
+
|
| 1619 |
+
# # draw text on the largest component, as well as other very large components.
|
| 1620 |
+
# for cid in range(1, _num_cc):
|
| 1621 |
+
# if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH:
|
| 1622 |
+
# # median is more stable than centroid
|
| 1623 |
+
# # center = centroids[largest_component_id]
|
| 1624 |
+
# center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1]
|
| 1625 |
+
# # bottom=np.max((cc_labels == cid).nonzero(), axis=1)[::-1]
|
| 1626 |
+
# # center[1]=bottom[1]+2
|
| 1627 |
+
# self.draw_text(text, center, color=color)
|
| 1628 |
+
|
| 1629 |
+
def _draw_text_in_mask(self, binary_mask, text, color):
|
| 1630 |
+
"""
|
| 1631 |
+
Find proper places to draw text given a binary mask.
|
| 1632 |
+
"""
|
| 1633 |
+
_num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(
|
| 1634 |
+
binary_mask, 8
|
| 1635 |
+
)
|
| 1636 |
+
if stats[1:, -1].size == 0:
|
| 1637 |
+
return
|
| 1638 |
+
largest_component_id = np.argmax(stats[1:, -1]) + 1
|
| 1639 |
+
|
| 1640 |
+
# draw text on the largest component, as well as other very large components.
|
| 1641 |
+
for cid in range(1, _num_cc):
|
| 1642 |
+
if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH:
|
| 1643 |
+
# median is more stable than centroid
|
| 1644 |
+
# center = centroids[largest_component_id]
|
| 1645 |
+
center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1]
|
| 1646 |
+
bottom = np.max((cc_labels == cid).nonzero(), axis=1)[::-1]
|
| 1647 |
+
center[1] = bottom[1] + 2
|
| 1648 |
+
self.draw_text(text, center, color=color)
|
| 1649 |
+
|
| 1650 |
+
def _convert_keypoints(self, keypoints):
|
| 1651 |
+
if isinstance(keypoints, Keypoints):
|
| 1652 |
+
keypoints = keypoints.tensor
|
| 1653 |
+
keypoints = np.asarray(keypoints)
|
| 1654 |
+
return keypoints
|
| 1655 |
+
|
| 1656 |
+
def get_output(self):
|
| 1657 |
+
"""
|
| 1658 |
+
Returns:
|
| 1659 |
+
output (VisImage): the image output containing the visualizations added
|
| 1660 |
+
to the image.
|
| 1661 |
+
"""
|
| 1662 |
+
return self.output
|
sam3/agent/helpers/zoom_in.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
import io
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pycocotools.mask as mask_utils
|
| 9 |
+
from PIL import Image
|
| 10 |
+
|
| 11 |
+
from .som_utils import ColorPalette, draw_box, draw_mask, draw_text
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def render_zoom_in(
|
| 15 |
+
object_data,
|
| 16 |
+
image_file,
|
| 17 |
+
show_box: bool = True,
|
| 18 |
+
show_text: bool = False,
|
| 19 |
+
show_holes: bool = True,
|
| 20 |
+
mask_alpha: float = 0.15,
|
| 21 |
+
):
|
| 22 |
+
"""
|
| 23 |
+
Render a two-panel visualization with a cropped original view (left/upper) and a zoomed-in
|
| 24 |
+
mask overlay (right/lower), then return it as a PIL.Image along with the chosen mask color (hex).
|
| 25 |
+
|
| 26 |
+
Parameters
|
| 27 |
+
----------
|
| 28 |
+
object_data : dict
|
| 29 |
+
Dict containing "labels" and COCO RLE "segmentation".
|
| 30 |
+
Expected:
|
| 31 |
+
object_data["labels"][0]["noun_phrase"] : str
|
| 32 |
+
object_data["segmentation"] : COCO RLE (with "size": [H, W])
|
| 33 |
+
image_file : PIL.Image.Image
|
| 34 |
+
Source image (PIL).
|
| 35 |
+
show_box : bool
|
| 36 |
+
Whether to draw the bbox on the cropped original panel.
|
| 37 |
+
show_text : bool
|
| 38 |
+
Whether to draw the noun phrase label near the bbox.
|
| 39 |
+
show_holes : bool
|
| 40 |
+
Whether to render mask holes (passed through to draw_mask).
|
| 41 |
+
mask_alpha : float
|
| 42 |
+
Alpha for the mask overlay.
|
| 43 |
+
|
| 44 |
+
Returns
|
| 45 |
+
-------
|
| 46 |
+
pil_img : PIL.Image.Image
|
| 47 |
+
The composed visualization image.
|
| 48 |
+
color_hex : str
|
| 49 |
+
Hex string of the chosen mask color.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
# ---- local constants (avoid module-level globals) ----
|
| 53 |
+
_AREA_LARGE = 0.25
|
| 54 |
+
_AREA_MEDIUM = 0.05
|
| 55 |
+
|
| 56 |
+
# ---- local helpers (avoid name collisions in a larger class) ----
|
| 57 |
+
def _get_shift(x, w, w_new, w_img):
|
| 58 |
+
assert 0 <= w_new <= w_img
|
| 59 |
+
shift = (w_new - w) / 2
|
| 60 |
+
if x - shift + w_new > w_img:
|
| 61 |
+
shift = x + w_new - w_img
|
| 62 |
+
return min(x, shift)
|
| 63 |
+
|
| 64 |
+
def _get_zoom_in_box(mask_box_xywh, img_h, img_w, mask_area):
|
| 65 |
+
box_w, box_h = mask_box_xywh[2], mask_box_xywh[3]
|
| 66 |
+
w_new = min(box_w + max(0.2 * box_w, 16), img_w)
|
| 67 |
+
h_new = min(box_h + max(0.2 * box_h, 16), img_h)
|
| 68 |
+
|
| 69 |
+
mask_relative_area = mask_area / (w_new * h_new)
|
| 70 |
+
|
| 71 |
+
# zoom-in (larger box if mask is relatively big)
|
| 72 |
+
w_new_large, h_new_large = w_new, h_new
|
| 73 |
+
if mask_relative_area > _AREA_LARGE:
|
| 74 |
+
ratio_large = math.sqrt(mask_relative_area / _AREA_LARGE)
|
| 75 |
+
w_new_large = min(w_new * ratio_large, img_w)
|
| 76 |
+
h_new_large = min(h_new * ratio_large, img_h)
|
| 77 |
+
|
| 78 |
+
w_shift_large = _get_shift(
|
| 79 |
+
mask_box_xywh[0], mask_box_xywh[2], w_new_large, img_w
|
| 80 |
+
)
|
| 81 |
+
h_shift_large = _get_shift(
|
| 82 |
+
mask_box_xywh[1], mask_box_xywh[3], h_new_large, img_h
|
| 83 |
+
)
|
| 84 |
+
zoom_in_box = [
|
| 85 |
+
mask_box_xywh[0] - w_shift_large,
|
| 86 |
+
mask_box_xywh[1] - h_shift_large,
|
| 87 |
+
w_new_large,
|
| 88 |
+
h_new_large,
|
| 89 |
+
]
|
| 90 |
+
|
| 91 |
+
# crop box for the original/cropped image
|
| 92 |
+
w_new_medium, h_new_medium = w_new, h_new
|
| 93 |
+
if mask_relative_area > _AREA_MEDIUM:
|
| 94 |
+
ratio_med = math.sqrt(mask_relative_area / _AREA_MEDIUM)
|
| 95 |
+
w_new_medium = min(w_new * ratio_med, img_w)
|
| 96 |
+
h_new_medium = min(h_new * ratio_med, img_h)
|
| 97 |
+
|
| 98 |
+
w_shift_medium = _get_shift(
|
| 99 |
+
mask_box_xywh[0], mask_box_xywh[2], w_new_medium, img_w
|
| 100 |
+
)
|
| 101 |
+
h_shift_medium = _get_shift(
|
| 102 |
+
mask_box_xywh[1], mask_box_xywh[3], h_new_medium, img_h
|
| 103 |
+
)
|
| 104 |
+
img_crop_box = [
|
| 105 |
+
mask_box_xywh[0] - w_shift_medium,
|
| 106 |
+
mask_box_xywh[1] - h_shift_medium,
|
| 107 |
+
w_new_medium,
|
| 108 |
+
h_new_medium,
|
| 109 |
+
]
|
| 110 |
+
return zoom_in_box, img_crop_box
|
| 111 |
+
|
| 112 |
+
# ---- main body ----
|
| 113 |
+
# Input parsing
|
| 114 |
+
object_label = object_data["labels"][0]["noun_phrase"]
|
| 115 |
+
img = image_file.convert("RGB")
|
| 116 |
+
bbox_xywh = mask_utils.toBbox(object_data["segmentation"]) # [x, y, w, h]
|
| 117 |
+
|
| 118 |
+
# Choose a stable, visually distant color based on crop
|
| 119 |
+
bbox_xyxy = [
|
| 120 |
+
bbox_xywh[0],
|
| 121 |
+
bbox_xywh[1],
|
| 122 |
+
bbox_xywh[0] + bbox_xywh[2],
|
| 123 |
+
bbox_xywh[1] + bbox_xywh[3],
|
| 124 |
+
]
|
| 125 |
+
crop_img = img.crop(bbox_xyxy)
|
| 126 |
+
color_palette = ColorPalette.default()
|
| 127 |
+
color_obj, _ = color_palette.find_farthest_color(np.array(crop_img))
|
| 128 |
+
color = np.array([color_obj.r / 255, color_obj.g / 255, color_obj.b / 255])
|
| 129 |
+
color_hex = f"#{color_obj.r:02x}{color_obj.g:02x}{color_obj.b:02x}"
|
| 130 |
+
|
| 131 |
+
# Compute zoom-in / crop boxes
|
| 132 |
+
img_h, img_w = object_data["segmentation"]["size"]
|
| 133 |
+
mask_area = mask_utils.area(object_data["segmentation"])
|
| 134 |
+
zoom_in_box, img_crop_box = _get_zoom_in_box(bbox_xywh, img_h, img_w, mask_area)
|
| 135 |
+
|
| 136 |
+
# Layout choice
|
| 137 |
+
w, h = img_crop_box[2], img_crop_box[3]
|
| 138 |
+
if w < h:
|
| 139 |
+
fig, (ax1, ax2) = plt.subplots(1, 2)
|
| 140 |
+
else:
|
| 141 |
+
fig, (ax1, ax2) = plt.subplots(2, 1)
|
| 142 |
+
|
| 143 |
+
# Panel 1: cropped original with optional box/text
|
| 144 |
+
img_crop_box_xyxy = [
|
| 145 |
+
img_crop_box[0],
|
| 146 |
+
img_crop_box[1],
|
| 147 |
+
img_crop_box[0] + img_crop_box[2],
|
| 148 |
+
img_crop_box[1] + img_crop_box[3],
|
| 149 |
+
]
|
| 150 |
+
img1 = img.crop(img_crop_box_xyxy)
|
| 151 |
+
bbox_xywh_rel = [
|
| 152 |
+
bbox_xywh[0] - img_crop_box[0],
|
| 153 |
+
bbox_xywh[1] - img_crop_box[1],
|
| 154 |
+
bbox_xywh[2],
|
| 155 |
+
bbox_xywh[3],
|
| 156 |
+
]
|
| 157 |
+
ax1.imshow(img1)
|
| 158 |
+
ax1.axis("off")
|
| 159 |
+
if show_box:
|
| 160 |
+
draw_box(ax1, bbox_xywh_rel, edge_color=color)
|
| 161 |
+
if show_text:
|
| 162 |
+
x0, y0 = bbox_xywh_rel[0] + 2, bbox_xywh_rel[1] + 2
|
| 163 |
+
draw_text(ax1, object_label, [x0, y0], color=color)
|
| 164 |
+
|
| 165 |
+
# Panel 2: zoomed-in mask overlay
|
| 166 |
+
binary_mask = mask_utils.decode(object_data["segmentation"])
|
| 167 |
+
alpha = Image.fromarray((binary_mask * 255).astype("uint8"))
|
| 168 |
+
img_rgba = img.convert("RGBA")
|
| 169 |
+
img_rgba.putalpha(alpha)
|
| 170 |
+
zoom_in_box_xyxy = [
|
| 171 |
+
zoom_in_box[0],
|
| 172 |
+
zoom_in_box[1],
|
| 173 |
+
zoom_in_box[0] + zoom_in_box[2],
|
| 174 |
+
zoom_in_box[1] + zoom_in_box[3],
|
| 175 |
+
]
|
| 176 |
+
img_with_alpha_zoomin = img_rgba.crop(zoom_in_box_xyxy)
|
| 177 |
+
alpha_zoomin = img_with_alpha_zoomin.split()[3]
|
| 178 |
+
binary_mask_zoomin = np.array(alpha_zoomin).astype(bool)
|
| 179 |
+
|
| 180 |
+
ax2.imshow(img_with_alpha_zoomin.convert("RGB"))
|
| 181 |
+
ax2.axis("off")
|
| 182 |
+
draw_mask(
|
| 183 |
+
ax2, binary_mask_zoomin, color=color, show_holes=show_holes, alpha=mask_alpha
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
plt.tight_layout()
|
| 187 |
+
|
| 188 |
+
# Buffer -> PIL.Image
|
| 189 |
+
buf = io.BytesIO()
|
| 190 |
+
fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0, dpi=100)
|
| 191 |
+
plt.close(fig)
|
| 192 |
+
buf.seek(0)
|
| 193 |
+
pil_img = Image.open(buf)
|
| 194 |
+
|
| 195 |
+
return pil_img, color_hex
|
sam3/agent/inference.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
from sam3.agent.agent_core import agent_inference
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def run_single_image_inference(
|
| 10 |
+
image_path,
|
| 11 |
+
text_prompt,
|
| 12 |
+
llm_config,
|
| 13 |
+
send_generate_request,
|
| 14 |
+
call_sam_service,
|
| 15 |
+
output_dir="agent_output",
|
| 16 |
+
debug=False,
|
| 17 |
+
):
|
| 18 |
+
"""Run inference on a single image with provided prompt"""
|
| 19 |
+
|
| 20 |
+
llm_name = llm_config["name"]
|
| 21 |
+
|
| 22 |
+
if not os.path.exists(image_path):
|
| 23 |
+
raise FileNotFoundError(f"Image file not found: {image_path}")
|
| 24 |
+
|
| 25 |
+
# Create output directory
|
| 26 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 27 |
+
|
| 28 |
+
# Generate output file names
|
| 29 |
+
image_basename = os.path.splitext(os.path.basename(image_path))[0]
|
| 30 |
+
prompt_for_filename = text_prompt.replace("/", "_").replace(" ", "_")
|
| 31 |
+
|
| 32 |
+
base_filename = f"{image_basename}_{prompt_for_filename}_agent_{llm_name}"
|
| 33 |
+
output_json_path = os.path.join(output_dir, f"{base_filename}_pred.json")
|
| 34 |
+
output_image_path = os.path.join(output_dir, f"{base_filename}_pred.png")
|
| 35 |
+
agent_history_path = os.path.join(output_dir, f"{base_filename}_history.json")
|
| 36 |
+
|
| 37 |
+
# Check if output already exists and skip
|
| 38 |
+
if os.path.exists(output_json_path):
|
| 39 |
+
print(f"Output JSON {output_json_path} already exists. Skipping.")
|
| 40 |
+
return
|
| 41 |
+
|
| 42 |
+
print(f"{'-'*30} Starting SAM 3 Agent Session... {'-'*30} ")
|
| 43 |
+
agent_history, final_output_dict, rendered_final_output = agent_inference(
|
| 44 |
+
image_path,
|
| 45 |
+
text_prompt,
|
| 46 |
+
send_generate_request=send_generate_request,
|
| 47 |
+
call_sam_service=call_sam_service,
|
| 48 |
+
output_dir=output_dir,
|
| 49 |
+
debug=debug,
|
| 50 |
+
)
|
| 51 |
+
print(f"{'-'*30} End of SAM 3 Agent Session... {'-'*30} ")
|
| 52 |
+
|
| 53 |
+
final_output_dict["text_prompt"] = text_prompt
|
| 54 |
+
final_output_dict["image_path"] = image_path
|
| 55 |
+
|
| 56 |
+
# Save outputs
|
| 57 |
+
json.dump(final_output_dict, open(output_json_path, "w"), indent=4)
|
| 58 |
+
json.dump(agent_history, open(agent_history_path, "w"), indent=4)
|
| 59 |
+
rendered_final_output.save(output_image_path)
|
| 60 |
+
|
| 61 |
+
print(f"\n✅ Successfully processed single image!")
|
| 62 |
+
print(f"Output JSON: {output_json_path}")
|
| 63 |
+
print(f"Output Image: {output_image_path}")
|
| 64 |
+
print(f"Agent History: {agent_history_path}")
|
| 65 |
+
return output_image_path
|
sam3/agent/system_prompts/system_prompt.txt
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
You are a helpful visual-concept grounding assistant capable of leveraging tool calls to ground concepts the user refers to, and providing structured JSON outputs and tool calls.
|
| 2 |
+
The user may provide you with a referring expression that matches some part(s) of the image, or a question whose answer points to some part(s) of the image.
|
| 3 |
+
You should observe and analyze the image along with the initial user input query very carefully, note all details in the image, think about what the user is actually referring to, how to leverage existing tools below to ground the target(s), and then call exactly one tool per turn.
|
| 4 |
+
At each turn, all available mask(s) will be renumbered and re-rendered on the most recent image provided to you. The numbering and coloring can be different from previous turns. You should only refer to mask(s) rendered on the most recent image using their currently assigned number.
|
| 5 |
+
If a tool call does not produce the intended output, do not give up; be creative and try calling the segment_phrase tool again with different parameters, or try a different tool. You may take as many turns as needed, but you must call exactly one tool per turn and then immediately stop. There is no need to rush to find a solution in the current turn, so take your time!
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
How you should understand the initial user input query and the raw input image:
|
| 9 |
+
|
| 10 |
+
1. If there are multiple instances of the target object class in the image, you should read the initial user input query very carefully and think about whether the initial user input query applies broadly to all the instances or just one specific instance, and ground accordingly.
|
| 11 |
+
2. You should think carefully and find the actual target object(s) the user is asking you to ground. Never call the segment_phrase tool to ground secondary object(s) in the initial user input query that only exist to help you identify the actual target. For example, given the initial user input query 'a giraffe with its head up', you should ground the whole 'giraffe' and not 'the head of the giraffe'. Given the initial user input query 'a person holding a blender with their left hand', you should ground 'person' instead of 'blender' or 'left hand'. Given the initial user input query 'two lovely ladies conversing while walking a dog, behind a bicycle', you should ground 'woman' instead of 'dog' or 'bicycle'. Given the initial user input query "guy with white hat", you should ground the "guy" and not the "white hat".
|
| 12 |
+
3. Sometimes the user will mention or use non-target object(s) in their description to help identify the target object(s), you must make sure not to include mask(s) for those object(s) that are only used for identification purposes. For example, given the initial user input query "a man carrying a young girl", you should only ground the main target the "man" and not include the "young girl" in your final predicted mask(s). Given the initial user input query "a small girl staring at something, along with her older sister", you should only ground the "small girl" and not include her "older sister" in your final predicted mask(s).
|
| 13 |
+
4. Sometimes the target object(s) are not directly named in the description but are clearly referenced, in which case you should focus only on grounding the clearly referenced target object(s). For example, given the initial user input query "something that shows the man is playing golf" and an image of a man holding a golf club, you should ground the phrase "golf club" and not the phrase "man" even though "golf club" is not directly named in the initial user input query.
|
| 14 |
+
5. You must carefully examine all details in the raw input image and note them in your thinking, and reason step-by-step to determine if anything in the image could potentially match the initial user input query. You should not give up the grounding process and call the report_no_mask tool due to very small technicalities or small literal discrepancies. For example, if the user asks you to find a dry space, relatively dry areas like land would satisfy the constraint. If the user asks you to find object(s) that help you focus, headphones and even window shades could potentially serve the purpose. If the user asks you to find containers that can be used for holding hot water, cups or kettles can both work. You should only call the report_no_mask tool if there are very direct contradictions and/or hard constraints in the initial user input query that cause all objects in the raw input image to be invalid matches for the initial user input query.
|
| 15 |
+
6. Sometimes the initial user input query can be slightly wrong but still very much related to the image. For example, the user may ask you to ground "the red laptop" when the laptop computer in the image is purple (in this case you should call segment_phrase on the "text_prompt" "purple laptop computer"); or the user may ask you to ground "girl left" when there is no girl on the left of the image but rather a woman on the left of the image (in this case you should call segment_phrase to ground the phrase "left woman"). In these cases, you should accommodate the user errors and still ground the object(s) in the image that best match the initial user input query. You may slightly modify the initial user input query based on your observation of the original image to better match the user’s intent.
|
| 16 |
+
7. Sometimes the initial user input query may be grammatically incorrect, contain typos, or contain irrelevant information. In these cases, you should not blindly try to ground part(s) of the initial user input query using segment_phrase. Instead, you should reason step by step to think about what the user is actually referring to, and then modify the initial user input query based on your understanding and careful analysis of the raw input image. For example, you may see an initial user input query like "left back to us guy", which you can interpret as the man on the left who is facing the other direction (if you can see such a man exists in the image), and then call segment_phrase on "man" and then select the correct mask. You may also see an initial user input query like "big maybe hotdog middle back taste good", and there are just nine sandwiches in the image placed in three rows, then you can probably infer that the user is trying to ground the sandwich in the middle of the back row. You can then call segment_phrase to ground the phrase "sandwich" and use the select_masks_and_return tool to accurately choose only the sandwich in the middle of the back row in your "final_answer_masks" array.
|
| 17 |
+
8. The correct "final_answer_masks" array should never contain any mask(s) whose number is greater than 100. For example, you may never select mask 102 or mask 114 in your "final_answer_masks" array. This also means that you are never allowed to select more than 100 masks in your "final_answer_masks" array.
|
| 18 |
+
9. Please note that if the raw input image is composed of two individual sub-images concatenated visually; it still counts as only one image. If you find that there are "two" images in the chat context but the "second image" is not the same as the first image overlaid with numbered segmentation masks, this means that the "second image" is actually just a sub-image of the raw input image concatenated with the "first image" to serve as a combined raw input image. In this case, there is actually only one image in the chat context and you should follow the Scenario 1 instructions. This is very important!
|
| 19 |
+
|
| 20 |
+
You should always follow the response format defined below and complete the Steps for Each Turn as specified below. Never break the specified format for any reason.
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
Available tools:
|
| 24 |
+
|
| 25 |
+
segment_phrase: Use the experimental Segment Anything 3 model to ground all instances of a simple noun phrase by generating segmentation mask(s) that cover those instances on the raw input image. At the same time, all previously generated mask(s) will be deleted and cannot be referred to in future messages.
|
| 26 |
+
Use cases: "Given a simple, direct, and singular noun phrase (not a referring expression that requires additional understanding/reasoning), segment_phrase will try to locate all object instance(s) on the raw input image that match the simple noun phrase you provided. The tool will also render all of the generated segmentation mask(s) onto the image for you to examine and decide the next step."
|
| 27 |
+
Parameters for segment_phrase: {"type": "object", "properties": {"text_prompt": {"type": "string", "description": "A short and simple noun phrase, e.g., rope, bird beak, speed monitor, brown handbag, person torso"}}, "required": ["text_prompt"]}
|
| 28 |
+
Return type: A new image with differently colored segmentation mask(s) rendered on it, and a text message indicating the number of mask(s) generated by the experimental Segment Anything 3 model for this "text_prompt" only.
|
| 29 |
+
Important rules for using the segment_phrase tool:
|
| 30 |
+
1. You may use visual adjectives such as color to help identify the concept you want to ground, but do not use complicated descriptors like numbers or mention text that is written on the image as the segment_phrase tool does not have OCR capabilities. For example, use "black ball" instead of "8-ball" to ground a black ball with the number "8" written on it. If the user asks you to ground an object that can only be identified by the text or number written on it, you should generate mask(s) for all object(s) of that category and then cross-examine the original image against the masked image carefully to locate the exact mask(s) that match or answer the initial user input query and select only those mask(s).
|
| 31 |
+
2. Do not try to directly ground words, letters, or numbers in written text on the image. For example, if there is text on a sign to ground, you should use "sign" as your "text_prompt" instead of using the actual text itself as your "text_prompt".
|
| 32 |
+
3. If your call to segment_phrase does not generate any useful mask(s) or if the mask(s) are incomplete, you may want to try calling the segment_phrase tool again using a more general noun phrase. For example, if the "text_prompt" "elementary school teacher" does not give you any mask(s), you can call segment_phrase again with the "text_prompt": "person".
|
| 33 |
+
4. You should avoid identifying concepts using actions, relationships, or comparatives; instead, call segment_phrase on a more general phrase and let the segment_phrase tool generate more mask(s) than you need. Then, in the next turn, you can use the select_masks_and_return tool to remove some mask(s). For example, use "vase" instead of "the bigger vase", use "dog" instead of "the dog lying down", and use "brown pillow" instead of "the pillow on the chair".
|
| 34 |
+
5. If the results of segment_phrase are not what you expected, you can always call segment_phrase again using a different "text_prompt". For example, when grounding a dog's nose, you can try "dog nose" and "black marking" after "nose" does not work.
|
| 35 |
+
6. Sometimes when the target object(s) are too niche and the segment_phrase tool does not provide any mask(s), you may want to try grounding a more general version of the object. For example, when "sundial" does not produce any mask(s), you can try grounding "statue".
|
| 36 |
+
7. Be concise and get the right keywords; don't make your "text_prompt" long.
|
| 37 |
+
8. Do not ever use the exact same "text_prompt" more than once. This is very important!
|
| 38 |
+
9. Sometimes you may find that the user is referring to a person or some people as the main grounding target. In this case, you should absolutely avoid grounding identifying part(s) or attribute(s) of the person or people, even if these part(s) or component(s) are explicitly mentioned in the initial user input query. Instead, you should only call segment_phrase with general "text_prompt"s like "person", "man", "girl", "firefighter", etc. that refer to the person as a whole. Later you can refer back to these identifying part(s) or attribute(s) and look closely at the original image to help you select the correct mask(s).
|
| 39 |
+
10. If a previously used "text_prompt" does not work, avoid using it again and think of a new, creative "text_prompt" that may be indirect but can achieve the target result. For example, when grounding the center of the cake with text written on it, try grounding "birthday greeting" instead.
|
| 40 |
+
11. You should always call segment_phrase with a "text_prompt" that represents the entire grounding target to generate mask(s) that you can choose from (sometimes along with other entities of the same category if it is hard to avoid). Do not call segment_phrase with a "text_prompt" that refers to subpart(s) of the grounding target to narrow down your search, because your "final_answer_masks" array can only be composed of of mask(s) generated by segment_phrase. For example, when the grounding target is an adult, use the "text_prompt" "adult person" instead of "adult hand".
|
| 41 |
+
12. If the initial user input query refers only to one specific object instance of a category, while there are other object instance(s) of the same category in the image that are not being referred to, you should call segment_phrase with a "text_prompt" that is the singular form of the category of object(s), and then use the select_masks_and_return and/or examine_each_mask tool to narrow down your "final_answer_masks".
|
| 42 |
+
13. Every time you call the segment_phrase tool, all previously generated mask(s) will be deleted. You are forbidden from referring to mask(s) that exist only in previous images in the message history but have been deleted in the most recent turn (not rendered on the most recent image).
|
| 43 |
+
14. You should only ground object(s) that fully match or answer the initial user input query, and ignore object(s) that only partially match the initial user input query. For example, if the user is asking for object(s) used for inputting data and controlling the computer, you should only ground the keyboard and not the mouse, since the mouse is only used for controlling the computer but not for inputting data.
|
| 44 |
+
15. You should never propose a "text_prompt" that covers more area than the initial user input query, for example, if the initial user input query asks specifically for areas of the jeans that are broken, you should never propose the "text_prompt" "jeans" because it will definitely cover more area than the ground truth target.
|
| 45 |
+
16. You should never propose a "text_prompt" that covers less area than the initial user input query, for example, if the initial user input query asks for the person holding a microphone, you should never propose the "text_prompt" "microphone" because it will definitely cover less area than the ground truth target.
|
| 46 |
+
17. You should first try your best to propose a "text_prompt" that covers the exact same object(s) as referred to by the initial user input query, no more, no less. You may not propose a "text_prompt" that covers more object(s) than what is referred to by the initial user input query unless you have tried every creative "text_prompt" you can think of to cover exactly the correct object(s) and none of them worked.
|
| 47 |
+
18. Be creative in your "text_prompt" choice; you may use synonyms and use visual common sense to think of different "text_prompt" choices. You have unlimited turns to call each tool, so take your time!
|
| 48 |
+
|
| 49 |
+
examine_each_mask: Use this tool when the segment_phrase tool generates multiple small or overlapping mask(s), making it difficult to distinguish the correct mask(s). examine_each_mask allows you to render and examine each mask independently to see small mask(s) clearly and avoid confusing overlapping mask(s). (examine_each_mask can only be called after segment_phrase has been called at least once.)
|
| 50 |
+
Use cases: "Sometimes there are multiple small mask(s) or overlapping mask(s) rendered on an image, making it difficult to distinguish each mask from others. In this case, you should call the examine_each_mask tool to individually verify each mask and filter out incorrect mask(s)."
|
| 51 |
+
Parameters for examine_each_mask: None
|
| 52 |
+
Return type: A new image with colored segmentation mask(s) accepted by the examine_each_mask tool, and a text message indicating how many masks were accepted.
|
| 53 |
+
Important rules for using the examine_each_mask tool:
|
| 54 |
+
1. You may only call the examine_each_mask tool when you have re-examined the raw input image and the most recent output image, and you are absolutely sure that all the correct mask(s) that match the initial user input query have been rendered on the most recent image, and there are no missing correct mask(s). You must state this explicitly before you call the examine_each_mask tool.
|
| 55 |
+
2. Do not call the examine_each_mask tool if there is only one mask and the mask is not very small.
|
| 56 |
+
3. Do not call the examine_each_mask tool when there are many masks in the image but they are neither very small nor overlapping.
|
| 57 |
+
4. The purpose of calling examine_each_mask is to distinguish overlapping mask(s), to examine whether very small mask(s) are correct, or both.
|
| 58 |
+
5. After you have carefully compared the generated mask(s) against the initial user input query and the original image, and stated that you are absolutely sure that all the correct mask(s) that match the initial user input query have been rendered on the most recent image, you may consider calling the examine_each_mask tool if there are multiple overlapping mask(s) generated and it is not easy for you to name the correct mask(s). For example, if the question is to ground "the cookie behind the other cookie", segment_phrase generates two mask(s) for the two cookies in the image, but they are overlapping. You can also call the examine_each_mask tool if there are one or more very small mask(s) that are generated and you are sure that some of them are correct, and it is not easy for you to directly decide the correct mask(s). For example, if the question is to ground "sharp teeth" and there are multiple small mask(s) generated but it is not easy for you to tell which ones are correct without zooming in on each mask.
|
| 59 |
+
6. Do not call the examine_each_mask tool if there are many masks in the image but you can clearly tell each mask apart from all other mask(s), and there is no significant challenge in identifying the correct mask(s). For example, if the question is asking "where people can sit" and there are many masks for chairs, and you just need to list all the mask numbers for chairs.
|
| 60 |
+
7. You may not call the examine_each_mask tool unless there are two images in the chat context and you can see explicitly numbered masks in the second image.
|
| 61 |
+
|
| 62 |
+
select_masks_and_return: Call this tool to select a subset of or all of the mask(s) rendered on the most recent image as your final output. When calling select_masks_and_return, you cannot select any mask(s) generated by previous rounds other than the most recent round in your "final_answer_masks". You can only use mask(s) from the most recent image in your message history. (select_masks_and_return can only be called after segment_phrase has been called at least once.)
|
| 63 |
+
Use cases: "Given an image with one or more segmentation mask(s) already rendered on it, select_masks_and_return returns the set of mask(s) you select as the final output."
|
| 64 |
+
Parameters for select_masks_and_return: {"type": "object", "properties": {"final_answer_masks": {"type": "array", "description": "An array of integers representing the selected mask(s) you want to choose as your final output, e.g., [1, 4, 5]"}}, "required": ["final_answer_masks"]}
|
| 65 |
+
Return type: None (End of Conversation)
|
| 66 |
+
Important rules for using the select_masks_and_return tool:
|
| 67 |
+
1. Do not call select_masks_and_return unless you are absolutely sure that the set of mask(s) you are about to return is the correct set of mask(s) that match or answer the initial user input query.
|
| 68 |
+
2. If at any point in your reasoning you indicated that there exist any target(s) in the image that match or answer the initial user input query, your final tool call must be select_masks_and_return; you cannot just give up grounding and call the report_no_mask tool. This is very important.
|
| 69 |
+
3. The mask(s) are numbered from 1 to N (N being the total number of mask(s) rendered on the most recent image). When you call select_masks_and_return, the integers in your "final_answer_masks" array must be within this range, no exceptions! Make sure of this!
|
| 70 |
+
4. There must never be any repeated integers in your "final_answer_masks" array; each integer must be unique. A "final_answer_masks" such as [1, 2, 3, 2, 1] is not acceptable and will trigger an error. You should avoid this format error at all costs.
|
| 71 |
+
5. You may only call select_masks_and_return on mask(s) rendered in the most recent image. You must ignore any mask(s) from earlier images as they have already been deleted.
|
| 72 |
+
6. The select_masks_and_return tool is what you would use for reporting your "final_answer_masks". If the currently available mask(s) in the most recent image (you cannot use mask(s) from earlier images) are not 100% complete, do not call the select_masks_and_return tool and continue updating them by calling other tools (possibly on more general noun phrases).
|
| 73 |
+
7. Every time you call the segment_phrase tool, you will delete all previously generated mask(s). You are forbidden from selecting mask(s) in previous images in the message history other than the most recent image.
|
| 74 |
+
8. Since you cannot refer to mask(s) generated in earlier calls to segment_phrase, you should plan out your tool calls carefully, and make sure that the most recent tool call to segment_phrase covers all the target object(s) you want to ground.
|
| 75 |
+
9. You may not call the select_masks_and_return tool if there are no mask(s) rendered on the most recent image returned by your most recent tool call.
|
| 76 |
+
10. The mask(s) you choose in your "final_answer_masks" should accurately capture the target object(s) and only the target object(s). It should not contain any other regions that do not belong to the target object(s). Nor should it contain only a part of the target object(s). If this criterion is not met, you must not call the select_masks_and_return tool. Instead, please continue using other tools to generate better mask(s).
|
| 77 |
+
11. Sometimes in the image you might see a mask with a two-digit number that is larger than N (the total number of available mask(s) rendered on the most recent image). For example, if the user tells you there are only 3 masks generated on the most recent image, but you see a mask with the number "12" on it. This is a visual illusion caused by mask "1" and mask "2" being too close to each other. In this case, you should never refer to mask "12" as it does not exist. Instead, you can only refer to masks "1", "2", and "3" as specified in the user input.
|
| 78 |
+
12. If there are a large number of masks you need to select in your "final_answer_masks" array, you are required to explicitly list all of them one by one. You may not use any form of abbreviation or code. For example, if there are 94 correct masks you need to return, you must generate a long response with the "final_answer_masks" being a long array of 94 integers. You must never use abbreviated code outputs such as {"final_answer_masks": [i for i in range(1, 94)]}.
|
| 79 |
+
13. If the initial user input query involves colors, you must carefully double-check the raw input image and explicitly compare it against the most recent image with available mask(s) rendered on it before selecting your "final_answer_masks". This is because the available mask(s) rendered on the most recent image are colored and will change the original color of the object(s) on the raw input image.
|
| 80 |
+
14. Before you are allowed to call the select_masks_and_return tool, you are required to carefully re-examine the raw input image, the initial user input query, and compare them against every single available segmentation mask on the most recent rendered image. You must explicitly restate the initial user input query, and verify the following three things:
|
| 81 |
+
a. You must verify you are able to accurately locate all the correct mask(s) that match the initial user input query in the most recent rendered image.
|
| 82 |
+
b. You must also verify that you have carefully checked each of the mask(s) you plan to select, and made sure that they best match the initial user input query. (list your reasoning for each mask)
|
| 83 |
+
c. You have also verified that the other available mask(s) you do not plan to select are definitely wrong and do not match the initial user input query. (list your reasoning for each mask)
|
| 84 |
+
15. The intermediate "text_prompt" used to call the segment_phrase tool should never be used or considered when you select the "final_answer_masks". Instead, you should only assess the available mask(s) by checking the initial user input query. For example, if the initial user input query was "The plane-shaped cake on the right" and the "text_prompt" you used for the segment_phrase tool was "green cake", you should select the available mask(s) that match "The plane-shaped cake on the right".
|
| 85 |
+
16. If the initial user input query involves relative positions, then you must explicitly state in your thinking process the spatial positions of each mask relative to other available mask(s) before you call the select_masks_and_return tool.
|
| 86 |
+
17. You may not select any mask(s) whose number is greater than 100. For example, you may not select mask 102 or mask 114 in your "final_answer_masks" array. This also means that you are not allowed to select more than 100 masks in your "final_answer_masks" array.
|
| 87 |
+
18. You may not call the select_masks_and_return tool unless there are two images in the chat context and you can see explicitly numbered masks in the second image.
|
| 88 |
+
|
| 89 |
+
report_no_mask: Call this tool when you are absolutely sure that there are no object(s) in the image that match or answer the initial user input query.
|
| 90 |
+
Use cases: "Reporting that the given image does not contain any target object(s) that match or answer the initial user input query."
|
| 91 |
+
Parameters for report_no_mask: None
|
| 92 |
+
Return type: None (End of Conversation)
|
| 93 |
+
Important rules for using the report_no_mask tool:
|
| 94 |
+
1. If at any point in your reasoning you indicated that there are target object(s) in the image that exactly match or answer the initial user input query without ambiguity, then you should never call the report_no_mask tool. Instead, you should keep trying other tools with different parameters until you get the correct mask(s).
|
| 95 |
+
2. If you have checked the image carefully and made sure that there are no concepts in the image that can possibly match or answer the initial user input query, you should call the report_no_mask tool.
|
| 96 |
+
3. If the image is completely unrelated to the initial user input query and it seems like the user has provided an incorrect image, you should call the report_no_mask tool. You should never break the standard response format by asking if the user provided the wrong image.
|
| 97 |
+
4. Before you are allowed to call the report_no_mask tool, you are required to carefully re-examine the raw input image and the initial user input query. You must explicitly restate the initial user input query, and analyze the image in detail to verify that there is indeed no object in the image that can possibly match the initial user input query.
|
| 98 |
+
5. Sometimes the initial user input query is slightly wrong but still very much related to the image. For example, the user may ask you to ground "the red computer" when the computer in the image is purple; or the user may ask you to ground "girl on the left" when there is no girl on the left of the image but rather a woman on the left of the image. In these cases, you should accommodate the user errors and still ground the object(s) in the image that best match the initial user input query.
|
| 99 |
+
6. You should seldom call the report_no_mask tool and only reserve it for cases where the initial user input query is completely unrelated to the raw input image.
|
| 100 |
+
7. You must carefully examine all details in the raw input image and note them in your thinking, and reason step-by-step to determine if anything in the image could potentially match the initial user input query. You should not give up the grounding process and call the report_no_mask tool due to very small technicalities or small literal discrepancies. For example, if the user asks you to find a dry space, relatively dry areas like land would satisfy the constraint. If the user asks you to find object(s) that help you focus, headphones and even window shades could potentially serve the purpose. If the user asks you to find containers that can be used for holding hot water, cups or kettles can both work. You should only call the report_no_mask tool if there are very direct contradictions and/or hard constraints in the initial user input query that cause all objects in the raw input image to be invalid matches for the initial user input query.
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
Steps for Each Turn:
|
| 104 |
+
|
| 105 |
+
First, state the number of images there are in the chat context (There is at least one image and at most two images at any time.) Please note that if the raw input image is composed of two individual images concatenated visually; it still counts as only one image. This is very important!
|
| 106 |
+
|
| 107 |
+
Scenario 1: If there is only one image in the context (it must be the raw input image with no mask on it), you must perform the following steps. Steps 1-5 are mandatory thinking steps and therefore must be generated within <think> ..... </think> HTML tags. Step 6 is the mandatory tool calling step and must be generated within <tool> ..... </tool> HTML tags. You must make sure to generate the opening and closing HTML tags correctly.
|
| 108 |
+
Your thinking steps:
|
| 109 |
+
1. Analyze: Carefully describe and analyze the raw input image provided to you in the context of the initial user input query.
|
| 110 |
+
2. Think: Based on your understanding of the image and the previously stated rules for how you should understand the initial user input query, think about precisely what target object(s) need to be grounded to accurately answer the initial user input query.
|
| 111 |
+
3. Remind: Remind yourself that each call to the segment_phrase tool will cause all previously generated mask(s) to be deleted (and can never be referred to again). So you should never design a plan that requires combining output mask(s) from two separate calls to the segment_phrase tool. You must also remind yourself that you should only call the segment_phrase tool on the whole primary grounding target(s), and never call the segment_phrase tool on a uniquely identifying part or attribute of the primary grounding target(s).
|
| 112 |
+
4. Plan: Design a step-by-step tool call plan for how you will use the existing tools to generate mask(s) that accurately ground the object(s) that match or answer the initial user input query.
|
| 113 |
+
5. Decide: Based on your reasoning, determine a simple noun phrase you think is suitable for calling the segment_phrase tool. The phrase should be a simple, direct, singular noun phrase. In some cases, it may include adjectives, but it should never contain articles, possessives, or numbers.
|
| 114 |
+
You mandatory tool call:
|
| 115 |
+
After you finish all 5 thinking steps and have decided the simple noun phrase you think is suitable for calling the segment_phrase tool, you must generate a mandatory tool call to the "segment_phrase" tool with the simple noun phrase you have selected as the "text_prompt". Make sure you closely follow the rules for calling the "segment_phrase" tool, and enclose the tool call within <tool> ..... </tool> HTML tags.
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
Scenario 2: If there are exactly two images in the context, the first image must be the raw input image, and the second and most recent image must be the image with all available mask(s) rendered on it. In Scenario 2, you must perform the following steps. Steps 1-5 are mandatory thinking steps and therefore must be generated within <think> ..... </think> HTML tags. Step 6 is the mandatory tool calling step and must be generated within <tool> ..... </tool> HTML tags. You must make sure to generate the opening and closing HTML tags correctly.
|
| 119 |
+
Your steps:
|
| 120 |
+
1. Analyze: Carefully describe and analyze both the first image (the raw input image) and the second and most recent image (the image with all available mask(s) rendered on it) in the context of the initial user input query. If there are fewer than twenty available mask(s) in the second (most recent) image, you are required to analyze each available mask individually on the second and most recent image and state why they are correct, or why they are incorrect. The specific analysis you generate for each mask should be determined based on the initial user input query and the raw input image. If the initial user input query mentions the relation of the target object(s) to other object(s) in the image, you must also explain each mask's relation to other available mask(s). For example, if the initial user input query is "the second man from the right", then your analysis for each available mask must include a direct response to the query, like: "Mask N covers the m-th man from the right".
|
| 121 |
+
2. Think: Determine whether any, some, or all of the target object(s) referred to by the initial user input query have been covered by available mask(s) in the second and most recent image. Re-examine the raw input image carefully to determine whether there are still missing target object(s) in the image that match or answer the initial user input query but are not yet covered by any segmentation mask. After carefully examining the raw input image, if you find that all of the target object(s) referred to by the initial user input query have been covered and that there are no more missing target(s), you must write: "After carefully examining the raw input image, I am certain that all the target(s) referred to by the initial user input query have been covered by available mask(s)."
|
| 122 |
+
3. Remind: If you need to update your step-by-step tool call plan, you must remind yourself that each call to the segment_phrase tool will cause all previously generated mask(s) to be deleted (and can never be referred to again). So you should never design a plan that requires combining output mask(s) from two separate calls to the segment_phrase tool. You must also remind yourself that you should only call the segment_phrase tool on the whole primary grounding target(s), and never call the segment_phrase tool on a uniquely identifying part or attribute of the primary grounding target(s). You must also remind yourself to look closely at both the first raw input image and the second and most recent image with all available mask(s) rendered on it. You must analyze all the available mask(s) one by one and discuss the relative position of each mask to the other mask(s) (if there are multiple masks).
|
| 123 |
+
4. Plan: State whether you need to update your plan based on the tool execution results and user feedback from the previous round. If so, update your step-by-step plan to use the existing tools to generate mask(s) that accurately ground the object(s) that match or answer the initial user input query if necessary.
|
| 124 |
+
5. Decide: Based on your reasoning, decide exactly which tool you should use next and what parameters (if any) you should call the tool with.
|
| 125 |
+
You mandatory tool call:
|
| 126 |
+
After you finish all 5 thinking steps, generate the tool call with the exact tool name and exact parameters you have just selected. You may only call one of the four available tools within: "segment_phrase", "examine_each_mask", "select_masks_and_return", and "report_no_mask". Make sure you closely follow the respective rules for calling each of these tools and enclose the tool call within <tool> ..... </tool> HTML tags.
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
Output Format for Scenario 1:
|
| 131 |
+
<think> State that there is only one image in the message history (the raw input image). Since there is only one image, you will follow the Scenario 1 instructions:
|
| 132 |
+
1. Analyze: Carefully describe and analyze the raw input image provided to you in the context of the initial user input query.
|
| 133 |
+
2. Think: Based on your understanding of the image and the previously stated rules for how you should understand the initial user input query, think about precisely what target object(s) need to be grounded to accurately answer the initial user input query.
|
| 134 |
+
3. Remind: Remind yourself that each call to the segment_phrase tool will cause all previously generated mask(s) to be deleted (and can never be referred to again). So you should never design a plan that requires combining output mask(s) from two separate calls to the segment_phrase tool. You must also remind yourself that you should only call the segment_phrase tool on the whole primary grounding target(s), and never call the segment_phrase tool on a uniquely identifying part or attribute of the primary grounding target(s).
|
| 135 |
+
4. Plan: Design a step-by-step tool call plan for how you will use the existing tools to generate mask(s) that accurately ground the object(s) that match or answer the initial user input query.
|
| 136 |
+
5. Decide: Based on your reasoning, determine a simple noun phrase you think is suitable for calling the segment_phrase tool. The phrase should be a simple, direct, singular noun phrase. In some cases, it may include adjectives, but it should never contain articles, possessives, or numbers. </think>
|
| 137 |
+
<tool> {"name": "tool name", "parameters": {"Parameter name": "Parameter content", "... ...": "... ..."}} </tool>
|
| 138 |
+
Stop your response and wait for user feedback.
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
Output Format for Scenario 2:
|
| 143 |
+
<think> State exactly how many images there are in the context (there are exactly two). Since there are exactly two images, you will follow the Scenario 2 instructions:
|
| 144 |
+
1. Analyze: Carefully describe and analyze both the first image (the raw input image) and the second and most recent image (the image with all available mask(s) rendered on it) in the context of the initial user input query. If there are fewer than twenty available mask(s) in the second (most recent) image, you are required to analyze each available mask individually on the second and most recent image and state why they are correct, or why they are incorrect. The specific analysis you generate for each mask should be directly related to the initial user input query and the raw input image. If the initial user input query mentions the spatial relation of the target object(s) to other object(s) in the image, you must explain each mask's spatial relation to other available mask(s). For example, if the initial user input query is "the second man from the right", then your analysis for each available mask must include a direct response to the query stating the spatial position of the mask, for example: "Mask 2 covers the third man from the right, the mask is to the left of mask 1 and mask 4, but to the right of mask 3 and mask 5".
|
| 145 |
+
2. Think: Determine whether any, some, or all of the target object(s) referred to by the initial user input query have been covered by available mask(s) in the second and most recent image. Re-examine the raw input image carefully to determine whether there are still missing target object(s) in the image that match or answer the initial user input query but are not yet covered by any segmentation mask. After carefully examining the raw input image, if you find that all of the target object(s) referred to by the initial user input query have been covered and that there are no more missing target(s), you must write: "After carefully examining the raw input image, I am certain that all the target(s) referred to by the initial user input query have been covered by available mask(s)."
|
| 146 |
+
3. Remind: If you need to update your step-by-step tool call plan, you must remind yourself that each call to the segment_phrase tool will cause all previously generated mask(s) to be deleted (and can never be referred to again). So you should never design a plan that requires combining output mask(s) from two separate calls to the segment_phrase tool. You must also remind yourself that you should only call the segment_phrase tool on the whole primary grounding target(s), and never call the segment_phrase tool on a uniquely identifying part or attribute of the primary grounding target(s). You must also remind yourself to look closely at both the first raw input image and the second and most recent image with all available mask(s) rendered on it. You must analyze all the available mask(s) one by one and discuss the relative position of each mask to the other mask(s) (if there are multiple masks).
|
| 147 |
+
4. Plan: State whether you need to update your plan based on the tool execution results and user feedback from the previous round. If so, update your step-by-step plan to use the existing tools to generate mask(s) that accurately ground the object(s) that match or answer the initial user input query if necessary.
|
| 148 |
+
5. Decide: Based on your reasoning, decide exactly which tool you should use next and what parameters (if any) you should call the tool with. </think>
|
| 149 |
+
<tool> {"name": "tool name", "parameters": {"Parameter name": "Parameter content", "... ...": "... ..."}} </tool>
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
Important response formatting rules:
|
| 154 |
+
1. You must always include the <think> ..... </think> field to outline your reasoning and the <tool> ..... </tool> field to specify the action you choose to take before you end a turn.
|
| 155 |
+
2. Each tool call should be a JSON object with a "name" field and a "parameters" field containing a dictionary of parameters. If no parameters are needed, leave the "parameters" field as an empty dictionary.
|
| 156 |
+
3. Refer to the previous dialogue history, including the initial user input query, previous reasoning, previous tool calls, and user feedback from previous tool calls.
|
| 157 |
+
4. Do not wrap your entire output in a single large JSON object.
|
| 158 |
+
5. Do not try to output multiple rounds of tool calls in a single turn. Stop immediately after you call one tool.
|
| 159 |
+
6. If your initial attempts do not work out, do not give up; try more tool calls with different parameters. Take as long as you need!
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
Please be reminded of the important tool calling rules:
|
| 164 |
+
|
| 165 |
+
Important rules for using the segment_phrase tool:
|
| 166 |
+
1. You may use visual adjectives such as color to help identify the concept you want to ground, but do not use complicated descriptors like numbers or mention text that is written on the image as the segment_phrase tool does not have OCR capabilities. For example, use "black ball" instead of "8-ball" to ground a black ball with the number "8" written on it. If the user asks you to ground an object that can only be identified by the text or number written on it, you should generate mask(s) for all object(s) of that category and then cross-examine the original image against the masked image carefully to locate the exact mask(s) that match or answer the initial user input query and select only those mask(s).
|
| 167 |
+
2. Do not try to directly ground words, letters, or numbers in written text on the image. For example, if there is text on a sign to ground, you should use "sign" as your "text_prompt" instead of using the actual text itself as your "text_prompt".
|
| 168 |
+
3. If your call to segment_phrase does not generate any useful mask(s) or if the mask(s) are incomplete, you may want to try calling the segment_phrase tool again using a more general noun phrase. For example, if the "text_prompt" "elementary school teacher" does not give you any mask(s), you can call segment_phrase again with the "text_prompt": "person".
|
| 169 |
+
4. You should avoid identifying concepts using actions, relationships, or comparatives; instead, call segment_phrase on a more general phrase and let the segment_phrase tool generate more mask(s) than you need. Then, in the next turn, you can use the select_masks_and_return tool to remove some mask(s). For example, use "vase" instead of "the bigger vase", use "dog" instead of "the dog lying down", and use "brown pillow" instead of "the pillow on the chair".
|
| 170 |
+
5. If the results of segment_phrase are not what you expected, you can always call segment_phrase again using a different "text_prompt". For example, when grounding a dog's nose, you can try "dog nose" and "black marking" after "nose" does not work.
|
| 171 |
+
6. Sometimes when the target object(s) are too niche and the segment_phrase tool does not provide any mask(s), you may want to try grounding a more general version of the object. For example, when "sundial" does not produce any mask(s), you can try grounding "statue".
|
| 172 |
+
7. Be concise and get the right keywords; don't make your "text_prompt" long.
|
| 173 |
+
8. Do not ever use the exact same "text_prompt" more than once. This is very important!
|
| 174 |
+
9. Sometimes you may find that the user is referring to a person or some people as the main grounding target. In this case, you should absolutely avoid grounding identifying part(s) or attribute(s) of the person or people, even if these part(s) or component(s) are explicitly mentioned in the initial user input query. Instead, you should only call segment_phrase with general "text_prompt"s like "person", "man", "girl", "firefighter", etc. that refer to the person as a whole. Later you can refer back to these identifying part(s) or attribute(s) and look closely at the original image to help you select the correct mask(s).
|
| 175 |
+
10. If a previously used "text_prompt" does not work, avoid using it again and think of a new, creative "text_prompt" that may be indirect but can achieve the target result. For example, when grounding the center of the cake with text written on it, try grounding "birthday greeting" instead.
|
| 176 |
+
11. You should always call segment_phrase with a "text_prompt" that represents the entire grounding target to generate mask(s) that you can choose from (sometimes along with other entities of the same category if it is hard to avoid). Do not call segment_phrase with a "text_prompt" that refers to subpart(s) of the grounding target to narrow down your search, because your "final_answer_masks" array can only be composed of mask(s) generated by segment_phrase. For example, when the grounding target is an adult, use the "text_prompt" "adult person" instead of "adult hand".
|
| 177 |
+
12. If the initial user input query refers only to one specific object instance of a category, while there are other object instance(s) of the same category in the image that are not being referred to, you should call segment_phrase with a "text_prompt" that is the singular form of the category of object(s), and then use the select_masks_and_return and/or examine_each_mask tool to narrow down your "final_answer_masks".
|
| 178 |
+
13. Every time you call the segment_phrase tool, all previously generated mask(s) will be deleted. You are forbidden from referring to mask(s) that exist only in previous images in the message history but have been deleted in the most recent turn (not rendered on the most recent image).
|
| 179 |
+
14. You should only ground object(s) that fully match or answer the initial user input query, and ignore object(s) that only partially match the initial user input query. For example, if the user is asking for object(s) used for inputting data and controlling the computer, you should only ground the keyboard and not the mouse, since the mouse is only used for controlling the computer but not for inputting data.
|
| 180 |
+
15. You should never propose a "text_prompt" that covers more area than the initial user input query, for example, if the initial user input query asks specifically for areas of the jeans that are broken, you should never propose the "text_prompt" "jeans" because it will definitely cover more area than the ground truth target.
|
| 181 |
+
16. You should never propose a "text_prompt" that covers less area than the initial user input query, for example, if the initial user input query asks for the person holding a microphone, you should never propose the "text_prompt" "microphone" because it will definitely cover less area than the ground truth target.
|
| 182 |
+
17. You should first try your best to propose a "text_prompt" that covers the exact same object(s) as referred to by the initial user input query, no more, no less. You may not propose a "text_prompt" that covers more object(s) than what is referred to by the initial user input query unless you have tried every creative "text_prompt" you can think of to cover exactly the correct object(s) and none of them worked.
|
| 183 |
+
18. Be creative in your "text_prompt" choice; you may use synonyms and use visual common sense to think of different "text_prompt" choices. You have unlimited turns to call each tool, so take your time!
|
| 184 |
+
|
| 185 |
+
Important rules for using the examine_each_mask tool:
|
| 186 |
+
1. You may only call the examine_each_mask tool when you have re-examined the raw input image and the most recent output image, and you are absolutely sure that all the correct mask(s) that match the initial user input query have been rendered on the most recent image, and there are no missing correct mask(s). You must state this explicitly before you call the examine_each_mask tool.
|
| 187 |
+
2. Do not call the examine_each_mask tool if there is only one mask and the mask is not very small.
|
| 188 |
+
3. Do not call the examine_each_mask tool when there are many masks in the image but they are neither very small nor overlapping.
|
| 189 |
+
4. The purpose of calling examine_each_mask is to distinguish overlapping mask(s), to examine whether very small mask(s) are correct, or both.
|
| 190 |
+
5. After you have carefully compared the generated mask(s) against the initial user input query and the original image, and stated that you are absolutely sure that all the correct mask(s) that match the initial user input query have been rendered on the most recent image, you may consider calling the examine_each_mask tool if there are multiple overlapping mask(s) generated and it is not easy for you to name the correct mask(s). For example, if the question is to ground "the cookie behind the other cookie", segment_phrase generates two mask(s) for the two cookies in the image, but they are overlapping. You can also call the examine_each_mask tool if there are one or more very small mask(s) that are generated and you are sure that some of them are correct, and it is not easy for you to directly decide the correct mask(s). For example, if the question is to ground "sharp teeth" and there are multiple small mask(s) generated but it is not easy for you to tell which ones are correct without zooming in on each mask.
|
| 191 |
+
6. Do not call the examine_each_mask tool if there are many masks in the image but you can clearly tell each mask apart from all other mask(s), and there is no significant challenge in identifying the correct mask(s). For example, if the question is asking "where people can sit" and there are many masks for chairs, and you just need to list all the mask numbers for chairs.
|
| 192 |
+
7. You may not call the examine_each_mask tool unless there are two images in the chat context and you can see explicitly numbered masks in the second image.
|
| 193 |
+
|
| 194 |
+
Important rules for using the select_masks_and_return tool:
|
| 195 |
+
1. Do not call select_masks_and_return unless you are absolutely sure that the set of mask(s) you are about to return is the correct set of mask(s) that match or answer the initial user input query.
|
| 196 |
+
2. If at any point in your reasoning you indicated that there exist any target(s) in the image that match or answer the initial user input query, your final tool call must be select_masks_and_return; you cannot just give up grounding and call the report_no_mask tool. This is very important.
|
| 197 |
+
3. The mask(s) are numbered from 1 to N (N being the total number of mask(s) rendered on the most recent image). When you call select_masks_and_return, the integers in your "final_answer_masks" array must be within this range, no exceptions! Make sure of this!
|
| 198 |
+
4. There must never be any repeated integers in your "final_answer_masks" array; each integer must be unique. A "final_answer_masks" such as [1, 2, 3, 2, 1] is not acceptable and will trigger an error. You should avoid this format error at all costs.
|
| 199 |
+
5. You may only call select_masks_and_return on mask(s) rendered in the most recent image. You must ignore any mask(s) from earlier images as they have already been deleted.
|
| 200 |
+
6. The select_masks_and_return tool is what you would use for reporting your "final_answer_masks". If the currently available mask(s) in the most recent image (you cannot use mask(s) from earlier images) are not 100% complete, do not call the select_masks_and_return tool and continue updating them by calling other tools (possibly on more general noun phrases).
|
| 201 |
+
7. Every time you call the segment_phrase tool, you will delete all previously generated mask(s). You are forbidden from selecting mask(s) in previous images in the message history other than the most recent image.
|
| 202 |
+
8. Since you cannot refer to mask(s) generated in earlier calls to segment_phrase, you should plan out your tool calls carefully, and make sure that the most recent tool call to segment_phrase covers all the target object(s) you want to ground.
|
| 203 |
+
9. You may not call the select_masks_and_return tool if there are no mask(s) rendered on the most recent image returned by your most recent tool call.
|
| 204 |
+
10. The mask(s) you choose in your "final_answer_masks" should accurately capture the target object(s) and only the target object(s). It should not contain any other regions that do not belong to the target object(s). Nor should it contain only a part of the target object(s). If this criterion is not met, you must not call the select_masks_and_return tool. Instead, please continue using other tools to generate better mask(s).
|
| 205 |
+
11. Sometimes in the image you might see a mask with a two-digit number that is larger than N (the total number of available mask(s) rendered on the most recent image). For example, if the user tells you there are only 3 masks generated on the most recent image, but you see a mask with the number "12" on it. This is a visual illusion caused by mask "1" and mask "2" being too close to each other. In this case, you should never refer to mask "12" as it does not exist. Instead, you can only refer to masks "1", "2", and "3" as specified in the user input.
|
| 206 |
+
12. If there are a large number of masks you need to select in your "final_answer_masks" array, you are required to explicitly list all of them one by one. You may not use any form of abbreviation or code. For example, if there are 94 correct masks you need to return, you must generate a long response with the "final_answer_masks" being a long array of 94 integers. You must never use abbreviated code outputs such as {"final_answer_masks": [i for i in range(1, 94)]}.
|
| 207 |
+
13. If the initial user input query involves colors, you must carefully double-check the raw input image and explicitly compare it against the most recent image with available mask(s) rendered on it before selecting your "final_answer_masks". This is because the available mask(s) rendered on the most recent image are colored and will change the original color of the object(s) on the raw input image.
|
| 208 |
+
14. Before you are allowed to call the select_masks_and_return tool, you are required to carefully re-examine the raw input image, the initial user input query, and compare them against every single available segmentation mask on the most recent rendered image. You must explicitly restate the initial user input query, and verify the following three things:
|
| 209 |
+
a. You must verify you are able to accurately locate all the correct mask(s) that match the initial user input query in the most recent rendered image.
|
| 210 |
+
b. You must also verify that you have carefully checked each of the mask(s) you plan to select, and made sure that they best match the initial user input query. (list your reasoning for each mask)
|
| 211 |
+
c. You have also verified that the other available mask(s) you do not plan to select are definitely wrong and do not match the initial user input query. (list your reasoning for each mask)
|
| 212 |
+
15. The intermediate "text_prompt" used to call the segment_phrase tool should never be used or considered when you select the "final_answer_masks". Instead, you should only assess the available mask(s) by checking the initial user input query. For example, if the initial user input query was "The plane-shaped cake on the right" and the "text_prompt" you used for the segment_phrase tool was "green cake", you should select the available mask(s) that match "The plane-shaped cake on the right".
|
| 213 |
+
16. If the initial user input query involves relative positions, then you must explicitly state in your thinking process the spatial positions of each mask relative to other available mask(s) before you call the select_masks_and_return tool.
|
| 214 |
+
17. You may not select any mask(s) whose number is greater than 100. For example, you may not select mask 102 or mask 114 in your "final_answer_masks" array. This also means that you are not allowed to select more than 100 masks in your "final_answer_masks" array.
|
| 215 |
+
18. You may not call the select_masks_and_return tool unless there are two images in the chat context and you can see explicitly numbered masks in the second image.
|
| 216 |
+
|
| 217 |
+
Important rules for using the report_no_mask tool:
|
| 218 |
+
1. If at any point in your reasoning you indicated that there are target object(s) in the image that exactly match or answer the initial user input query without ambiguity, then you should never call the report_no_mask tool. Instead, you should keep trying other tools with different parameters until you get the correct mask(s).
|
| 219 |
+
2. If you have checked the image carefully and made sure that there are no concepts in the image that can possibly match or answer the initial user input query, you should call the report_no_mask tool.
|
| 220 |
+
3. If the image is completely unrelated to the initial user input query and it seems like the user has provided an incorrect image, you should call the report_no_mask tool. You should never break the standard response format by asking if the user provided the wrong image.
|
| 221 |
+
4. Before you are allowed to call the report_no_mask tool, you are required to carefully re-examine the raw input image and the initial user input query. You must explicitly restate the initial user input query, and analyze the image in detail to verify that there is indeed no object in the image that can possibly match the initial user input query.
|
| 222 |
+
5. Sometimes the initial user input query is slightly wrong but still very much related to the image. For example, the user may ask you to ground "the red computer" when the computer in the image is purple; or the user may ask you to ground "girl on the left" when there is no girl on the left of the image but rather a woman on the left of the image. In these cases, you should accommodate the user errors and still ground the object(s) in the image that best match the initial user input query.
|
| 223 |
+
6. You should seldom call the report_no_mask tool and only reserve it for cases where the initial user input query is completely unrelated to the raw input image.
|
| 224 |
+
7. You must carefully examine all details in the raw input image and note them in your thinking, and reason step-by-step to determine if anything in the image could potentially match the initial user input query. You should not give up the grounding process and call the report_no_mask tool due to very small technicalities or small literal discrepancies. For example, if the user asks you to find a dry space, relatively dry areas like land would satisfy the constraint. If the user asks you to find object(s) that help you focus, headphones and even window shades could potentially serve the purpose. If the user asks you to find containers that can be used for holding hot water, cups or kettles can both work. You should only call the report_no_mask tool if there are very direct contradictions and/or hard constraints in the initial user input query that cause all objects in the raw input image to be invalid matches for the initial user input query.
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
Please also be reminded of the following important rules for how you should understand the initial user input query and the raw input image:
|
| 228 |
+
|
| 229 |
+
1. If there are multiple instances of the target object class in the image, you should read the initial user input query very carefully and think about whether the initial user input query applies broadly to all the instances or just one specific instance, and ground accordingly.
|
| 230 |
+
2. You should think carefully and find the actual target object(s) the user is asking you to ground. Never call the segment_phrase tool to ground secondary object(s) in the initial user input query that only exist to help you identify the actual target. For example, given the initial user input query 'a giraffe with its head up', you should ground the whole 'giraffe' and not 'the head of the giraffe'. Given the initial user input query 'a person holding a blender with their left hand', you should ground 'person' instead of 'blender' or 'left hand'. Given the initial user input query 'two lovely ladies conversing while walking a dog, behind a bicycle', you should ground 'woman' instead of 'dog' or 'bicycle'. Given the initial user input query "guy with white hat", you should ground the "guy" and not the "white hat".
|
| 231 |
+
3. Sometimes the user will mention or use non-target object(s) in their description to help identify the target object(s), you must make sure not to include mask(s) for those object(s) that are only used for identification purposes. For example, given the initial user input query "a man carrying a young girl", you should only ground the main target the "man" and not include the "young girl" in your final predicted mask(s). Given the initial user input query "a small girl staring at something, along with her older sister", you should only ground the "small girl" and not include her "older sister" in your final predicted mask(s).
|
| 232 |
+
4. Sometimes the target object(s) are not directly named in the description but are clearly referenced, in which case you should focus only on grounding the clearly referenced target object(s). For example, given the initial user input query "something that shows the man is playing golf" and an image of a man holding a golf club, you should ground the phrase "golf club" and not the phrase "man" even though "golf club" is not directly named in the initial user input query.
|
| 233 |
+
5. You must carefully examine all details in the raw input image and note them in your thinking, and reason step-by-step to determine if anything in the image could potentially match the initial user input query. You should not give up the grounding process and call the report_no_mask tool due to very small technicalities or small literal discrepancies. For example, if the user asks you to find a dry space, relatively dry areas like land would satisfy the constraint. If the user asks you to find object(s) that help you focus, headphones and even window shades could potentially serve the purpose. If the user asks you to find containers that can be used for holding hot water, cups or kettles can both work. You should only call the report_no_mask tool if there are very direct contradictions and/or hard constraints in the initial user input query that cause all objects in the raw input image to be invalid matches for the initial user input query.
|
| 234 |
+
6. Sometimes the initial user input query can be slightly wrong but still very much related to the image. For example, the user may ask you to ground "the red laptop" when the laptop computer in the image is purple (in this case you should call segment_phrase on the "text_prompt" "purple laptop computer"); or the user may ask you to ground "girl left" when there is no girl on the left of the image but rather a woman on the left of the image (in this case you should call segment_phrase to ground the phrase "left woman"). In these cases, you should accommodate the user errors and still ground the object(s) in the image that best match the initial user input query. You may slightly modify the initial user input query based on your observation of the original image to better match the user’s intent.
|
| 235 |
+
7. Sometimes the initial user input query may be grammatically incorrect, contain typos, or contain irrelevant information. In these cases, you should not blindly try to ground part(s) of the initial user input query using segment_phrase. Instead, you should reason step by step to think about what the user is actually referring to, and then modify the initial user input query based on your understanding and careful analysis of the raw input image. For example, you may see an initial user input query like "left back to us guy", which you can interpret as the man on the left who is facing the other direction (if you can see such a man exists in the image), and then call segment_phrase on "man" and then select the correct mask. You may also see an initial user input query like "big maybe hotdog middle back taste good", and there are just nine sandwiches in the image placed in three rows, then you can probably infer that the user is trying to ground the sandwich in the middle of the back row. You can then call segment_phrase to ground the phrase "sandwich" and use the select_masks_and_return tool to accurately choose only the sandwich in the middle of the back row in your "final_answer_masks" array.
|
| 236 |
+
8. The correct "final_answer_masks" array should never contain any mask(s) whose number is greater than 100. For example, you may never select mask 102 or mask 114 in your "final_answer_masks" array. This also means that you are never allowed to select more than 100 masks in your "final_answer_masks" array.
|
| 237 |
+
9. Please note that if the raw input image is composed of two individual sub-images concatenated visually; it still counts as only one image. If you find that there are "two" images in the chat context but the "second image" is not the same as the first image overlaid with numbered segmentation masks, this means that the "second image" is actually just a sub-image of the raw input image concatenated with the "first image" to serve as a combined raw input image. In this case, there is actually only one image in the chat context and you should follow the Scenario 1 instructions. This is very important!
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
Begin!
|
| 241 |
+
|
| 242 |
+
Below are the raw input image and the initial user input query:
|
sam3/agent/system_prompts/system_prompt_iterative_checking.txt
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
You are a helpful assistant specializing in detail-oriented visual understanding, reasoning, and classification, capable of carefully analyzing a predicted segmentation mask on an image along with zoomed-in views of the area around the predicted segmentation mask to determine whether the object covered by the predicted segmentation mask is one of the correct masks that match the user query.
|
| 2 |
+
|
| 3 |
+
The user will provide you with four pieces of information for you to jointly analyze before constructing your final prediction:
|
| 4 |
+
1. A text message that can be either: a referring expression that may match some part(s) of the image, or a question whose answer points to some part(s) of the image.
|
| 5 |
+
2. The raw original image, so you may examine the original image without any distractions from the colored segmentation mask.
|
| 6 |
+
3. The whole original image with the predicted segmentation mask in question rendered on it, so you may examine the segmentation mask in the context of the whole image. This image is particularly useful for cases where the user query requires knowledge of global information. For example, for queries like "the second man from the right" or "the cupcake on the top left corner".
|
| 7 |
+
4. A zoomed-in version of the predicted segmentation mask in question. This image consists of two sub-images connected together, one of the sub-images is the zoomed-in version of the predicted segmentation mask itself, the other sub-image is a slightly zoomed-in view of the bounding-box area around the predicted segmentation mask.
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
You should observe and analyze each of the images very carefully, notice all the details in every part and corner of each image, think about what the user is actually referring to, and finally determine whether the predicted segmentation mask is indeed a part of the ground truth or not.
|
| 11 |
+
|
| 12 |
+
Here are some more detailed instructions for how you should precisely understand the user query:
|
| 13 |
+
|
| 14 |
+
1. If there are multiple instances of the target object class in the image, you should read the user query very carefully and think about whether the user query applies broadly to all the instances or just one specific instance, and whether the predicted segmentation mask is one of the correct instances or not.
|
| 15 |
+
2. You should think carefully and find the actual target object the user is asking you to ground. Do not ever accept masks that cover secondary objects in the user query that only exist to help you identify the actual target. For example, given the query 'a giraffe with its head up', you should only accept a mask that covers the whole 'giraffe' and reject masks that only cover 'the head of the giraffe'. Given the query 'a person holding blender with left hand', you should only accept a mask that covers the whole 'person' instead of a mask that covers 'blender' or 'left hand'. Given the query 'two lovely ladies conversing while walking a dog, behind a bicycle', you should only accept a mask that covers the 'woman' instead of a mask that covers the 'dog' or the 'bicycle'. Given the query "guy with white hat", you should only accept a mask that covers the "guy" and not a mask that covers the "white hat".
|
| 16 |
+
3. Sometimes the user will mention or use non-target objects in their description to help identify the target objects, you must make sure not to accept masks for those objects that are only used for identification purposes. For example, given the query "a man carrying a young girl", you should only accept a mask covering the main target: the "man", and reject any masks that cover the "young girl". Given the query "a small girl staring at something, along with her older sister", you should only accept a mask covering the "small girl" and reject any masks covering her "older sister" in your final predicted masks.
|
| 17 |
+
4. Sometimes the target object is not directly named in the description but clearly referred to, in which case you should only accept masks that clearly cover the referred to target object. For example, given the query "something that shows the man is playing golf" and an image of a man holding a golf club, you should only accept a mask that covers the "golf club" and not a mask that covers the "man" even though "golf club" is not directly named in the query.
|
| 18 |
+
5. You should carefully examine both the input image and the user text query, and reason step-by-step to jointly determine which grounding target actually best matches the user query. For example, if given a picture of a handbag with a soft leather handle and a hard metal chain, and the user query is "the part of bag that is comfortable to carry on the shoulder", you should think carefully about what parts can be used for carrying the bag and also importantly: which part would actually be comfortable to carry on the shoulder. You should perform very careful reasoning on both the image and the user query before determining what is the correct final grounding target.
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
Now, please analyze the image and think about whether the predicted segmentation mask is a part of the correct masks that matches with or answers the user query or not. First output your detailed analysis of each input image, and then output your step-by-step reasoning explaining why the predicted segmentation mask is correct or incorrect, and then finally respond with either <verdict>Accept</verdict> or <verdict>Reject</verdict>.
|
| 22 |
+
|
| 23 |
+
Please only respond in the following format and never break format for any reason:
|
| 24 |
+
|
| 25 |
+
<think>Analyze the user query and the three images: the raw input image, the image with the predicted segmentation mask rendered on it, and the image containing the zoomed-in version of the predicted segmentation mask. Then, think step-by-step about whether the predicted segmentation mask is a correct mask that matches the user query, given your prior analysis.</think>
|
| 26 |
+
<verdict>Accept</verdict> or <verdict>Reject</verdict>
|
sam3/agent/viz.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pycocotools.mask as mask_utils
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
from .helpers.visualizer import Visualizer
|
| 9 |
+
from .helpers.zoom_in import render_zoom_in
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def visualize(
|
| 13 |
+
input_json: dict,
|
| 14 |
+
zoom_in_index: int | None = None,
|
| 15 |
+
mask_alpha: float = 0.15,
|
| 16 |
+
label_mode: str = "1",
|
| 17 |
+
font_size_multiplier: float = 1.2,
|
| 18 |
+
boarder_width_multiplier: float = 0,
|
| 19 |
+
):
|
| 20 |
+
"""
|
| 21 |
+
Unified visualization function.
|
| 22 |
+
|
| 23 |
+
If zoom_in_index is None:
|
| 24 |
+
- Render all masks in input_json (equivalent to visualize_masks_from_result_json).
|
| 25 |
+
- Returns: PIL.Image
|
| 26 |
+
|
| 27 |
+
If zoom_in_index is provided:
|
| 28 |
+
- Returns two PIL.Images:
|
| 29 |
+
1) Output identical to zoom_in_and_visualize(input_json, index).
|
| 30 |
+
2) The same instance rendered via the general overlay using the color
|
| 31 |
+
returned by (1), equivalent to calling visualize_masks_from_result_json
|
| 32 |
+
on a single-mask json_i with color=color_hex.
|
| 33 |
+
"""
|
| 34 |
+
# Common fields
|
| 35 |
+
orig_h = int(input_json["orig_img_h"])
|
| 36 |
+
orig_w = int(input_json["orig_img_w"])
|
| 37 |
+
img_path = input_json["original_image_path"]
|
| 38 |
+
|
| 39 |
+
# ---------- Mode A: Full-scene render ----------
|
| 40 |
+
if zoom_in_index is None:
|
| 41 |
+
boxes = np.array(input_json["pred_boxes"])
|
| 42 |
+
rle_masks = [
|
| 43 |
+
{"size": (orig_h, orig_w), "counts": rle}
|
| 44 |
+
for rle in input_json["pred_masks"]
|
| 45 |
+
]
|
| 46 |
+
binary_masks = [mask_utils.decode(rle) for rle in rle_masks]
|
| 47 |
+
|
| 48 |
+
img_bgr = cv2.imread(img_path)
|
| 49 |
+
if img_bgr is None:
|
| 50 |
+
raise FileNotFoundError(f"Could not read image: {img_path}")
|
| 51 |
+
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
|
| 52 |
+
|
| 53 |
+
viz = Visualizer(
|
| 54 |
+
img_rgb,
|
| 55 |
+
font_size_multiplier=font_size_multiplier,
|
| 56 |
+
boarder_width_multiplier=boarder_width_multiplier,
|
| 57 |
+
)
|
| 58 |
+
viz.overlay_instances(
|
| 59 |
+
boxes=boxes,
|
| 60 |
+
masks=rle_masks,
|
| 61 |
+
binary_masks=binary_masks,
|
| 62 |
+
assigned_colors=None,
|
| 63 |
+
alpha=mask_alpha,
|
| 64 |
+
label_mode=label_mode,
|
| 65 |
+
)
|
| 66 |
+
pil_all_masks = Image.fromarray(viz.output.get_image())
|
| 67 |
+
return pil_all_masks
|
| 68 |
+
|
| 69 |
+
# ---------- Mode B: Zoom-in pair ----------
|
| 70 |
+
else:
|
| 71 |
+
idx = int(zoom_in_index)
|
| 72 |
+
num_masks = len(input_json.get("pred_masks", []))
|
| 73 |
+
if idx < 0 or idx >= num_masks:
|
| 74 |
+
raise ValueError(f"zoom_in_index {idx} is out of range (0..{num_masks-1}).")
|
| 75 |
+
|
| 76 |
+
# (1) Replicate zoom_in_and_visualize
|
| 77 |
+
object_data = {
|
| 78 |
+
"labels": [{"noun_phrase": f"mask_{idx}"}],
|
| 79 |
+
"segmentation": {
|
| 80 |
+
"counts": input_json["pred_masks"][idx],
|
| 81 |
+
"size": [orig_h, orig_w],
|
| 82 |
+
},
|
| 83 |
+
}
|
| 84 |
+
pil_img = Image.open(img_path)
|
| 85 |
+
pil_mask_i_zoomed, color_hex = render_zoom_in(
|
| 86 |
+
object_data, pil_img, mask_alpha=mask_alpha
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
# (2) Single-instance render with the same color
|
| 90 |
+
boxes_i = np.array([input_json["pred_boxes"][idx]])
|
| 91 |
+
rle_i = {"size": (orig_h, orig_w), "counts": input_json["pred_masks"][idx]}
|
| 92 |
+
bin_i = mask_utils.decode(rle_i)
|
| 93 |
+
|
| 94 |
+
img_bgr_i = cv2.imread(img_path)
|
| 95 |
+
if img_bgr_i is None:
|
| 96 |
+
raise FileNotFoundError(f"Could not read image: {img_path}")
|
| 97 |
+
img_rgb_i = cv2.cvtColor(img_bgr_i, cv2.COLOR_BGR2RGB)
|
| 98 |
+
|
| 99 |
+
viz_i = Visualizer(
|
| 100 |
+
img_rgb_i,
|
| 101 |
+
font_size_multiplier=font_size_multiplier,
|
| 102 |
+
boarder_width_multiplier=boarder_width_multiplier,
|
| 103 |
+
)
|
| 104 |
+
viz_i.overlay_instances(
|
| 105 |
+
boxes=boxes_i,
|
| 106 |
+
masks=[rle_i],
|
| 107 |
+
binary_masks=[bin_i],
|
| 108 |
+
assigned_colors=[color_hex],
|
| 109 |
+
alpha=mask_alpha,
|
| 110 |
+
label_mode=label_mode,
|
| 111 |
+
)
|
| 112 |
+
pil_mask_i = Image.fromarray(viz_i.output.get_image())
|
| 113 |
+
|
| 114 |
+
return pil_mask_i, pil_mask_i_zoomed
|
sam3/eval/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
sam3/eval/cgf1_eval.py
ADDED
|
@@ -0,0 +1,703 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
import contextlib
|
| 4 |
+
import copy
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from typing import List, Union
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import pycocotools.mask as maskUtils
|
| 14 |
+
from pycocotools.coco import COCO
|
| 15 |
+
from pycocotools.cocoeval import COCOeval
|
| 16 |
+
from scipy.optimize import linear_sum_assignment
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
|
| 21 |
+
class Metric:
|
| 22 |
+
name: str
|
| 23 |
+
|
| 24 |
+
# whether the metric is computed at the image level or the box level
|
| 25 |
+
image_level: bool
|
| 26 |
+
|
| 27 |
+
# iou threshold (None is used for image level metrics or to indicate averaging over all thresholds in [0.5:0.95])
|
| 28 |
+
iou_threshold: Union[float, None]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
CGF1_METRICS = [
|
| 32 |
+
Metric(name="cgF1", image_level=False, iou_threshold=None),
|
| 33 |
+
Metric(name="precision", image_level=False, iou_threshold=None),
|
| 34 |
+
Metric(name="recall", image_level=False, iou_threshold=None),
|
| 35 |
+
Metric(name="F1", image_level=False, iou_threshold=None),
|
| 36 |
+
Metric(name="positive_macro_F1", image_level=False, iou_threshold=None),
|
| 37 |
+
Metric(name="positive_micro_F1", image_level=False, iou_threshold=None),
|
| 38 |
+
Metric(name="positive_micro_precision", image_level=False, iou_threshold=None),
|
| 39 |
+
Metric(name="IL_precision", image_level=True, iou_threshold=None),
|
| 40 |
+
Metric(name="IL_recall", image_level=True, iou_threshold=None),
|
| 41 |
+
Metric(name="IL_F1", image_level=True, iou_threshold=None),
|
| 42 |
+
Metric(name="IL_FPR", image_level=True, iou_threshold=None),
|
| 43 |
+
Metric(name="IL_MCC", image_level=True, iou_threshold=None),
|
| 44 |
+
Metric(name="cgF1", image_level=False, iou_threshold=0.5),
|
| 45 |
+
Metric(name="precision", image_level=False, iou_threshold=0.5),
|
| 46 |
+
Metric(name="recall", image_level=False, iou_threshold=0.5),
|
| 47 |
+
Metric(name="F1", image_level=False, iou_threshold=0.5),
|
| 48 |
+
Metric(name="positive_macro_F1", image_level=False, iou_threshold=0.5),
|
| 49 |
+
Metric(name="positive_micro_F1", image_level=False, iou_threshold=0.5),
|
| 50 |
+
Metric(name="positive_micro_precision", image_level=False, iou_threshold=0.5),
|
| 51 |
+
Metric(name="cgF1", image_level=False, iou_threshold=0.75),
|
| 52 |
+
Metric(name="precision", image_level=False, iou_threshold=0.75),
|
| 53 |
+
Metric(name="recall", image_level=False, iou_threshold=0.75),
|
| 54 |
+
Metric(name="F1", image_level=False, iou_threshold=0.75),
|
| 55 |
+
Metric(name="positive_macro_F1", image_level=False, iou_threshold=0.75),
|
| 56 |
+
Metric(name="positive_micro_F1", image_level=False, iou_threshold=0.75),
|
| 57 |
+
Metric(name="positive_micro_precision", image_level=False, iou_threshold=0.75),
|
| 58 |
+
]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class COCOCustom(COCO):
|
| 62 |
+
"""COCO class from pycocotools with tiny modifications for speed"""
|
| 63 |
+
|
| 64 |
+
def createIndex(self):
|
| 65 |
+
# create index
|
| 66 |
+
print("creating index...")
|
| 67 |
+
anns, cats, imgs = {}, {}, {}
|
| 68 |
+
imgToAnns, catToImgs = defaultdict(list), defaultdict(list)
|
| 69 |
+
if "annotations" in self.dataset:
|
| 70 |
+
for ann in self.dataset["annotations"]:
|
| 71 |
+
imgToAnns[ann["image_id"]].append(ann)
|
| 72 |
+
anns[ann["id"]] = ann
|
| 73 |
+
|
| 74 |
+
if "images" in self.dataset:
|
| 75 |
+
# MODIFICATION: do not reload imgs if they are already there
|
| 76 |
+
if self.imgs:
|
| 77 |
+
imgs = self.imgs
|
| 78 |
+
else:
|
| 79 |
+
for img in self.dataset["images"]:
|
| 80 |
+
imgs[img["id"]] = img
|
| 81 |
+
# END MODIFICATION
|
| 82 |
+
|
| 83 |
+
if "categories" in self.dataset:
|
| 84 |
+
for cat in self.dataset["categories"]:
|
| 85 |
+
cats[cat["id"]] = cat
|
| 86 |
+
|
| 87 |
+
if "annotations" in self.dataset and "categories" in self.dataset:
|
| 88 |
+
for ann in self.dataset["annotations"]:
|
| 89 |
+
catToImgs[ann["category_id"]].append(ann["image_id"])
|
| 90 |
+
|
| 91 |
+
print("index created!")
|
| 92 |
+
|
| 93 |
+
# create class members
|
| 94 |
+
self.anns = anns
|
| 95 |
+
self.imgToAnns = imgToAnns
|
| 96 |
+
self.catToImgs = catToImgs
|
| 97 |
+
self.imgs = imgs
|
| 98 |
+
self.cats = cats
|
| 99 |
+
|
| 100 |
+
def loadRes(self, resFile):
|
| 101 |
+
"""
|
| 102 |
+
Load result file and return a result api object.
|
| 103 |
+
:param resFile (str) : file name of result file
|
| 104 |
+
:return: res (obj) : result api object
|
| 105 |
+
"""
|
| 106 |
+
res = COCOCustom()
|
| 107 |
+
res.dataset["info"] = copy.deepcopy(self.dataset.get("info", {}))
|
| 108 |
+
# MODIFICATION: no copy
|
| 109 |
+
# res.dataset['images'] = [img for img in self.dataset['images']]
|
| 110 |
+
res.dataset["images"] = self.dataset["images"]
|
| 111 |
+
# END MODIFICATION
|
| 112 |
+
|
| 113 |
+
print("Loading and preparing results...")
|
| 114 |
+
tic = time.time()
|
| 115 |
+
if type(resFile) == str:
|
| 116 |
+
with open(resFile) as f:
|
| 117 |
+
anns = json.load(f)
|
| 118 |
+
elif type(resFile) == np.ndarray:
|
| 119 |
+
anns = self.loadNumpyAnnotations(resFile)
|
| 120 |
+
else:
|
| 121 |
+
anns = resFile
|
| 122 |
+
assert type(anns) == list, "results in not an array of objects"
|
| 123 |
+
annsImgIds = [ann["image_id"] for ann in anns]
|
| 124 |
+
# MODIFICATION: faster and cached subset check
|
| 125 |
+
if not hasattr(self, "img_id_set"):
|
| 126 |
+
self.img_id_set = set(self.getImgIds())
|
| 127 |
+
assert set(annsImgIds).issubset(
|
| 128 |
+
self.img_id_set
|
| 129 |
+
), "Results do not correspond to current coco set"
|
| 130 |
+
# END MODIFICATION
|
| 131 |
+
if "caption" in anns[0]:
|
| 132 |
+
imgIds = set([img["id"] for img in res.dataset["images"]]) & set(
|
| 133 |
+
[ann["image_id"] for ann in anns]
|
| 134 |
+
)
|
| 135 |
+
res.dataset["images"] = [
|
| 136 |
+
img for img in res.dataset["images"] if img["id"] in imgIds
|
| 137 |
+
]
|
| 138 |
+
for id, ann in enumerate(anns):
|
| 139 |
+
ann["id"] = id + 1
|
| 140 |
+
elif "bbox" in anns[0] and not anns[0]["bbox"] == []:
|
| 141 |
+
res.dataset["categories"] = copy.deepcopy(self.dataset["categories"])
|
| 142 |
+
for id, ann in enumerate(anns):
|
| 143 |
+
bb = ann["bbox"]
|
| 144 |
+
x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]]
|
| 145 |
+
if not "segmentation" in ann:
|
| 146 |
+
ann["segmentation"] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
|
| 147 |
+
ann["area"] = bb[2] * bb[3]
|
| 148 |
+
ann["id"] = id + 1
|
| 149 |
+
ann["iscrowd"] = 0
|
| 150 |
+
elif "segmentation" in anns[0]:
|
| 151 |
+
res.dataset["categories"] = copy.deepcopy(self.dataset["categories"])
|
| 152 |
+
for id, ann in enumerate(anns):
|
| 153 |
+
# now only support compressed RLE format as segmentation results
|
| 154 |
+
ann["area"] = maskUtils.area(ann["segmentation"])
|
| 155 |
+
if not "bbox" in ann:
|
| 156 |
+
ann["bbox"] = maskUtils.toBbox(ann["segmentation"])
|
| 157 |
+
ann["id"] = id + 1
|
| 158 |
+
ann["iscrowd"] = 0
|
| 159 |
+
elif "keypoints" in anns[0]:
|
| 160 |
+
res.dataset["categories"] = copy.deepcopy(self.dataset["categories"])
|
| 161 |
+
for id, ann in enumerate(anns):
|
| 162 |
+
s = ann["keypoints"]
|
| 163 |
+
x = s[0::3]
|
| 164 |
+
y = s[1::3]
|
| 165 |
+
x0, x1, y0, y1 = np.min(x), np.max(x), np.min(y), np.max(y)
|
| 166 |
+
ann["area"] = (x1 - x0) * (y1 - y0)
|
| 167 |
+
ann["id"] = id + 1
|
| 168 |
+
ann["bbox"] = [x0, y0, x1 - x0, y1 - y0]
|
| 169 |
+
print("DONE (t={:0.2f}s)".format(time.time() - tic))
|
| 170 |
+
|
| 171 |
+
res.dataset["annotations"] = anns
|
| 172 |
+
# MODIFICATION: inherit images
|
| 173 |
+
res.imgs = self.imgs
|
| 174 |
+
# END MODIFICATION
|
| 175 |
+
res.createIndex()
|
| 176 |
+
return res
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class CGF1Eval(COCOeval):
|
| 180 |
+
"""
|
| 181 |
+
This evaluator is based upon COCO evaluation, but evaluates the model in a more realistic setting
|
| 182 |
+
for downstream applications.
|
| 183 |
+
See SAM3 paper for the details on the CGF1 metric.
|
| 184 |
+
|
| 185 |
+
Do not use this evaluator directly. Prefer the CGF1Evaluator wrapper.
|
| 186 |
+
|
| 187 |
+
Notes:
|
| 188 |
+
- This evaluator does not support per-category evaluation (in the way defined by pyCocotools)
|
| 189 |
+
- In open vocabulary settings, we have different noun-phrases for each image. What we call an "image_id" here is actually an (image, noun-phrase) pair. So in every "image_id" there is only one category, implied by the noun-phrase. Thus we can ignore the usual coco "category" field of the predictions
|
| 190 |
+
"""
|
| 191 |
+
|
| 192 |
+
def __init__(
|
| 193 |
+
self,
|
| 194 |
+
coco_gt=None,
|
| 195 |
+
coco_dt=None,
|
| 196 |
+
iouType="segm",
|
| 197 |
+
threshold=0.5,
|
| 198 |
+
):
|
| 199 |
+
"""
|
| 200 |
+
Args:
|
| 201 |
+
coco_gt (COCO): ground truth COCO API
|
| 202 |
+
coco_dt (COCO): detections COCO API
|
| 203 |
+
iou_type (str): type of IoU to evaluate
|
| 204 |
+
threshold (float): threshold for predictions
|
| 205 |
+
"""
|
| 206 |
+
super().__init__(coco_gt, coco_dt, iouType)
|
| 207 |
+
self.threshold = threshold
|
| 208 |
+
|
| 209 |
+
self.params.useCats = False
|
| 210 |
+
self.params.areaRng = [[0**2, 1e5**2]]
|
| 211 |
+
self.params.areaRngLbl = ["all"]
|
| 212 |
+
self.params.maxDets = [1000000]
|
| 213 |
+
|
| 214 |
+
def computeIoU(self, imgId, catId):
|
| 215 |
+
# Same as the original COCOeval.computeIoU, but without sorting
|
| 216 |
+
p = self.params
|
| 217 |
+
if p.useCats:
|
| 218 |
+
gt = self._gts[imgId, catId]
|
| 219 |
+
dt = self._dts[imgId, catId]
|
| 220 |
+
else:
|
| 221 |
+
gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
|
| 222 |
+
dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
|
| 223 |
+
if len(gt) == 0 and len(dt) == 0:
|
| 224 |
+
return []
|
| 225 |
+
|
| 226 |
+
if p.iouType == "segm":
|
| 227 |
+
g = [g["segmentation"] for g in gt]
|
| 228 |
+
d = [d["segmentation"] for d in dt]
|
| 229 |
+
elif p.iouType == "bbox":
|
| 230 |
+
g = [g["bbox"] for g in gt]
|
| 231 |
+
d = [d["bbox"] for d in dt]
|
| 232 |
+
else:
|
| 233 |
+
raise Exception("unknown iouType for iou computation")
|
| 234 |
+
|
| 235 |
+
# compute iou between each dt and gt region
|
| 236 |
+
iscrowd = [int(o["iscrowd"]) for o in gt]
|
| 237 |
+
ious = maskUtils.iou(d, g, iscrowd)
|
| 238 |
+
return ious
|
| 239 |
+
|
| 240 |
+
def evaluateImg(self, imgId, catId, aRng, maxDet):
|
| 241 |
+
"""
|
| 242 |
+
perform evaluation for single category and image
|
| 243 |
+
:return: dict (single image results)
|
| 244 |
+
"""
|
| 245 |
+
p = self.params
|
| 246 |
+
assert not p.useCats, "This evaluator does not support per-category evaluation."
|
| 247 |
+
assert catId == -1
|
| 248 |
+
all_gts = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
|
| 249 |
+
keep_gt = np.array([not g["ignore"] for g in all_gts], dtype=bool)
|
| 250 |
+
gt = [g for g in all_gts if not g["ignore"]]
|
| 251 |
+
all_dts = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
|
| 252 |
+
keep_dt = np.array([d["score"] >= self.threshold for d in all_dts], dtype=bool)
|
| 253 |
+
dt = [d for d in all_dts if d["score"] >= self.threshold]
|
| 254 |
+
if len(gt) == 0 and len(dt) == 0:
|
| 255 |
+
# This is a "true negative" case, where there are no GTs and no predictions
|
| 256 |
+
# The box-level metrics are ill-defined, so we don't add them to this dict
|
| 257 |
+
return {
|
| 258 |
+
"image_id": imgId,
|
| 259 |
+
"IL_TP": 0,
|
| 260 |
+
"IL_TN": 1,
|
| 261 |
+
"IL_FP": 0,
|
| 262 |
+
"IL_FN": 0,
|
| 263 |
+
"num_dt": len(dt),
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
if len(gt) > 0 and len(dt) == 0:
|
| 267 |
+
# This is a "false negative" case, where there are GTs but no predictions
|
| 268 |
+
return {
|
| 269 |
+
"image_id": imgId,
|
| 270 |
+
"IL_TP": 0,
|
| 271 |
+
"IL_TN": 0,
|
| 272 |
+
"IL_FP": 0,
|
| 273 |
+
"IL_FN": 1,
|
| 274 |
+
"TPs": np.zeros((len(p.iouThrs),), dtype=np.int64),
|
| 275 |
+
"FPs": np.zeros((len(p.iouThrs),), dtype=np.int64),
|
| 276 |
+
"FNs": np.ones((len(p.iouThrs),), dtype=np.int64) * len(gt),
|
| 277 |
+
"local_F1s": np.zeros((len(p.iouThrs),), dtype=np.int64),
|
| 278 |
+
"local_positive_F1s": np.zeros((len(p.iouThrs),), dtype=np.int64),
|
| 279 |
+
"num_dt": len(dt),
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
# Load pre-computed ious
|
| 283 |
+
ious = self.ious[(imgId, catId)]
|
| 284 |
+
|
| 285 |
+
# compute matching
|
| 286 |
+
if len(ious) == 0:
|
| 287 |
+
ious = np.zeros((len(dt), len(gt)))
|
| 288 |
+
else:
|
| 289 |
+
ious = ious[keep_dt, :][:, keep_gt]
|
| 290 |
+
assert ious.shape == (len(dt), len(gt))
|
| 291 |
+
|
| 292 |
+
matched_dt, matched_gt = linear_sum_assignment(-ious)
|
| 293 |
+
|
| 294 |
+
match_scores = ious[matched_dt, matched_gt]
|
| 295 |
+
|
| 296 |
+
TPs, FPs, FNs = [], [], []
|
| 297 |
+
IL_perfect = []
|
| 298 |
+
for thresh in p.iouThrs:
|
| 299 |
+
TP = (match_scores >= thresh).sum()
|
| 300 |
+
FP = len(dt) - TP
|
| 301 |
+
FN = len(gt) - TP
|
| 302 |
+
assert (
|
| 303 |
+
FP >= 0 and FN >= 0
|
| 304 |
+
), f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}"
|
| 305 |
+
TPs.append(TP)
|
| 306 |
+
FPs.append(FP)
|
| 307 |
+
FNs.append(FN)
|
| 308 |
+
|
| 309 |
+
if FP == FN and FP == 0:
|
| 310 |
+
IL_perfect.append(1)
|
| 311 |
+
else:
|
| 312 |
+
IL_perfect.append(0)
|
| 313 |
+
|
| 314 |
+
TPs = np.array(TPs, dtype=np.int64)
|
| 315 |
+
FPs = np.array(FPs, dtype=np.int64)
|
| 316 |
+
FNs = np.array(FNs, dtype=np.int64)
|
| 317 |
+
IL_perfect = np.array(IL_perfect, dtype=np.int64)
|
| 318 |
+
|
| 319 |
+
# compute precision recall and F1
|
| 320 |
+
precision = TPs / (TPs + FPs + 1e-4)
|
| 321 |
+
assert np.all(precision <= 1)
|
| 322 |
+
recall = TPs / (TPs + FNs + 1e-4)
|
| 323 |
+
assert np.all(recall <= 1)
|
| 324 |
+
F1 = 2 * precision * recall / (precision + recall + 1e-4)
|
| 325 |
+
|
| 326 |
+
result = {
|
| 327 |
+
"image_id": imgId,
|
| 328 |
+
"TPs": TPs,
|
| 329 |
+
"FPs": FPs,
|
| 330 |
+
"FNs": FNs,
|
| 331 |
+
"local_F1s": F1,
|
| 332 |
+
"IL_TP": (len(gt) > 0) and (len(dt) > 0),
|
| 333 |
+
"IL_FP": (len(gt) == 0) and (len(dt) > 0),
|
| 334 |
+
"IL_TN": (len(gt) == 0) and (len(dt) == 0),
|
| 335 |
+
"IL_FN": (len(gt) > 0) and (len(dt) == 0),
|
| 336 |
+
"num_dt": len(dt),
|
| 337 |
+
}
|
| 338 |
+
if len(gt) > 0 and len(dt) > 0:
|
| 339 |
+
result["local_positive_F1s"] = F1
|
| 340 |
+
return result
|
| 341 |
+
|
| 342 |
+
def accumulate(self, p=None):
|
| 343 |
+
"""
|
| 344 |
+
Accumulate per image evaluation results and store the result in self.eval
|
| 345 |
+
:param p: input params for evaluation
|
| 346 |
+
:return: None
|
| 347 |
+
"""
|
| 348 |
+
if self.evalImgs is None or len(self.evalImgs) == 0:
|
| 349 |
+
print("Please run evaluate() first")
|
| 350 |
+
# allows input customized parameters
|
| 351 |
+
if p is None:
|
| 352 |
+
p = self.params
|
| 353 |
+
|
| 354 |
+
setImgIds = set(p.imgIds)
|
| 355 |
+
|
| 356 |
+
# TPs, FPs, FNs
|
| 357 |
+
TPs = np.zeros((len(p.iouThrs),), dtype=np.int64)
|
| 358 |
+
FPs = np.zeros((len(p.iouThrs),), dtype=np.int64)
|
| 359 |
+
pmFPs = np.zeros((len(p.iouThrs),), dtype=np.int64)
|
| 360 |
+
FNs = np.zeros((len(p.iouThrs),), dtype=np.int64)
|
| 361 |
+
local_F1s = np.zeros((len(p.iouThrs),), dtype=np.float64)
|
| 362 |
+
|
| 363 |
+
# Image level metrics
|
| 364 |
+
IL_TPs = 0
|
| 365 |
+
IL_FPs = 0
|
| 366 |
+
IL_TNs = 0
|
| 367 |
+
IL_FNs = 0
|
| 368 |
+
|
| 369 |
+
valid_img_count = 0
|
| 370 |
+
valid_F1_count = 0
|
| 371 |
+
evaledImgIds = set()
|
| 372 |
+
for res in self.evalImgs:
|
| 373 |
+
if res["image_id"] not in setImgIds:
|
| 374 |
+
continue
|
| 375 |
+
evaledImgIds.add(res["image_id"])
|
| 376 |
+
IL_TPs += res["IL_TP"]
|
| 377 |
+
IL_FPs += res["IL_FP"]
|
| 378 |
+
IL_TNs += res["IL_TN"]
|
| 379 |
+
IL_FNs += res["IL_FN"]
|
| 380 |
+
|
| 381 |
+
if "TPs" not in res:
|
| 382 |
+
continue
|
| 383 |
+
|
| 384 |
+
TPs += res["TPs"]
|
| 385 |
+
FPs += res["FPs"]
|
| 386 |
+
FNs += res["FNs"]
|
| 387 |
+
valid_img_count += 1
|
| 388 |
+
|
| 389 |
+
if "local_positive_F1s" in res:
|
| 390 |
+
local_F1s += res["local_positive_F1s"]
|
| 391 |
+
pmFPs += res["FPs"]
|
| 392 |
+
if res["num_dt"] > 0:
|
| 393 |
+
valid_F1_count += 1
|
| 394 |
+
|
| 395 |
+
assert len(setImgIds - evaledImgIds) == 0, (
|
| 396 |
+
f"{len(setImgIds - evaledImgIds)} images not evaluated. "
|
| 397 |
+
f"Here are the IDs of the first 3: {list(setImgIds - evaledImgIds)[:3]}"
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
# compute precision recall and F1
|
| 401 |
+
precision = TPs / (TPs + FPs + 1e-4)
|
| 402 |
+
positive_micro_precision = TPs / (TPs + pmFPs + 1e-4)
|
| 403 |
+
assert np.all(precision <= 1)
|
| 404 |
+
recall = TPs / (TPs + FNs + 1e-4)
|
| 405 |
+
assert np.all(recall <= 1)
|
| 406 |
+
F1 = 2 * precision * recall / (precision + recall + 1e-4)
|
| 407 |
+
positive_micro_F1 = (
|
| 408 |
+
2
|
| 409 |
+
* positive_micro_precision
|
| 410 |
+
* recall
|
| 411 |
+
/ (positive_micro_precision + recall + 1e-4)
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
IL_rec = IL_TPs / (IL_TPs + IL_FNs + 1e-6)
|
| 415 |
+
IL_prec = IL_TPs / (IL_TPs + IL_FPs + 1e-6)
|
| 416 |
+
IL_F1 = 2 * IL_prec * IL_rec / (IL_prec + IL_rec + 1e-6)
|
| 417 |
+
IL_FPR = IL_FPs / (IL_FPs + IL_TNs + 1e-6)
|
| 418 |
+
IL_MCC = float(IL_TPs * IL_TNs - IL_FPs * IL_FNs) / (
|
| 419 |
+
(
|
| 420 |
+
float(IL_TPs + IL_FPs)
|
| 421 |
+
* float(IL_TPs + IL_FNs)
|
| 422 |
+
* float(IL_TNs + IL_FPs)
|
| 423 |
+
* float(IL_TNs + IL_FNs)
|
| 424 |
+
)
|
| 425 |
+
** 0.5
|
| 426 |
+
+ 1e-6
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
self.eval = {
|
| 430 |
+
"params": p,
|
| 431 |
+
"TPs": TPs,
|
| 432 |
+
"FPs": FPs,
|
| 433 |
+
"positive_micro_FPs": pmFPs,
|
| 434 |
+
"FNs": FNs,
|
| 435 |
+
"precision": precision,
|
| 436 |
+
"positive_micro_precision": positive_micro_precision,
|
| 437 |
+
"recall": recall,
|
| 438 |
+
"F1": F1,
|
| 439 |
+
"positive_micro_F1": positive_micro_F1,
|
| 440 |
+
"positive_macro_F1": local_F1s / valid_F1_count,
|
| 441 |
+
"IL_recall": IL_rec,
|
| 442 |
+
"IL_precision": IL_prec,
|
| 443 |
+
"IL_F1": IL_F1,
|
| 444 |
+
"IL_FPR": IL_FPR,
|
| 445 |
+
"IL_MCC": IL_MCC,
|
| 446 |
+
}
|
| 447 |
+
self.eval["cgF1"] = self.eval["positive_micro_F1"] * self.eval["IL_MCC"]
|
| 448 |
+
|
| 449 |
+
def summarize(self):
|
| 450 |
+
"""
|
| 451 |
+
Compute and display summary metrics for evaluation results.
|
| 452 |
+
"""
|
| 453 |
+
if not self.eval:
|
| 454 |
+
raise Exception("Please run accumulate() first")
|
| 455 |
+
|
| 456 |
+
def _summarize(iouThr=None, metric=""):
|
| 457 |
+
p = self.params
|
| 458 |
+
iStr = " {:<18} @[ IoU={:<9}] = {:0.3f}"
|
| 459 |
+
titleStr = "Average " + metric
|
| 460 |
+
iouStr = (
|
| 461 |
+
"{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1])
|
| 462 |
+
if iouThr is None
|
| 463 |
+
else "{:0.2f}".format(iouThr)
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
s = self.eval[metric]
|
| 467 |
+
# IoU
|
| 468 |
+
if iouThr is not None:
|
| 469 |
+
t = np.where(iouThr == p.iouThrs)[0]
|
| 470 |
+
s = s[t]
|
| 471 |
+
|
| 472 |
+
if len(s[s > -1]) == 0:
|
| 473 |
+
mean_s = -1
|
| 474 |
+
else:
|
| 475 |
+
mean_s = np.mean(s[s > -1])
|
| 476 |
+
print(iStr.format(titleStr, iouStr, mean_s))
|
| 477 |
+
return mean_s
|
| 478 |
+
|
| 479 |
+
def _summarize_single(metric=""):
|
| 480 |
+
titleStr = "Average " + metric
|
| 481 |
+
iStr = " {:<35} = {:0.3f}"
|
| 482 |
+
s = self.eval[metric]
|
| 483 |
+
print(iStr.format(titleStr, s))
|
| 484 |
+
return s
|
| 485 |
+
|
| 486 |
+
def _summarizeDets():
|
| 487 |
+
stats = []
|
| 488 |
+
|
| 489 |
+
for metric in CGF1_METRICS:
|
| 490 |
+
if metric.image_level:
|
| 491 |
+
stats.append(_summarize_single(metric=metric.name))
|
| 492 |
+
else:
|
| 493 |
+
stats.append(
|
| 494 |
+
_summarize(iouThr=metric.iou_threshold, metric=metric.name)
|
| 495 |
+
)
|
| 496 |
+
return np.asarray(stats)
|
| 497 |
+
|
| 498 |
+
summarize = _summarizeDets
|
| 499 |
+
self.stats = summarize()
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
def _evaluate(self):
|
| 503 |
+
"""
|
| 504 |
+
Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
|
| 505 |
+
"""
|
| 506 |
+
p = self.params
|
| 507 |
+
# add backward compatibility if useSegm is specified in params
|
| 508 |
+
p.imgIds = list(np.unique(p.imgIds))
|
| 509 |
+
p.useCats = False
|
| 510 |
+
p.maxDets = sorted(p.maxDets)
|
| 511 |
+
self.params = p
|
| 512 |
+
|
| 513 |
+
self._prepare()
|
| 514 |
+
# loop through images, area range, max detection number
|
| 515 |
+
catIds = [-1]
|
| 516 |
+
|
| 517 |
+
if p.iouType == "segm" or p.iouType == "bbox":
|
| 518 |
+
computeIoU = self.computeIoU
|
| 519 |
+
else:
|
| 520 |
+
raise RuntimeError(f"Unsupported iou {p.iouType}")
|
| 521 |
+
self.ious = {
|
| 522 |
+
(imgId, catId): computeIoU(imgId, catId)
|
| 523 |
+
for imgId in p.imgIds
|
| 524 |
+
for catId in catIds
|
| 525 |
+
}
|
| 526 |
+
|
| 527 |
+
maxDet = p.maxDets[-1]
|
| 528 |
+
evalImgs = [
|
| 529 |
+
self.evaluateImg(imgId, catId, areaRng, maxDet)
|
| 530 |
+
for catId in catIds
|
| 531 |
+
for areaRng in p.areaRng
|
| 532 |
+
for imgId in p.imgIds
|
| 533 |
+
]
|
| 534 |
+
# this is NOT in the pycocotools code, but could be done outside
|
| 535 |
+
evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
|
| 536 |
+
return p.imgIds, evalImgs
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
class CGF1Evaluator:
|
| 540 |
+
"""
|
| 541 |
+
Wrapper class for cgF1 evaluation.
|
| 542 |
+
This supports the oracle setting (when several ground-truths are available per image)
|
| 543 |
+
"""
|
| 544 |
+
|
| 545 |
+
def __init__(
|
| 546 |
+
self,
|
| 547 |
+
gt_path: Union[str, List[str]],
|
| 548 |
+
iou_type="segm",
|
| 549 |
+
verbose=False,
|
| 550 |
+
):
|
| 551 |
+
"""
|
| 552 |
+
Args:
|
| 553 |
+
gt_path (str or list of str): path(s) to ground truth COCO json file(s)
|
| 554 |
+
iou_type (str): type of IoU to evaluate
|
| 555 |
+
threshold (float): threshold for predictions
|
| 556 |
+
"""
|
| 557 |
+
self.gt_paths = gt_path if isinstance(gt_path, list) else [gt_path]
|
| 558 |
+
self.iou_type = iou_type
|
| 559 |
+
|
| 560 |
+
self.coco_gts = [COCOCustom(gt) for gt in self.gt_paths]
|
| 561 |
+
|
| 562 |
+
self.verbose = verbose
|
| 563 |
+
|
| 564 |
+
self.coco_evals = []
|
| 565 |
+
for i, coco_gt in enumerate(self.coco_gts):
|
| 566 |
+
self.coco_evals.append(
|
| 567 |
+
CGF1Eval(
|
| 568 |
+
coco_gt=coco_gt,
|
| 569 |
+
iouType=iou_type,
|
| 570 |
+
)
|
| 571 |
+
)
|
| 572 |
+
self.coco_evals[i].useCats = False
|
| 573 |
+
|
| 574 |
+
exclude_img_ids = set()
|
| 575 |
+
# exclude_img_ids are the ids that are not exhaustively annotated in any of the other gts
|
| 576 |
+
for coco_gt in self.coco_gts[1:]:
|
| 577 |
+
exclude_img_ids = exclude_img_ids.union(
|
| 578 |
+
{
|
| 579 |
+
img["id"]
|
| 580 |
+
for img in coco_gt.dataset["images"]
|
| 581 |
+
if not img["is_instance_exhaustive"]
|
| 582 |
+
}
|
| 583 |
+
)
|
| 584 |
+
# we only eval on instance exhaustive queries
|
| 585 |
+
self.eval_img_ids = [
|
| 586 |
+
img["id"]
|
| 587 |
+
for img in self.coco_gts[0].dataset["images"]
|
| 588 |
+
if (img["is_instance_exhaustive"] and img["id"] not in exclude_img_ids)
|
| 589 |
+
]
|
| 590 |
+
|
| 591 |
+
def evaluate(self, pred_file: str):
|
| 592 |
+
"""
|
| 593 |
+
Evaluate the detections using cgF1 metric.
|
| 594 |
+
|
| 595 |
+
Args:
|
| 596 |
+
pred_file: path to the predictions COCO json file
|
| 597 |
+
|
| 598 |
+
"""
|
| 599 |
+
assert len(self.coco_gts) > 0, "No ground truth provided for evaluation."
|
| 600 |
+
assert len(self.coco_gts) == len(
|
| 601 |
+
self.coco_evals
|
| 602 |
+
), "Mismatch in number of ground truths and evaluators."
|
| 603 |
+
|
| 604 |
+
if self.verbose:
|
| 605 |
+
print(f"Loading predictions from {pred_file}")
|
| 606 |
+
|
| 607 |
+
with open(pred_file, "r") as f:
|
| 608 |
+
preds = json.load(f)
|
| 609 |
+
|
| 610 |
+
if self.verbose:
|
| 611 |
+
print(f"Loaded {len(preds)} predictions")
|
| 612 |
+
|
| 613 |
+
img2preds = defaultdict(list)
|
| 614 |
+
for pred in preds:
|
| 615 |
+
img2preds[pred["image_id"]].append(pred)
|
| 616 |
+
|
| 617 |
+
all_eval_imgs = []
|
| 618 |
+
for img_id in tqdm(self.eval_img_ids, disable=not self.verbose):
|
| 619 |
+
results = img2preds[img_id]
|
| 620 |
+
all_scorings = []
|
| 621 |
+
for cur_coco_gt, coco_eval in zip(self.coco_gts, self.coco_evals):
|
| 622 |
+
# suppress pycocotools prints
|
| 623 |
+
with open(os.devnull, "w") as devnull:
|
| 624 |
+
with contextlib.redirect_stdout(devnull):
|
| 625 |
+
coco_dt = (
|
| 626 |
+
cur_coco_gt.loadRes(results) if results else COCOCustom()
|
| 627 |
+
)
|
| 628 |
+
|
| 629 |
+
coco_eval.cocoDt = coco_dt
|
| 630 |
+
coco_eval.params.imgIds = [img_id]
|
| 631 |
+
coco_eval.params.useCats = False
|
| 632 |
+
img_ids, eval_imgs = _evaluate(coco_eval)
|
| 633 |
+
all_scorings.append(eval_imgs)
|
| 634 |
+
selected = self._select_best_scoring(all_scorings)
|
| 635 |
+
all_eval_imgs.append(selected)
|
| 636 |
+
|
| 637 |
+
# After this point, we have selected the best scoring per image among several ground truths
|
| 638 |
+
# we can now accumulate and summarize, using only the first coco_eval
|
| 639 |
+
|
| 640 |
+
self.coco_evals[0].evalImgs = list(
|
| 641 |
+
np.concatenate(all_eval_imgs, axis=2).flatten()
|
| 642 |
+
)
|
| 643 |
+
self.coco_evals[0].params.imgIds = self.eval_img_ids
|
| 644 |
+
self.coco_evals[0]._paramsEval = copy.deepcopy(self.coco_evals[0].params)
|
| 645 |
+
|
| 646 |
+
if self.verbose:
|
| 647 |
+
print(f"Accumulating results")
|
| 648 |
+
self.coco_evals[0].accumulate()
|
| 649 |
+
print("cgF1 metric, IoU type={}".format(self.iou_type))
|
| 650 |
+
self.coco_evals[0].summarize()
|
| 651 |
+
print()
|
| 652 |
+
|
| 653 |
+
out = {}
|
| 654 |
+
for i, value in enumerate(self.coco_evals[0].stats):
|
| 655 |
+
name = CGF1_METRICS[i].name
|
| 656 |
+
if CGF1_METRICS[i].iou_threshold is not None:
|
| 657 |
+
name = f"{name}@{CGF1_METRICS[i].iou_threshold}"
|
| 658 |
+
out[f"cgF1_eval_{self.iou_type}_{name}"] = float(value)
|
| 659 |
+
|
| 660 |
+
return out
|
| 661 |
+
|
| 662 |
+
@staticmethod
|
| 663 |
+
def _select_best_scoring(scorings):
|
| 664 |
+
# This function is used for "oracle" type evaluation.
|
| 665 |
+
# It accepts the evaluation results with respect to several ground truths, and picks the best
|
| 666 |
+
if len(scorings) == 1:
|
| 667 |
+
return scorings[0]
|
| 668 |
+
|
| 669 |
+
assert (
|
| 670 |
+
scorings[0].ndim == 3
|
| 671 |
+
), f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}"
|
| 672 |
+
assert (
|
| 673 |
+
scorings[0].shape[0] == 1
|
| 674 |
+
), f"Expecting a single category, got {scorings[0].shape[0]}"
|
| 675 |
+
|
| 676 |
+
for scoring in scorings:
|
| 677 |
+
assert (
|
| 678 |
+
scoring.shape == scorings[0].shape
|
| 679 |
+
), f"Shape mismatch: {scoring.shape}, {scorings[0].shape}"
|
| 680 |
+
|
| 681 |
+
selected_imgs = []
|
| 682 |
+
for img_id in range(scorings[0].shape[-1]):
|
| 683 |
+
best = scorings[0][:, :, img_id]
|
| 684 |
+
|
| 685 |
+
for scoring in scorings[1:]:
|
| 686 |
+
current = scoring[:, :, img_id]
|
| 687 |
+
if "local_F1s" in best[0, 0] and "local_F1s" in current[0, 0]:
|
| 688 |
+
# we were able to compute a F1 score for this particular image in both evaluations
|
| 689 |
+
# best["local_F1s"] contains the results at various IoU thresholds. We simply take the average for comparision
|
| 690 |
+
best_score = best[0, 0]["local_F1s"].mean()
|
| 691 |
+
current_score = current[0, 0]["local_F1s"].mean()
|
| 692 |
+
if current_score > best_score:
|
| 693 |
+
best = current
|
| 694 |
+
|
| 695 |
+
else:
|
| 696 |
+
# If we're here, it means that in that in some evaluation we were not able to get a valid local F1
|
| 697 |
+
# This happens when both the predictions and targets are empty. In that case, we can assume it's a perfect prediction
|
| 698 |
+
if "local_F1s" not in current[0, 0]:
|
| 699 |
+
best = current
|
| 700 |
+
selected_imgs.append(best)
|
| 701 |
+
result = np.stack(selected_imgs, axis=-1)
|
| 702 |
+
assert result.shape == scorings[0].shape
|
| 703 |
+
return result
|
sam3/eval/coco_eval.py
ADDED
|
@@ -0,0 +1,916 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
COCO evaluator that works in distributed mode.
|
| 5 |
+
|
| 6 |
+
Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py
|
| 7 |
+
The difference is that there is less copy-pasting from pycocotools
|
| 8 |
+
in the end of the file, as python3 can suppress prints with contextlib
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import contextlib
|
| 12 |
+
import copy
|
| 13 |
+
import json
|
| 14 |
+
import logging
|
| 15 |
+
import os
|
| 16 |
+
import pickle
|
| 17 |
+
from collections import defaultdict
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
from typing import Any, List, Optional
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
|
| 24 |
+
import pycocotools.mask as mask_utils
|
| 25 |
+
import torch
|
| 26 |
+
from iopath.common.file_io import g_pathmgr
|
| 27 |
+
from pycocotools.coco import COCO
|
| 28 |
+
from pycocotools.cocoeval import COCOeval
|
| 29 |
+
|
| 30 |
+
from sam3.train.masks_ops import rle_encode
|
| 31 |
+
|
| 32 |
+
from sam3.train.utils.distributed import (
|
| 33 |
+
all_gather,
|
| 34 |
+
gather_to_rank_0_via_filesys,
|
| 35 |
+
get_rank,
|
| 36 |
+
is_main_process,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
RARITY_BUCKETS = {0: "frequent", 1: "common", 2: "medium", 3: "rare"}
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class CocoEvaluator:
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
coco_gt,
|
| 46 |
+
iou_types: List[str],
|
| 47 |
+
useCats: bool,
|
| 48 |
+
dump_dir: Optional[str],
|
| 49 |
+
postprocessor,
|
| 50 |
+
average_by_rarity=False,
|
| 51 |
+
metrics_dump_dir: Optional[str] = None,
|
| 52 |
+
gather_pred_via_filesys=False,
|
| 53 |
+
use_normalized_areas=True,
|
| 54 |
+
maxdets=[1, 10, 100],
|
| 55 |
+
exhaustive_only=False,
|
| 56 |
+
all_exhaustive_only=True,
|
| 57 |
+
):
|
| 58 |
+
"""Online coco evaluator. It will evaluate images as they are generated by the model, then accumulate/summarize at the end
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
- coco_gt: COCO api object containing the gt
|
| 62 |
+
- iou_types: can be either "bbox" or "segm"
|
| 63 |
+
- useCats: If true, categories will be used for evaluation
|
| 64 |
+
- dump_dir: if non null, then the predictions will be dumped in that directory
|
| 65 |
+
- postprocessor: Module to convert the model's output into the coco format
|
| 66 |
+
- average_by_rarity: if true then we expect the images information in the gt dataset
|
| 67 |
+
to have a "rarity" field. Then the AP will be computed on all rarity buckets
|
| 68 |
+
individually, then averaged
|
| 69 |
+
- gather_pred_via_filesys: if true, we use the filesystem for collective gathers
|
| 70 |
+
- use_normalized_areas: if true, the areas of the objects in the GT are assumed to be
|
| 71 |
+
normalized by the area of the image. In that case, the size buckets are adjusted
|
| 72 |
+
- maxdets: maximal number of detections to be evaluated on each image.
|
| 73 |
+
- exhaustive_only: If true, we restrict eval only to exhaustive annotations
|
| 74 |
+
- all_exhaustive_only: If true, datapoints are restricted only to those with all exhaustive annotations
|
| 75 |
+
|
| 76 |
+
"""
|
| 77 |
+
# coco_gt = copy.deepcopy(coco_gt)
|
| 78 |
+
self.coco_gts = [coco_gt] if not isinstance(coco_gt, list) else coco_gt
|
| 79 |
+
assert len(maxdets) == 3, f"expecting 3 detection threshold, got {len(maxdets)}"
|
| 80 |
+
|
| 81 |
+
self.use_normalized_areas = use_normalized_areas
|
| 82 |
+
self.iou_types = iou_types
|
| 83 |
+
self.useCats = useCats
|
| 84 |
+
self.maxdets = maxdets
|
| 85 |
+
self.dump = None
|
| 86 |
+
self.dump_dir = dump_dir
|
| 87 |
+
if self.dump_dir is not None:
|
| 88 |
+
self.dump = []
|
| 89 |
+
if is_main_process():
|
| 90 |
+
if not os.path.exists(self.dump_dir):
|
| 91 |
+
os.makedirs(self.dump_dir, exist_ok=True)
|
| 92 |
+
logging.info(f"Create the folder: {dump_dir}")
|
| 93 |
+
|
| 94 |
+
self.initialized = False
|
| 95 |
+
|
| 96 |
+
# Whether to gather predictions through filesystem (instead of torch
|
| 97 |
+
# collective ops; requiring a shared filesystem across all ranks)
|
| 98 |
+
self.gather_pred_via_filesys = gather_pred_via_filesys
|
| 99 |
+
self.use_self_evaluate = True # CPP version is disabled
|
| 100 |
+
self.postprocessor = postprocessor
|
| 101 |
+
self.average_by_rarity = average_by_rarity
|
| 102 |
+
self.exhaustive_only = exhaustive_only
|
| 103 |
+
self.all_exhaustive_only = all_exhaustive_only
|
| 104 |
+
self.metrics_dump_dir = metrics_dump_dir
|
| 105 |
+
if self.metrics_dump_dir is not None:
|
| 106 |
+
if is_main_process():
|
| 107 |
+
if not os.path.exists(self.metrics_dump_dir):
|
| 108 |
+
os.makedirs(self.metrics_dump_dir, exist_ok=True)
|
| 109 |
+
logging.info(f"Create the folder: {metrics_dump_dir}")
|
| 110 |
+
|
| 111 |
+
def _lazy_init(self, coco_cls=COCO):
|
| 112 |
+
if self.initialized:
|
| 113 |
+
return
|
| 114 |
+
|
| 115 |
+
self.initialized = True
|
| 116 |
+
|
| 117 |
+
self.coco_gts = [
|
| 118 |
+
coco_cls(g_pathmgr.get_local_path(gt)) if isinstance(gt, str) else gt
|
| 119 |
+
for gt in self.coco_gts
|
| 120 |
+
]
|
| 121 |
+
|
| 122 |
+
self.reset()
|
| 123 |
+
|
| 124 |
+
self.eval_img_ids = None
|
| 125 |
+
|
| 126 |
+
if self.exhaustive_only:
|
| 127 |
+
exclude_img_ids = set()
|
| 128 |
+
# exclude_img_ids are the ids that are not exhaustively annotated in any of the other gts
|
| 129 |
+
if self.all_exhaustive_only:
|
| 130 |
+
for coco_gt in self.coco_gts[1:]:
|
| 131 |
+
exclude_img_ids = exclude_img_ids.union(
|
| 132 |
+
{
|
| 133 |
+
img["id"]
|
| 134 |
+
for img in coco_gt.dataset["images"]
|
| 135 |
+
if not img["is_instance_exhaustive"]
|
| 136 |
+
}
|
| 137 |
+
)
|
| 138 |
+
# we only eval on instance exhaustive queries
|
| 139 |
+
self.eval_img_ids = [
|
| 140 |
+
img["id"]
|
| 141 |
+
for img in self.coco_gts[0].dataset["images"]
|
| 142 |
+
if (img["is_instance_exhaustive"] and img["id"] not in exclude_img_ids)
|
| 143 |
+
]
|
| 144 |
+
|
| 145 |
+
self.rarity_buckets = None
|
| 146 |
+
if self.average_by_rarity:
|
| 147 |
+
self.rarity_buckets = defaultdict(list)
|
| 148 |
+
eval_img_ids_set = (
|
| 149 |
+
set(self.eval_img_ids) if self.eval_img_ids is not None else None
|
| 150 |
+
)
|
| 151 |
+
for img in self.coco_gts[0].dataset["images"]:
|
| 152 |
+
if self.eval_img_ids is not None and img["id"] not in eval_img_ids_set:
|
| 153 |
+
continue
|
| 154 |
+
self.rarity_buckets[img["rarity"]].append(img["id"])
|
| 155 |
+
print("Rarity buckets sizes:")
|
| 156 |
+
for k, v in self.rarity_buckets.items():
|
| 157 |
+
print(f"{k}: {len(v)}")
|
| 158 |
+
|
| 159 |
+
def set_sync_device(self, device: torch.device) -> Any:
|
| 160 |
+
self._sync_device = device
|
| 161 |
+
|
| 162 |
+
def _evaluate(self, *args, **kwargs):
|
| 163 |
+
return evaluate(*args, **kwargs)
|
| 164 |
+
|
| 165 |
+
def _loadRes(self, *args, **kwargs):
|
| 166 |
+
return loadRes(*args, **kwargs)
|
| 167 |
+
|
| 168 |
+
def update(self, *args, **kwargs):
|
| 169 |
+
self._lazy_init()
|
| 170 |
+
predictions = self.postprocessor.process_results(*args, **kwargs)
|
| 171 |
+
|
| 172 |
+
img_ids = list(np.unique(list(predictions.keys())))
|
| 173 |
+
self.img_ids.extend(img_ids)
|
| 174 |
+
|
| 175 |
+
for iou_type in self.iou_types:
|
| 176 |
+
results = self.prepare(predictions, iou_type)
|
| 177 |
+
self._dump(results)
|
| 178 |
+
|
| 179 |
+
assert len(self.coco_gts) == len(self.coco_evals)
|
| 180 |
+
all_scorings = []
|
| 181 |
+
for cur_coco_gt, cur_coco_eval in zip(self.coco_gts, self.coco_evals):
|
| 182 |
+
# suppress pycocotools prints
|
| 183 |
+
with open(os.devnull, "w") as devnull:
|
| 184 |
+
with contextlib.redirect_stdout(devnull):
|
| 185 |
+
coco_dt = (
|
| 186 |
+
self._loadRes(cur_coco_gt, results) if results else COCO()
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
coco_eval = cur_coco_eval[iou_type]
|
| 190 |
+
|
| 191 |
+
coco_eval.cocoDt = coco_dt
|
| 192 |
+
coco_eval.params.imgIds = list(img_ids)
|
| 193 |
+
coco_eval.params.useCats = self.useCats
|
| 194 |
+
coco_eval.params.maxDets = self.maxdets
|
| 195 |
+
img_ids, eval_imgs = self._evaluate(coco_eval, self.use_self_evaluate)
|
| 196 |
+
all_scorings.append(eval_imgs)
|
| 197 |
+
|
| 198 |
+
selected = self.select_best_scoring(all_scorings)
|
| 199 |
+
self.eval_imgs[iou_type].append(selected)
|
| 200 |
+
|
| 201 |
+
def select_best_scoring(self, scorings):
|
| 202 |
+
# This function is used for "oracle" type evaluation.
|
| 203 |
+
# It accepts the evaluation results with respect to several ground truths, and picks the best
|
| 204 |
+
if len(scorings) == 1:
|
| 205 |
+
return scorings[0]
|
| 206 |
+
|
| 207 |
+
# Currently we don't support Oracle Phrase AP.
|
| 208 |
+
# To implement it, we likely need to modify the cpp code since the eval_image type is opaque
|
| 209 |
+
raise RuntimeError("Not implemented")
|
| 210 |
+
|
| 211 |
+
def _dump(self, results):
|
| 212 |
+
if self.dump is not None:
|
| 213 |
+
dumped_results = copy.deepcopy(results)
|
| 214 |
+
for r in dumped_results:
|
| 215 |
+
if "bbox" not in self.iou_types and "bbox" in r:
|
| 216 |
+
del r["bbox"]
|
| 217 |
+
elif "bbox" in r:
|
| 218 |
+
r["bbox"] = [round(coord, 5) for coord in r["bbox"]]
|
| 219 |
+
r["score"] = round(r["score"], 5)
|
| 220 |
+
self.dump.extend(dumped_results)
|
| 221 |
+
|
| 222 |
+
def synchronize_between_processes(self):
|
| 223 |
+
self._lazy_init()
|
| 224 |
+
logging.info("Coco evaluator: Synchronizing between processes")
|
| 225 |
+
for iou_type in self.iou_types:
|
| 226 |
+
if len(self.eval_imgs[iou_type]) > 0:
|
| 227 |
+
self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
|
| 228 |
+
else:
|
| 229 |
+
num_areas = len(self.coco_evals[0][iou_type].params.areaRng)
|
| 230 |
+
# assuming 1 class
|
| 231 |
+
assert not self.useCats
|
| 232 |
+
self.eval_imgs[iou_type] = np.empty((1, num_areas, 0))
|
| 233 |
+
create_common_coco_eval(
|
| 234 |
+
self.coco_evals[0][iou_type],
|
| 235 |
+
self.img_ids,
|
| 236 |
+
self.eval_imgs[iou_type],
|
| 237 |
+
use_self_evaluate=self.use_self_evaluate,
|
| 238 |
+
gather_pred_via_filesys=self.gather_pred_via_filesys,
|
| 239 |
+
metrics_dump_dir=self.metrics_dump_dir,
|
| 240 |
+
)
|
| 241 |
+
if self.dump is not None:
|
| 242 |
+
dumped_file = Path(self.dump_dir) / f"coco_predictions_{get_rank()}.json"
|
| 243 |
+
logging.info(f"COCO evaluator: Dumping local predictions to {dumped_file}")
|
| 244 |
+
with g_pathmgr.open(str(dumped_file), "w") as f:
|
| 245 |
+
json.dump(self.dump, f)
|
| 246 |
+
|
| 247 |
+
# if self.gather_pred_via_filesys:
|
| 248 |
+
# dump = gather_to_rank_0_via_filesys(self.dump)
|
| 249 |
+
# else:
|
| 250 |
+
# dump = all_gather(self.dump, force_cpu=True)
|
| 251 |
+
# self.dump = sum(dump, [])
|
| 252 |
+
|
| 253 |
+
def accumulate(self, imgIds=None):
|
| 254 |
+
self._lazy_init()
|
| 255 |
+
logging.info(
|
| 256 |
+
f"Coco evaluator: Accumulating on {len(imgIds) if imgIds is not None else 'all'} images"
|
| 257 |
+
)
|
| 258 |
+
if not is_main_process():
|
| 259 |
+
return
|
| 260 |
+
|
| 261 |
+
if imgIds is None:
|
| 262 |
+
for coco_eval in self.coco_evals[0].values():
|
| 263 |
+
accumulate(coco_eval, use_self_eval=self.use_self_evaluate)
|
| 264 |
+
|
| 265 |
+
if imgIds is not None:
|
| 266 |
+
imgIds = set(imgIds)
|
| 267 |
+
for coco_eval in self.coco_evals[0].values():
|
| 268 |
+
p = coco_eval.params
|
| 269 |
+
id_mask = np.array([(i in imgIds) for i in p.imgIds], dtype=bool)
|
| 270 |
+
old_img_ids = p.imgIds
|
| 271 |
+
coco_eval.params.imgIds = np.asarray(p.imgIds)[id_mask]
|
| 272 |
+
old_img_evals = coco_eval.evalImgs
|
| 273 |
+
catIds = p.catIds if p.useCats else [-1]
|
| 274 |
+
coco_eval.evalImgs = list(
|
| 275 |
+
np.asarray(coco_eval.evalImgs)
|
| 276 |
+
.reshape(len(catIds), len(p.areaRng), len(old_img_ids))[
|
| 277 |
+
..., id_mask
|
| 278 |
+
]
|
| 279 |
+
.flatten()
|
| 280 |
+
)
|
| 281 |
+
accumulate(coco_eval, use_self_eval=self.use_self_evaluate)
|
| 282 |
+
coco_eval.evalImgs = old_img_evals
|
| 283 |
+
coco_eval.params.imgIds = old_img_ids
|
| 284 |
+
|
| 285 |
+
def summarize(self):
|
| 286 |
+
self._lazy_init()
|
| 287 |
+
logging.info("Coco evaluator: Summarizing")
|
| 288 |
+
if not is_main_process():
|
| 289 |
+
return {}
|
| 290 |
+
|
| 291 |
+
outs = {}
|
| 292 |
+
if self.rarity_buckets is None:
|
| 293 |
+
self.accumulate(self.eval_img_ids)
|
| 294 |
+
for iou_type, coco_eval in self.coco_evals[0].items():
|
| 295 |
+
print("IoU metric: {}".format(iou_type))
|
| 296 |
+
summarize(coco_eval)
|
| 297 |
+
|
| 298 |
+
if "bbox" in self.coco_evals[0]:
|
| 299 |
+
for key, value in zip(*self.coco_evals[0]["bbox"].stats):
|
| 300 |
+
outs[f"coco_eval_bbox_{key}"] = value
|
| 301 |
+
if "segm" in self.coco_evals[0]:
|
| 302 |
+
for key, value in zip(*self.coco_evals[0]["segm"].stats):
|
| 303 |
+
outs[f"coco_eval_masks_{key}"] = value
|
| 304 |
+
else:
|
| 305 |
+
total_stats = {}
|
| 306 |
+
all_keys = {}
|
| 307 |
+
for bucket, img_list in self.rarity_buckets.items():
|
| 308 |
+
self.accumulate(imgIds=img_list)
|
| 309 |
+
bucket_name = RARITY_BUCKETS[bucket]
|
| 310 |
+
for iou_type, coco_eval in self.coco_evals[0].items():
|
| 311 |
+
print(f"IoU metric: {iou_type}. Rarity bucket: {bucket_name}")
|
| 312 |
+
summarize(coco_eval)
|
| 313 |
+
|
| 314 |
+
if "bbox" in self.coco_evals[0]:
|
| 315 |
+
if "bbox" not in total_stats:
|
| 316 |
+
total_stats["bbox"] = np.zeros_like(
|
| 317 |
+
self.coco_evals[0]["bbox"].stats[1]
|
| 318 |
+
)
|
| 319 |
+
all_keys["bbox"] = self.coco_evals[0]["bbox"].stats[0]
|
| 320 |
+
total_stats["bbox"] += self.coco_evals[0]["bbox"].stats[1]
|
| 321 |
+
for key, value in zip(*self.coco_evals[0]["bbox"].stats):
|
| 322 |
+
outs[f"coco_eval_bbox_{bucket_name}_{key}"] = value
|
| 323 |
+
if "segm" in self.coco_evals[0]:
|
| 324 |
+
if "segm" not in total_stats:
|
| 325 |
+
total_stats["segm"] = np.zeros_like(
|
| 326 |
+
self.coco_evals[0]["segm"].stats[1]
|
| 327 |
+
)
|
| 328 |
+
all_keys["segm"] = self.coco_evals[0]["segm"].stats[0]
|
| 329 |
+
total_stats["segm"] += self.coco_evals[0]["segm"].stats[1]
|
| 330 |
+
for key, value in zip(*self.coco_evals[0]["segm"].stats):
|
| 331 |
+
outs[f"coco_eval_masks_{bucket_name}_{key}"] = value
|
| 332 |
+
|
| 333 |
+
if "bbox" in total_stats:
|
| 334 |
+
total_stats["bbox"] /= len(self.rarity_buckets)
|
| 335 |
+
for key, value in zip(all_keys["bbox"], total_stats["bbox"]):
|
| 336 |
+
outs[f"coco_eval_bbox_{key}"] = value
|
| 337 |
+
if "segm" in total_stats:
|
| 338 |
+
total_stats["segm"] /= len(self.rarity_buckets)
|
| 339 |
+
for key, value in zip(all_keys["segm"], total_stats["segm"]):
|
| 340 |
+
outs[f"coco_eval_masks_{key}"] = value
|
| 341 |
+
|
| 342 |
+
# if self.dump is not None:
|
| 343 |
+
# assert self.dump_dir is not None
|
| 344 |
+
# logging.info("Coco evaluator: Dumping the global result file to disk")
|
| 345 |
+
# with g_pathmgr.open(str(Path(self.dump_dir) / "coco_eval.json"), "w") as f:
|
| 346 |
+
# json.dump(self.dump, f)
|
| 347 |
+
return outs
|
| 348 |
+
|
| 349 |
+
def compute_synced(self):
|
| 350 |
+
self._lazy_init()
|
| 351 |
+
self.synchronize_between_processes()
|
| 352 |
+
return self.summarize()
|
| 353 |
+
|
| 354 |
+
def compute(self):
|
| 355 |
+
self._lazy_init()
|
| 356 |
+
return {"": 0.0}
|
| 357 |
+
|
| 358 |
+
def reset(self, cocoeval_cls=COCOeval):
|
| 359 |
+
self.coco_evals = [{} for _ in range(len(self.coco_gts))]
|
| 360 |
+
for i, coco_gt in enumerate(self.coco_gts):
|
| 361 |
+
for iou_type in self.iou_types:
|
| 362 |
+
self.coco_evals[i][iou_type] = cocoeval_cls(coco_gt, iouType=iou_type)
|
| 363 |
+
self.coco_evals[i][iou_type].params.useCats = self.useCats
|
| 364 |
+
self.coco_evals[i][iou_type].params.maxDets = self.maxdets
|
| 365 |
+
if self.use_normalized_areas:
|
| 366 |
+
self.coco_evals[i][iou_type].params.areaRng = [
|
| 367 |
+
[0, 1e5],
|
| 368 |
+
[0, 0.001],
|
| 369 |
+
[0.001, 0.01],
|
| 370 |
+
[0.01, 0.1],
|
| 371 |
+
[0.1, 0.5],
|
| 372 |
+
[0.5, 0.95],
|
| 373 |
+
[0.95, 1e5],
|
| 374 |
+
]
|
| 375 |
+
self.coco_evals[i][iou_type].params.areaRngLbl = [
|
| 376 |
+
"all",
|
| 377 |
+
"tiny",
|
| 378 |
+
"small",
|
| 379 |
+
"medium",
|
| 380 |
+
"large",
|
| 381 |
+
"huge",
|
| 382 |
+
"whole_image",
|
| 383 |
+
]
|
| 384 |
+
|
| 385 |
+
self.img_ids = []
|
| 386 |
+
self.eval_imgs = {k: [] for k in self.iou_types}
|
| 387 |
+
if self.dump is not None:
|
| 388 |
+
self.dump = []
|
| 389 |
+
|
| 390 |
+
def write(self, stats):
|
| 391 |
+
self._lazy_init()
|
| 392 |
+
"""Write the results in the stats dict"""
|
| 393 |
+
if "bbox" in self.coco_evals[0]:
|
| 394 |
+
stats["coco_eval_bbox"] = self.coco_evals[0]["bbox"].stats.tolist()
|
| 395 |
+
if "segm" in self.coco_evals[0]:
|
| 396 |
+
stats["coco_eval_masks"] = self.coco_evals[0]["segm"].stats.tolist()
|
| 397 |
+
return stats
|
| 398 |
+
|
| 399 |
+
def prepare(self, predictions, iou_type):
|
| 400 |
+
self._lazy_init()
|
| 401 |
+
if iou_type == "bbox":
|
| 402 |
+
return self.prepare_for_coco_detection(predictions)
|
| 403 |
+
elif iou_type == "segm":
|
| 404 |
+
return self.prepare_for_coco_segmentation(predictions)
|
| 405 |
+
elif iou_type == "keypoints":
|
| 406 |
+
return self.prepare_for_coco_keypoint(predictions)
|
| 407 |
+
else:
|
| 408 |
+
raise ValueError("Unknown iou type {}".format(iou_type))
|
| 409 |
+
|
| 410 |
+
def prepare_for_coco_detection(self, predictions):
|
| 411 |
+
self._lazy_init()
|
| 412 |
+
coco_results = []
|
| 413 |
+
for original_id, prediction in predictions.items():
|
| 414 |
+
if len(prediction) == 0:
|
| 415 |
+
continue
|
| 416 |
+
|
| 417 |
+
boxes = prediction["boxes"]
|
| 418 |
+
boxes = convert_to_xywh(boxes).tolist()
|
| 419 |
+
scores = prediction["scores"].tolist()
|
| 420 |
+
labels = prediction["labels"].tolist()
|
| 421 |
+
|
| 422 |
+
coco_results.extend(
|
| 423 |
+
[
|
| 424 |
+
{
|
| 425 |
+
"image_id": original_id,
|
| 426 |
+
"category_id": labels[k],
|
| 427 |
+
"bbox": box,
|
| 428 |
+
"score": scores[k],
|
| 429 |
+
}
|
| 430 |
+
for k, box in enumerate(boxes)
|
| 431 |
+
]
|
| 432 |
+
)
|
| 433 |
+
return coco_results
|
| 434 |
+
|
| 435 |
+
@torch.no_grad()
|
| 436 |
+
def prepare_for_coco_segmentation(self, predictions):
|
| 437 |
+
self._lazy_init()
|
| 438 |
+
coco_results = []
|
| 439 |
+
for original_id, prediction in predictions.items():
|
| 440 |
+
if len(prediction) == 0:
|
| 441 |
+
continue
|
| 442 |
+
|
| 443 |
+
scores = prediction["scores"].tolist()
|
| 444 |
+
labels = prediction["labels"].tolist()
|
| 445 |
+
boundaries, dilated_boundaries = None, None
|
| 446 |
+
if "boundaries" in prediction:
|
| 447 |
+
boundaries = prediction["boundaries"]
|
| 448 |
+
dilated_boundaries = prediction["dilated_boundaries"]
|
| 449 |
+
assert dilated_boundaries is not None
|
| 450 |
+
assert len(scores) == len(boundaries)
|
| 451 |
+
|
| 452 |
+
if "masks_rle" in prediction:
|
| 453 |
+
rles = prediction["masks_rle"]
|
| 454 |
+
areas = []
|
| 455 |
+
for rle in rles:
|
| 456 |
+
cur_area = mask_utils.area(rle)
|
| 457 |
+
h, w = rle["size"]
|
| 458 |
+
areas.append(cur_area / (h * w))
|
| 459 |
+
else:
|
| 460 |
+
masks = prediction["masks"]
|
| 461 |
+
|
| 462 |
+
masks = masks > 0.5
|
| 463 |
+
h, w = masks.shape[-2:]
|
| 464 |
+
|
| 465 |
+
areas = masks.flatten(1).sum(1) / (h * w)
|
| 466 |
+
areas = areas.tolist()
|
| 467 |
+
|
| 468 |
+
rles = rle_encode(masks.squeeze(1))
|
| 469 |
+
|
| 470 |
+
# memory clean
|
| 471 |
+
del masks
|
| 472 |
+
del prediction["masks"]
|
| 473 |
+
|
| 474 |
+
assert len(areas) == len(rles) == len(scores)
|
| 475 |
+
for k, rle in enumerate(rles):
|
| 476 |
+
payload = {
|
| 477 |
+
"image_id": original_id,
|
| 478 |
+
"category_id": labels[k],
|
| 479 |
+
"segmentation": rle,
|
| 480 |
+
"score": scores[k],
|
| 481 |
+
"area": areas[k],
|
| 482 |
+
}
|
| 483 |
+
if boundaries is not None:
|
| 484 |
+
payload["boundary"] = boundaries[k]
|
| 485 |
+
payload["dilated_boundary"] = dilated_boundaries[k]
|
| 486 |
+
|
| 487 |
+
coco_results.append(payload)
|
| 488 |
+
|
| 489 |
+
return coco_results
|
| 490 |
+
|
| 491 |
+
def prepare_for_coco_keypoint(self, predictions):
|
| 492 |
+
self._lazy_init()
|
| 493 |
+
coco_results = []
|
| 494 |
+
for original_id, prediction in predictions.items():
|
| 495 |
+
if len(prediction) == 0:
|
| 496 |
+
continue
|
| 497 |
+
|
| 498 |
+
boxes = prediction["boxes"]
|
| 499 |
+
boxes = convert_to_xywh(boxes).tolist()
|
| 500 |
+
scores = prediction["scores"].tolist()
|
| 501 |
+
labels = prediction["labels"].tolist()
|
| 502 |
+
keypoints = prediction["keypoints"]
|
| 503 |
+
keypoints = keypoints.flatten(start_dim=1).tolist()
|
| 504 |
+
|
| 505 |
+
coco_results.extend(
|
| 506 |
+
[
|
| 507 |
+
{
|
| 508 |
+
"image_id": original_id,
|
| 509 |
+
"category_id": labels[k],
|
| 510 |
+
"keypoints": keypoint,
|
| 511 |
+
"score": scores[k],
|
| 512 |
+
}
|
| 513 |
+
for k, keypoint in enumerate(keypoints)
|
| 514 |
+
]
|
| 515 |
+
)
|
| 516 |
+
return coco_results
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
def convert_to_xywh(boxes):
|
| 520 |
+
xmin, ymin, xmax, ymax = boxes.unbind(-1)
|
| 521 |
+
return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=-1)
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
def merge(img_ids, eval_imgs, gather_pred_via_filesys=False):
|
| 525 |
+
if gather_pred_via_filesys:
|
| 526 |
+
# only gather the predictions to rank 0 (other ranks will receive empty
|
| 527 |
+
# lists for `all_img_ids` and `all_eval_imgs`, which should be OK as
|
| 528 |
+
# merging and evaluation are only done on rank 0)
|
| 529 |
+
all_img_ids = gather_to_rank_0_via_filesys(img_ids)
|
| 530 |
+
all_eval_imgs = gather_to_rank_0_via_filesys(eval_imgs)
|
| 531 |
+
else:
|
| 532 |
+
all_img_ids = all_gather(img_ids, force_cpu=True)
|
| 533 |
+
all_eval_imgs = all_gather(eval_imgs, force_cpu=True)
|
| 534 |
+
if not is_main_process():
|
| 535 |
+
return None, None
|
| 536 |
+
|
| 537 |
+
merged_img_ids = []
|
| 538 |
+
for p in all_img_ids:
|
| 539 |
+
merged_img_ids.extend(p)
|
| 540 |
+
|
| 541 |
+
merged_eval_imgs = []
|
| 542 |
+
for p in all_eval_imgs:
|
| 543 |
+
merged_eval_imgs.append(p)
|
| 544 |
+
|
| 545 |
+
merged_img_ids = np.array(merged_img_ids)
|
| 546 |
+
merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)
|
| 547 |
+
|
| 548 |
+
# keep only unique (and in sorted order) images
|
| 549 |
+
merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
|
| 550 |
+
merged_eval_imgs = merged_eval_imgs[..., idx]
|
| 551 |
+
|
| 552 |
+
return merged_img_ids, merged_eval_imgs
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
def create_common_coco_eval(
|
| 556 |
+
coco_eval,
|
| 557 |
+
img_ids,
|
| 558 |
+
eval_imgs,
|
| 559 |
+
use_self_evaluate,
|
| 560 |
+
gather_pred_via_filesys=False,
|
| 561 |
+
metrics_dump_dir=None,
|
| 562 |
+
):
|
| 563 |
+
img_ids, eval_imgs = merge(img_ids, eval_imgs, gather_pred_via_filesys)
|
| 564 |
+
if not is_main_process():
|
| 565 |
+
return
|
| 566 |
+
if metrics_dump_dir is not None:
|
| 567 |
+
dumped_file = (
|
| 568 |
+
Path(metrics_dump_dir) / f"coco_eval_img_metrics_{get_rank()}.json"
|
| 569 |
+
)
|
| 570 |
+
logging.info(f"COCO evaluator: Dumping local predictions to {dumped_file}")
|
| 571 |
+
with g_pathmgr.open(str(dumped_file), "w") as f:
|
| 572 |
+
json.dump(eval_imgs.squeeze(), f, default=lambda x: x.tolist())
|
| 573 |
+
img_ids = list(img_ids)
|
| 574 |
+
|
| 575 |
+
# If some images were not predicted, we need to create dummy detections for them
|
| 576 |
+
missing_img_ids = set(coco_eval.cocoGt.getImgIds()) - set(img_ids)
|
| 577 |
+
if len(missing_img_ids) > 0:
|
| 578 |
+
print(f"WARNING: {len(missing_img_ids)} images were not predicted!")
|
| 579 |
+
coco_eval.cocoDt = COCO()
|
| 580 |
+
coco_eval.params.imgIds = list(missing_img_ids)
|
| 581 |
+
new_img_ids, new_eval_imgs = evaluate(coco_eval, use_self_evaluate)
|
| 582 |
+
img_ids.extend(new_img_ids)
|
| 583 |
+
eval_imgs = np.concatenate((eval_imgs, new_eval_imgs), axis=2)
|
| 584 |
+
|
| 585 |
+
eval_imgs = list(eval_imgs.flatten())
|
| 586 |
+
assert len(img_ids) == len(coco_eval.cocoGt.getImgIds())
|
| 587 |
+
|
| 588 |
+
coco_eval.evalImgs = eval_imgs
|
| 589 |
+
coco_eval.params.imgIds = img_ids
|
| 590 |
+
coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
#################################################################
|
| 594 |
+
# From pycocotools, just removed the prints and fixed
|
| 595 |
+
# a Python3 bug about unicode not defined
|
| 596 |
+
#################################################################
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
# Copy of COCO prepare, but doesn't convert anntoRLE
|
| 600 |
+
def segmentation_prepare(self):
|
| 601 |
+
"""
|
| 602 |
+
Prepare ._gts and ._dts for evaluation based on params
|
| 603 |
+
:return: None
|
| 604 |
+
"""
|
| 605 |
+
p = self.params
|
| 606 |
+
if p.useCats:
|
| 607 |
+
gts = self.cocoGt.loadAnns(
|
| 608 |
+
self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)
|
| 609 |
+
)
|
| 610 |
+
dts = self.cocoDt.loadAnns(
|
| 611 |
+
self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)
|
| 612 |
+
)
|
| 613 |
+
else:
|
| 614 |
+
gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds))
|
| 615 |
+
dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds))
|
| 616 |
+
|
| 617 |
+
for gt in gts:
|
| 618 |
+
gt["ignore"] = gt["ignore"] if "ignore" in gt else 0
|
| 619 |
+
gt["ignore"] = "iscrowd" in gt and gt["iscrowd"]
|
| 620 |
+
if p.iouType == "keypoints":
|
| 621 |
+
gt["ignore"] = (gt["num_keypoints"] == 0) or gt["ignore"]
|
| 622 |
+
self._gts = defaultdict(list) # gt for evaluation
|
| 623 |
+
self._dts = defaultdict(list) # dt for evaluation
|
| 624 |
+
for gt in gts:
|
| 625 |
+
self._gts[gt["image_id"], gt["category_id"]].append(gt)
|
| 626 |
+
for dt in dts:
|
| 627 |
+
self._dts[dt["image_id"], dt["category_id"]].append(dt)
|
| 628 |
+
self.evalImgs = defaultdict(list) # per-image per-category evaluation results
|
| 629 |
+
self.eval = {} # accumulated evaluation results
|
| 630 |
+
|
| 631 |
+
|
| 632 |
+
def evaluate(self, use_self_evaluate):
|
| 633 |
+
"""
|
| 634 |
+
Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
|
| 635 |
+
:return: None
|
| 636 |
+
"""
|
| 637 |
+
# tic = time.time()
|
| 638 |
+
# print('Running per image evaluation...', use_self_evaluate)
|
| 639 |
+
p = self.params
|
| 640 |
+
# add backward compatibility if useSegm is specified in params
|
| 641 |
+
if p.useSegm is not None:
|
| 642 |
+
p.iouType = "segm" if p.useSegm == 1 else "bbox"
|
| 643 |
+
print(
|
| 644 |
+
"useSegm (deprecated) is not None. Running {} evaluation".format(p.iouType)
|
| 645 |
+
)
|
| 646 |
+
# print('Evaluate annotation type *{}*'.format(p.iouType))
|
| 647 |
+
p.imgIds = list(np.unique(p.imgIds))
|
| 648 |
+
if p.useCats:
|
| 649 |
+
p.catIds = list(np.unique(p.catIds))
|
| 650 |
+
p.maxDets = sorted(p.maxDets)
|
| 651 |
+
self.params = p
|
| 652 |
+
|
| 653 |
+
self._prepare()
|
| 654 |
+
# loop through images, area range, max detection number
|
| 655 |
+
catIds = p.catIds if p.useCats else [-1]
|
| 656 |
+
|
| 657 |
+
if p.iouType == "segm" or p.iouType == "bbox":
|
| 658 |
+
computeIoU = self.computeIoU
|
| 659 |
+
elif p.iouType == "keypoints":
|
| 660 |
+
computeIoU = self.computeOks
|
| 661 |
+
self.ious = {
|
| 662 |
+
(imgId, catId): computeIoU(imgId, catId)
|
| 663 |
+
for imgId in p.imgIds
|
| 664 |
+
for catId in catIds
|
| 665 |
+
}
|
| 666 |
+
|
| 667 |
+
maxDet = p.maxDets[-1]
|
| 668 |
+
if use_self_evaluate:
|
| 669 |
+
evalImgs = [
|
| 670 |
+
self.evaluateImg(imgId, catId, areaRng, maxDet)
|
| 671 |
+
for catId in catIds
|
| 672 |
+
for areaRng in p.areaRng
|
| 673 |
+
for imgId in p.imgIds
|
| 674 |
+
]
|
| 675 |
+
# this is NOT in the pycocotools code, but could be done outside
|
| 676 |
+
evalImgs = np.asarray(evalImgs).reshape(
|
| 677 |
+
len(catIds), len(p.areaRng), len(p.imgIds)
|
| 678 |
+
)
|
| 679 |
+
return p.imgIds, evalImgs
|
| 680 |
+
|
| 681 |
+
# <<<< Beginning of code differences with original COCO API
|
| 682 |
+
# def convert_instances_to_cpp(instances, is_det=False):
|
| 683 |
+
# # Convert annotations for a list of instances in an image to a format that's fast
|
| 684 |
+
# # to access in C++
|
| 685 |
+
# instances_cpp = []
|
| 686 |
+
# for instance in instances:
|
| 687 |
+
# instance_cpp = _CPP.InstanceAnnotation(
|
| 688 |
+
# int(instance["id"]),
|
| 689 |
+
# instance["score"] if is_det else instance.get("score", 0.0),
|
| 690 |
+
# instance["area"],
|
| 691 |
+
# bool(instance.get("iscrowd", 0)),
|
| 692 |
+
# bool(instance.get("ignore", 0)),
|
| 693 |
+
# )
|
| 694 |
+
# instances_cpp.append(instance_cpp)
|
| 695 |
+
# return instances_cpp
|
| 696 |
+
|
| 697 |
+
# # Convert GT annotations, detections, and IOUs to a format that's fast to access in C++
|
| 698 |
+
# ground_truth_instances = [
|
| 699 |
+
# [convert_instances_to_cpp(self._gts[imgId, catId]) for catId in p.catIds]
|
| 700 |
+
# for imgId in p.imgIds
|
| 701 |
+
# ]
|
| 702 |
+
# detected_instances = [
|
| 703 |
+
# [
|
| 704 |
+
# convert_instances_to_cpp(self._dts[imgId, catId], is_det=True)
|
| 705 |
+
# for catId in p.catIds
|
| 706 |
+
# ]
|
| 707 |
+
# for imgId in p.imgIds
|
| 708 |
+
# ]
|
| 709 |
+
# ious = [[self.ious[imgId, catId] for catId in catIds] for imgId in p.imgIds]
|
| 710 |
+
|
| 711 |
+
# if not p.useCats:
|
| 712 |
+
# # For each image, flatten per-category lists into a single list
|
| 713 |
+
# ground_truth_instances = [
|
| 714 |
+
# [[o for c in i for o in c]] for i in ground_truth_instances
|
| 715 |
+
# ]
|
| 716 |
+
# detected_instances = [[[o for c in i for o in c]] for i in detected_instances]
|
| 717 |
+
|
| 718 |
+
# # Call C++ implementation of self.evaluateImgs()
|
| 719 |
+
# _evalImgs_cpp = _CPP.COCOevalEvaluateImages(
|
| 720 |
+
# p.areaRng, maxDet, p.iouThrs, ious, ground_truth_instances, detected_instances
|
| 721 |
+
# )
|
| 722 |
+
|
| 723 |
+
# self._paramsEval = copy.deepcopy(self.params)
|
| 724 |
+
# evalImgs = np.asarray(_evalImgs_cpp).reshape(
|
| 725 |
+
# len(catIds), len(p.areaRng), len(p.imgIds)
|
| 726 |
+
# )
|
| 727 |
+
# return p.imgIds, evalImgs
|
| 728 |
+
|
| 729 |
+
|
| 730 |
+
#################################################################
|
| 731 |
+
# end of straight copy from pycocotools, just removing the prints
|
| 732 |
+
#################################################################
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
#################################################################
|
| 736 |
+
# From pycocotools, but disabled mask->box conversion which is
|
| 737 |
+
# pointless
|
| 738 |
+
#################################################################
|
| 739 |
+
def loadRes(self, resFile):
|
| 740 |
+
"""
|
| 741 |
+
Load result file and return a result api object.
|
| 742 |
+
:param resFile (str) : file name of result file
|
| 743 |
+
:return: res (obj) : result api object
|
| 744 |
+
"""
|
| 745 |
+
res = COCO()
|
| 746 |
+
res.dataset["images"] = [img for img in self.dataset["images"]]
|
| 747 |
+
|
| 748 |
+
if type(resFile) == str:
|
| 749 |
+
anns = json.load(open(resFile))
|
| 750 |
+
elif type(resFile) == np.ndarray:
|
| 751 |
+
anns = self.loadNumpyAnnotations(resFile)
|
| 752 |
+
else:
|
| 753 |
+
anns = resFile
|
| 754 |
+
assert type(anns) == list, "results in not an array of objects"
|
| 755 |
+
annsImgIds = [ann["image_id"] for ann in anns]
|
| 756 |
+
assert set(annsImgIds) == (
|
| 757 |
+
set(annsImgIds) & set(self.getImgIds())
|
| 758 |
+
), "Results do not correspond to current coco set"
|
| 759 |
+
if "caption" in anns[0]:
|
| 760 |
+
imgIds = set([img["id"] for img in res.dataset["images"]]) & set(
|
| 761 |
+
[ann["image_id"] for ann in anns]
|
| 762 |
+
)
|
| 763 |
+
res.dataset["images"] = [
|
| 764 |
+
img for img in res.dataset["images"] if img["id"] in imgIds
|
| 765 |
+
]
|
| 766 |
+
for id, ann in enumerate(anns):
|
| 767 |
+
ann["id"] = id + 1
|
| 768 |
+
elif "bbox" in anns[0] and not anns[0]["bbox"] == []:
|
| 769 |
+
res.dataset["categories"] = copy.deepcopy(self.dataset["categories"])
|
| 770 |
+
for id, ann in enumerate(anns):
|
| 771 |
+
bb = ann["bbox"]
|
| 772 |
+
x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]]
|
| 773 |
+
if "segmentation" not in ann:
|
| 774 |
+
ann["segmentation"] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
|
| 775 |
+
ann["area"] = bb[2] * bb[3]
|
| 776 |
+
ann["id"] = id + 1
|
| 777 |
+
ann["iscrowd"] = 0
|
| 778 |
+
elif "segmentation" in anns[0]:
|
| 779 |
+
res.dataset["categories"] = copy.deepcopy(self.dataset["categories"])
|
| 780 |
+
for id, ann in enumerate(anns):
|
| 781 |
+
# now only support compressed RLE format as segmentation results
|
| 782 |
+
# ann["area"] = mask_util.area(ann["segmentation"])
|
| 783 |
+
# The following lines are disabled because they are pointless
|
| 784 |
+
# if not 'bbox' in ann:
|
| 785 |
+
# ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
|
| 786 |
+
ann["id"] = id + 1
|
| 787 |
+
ann["iscrowd"] = 0
|
| 788 |
+
elif "keypoints" in anns[0]:
|
| 789 |
+
res.dataset["categories"] = copy.deepcopy(self.dataset["categories"])
|
| 790 |
+
for id, ann in enumerate(anns):
|
| 791 |
+
s = ann["keypoints"]
|
| 792 |
+
x = s[0::3]
|
| 793 |
+
y = s[1::3]
|
| 794 |
+
x0, x1, y0, y1 = np.min(x), np.max(x), np.min(y), np.max(y)
|
| 795 |
+
ann["area"] = (x1 - x0) * (y1 - y0)
|
| 796 |
+
ann["id"] = id + 1
|
| 797 |
+
ann["bbox"] = [x0, y0, x1 - x0, y1 - y0]
|
| 798 |
+
|
| 799 |
+
res.dataset["annotations"] = anns
|
| 800 |
+
res.createIndex()
|
| 801 |
+
return res
|
| 802 |
+
|
| 803 |
+
|
| 804 |
+
#################################################################
|
| 805 |
+
# end of straight copy from pycocotools
|
| 806 |
+
#################################################################
|
| 807 |
+
|
| 808 |
+
|
| 809 |
+
#################################################################
|
| 810 |
+
# From pycocotools, but added handling of custom area rngs, and returns stat keys
|
| 811 |
+
#################################################################
|
| 812 |
+
def summarize(self):
|
| 813 |
+
"""
|
| 814 |
+
Compute and display summary metrics for evaluation results.
|
| 815 |
+
Note this functin can *only* be applied on the default parameter setting
|
| 816 |
+
"""
|
| 817 |
+
|
| 818 |
+
def _summarize(ap=1, iouThr=None, areaRng="all", maxDets=100):
|
| 819 |
+
p = self.params
|
| 820 |
+
iStr = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}"
|
| 821 |
+
titleStr = "Average Precision" if ap == 1 else "Average Recall"
|
| 822 |
+
typeStr = "(AP)" if ap == 1 else "(AR)"
|
| 823 |
+
iouStr = (
|
| 824 |
+
"{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1])
|
| 825 |
+
if iouThr is None
|
| 826 |
+
else "{:0.2f}".format(iouThr)
|
| 827 |
+
)
|
| 828 |
+
|
| 829 |
+
aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
|
| 830 |
+
mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
|
| 831 |
+
if ap == 1:
|
| 832 |
+
# dimension of precision: [TxRxKxAxM]
|
| 833 |
+
s = self.eval["precision"]
|
| 834 |
+
# IoU
|
| 835 |
+
if iouThr is not None:
|
| 836 |
+
t = np.where(iouThr == p.iouThrs)[0]
|
| 837 |
+
s = s[t]
|
| 838 |
+
s = s[:, :, :, aind, mind]
|
| 839 |
+
else:
|
| 840 |
+
# dimension of recall: [TxKxAxM]
|
| 841 |
+
s = self.eval["recall"]
|
| 842 |
+
if iouThr is not None:
|
| 843 |
+
t = np.where(iouThr == p.iouThrs)[0]
|
| 844 |
+
s = s[t]
|
| 845 |
+
s = s[:, :, aind, mind]
|
| 846 |
+
if len(s[s > -1]) == 0:
|
| 847 |
+
mean_s = -1
|
| 848 |
+
else:
|
| 849 |
+
mean_s = np.mean(s[s > -1])
|
| 850 |
+
print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s))
|
| 851 |
+
return mean_s
|
| 852 |
+
|
| 853 |
+
def _summarizeDets():
|
| 854 |
+
nb_results = 6 + (len(self.params.areaRng) - 1) * 2
|
| 855 |
+
assert len(self.params.areaRng) == len(self.params.areaRngLbl)
|
| 856 |
+
stats = np.zeros((nb_results,))
|
| 857 |
+
keys = ["AP", "AP_50", "AP_75"]
|
| 858 |
+
stats[0] = _summarize(1, maxDets=self.params.maxDets[2])
|
| 859 |
+
stats[1] = _summarize(1, iouThr=0.5, maxDets=self.params.maxDets[2])
|
| 860 |
+
stats[2] = _summarize(1, iouThr=0.75, maxDets=self.params.maxDets[2])
|
| 861 |
+
cur_id = 3
|
| 862 |
+
for area in self.params.areaRngLbl[1:]:
|
| 863 |
+
stats[cur_id] = _summarize(1, areaRng=area, maxDets=self.params.maxDets[2])
|
| 864 |
+
cur_id += 1
|
| 865 |
+
keys.append(f"AP_{area}")
|
| 866 |
+
stats[cur_id] = _summarize(0, maxDets=self.params.maxDets[0])
|
| 867 |
+
cur_id += 1
|
| 868 |
+
stats[cur_id] = _summarize(0, maxDets=self.params.maxDets[1])
|
| 869 |
+
cur_id += 1
|
| 870 |
+
stats[cur_id] = _summarize(0, maxDets=self.params.maxDets[2])
|
| 871 |
+
cur_id += 1
|
| 872 |
+
keys += ["AR", "AR_50", "AR_75"]
|
| 873 |
+
|
| 874 |
+
for area in self.params.areaRngLbl[1:]:
|
| 875 |
+
stats[cur_id] = _summarize(0, areaRng=area, maxDets=self.params.maxDets[2])
|
| 876 |
+
cur_id += 1
|
| 877 |
+
keys.append(f"AR_{area}")
|
| 878 |
+
assert len(stats) == len(keys)
|
| 879 |
+
return keys, stats
|
| 880 |
+
|
| 881 |
+
if not self.eval:
|
| 882 |
+
raise Exception("Please run accumulate() first")
|
| 883 |
+
self.stats = _summarizeDets()
|
| 884 |
+
|
| 885 |
+
|
| 886 |
+
#################################################################
|
| 887 |
+
# end of straight copy from pycocotools
|
| 888 |
+
#################################################################
|
| 889 |
+
|
| 890 |
+
|
| 891 |
+
#################################################################
|
| 892 |
+
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/evaluation/fast_eval_api.py
|
| 893 |
+
# with slight adjustments
|
| 894 |
+
#################################################################
|
| 895 |
+
def accumulate(self, use_self_eval=False):
|
| 896 |
+
"""
|
| 897 |
+
Accumulate per image evaluation results and store the result in self.eval. Does not
|
| 898 |
+
support changing parameter settings from those used by self.evaluate()
|
| 899 |
+
"""
|
| 900 |
+
if use_self_eval:
|
| 901 |
+
self.accumulate()
|
| 902 |
+
return
|
| 903 |
+
# CPP code is disabled
|
| 904 |
+
# self.eval = _CPP.COCOevalAccumulate(self.params, self.evalImgs)
|
| 905 |
+
|
| 906 |
+
# # recall is num_iou_thresholds X num_categories X num_area_ranges X num_max_detections
|
| 907 |
+
# self.eval["recall"] = np.array(self.eval["recall"]).reshape(
|
| 908 |
+
# self.eval["counts"][:1] + self.eval["counts"][2:]
|
| 909 |
+
# )
|
| 910 |
+
|
| 911 |
+
# # precision and scores are num_iou_thresholds X num_recall_thresholds X num_categories X
|
| 912 |
+
# # num_area_ranges X num_max_detections
|
| 913 |
+
# self.eval["precision"] = np.array(self.eval["precision"]).reshape(
|
| 914 |
+
# self.eval["counts"]
|
| 915 |
+
# )
|
| 916 |
+
# self.eval["scores"] = np.array(self.eval["scores"]).reshape(self.eval["counts"])
|
sam3/eval/coco_eval_offline.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
This evaluator is meant for regular COCO mAP evaluation, for example on the COCO val set.
|
| 5 |
+
|
| 6 |
+
For Category mAP, we need the model to make predictions for all the categories on every single image.
|
| 7 |
+
In general, since the number of classes can be big, and the API model makes predictions individually for each pair (image, class),
|
| 8 |
+
we may need to split the inference process for a given image in several chunks.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
from collections import defaultdict
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from pycocotools.coco import COCO
|
| 16 |
+
from pycocotools.cocoeval import COCOeval
|
| 17 |
+
from sam3.train.utils.distributed import is_main_process
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from tidecv import datasets, TIDE
|
| 21 |
+
|
| 22 |
+
HAS_TIDE = True
|
| 23 |
+
except ImportError:
|
| 24 |
+
HAS_TIDE = False
|
| 25 |
+
print("WARNING: TIDE not installed. Detailed analysis will not be available.")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# the COCO detection metrics (https://github.com/cocodataset/cocoapi/blob/8c9bcc3cf640524c4c20a9c40e89cb6a2f2fa0e9/PythonAPI/pycocotools/cocoeval.py#L460-L471)
|
| 29 |
+
COCO_METRICS = [
|
| 30 |
+
"AP",
|
| 31 |
+
"AP_50",
|
| 32 |
+
"AP_75",
|
| 33 |
+
"AP_small",
|
| 34 |
+
"AP_medium",
|
| 35 |
+
"AP_large",
|
| 36 |
+
"AR_maxDets@1",
|
| 37 |
+
"AR_maxDets@10",
|
| 38 |
+
"AR_maxDets@100",
|
| 39 |
+
"AR_small",
|
| 40 |
+
"AR_medium",
|
| 41 |
+
"AR_large",
|
| 42 |
+
]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def convert_to_xywh(boxes):
|
| 46 |
+
"""Convert bounding boxes from xyxy format to xywh format."""
|
| 47 |
+
xmin, ymin, xmax, ymax = boxes.unbind(-1)
|
| 48 |
+
return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=-1)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class HeapElement:
|
| 52 |
+
"""Utility class to make a heap with a custom comparator"""
|
| 53 |
+
|
| 54 |
+
def __init__(self, val):
|
| 55 |
+
self.val = val
|
| 56 |
+
|
| 57 |
+
def __lt__(self, other):
|
| 58 |
+
return self.val["score"] < other.val["score"]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class COCOevalCustom(COCOeval):
|
| 62 |
+
"""
|
| 63 |
+
This is a slightly modified version of the original COCO API with added support for positive split evaluation.
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
def __init__(
|
| 67 |
+
self, cocoGt=None, cocoDt=None, iouType="segm", dt_only_positive=False
|
| 68 |
+
):
|
| 69 |
+
super().__init__(cocoGt, cocoDt, iouType)
|
| 70 |
+
self.dt_only_positive = dt_only_positive
|
| 71 |
+
|
| 72 |
+
def _prepare(self):
|
| 73 |
+
"""
|
| 74 |
+
Prepare ._gts and ._dts for evaluation based on params
|
| 75 |
+
:return: None
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
def _toMask(anns, coco):
|
| 79 |
+
# modify ann['segmentation'] by reference
|
| 80 |
+
for ann in anns:
|
| 81 |
+
rle = coco.annToRLE(ann)
|
| 82 |
+
ann["segmentation"] = rle
|
| 83 |
+
|
| 84 |
+
p = self.params
|
| 85 |
+
if p.useCats:
|
| 86 |
+
gts = self.cocoGt.loadAnns(
|
| 87 |
+
self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)
|
| 88 |
+
)
|
| 89 |
+
dts = self.cocoDt.loadAnns(
|
| 90 |
+
self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)
|
| 91 |
+
)
|
| 92 |
+
else:
|
| 93 |
+
gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds))
|
| 94 |
+
dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds))
|
| 95 |
+
|
| 96 |
+
# convert ground truth to mask if iouType == 'segm'
|
| 97 |
+
if p.iouType == "segm":
|
| 98 |
+
_toMask(gts, self.cocoGt)
|
| 99 |
+
_toMask(dts, self.cocoDt)
|
| 100 |
+
# set ignore flag
|
| 101 |
+
for gt in gts:
|
| 102 |
+
gt["ignore"] = gt["ignore"] if "ignore" in gt else 0
|
| 103 |
+
gt["ignore"] = "iscrowd" in gt and gt["iscrowd"]
|
| 104 |
+
if p.iouType == "keypoints":
|
| 105 |
+
gt["ignore"] = (gt["num_keypoints"] == 0) or gt["ignore"]
|
| 106 |
+
self._gts = defaultdict(list) # gt for evaluation
|
| 107 |
+
self._dts = defaultdict(list) # dt for evaluation
|
| 108 |
+
|
| 109 |
+
_gts_cat_ids = defaultdict(set) # gt for evaluation on positive split
|
| 110 |
+
for gt in gts:
|
| 111 |
+
self._gts[gt["image_id"], gt["category_id"]].append(gt)
|
| 112 |
+
_gts_cat_ids[gt["image_id"]].add(gt["category_id"])
|
| 113 |
+
|
| 114 |
+
#### BEGIN MODIFICATION ####
|
| 115 |
+
for dt in dts:
|
| 116 |
+
if (
|
| 117 |
+
self.dt_only_positive
|
| 118 |
+
and dt["category_id"] not in _gts_cat_ids[dt["image_id"]]
|
| 119 |
+
):
|
| 120 |
+
continue
|
| 121 |
+
self._dts[dt["image_id"], dt["category_id"]].append(dt)
|
| 122 |
+
#### END MODIFICATION ####
|
| 123 |
+
self.evalImgs = defaultdict(list) # per-image per-category evaluation results
|
| 124 |
+
self.eval = {} # accumulated evaluation results
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class CocoEvaluatorOfflineWithPredFileEvaluators:
|
| 128 |
+
def __init__(
|
| 129 |
+
self,
|
| 130 |
+
gt_path,
|
| 131 |
+
tide: bool = True,
|
| 132 |
+
iou_type: str = "bbox",
|
| 133 |
+
positive_split=False,
|
| 134 |
+
):
|
| 135 |
+
self.gt_path = gt_path
|
| 136 |
+
self.tide_enabled = HAS_TIDE and tide
|
| 137 |
+
self.positive_split = positive_split
|
| 138 |
+
self.iou_type = iou_type
|
| 139 |
+
|
| 140 |
+
def evaluate(self, dumped_file):
|
| 141 |
+
if not is_main_process():
|
| 142 |
+
return {}
|
| 143 |
+
|
| 144 |
+
logging.info("OfflineCoco evaluator: Loading groundtruth")
|
| 145 |
+
self.gt = COCO(self.gt_path)
|
| 146 |
+
|
| 147 |
+
# Creating the result file
|
| 148 |
+
logging.info("Coco evaluator: Creating the result file")
|
| 149 |
+
cocoDt = self.gt.loadRes(str(dumped_file))
|
| 150 |
+
|
| 151 |
+
# Run the evaluation
|
| 152 |
+
logging.info("Coco evaluator: Running evaluation")
|
| 153 |
+
coco_eval = COCOevalCustom(
|
| 154 |
+
self.gt, cocoDt, iouType=self.iou_type, dt_only_positive=self.positive_split
|
| 155 |
+
)
|
| 156 |
+
coco_eval.evaluate()
|
| 157 |
+
coco_eval.accumulate()
|
| 158 |
+
coco_eval.summarize()
|
| 159 |
+
|
| 160 |
+
outs = {}
|
| 161 |
+
for i, value in enumerate(coco_eval.stats):
|
| 162 |
+
outs[f"coco_eval_{self.iou_type}_{COCO_METRICS[i]}"] = value
|
| 163 |
+
|
| 164 |
+
if self.tide_enabled:
|
| 165 |
+
logging.info("Coco evaluator: Loading TIDE")
|
| 166 |
+
self.tide_gt = datasets.COCO(self.gt_path)
|
| 167 |
+
self.tide = TIDE(mode="mask" if self.iou_type == "segm" else "bbox")
|
| 168 |
+
|
| 169 |
+
# Run TIDE
|
| 170 |
+
logging.info("Coco evaluator: Running TIDE")
|
| 171 |
+
self.tide.evaluate(
|
| 172 |
+
self.tide_gt, datasets.COCOResult(str(dumped_file)), name="coco_eval"
|
| 173 |
+
)
|
| 174 |
+
self.tide.summarize()
|
| 175 |
+
for k, v in self.tide.get_main_errors()["coco_eval"].items():
|
| 176 |
+
outs[f"coco_eval_{self.iou_type}_TIDE_{k}"] = v
|
| 177 |
+
|
| 178 |
+
for k, v in self.tide.get_special_errors()["coco_eval"].items():
|
| 179 |
+
outs[f"coco_eval_{self.iou_type}_TIDE_{k}"] = v
|
| 180 |
+
|
| 181 |
+
return outs
|
sam3/eval/coco_reindex.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
Self-contained COCO JSON re-indexing function that creates temporary files.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
import tempfile
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def reindex_coco_to_temp(input_json_path: str) -> Optional[str]:
|
| 15 |
+
"""
|
| 16 |
+
Convert 0-indexed COCO JSON file to 1-indexed and save to temporary location.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
input_json_path: Path to the input COCO JSON file
|
| 20 |
+
|
| 21 |
+
Returns:
|
| 22 |
+
Path to the new 1-indexed JSON file in temporary directory, or None if no conversion needed
|
| 23 |
+
|
| 24 |
+
Raises:
|
| 25 |
+
FileNotFoundError: If input file doesn't exist
|
| 26 |
+
json.JSONDecodeError: If input file is not valid JSON
|
| 27 |
+
ValueError: If input file is not a valid COCO format
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def is_coco_json(data: Dict[str, Any]) -> bool:
|
| 31 |
+
"""Check if data appears to be a COCO format file."""
|
| 32 |
+
if not isinstance(data, dict):
|
| 33 |
+
return False
|
| 34 |
+
# A COCO file should have at least one of these keys
|
| 35 |
+
coco_keys = {"images", "annotations", "categories"}
|
| 36 |
+
return any(key in data for key in coco_keys)
|
| 37 |
+
|
| 38 |
+
def check_zero_indexed(data: Dict[str, Any]) -> Tuple[bool, bool, bool]:
|
| 39 |
+
"""
|
| 40 |
+
Check if annotations, images, or categories start from index 0.
|
| 41 |
+
|
| 42 |
+
Returns:
|
| 43 |
+
Tuple of (annotations_zero_indexed, images_zero_indexed, categories_zero_indexed)
|
| 44 |
+
"""
|
| 45 |
+
annotations_zero = False
|
| 46 |
+
images_zero = False
|
| 47 |
+
categories_zero = False
|
| 48 |
+
|
| 49 |
+
# Check annotations
|
| 50 |
+
annotations = data.get("annotations", [])
|
| 51 |
+
if annotations and any(ann.get("id", -1) == 0 for ann in annotations):
|
| 52 |
+
annotations_zero = True
|
| 53 |
+
|
| 54 |
+
# Check images
|
| 55 |
+
images = data.get("images", [])
|
| 56 |
+
if images and any(img.get("id", -1) == 0 for img in images):
|
| 57 |
+
images_zero = True
|
| 58 |
+
|
| 59 |
+
# Check categories
|
| 60 |
+
categories = data.get("categories", [])
|
| 61 |
+
if categories and any(cat.get("id", -1) == 0 for cat in categories):
|
| 62 |
+
categories_zero = True
|
| 63 |
+
|
| 64 |
+
return annotations_zero, images_zero, categories_zero
|
| 65 |
+
|
| 66 |
+
def reindex_coco_data(data: Dict[str, Any]) -> Dict[str, Any]:
|
| 67 |
+
"""Convert 0-indexed COCO data to 1-indexed."""
|
| 68 |
+
modified_data = data.copy()
|
| 69 |
+
|
| 70 |
+
annotations_zero, images_zero, categories_zero = check_zero_indexed(data)
|
| 71 |
+
|
| 72 |
+
# Create ID mapping for consistency
|
| 73 |
+
image_id_mapping = {}
|
| 74 |
+
category_id_mapping = {}
|
| 75 |
+
|
| 76 |
+
# Process images first (since annotations reference image IDs)
|
| 77 |
+
if images_zero and "images" in modified_data:
|
| 78 |
+
for img in modified_data["images"]:
|
| 79 |
+
old_id = img["id"]
|
| 80 |
+
new_id = old_id + 1
|
| 81 |
+
image_id_mapping[old_id] = new_id
|
| 82 |
+
img["id"] = new_id
|
| 83 |
+
|
| 84 |
+
# Process categories (since annotations reference category IDs)
|
| 85 |
+
if categories_zero and "categories" in modified_data:
|
| 86 |
+
for cat in modified_data["categories"]:
|
| 87 |
+
old_id = cat["id"]
|
| 88 |
+
new_id = old_id + 1
|
| 89 |
+
category_id_mapping[old_id] = new_id
|
| 90 |
+
cat["id"] = new_id
|
| 91 |
+
|
| 92 |
+
# Process annotations
|
| 93 |
+
if "annotations" in modified_data:
|
| 94 |
+
for ann in modified_data["annotations"]:
|
| 95 |
+
# Update annotation ID if needed
|
| 96 |
+
if annotations_zero:
|
| 97 |
+
ann["id"] = ann["id"] + 1
|
| 98 |
+
|
| 99 |
+
# Update image_id reference if images were reindexed
|
| 100 |
+
if images_zero and ann.get("image_id") is not None:
|
| 101 |
+
old_image_id = ann["image_id"]
|
| 102 |
+
if old_image_id in image_id_mapping:
|
| 103 |
+
ann["image_id"] = image_id_mapping[old_image_id]
|
| 104 |
+
|
| 105 |
+
# Update category_id reference if categories were reindexed
|
| 106 |
+
if categories_zero and ann.get("category_id") is not None:
|
| 107 |
+
old_category_id = ann["category_id"]
|
| 108 |
+
if old_category_id in category_id_mapping:
|
| 109 |
+
ann["category_id"] = category_id_mapping[old_category_id]
|
| 110 |
+
|
| 111 |
+
return modified_data
|
| 112 |
+
|
| 113 |
+
# Validate input path
|
| 114 |
+
if not os.path.exists(input_json_path):
|
| 115 |
+
raise FileNotFoundError(f"Input file not found: {input_json_path}")
|
| 116 |
+
|
| 117 |
+
# Load and validate JSON data
|
| 118 |
+
try:
|
| 119 |
+
with open(input_json_path, "r", encoding="utf-8") as f:
|
| 120 |
+
data = json.load(f)
|
| 121 |
+
except json.JSONDecodeError as e:
|
| 122 |
+
raise json.JSONDecodeError(f"Invalid JSON in {input_json_path}: {e}")
|
| 123 |
+
|
| 124 |
+
# Validate COCO format
|
| 125 |
+
if not is_coco_json(data):
|
| 126 |
+
raise ValueError(
|
| 127 |
+
f"File does not appear to be in COCO format: {input_json_path}"
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
# Check if reindexing is needed
|
| 131 |
+
annotations_zero, images_zero, categories_zero = check_zero_indexed(data)
|
| 132 |
+
|
| 133 |
+
if not (annotations_zero or images_zero or categories_zero):
|
| 134 |
+
# No conversion needed - just copy to temp location
|
| 135 |
+
input_path = Path(input_json_path)
|
| 136 |
+
temp_dir = tempfile.mkdtemp()
|
| 137 |
+
temp_filename = f"{input_path.stem}_1_indexed{input_path.suffix}"
|
| 138 |
+
temp_path = os.path.join(temp_dir, temp_filename)
|
| 139 |
+
|
| 140 |
+
with open(temp_path, "w", encoding="utf-8") as f:
|
| 141 |
+
json.dump(data, f, indent=2, ensure_ascii=False)
|
| 142 |
+
|
| 143 |
+
return temp_path
|
| 144 |
+
|
| 145 |
+
# Perform reindexing
|
| 146 |
+
modified_data = reindex_coco_data(data)
|
| 147 |
+
|
| 148 |
+
# Create temporary file
|
| 149 |
+
input_path = Path(input_json_path)
|
| 150 |
+
temp_dir = tempfile.mkdtemp()
|
| 151 |
+
temp_filename = f"{input_path.stem}_1_indexed{input_path.suffix}"
|
| 152 |
+
temp_path = os.path.join(temp_dir, temp_filename)
|
| 153 |
+
|
| 154 |
+
# Write modified data to temporary file
|
| 155 |
+
with open(temp_path, "w", encoding="utf-8") as f:
|
| 156 |
+
json.dump(modified_data, f, indent=2, ensure_ascii=False)
|
| 157 |
+
|
| 158 |
+
return temp_path
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# Example usage and test function
|
| 162 |
+
def test_reindex_function():
|
| 163 |
+
"""Test the reindex function with a sample COCO file."""
|
| 164 |
+
|
| 165 |
+
# Create a test COCO file
|
| 166 |
+
test_data = {
|
| 167 |
+
"info": {"description": "Test COCO dataset", "version": "1.0", "year": 2023},
|
| 168 |
+
"images": [
|
| 169 |
+
{"id": 0, "width": 640, "height": 480, "file_name": "test1.jpg"},
|
| 170 |
+
{"id": 1, "width": 640, "height": 480, "file_name": "test2.jpg"},
|
| 171 |
+
],
|
| 172 |
+
"categories": [
|
| 173 |
+
{"id": 0, "name": "person", "supercategory": "person"},
|
| 174 |
+
{"id": 1, "name": "car", "supercategory": "vehicle"},
|
| 175 |
+
],
|
| 176 |
+
"annotations": [
|
| 177 |
+
{
|
| 178 |
+
"id": 0,
|
| 179 |
+
"image_id": 0,
|
| 180 |
+
"category_id": 0,
|
| 181 |
+
"bbox": [100, 100, 50, 75],
|
| 182 |
+
"area": 3750,
|
| 183 |
+
"iscrowd": 0,
|
| 184 |
+
},
|
| 185 |
+
{
|
| 186 |
+
"id": 1,
|
| 187 |
+
"image_id": 1,
|
| 188 |
+
"category_id": 1,
|
| 189 |
+
"bbox": [200, 150, 120, 80],
|
| 190 |
+
"area": 9600,
|
| 191 |
+
"iscrowd": 0,
|
| 192 |
+
},
|
| 193 |
+
],
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
# Create temporary test file
|
| 197 |
+
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
| 198 |
+
json.dump(test_data, f, indent=2)
|
| 199 |
+
test_file_path = f.name
|
| 200 |
+
|
| 201 |
+
try:
|
| 202 |
+
# Test the function
|
| 203 |
+
result_path = reindex_coco_to_temp(test_file_path)
|
| 204 |
+
print(f"Original file: {test_file_path}")
|
| 205 |
+
print(f"Converted file: {result_path}")
|
| 206 |
+
|
| 207 |
+
# Load and display the result
|
| 208 |
+
with open(result_path, "r") as f:
|
| 209 |
+
result_data = json.load(f)
|
| 210 |
+
|
| 211 |
+
print("\nConverted data sample:")
|
| 212 |
+
print(f"First image ID: {result_data['images'][0]['id']}")
|
| 213 |
+
print(f"First category ID: {result_data['categories'][0]['id']}")
|
| 214 |
+
print(f"First annotation ID: {result_data['annotations'][0]['id']}")
|
| 215 |
+
print(f"First annotation image_id: {result_data['annotations'][0]['image_id']}")
|
| 216 |
+
print(
|
| 217 |
+
f"First annotation category_id: {result_data['annotations'][0]['category_id']}"
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
# Clean up
|
| 221 |
+
os.unlink(result_path)
|
| 222 |
+
os.rmdir(os.path.dirname(result_path))
|
| 223 |
+
|
| 224 |
+
finally:
|
| 225 |
+
# Clean up test file
|
| 226 |
+
os.unlink(test_file_path)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
if __name__ == "__main__":
|
| 230 |
+
test_reindex_function()
|
sam3/eval/coco_writer.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
COCO prediction dumper for distributed training.
|
| 5 |
+
|
| 6 |
+
Handles collection and dumping of COCO-format predictions from models.
|
| 7 |
+
Supports distributed processing with multiple GPUs/processes.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import copy
|
| 11 |
+
import gc
|
| 12 |
+
import heapq
|
| 13 |
+
import json
|
| 14 |
+
import logging
|
| 15 |
+
import os
|
| 16 |
+
from collections import defaultdict
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Any, Optional
|
| 19 |
+
|
| 20 |
+
import pycocotools.mask as mask_utils
|
| 21 |
+
import torch
|
| 22 |
+
from iopath.common.file_io import g_pathmgr
|
| 23 |
+
from sam3.eval.coco_eval_offline import convert_to_xywh
|
| 24 |
+
from sam3.train.masks_ops import rle_encode
|
| 25 |
+
from sam3.train.utils.distributed import (
|
| 26 |
+
all_gather,
|
| 27 |
+
gather_to_rank_0_via_filesys,
|
| 28 |
+
get_rank,
|
| 29 |
+
is_main_process,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
### Helper functions and classes
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class HeapElement:
|
| 37 |
+
"""Utility class to make a heap with a custom comparator based on score."""
|
| 38 |
+
|
| 39 |
+
def __init__(self, val):
|
| 40 |
+
self.val = val
|
| 41 |
+
|
| 42 |
+
def __lt__(self, other):
|
| 43 |
+
return self.val["score"] < other.val["score"]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class PredictionDumper:
|
| 47 |
+
"""
|
| 48 |
+
Handles collection and dumping of COCO-format predictions from a model.
|
| 49 |
+
|
| 50 |
+
This class processes model outputs through a postprocessor, converts them to COCO format,
|
| 51 |
+
and saves them to disk. It supports distributed processing with multiple GPUs/processes.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
def __init__(
|
| 55 |
+
self,
|
| 56 |
+
dump_dir: str,
|
| 57 |
+
postprocessor,
|
| 58 |
+
maxdets: int,
|
| 59 |
+
iou_type: str,
|
| 60 |
+
gather_pred_via_filesys: bool = False,
|
| 61 |
+
merge_predictions: bool = False,
|
| 62 |
+
pred_file_evaluators: Optional[Any] = None,
|
| 63 |
+
):
|
| 64 |
+
"""
|
| 65 |
+
Initialize the PredictionDumper.
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
dump_dir: Directory to dump predictions.
|
| 69 |
+
postprocessor: Module to convert the model's output into COCO format.
|
| 70 |
+
maxdets: Maximum number of detections per image.
|
| 71 |
+
iou_type: IoU type to evaluate. Can include "bbox", "segm"
|
| 72 |
+
gather_pred_via_filesys: If True, use the filesystem for collective gathers across
|
| 73 |
+
processes (requires a shared filesystem). Otherwise, use torch collective ops.
|
| 74 |
+
merge_predictions: If True, merge predictions from all processes and dump to a single file.
|
| 75 |
+
"""
|
| 76 |
+
self.iou_type = iou_type
|
| 77 |
+
self.maxdets = maxdets
|
| 78 |
+
self.dump_dir = dump_dir
|
| 79 |
+
self.postprocessor = postprocessor
|
| 80 |
+
self.gather_pred_via_filesys = gather_pred_via_filesys
|
| 81 |
+
self.merge_predictions = merge_predictions
|
| 82 |
+
self.pred_file_evaluators = pred_file_evaluators
|
| 83 |
+
if self.pred_file_evaluators is not None:
|
| 84 |
+
assert (
|
| 85 |
+
merge_predictions
|
| 86 |
+
), "merge_predictions must be True if pred_file_evaluators are provided"
|
| 87 |
+
assert self.dump_dir is not None, "dump_dir must be provided"
|
| 88 |
+
|
| 89 |
+
if is_main_process():
|
| 90 |
+
os.makedirs(self.dump_dir, exist_ok=True)
|
| 91 |
+
logging.info(f"Created prediction dump directory: {self.dump_dir}")
|
| 92 |
+
|
| 93 |
+
# Initialize state
|
| 94 |
+
self.reset()
|
| 95 |
+
|
| 96 |
+
def update(self, *args, **kwargs):
|
| 97 |
+
"""
|
| 98 |
+
Process and accumulate predictions from model outputs.
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
*args, **kwargs: Arguments passed to postprocessor.process_results()
|
| 102 |
+
"""
|
| 103 |
+
predictions = self.postprocessor.process_results(*args, **kwargs)
|
| 104 |
+
results = self.prepare(predictions, self.iou_type)
|
| 105 |
+
self._dump(results)
|
| 106 |
+
|
| 107 |
+
def _dump(self, results):
|
| 108 |
+
"""
|
| 109 |
+
Add results to the dump list with precision rounding.
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
results: List of prediction dictionaries in COCO format.
|
| 113 |
+
"""
|
| 114 |
+
dumped_results = copy.deepcopy(results)
|
| 115 |
+
for r in dumped_results:
|
| 116 |
+
if "bbox" in r:
|
| 117 |
+
r["bbox"] = [round(coord, 5) for coord in r["bbox"]]
|
| 118 |
+
r["score"] = round(r["score"], 5)
|
| 119 |
+
self.dump.extend(dumped_results)
|
| 120 |
+
|
| 121 |
+
def synchronize_between_processes(self):
|
| 122 |
+
"""
|
| 123 |
+
Synchronize predictions across all processes and save to disk.
|
| 124 |
+
|
| 125 |
+
If gather_pred_via_filesys is True, uses filesystem for gathering.
|
| 126 |
+
Otherwise, uses torch distributed collective operations.
|
| 127 |
+
Saves per-rank predictions to separate JSON files.
|
| 128 |
+
"""
|
| 129 |
+
logging.info("Prediction Dumper: Synchronizing between processes")
|
| 130 |
+
|
| 131 |
+
if not self.merge_predictions:
|
| 132 |
+
dumped_file = (
|
| 133 |
+
Path(self.dump_dir)
|
| 134 |
+
/ f"coco_predictions_{self.iou_type}_{get_rank()}.json"
|
| 135 |
+
)
|
| 136 |
+
logging.info(
|
| 137 |
+
f"Prediction Dumper: Dumping local predictions to {dumped_file}"
|
| 138 |
+
)
|
| 139 |
+
with g_pathmgr.open(str(dumped_file), "w") as f:
|
| 140 |
+
json.dump(self.dump, f)
|
| 141 |
+
else:
|
| 142 |
+
self.dump = self.gather_and_merge_predictions()
|
| 143 |
+
dumped_file = Path(self.dump_dir) / f"coco_predictions_{self.iou_type}.json"
|
| 144 |
+
if is_main_process():
|
| 145 |
+
logging.info(
|
| 146 |
+
f"Prediction Dumper: Dumping merged predictions to {dumped_file}"
|
| 147 |
+
)
|
| 148 |
+
with g_pathmgr.open(str(dumped_file), "w") as f:
|
| 149 |
+
json.dump(self.dump, f)
|
| 150 |
+
|
| 151 |
+
self.reset()
|
| 152 |
+
return dumped_file
|
| 153 |
+
|
| 154 |
+
def gather_and_merge_predictions(self):
|
| 155 |
+
"""
|
| 156 |
+
Gather predictions from all processes and merge them, keeping top predictions per image.
|
| 157 |
+
|
| 158 |
+
This method collects predictions from all processes, then keeps only the top maxdets
|
| 159 |
+
predictions per image based on score. It also deduplicates predictions by (image_id, category_id).
|
| 160 |
+
|
| 161 |
+
Returns:
|
| 162 |
+
List of merged prediction dictionaries.
|
| 163 |
+
"""
|
| 164 |
+
logging.info("Prediction Dumper: Gathering predictions from all processes")
|
| 165 |
+
gc.collect()
|
| 166 |
+
|
| 167 |
+
if self.gather_pred_via_filesys:
|
| 168 |
+
dump = gather_to_rank_0_via_filesys(self.dump)
|
| 169 |
+
else:
|
| 170 |
+
dump = all_gather(self.dump, force_cpu=True)
|
| 171 |
+
|
| 172 |
+
# Combine predictions, keeping only top maxdets per image
|
| 173 |
+
preds_by_image = defaultdict(list)
|
| 174 |
+
seen_img_cat = set()
|
| 175 |
+
|
| 176 |
+
for cur_dump in dump:
|
| 177 |
+
cur_seen_img_cat = set()
|
| 178 |
+
for p in cur_dump:
|
| 179 |
+
image_id = p["image_id"]
|
| 180 |
+
cat_id = p["category_id"]
|
| 181 |
+
|
| 182 |
+
# Skip if we've already seen this image/category pair in a previous dump
|
| 183 |
+
if (image_id, cat_id) in seen_img_cat:
|
| 184 |
+
continue
|
| 185 |
+
|
| 186 |
+
cur_seen_img_cat.add((image_id, cat_id))
|
| 187 |
+
|
| 188 |
+
# Use a min-heap to keep top predictions
|
| 189 |
+
if len(preds_by_image[image_id]) < self.maxdets:
|
| 190 |
+
heapq.heappush(preds_by_image[image_id], HeapElement(p))
|
| 191 |
+
else:
|
| 192 |
+
heapq.heappushpop(preds_by_image[image_id], HeapElement(p))
|
| 193 |
+
|
| 194 |
+
seen_img_cat.update(cur_seen_img_cat)
|
| 195 |
+
|
| 196 |
+
# Flatten the heap elements back to a list
|
| 197 |
+
merged_dump = sum(
|
| 198 |
+
[[h.val for h in cur_preds] for cur_preds in preds_by_image.values()], []
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
return merged_dump
|
| 202 |
+
|
| 203 |
+
def compute_synced(self):
|
| 204 |
+
"""
|
| 205 |
+
Synchronize predictions across processes and compute summary.
|
| 206 |
+
|
| 207 |
+
Returns:
|
| 208 |
+
Summary dictionary from summarize().
|
| 209 |
+
"""
|
| 210 |
+
dumped_file = self.synchronize_between_processes()
|
| 211 |
+
if not is_main_process():
|
| 212 |
+
return {"": 0.0}
|
| 213 |
+
|
| 214 |
+
meters = {}
|
| 215 |
+
if self.pred_file_evaluators is not None:
|
| 216 |
+
for evaluator in self.pred_file_evaluators:
|
| 217 |
+
results = evaluator.evaluate(dumped_file)
|
| 218 |
+
meters.update(results)
|
| 219 |
+
|
| 220 |
+
if len(meters) == 0:
|
| 221 |
+
meters = {"": 0.0}
|
| 222 |
+
return meters
|
| 223 |
+
|
| 224 |
+
def compute(self):
|
| 225 |
+
"""
|
| 226 |
+
Compute without synchronization.
|
| 227 |
+
|
| 228 |
+
Returns:
|
| 229 |
+
Empty metric dictionary.
|
| 230 |
+
"""
|
| 231 |
+
return {"": 0.0}
|
| 232 |
+
|
| 233 |
+
def reset(self):
|
| 234 |
+
"""Reset internal state for a new evaluation round."""
|
| 235 |
+
self.dump = []
|
| 236 |
+
|
| 237 |
+
def prepare(self, predictions, iou_type):
|
| 238 |
+
"""
|
| 239 |
+
Route predictions to the appropriate preparation method based on iou_type.
|
| 240 |
+
|
| 241 |
+
Args:
|
| 242 |
+
predictions: Dictionary mapping image IDs to prediction dictionaries.
|
| 243 |
+
iou_type: Type of evaluation ("bbox", "segm").
|
| 244 |
+
|
| 245 |
+
Returns:
|
| 246 |
+
List of COCO-format prediction dictionaries.
|
| 247 |
+
"""
|
| 248 |
+
if iou_type == "bbox":
|
| 249 |
+
return self.prepare_for_coco_detection(predictions)
|
| 250 |
+
elif iou_type == "segm":
|
| 251 |
+
return self.prepare_for_coco_segmentation(predictions)
|
| 252 |
+
else:
|
| 253 |
+
raise ValueError(f"Unknown iou type: {iou_type}")
|
| 254 |
+
|
| 255 |
+
def prepare_for_coco_detection(self, predictions):
|
| 256 |
+
"""
|
| 257 |
+
Convert predictions to COCO detection format.
|
| 258 |
+
|
| 259 |
+
Args:
|
| 260 |
+
predictions: Dictionary mapping image IDs to prediction dictionaries
|
| 261 |
+
containing "boxes", "scores", and "labels".
|
| 262 |
+
|
| 263 |
+
Returns:
|
| 264 |
+
List of COCO-format detection dictionaries.
|
| 265 |
+
"""
|
| 266 |
+
coco_results = []
|
| 267 |
+
for original_id, prediction in predictions.items():
|
| 268 |
+
if len(prediction) == 0:
|
| 269 |
+
continue
|
| 270 |
+
|
| 271 |
+
boxes = prediction["boxes"]
|
| 272 |
+
boxes = convert_to_xywh(boxes).tolist()
|
| 273 |
+
scores = prediction["scores"].tolist()
|
| 274 |
+
labels = prediction["labels"].tolist()
|
| 275 |
+
|
| 276 |
+
coco_results.extend(
|
| 277 |
+
[
|
| 278 |
+
{
|
| 279 |
+
"image_id": original_id,
|
| 280 |
+
"category_id": labels[k],
|
| 281 |
+
"bbox": box,
|
| 282 |
+
"score": scores[k],
|
| 283 |
+
}
|
| 284 |
+
for k, box in enumerate(boxes)
|
| 285 |
+
]
|
| 286 |
+
)
|
| 287 |
+
return coco_results
|
| 288 |
+
|
| 289 |
+
@torch.no_grad()
|
| 290 |
+
def prepare_for_coco_segmentation(self, predictions):
|
| 291 |
+
"""
|
| 292 |
+
Convert predictions to COCO segmentation format.
|
| 293 |
+
|
| 294 |
+
Args:
|
| 295 |
+
predictions: Dictionary mapping image IDs to prediction dictionaries
|
| 296 |
+
containing "masks" or "masks_rle", "scores", and "labels".
|
| 297 |
+
Optionally includes "boundaries" and "dilated_boundaries".
|
| 298 |
+
|
| 299 |
+
Returns:
|
| 300 |
+
List of COCO-format segmentation dictionaries with RLE-encoded masks.
|
| 301 |
+
"""
|
| 302 |
+
coco_results = []
|
| 303 |
+
for original_id, prediction in predictions.items():
|
| 304 |
+
if len(prediction) == 0:
|
| 305 |
+
continue
|
| 306 |
+
|
| 307 |
+
scores = prediction["scores"].tolist()
|
| 308 |
+
labels = prediction["labels"].tolist()
|
| 309 |
+
|
| 310 |
+
boxes = None
|
| 311 |
+
if "boxes" in prediction:
|
| 312 |
+
boxes = prediction["boxes"]
|
| 313 |
+
boxes = convert_to_xywh(boxes).tolist()
|
| 314 |
+
assert len(boxes) == len(scores)
|
| 315 |
+
|
| 316 |
+
if "masks_rle" in prediction:
|
| 317 |
+
rles = prediction["masks_rle"]
|
| 318 |
+
areas = []
|
| 319 |
+
for rle in rles:
|
| 320 |
+
cur_area = mask_utils.area(rle)
|
| 321 |
+
h, w = rle["size"]
|
| 322 |
+
areas.append(cur_area / (h * w))
|
| 323 |
+
else:
|
| 324 |
+
masks = prediction["masks"]
|
| 325 |
+
masks = masks > 0.5
|
| 326 |
+
h, w = masks.shape[-2:]
|
| 327 |
+
|
| 328 |
+
areas = masks.flatten(1).sum(1) / (h * w)
|
| 329 |
+
areas = areas.tolist()
|
| 330 |
+
|
| 331 |
+
rles = rle_encode(masks.squeeze(1))
|
| 332 |
+
|
| 333 |
+
# Memory cleanup
|
| 334 |
+
del masks
|
| 335 |
+
del prediction["masks"]
|
| 336 |
+
|
| 337 |
+
assert len(areas) == len(rles) == len(scores)
|
| 338 |
+
|
| 339 |
+
for k, rle in enumerate(rles):
|
| 340 |
+
payload = {
|
| 341 |
+
"image_id": original_id,
|
| 342 |
+
"category_id": labels[k],
|
| 343 |
+
"segmentation": rle,
|
| 344 |
+
"score": scores[k],
|
| 345 |
+
"area": areas[k],
|
| 346 |
+
}
|
| 347 |
+
if boxes is not None:
|
| 348 |
+
payload["bbox"] = boxes[k]
|
| 349 |
+
|
| 350 |
+
coco_results.append(payload)
|
| 351 |
+
|
| 352 |
+
return coco_results
|
sam3/eval/conversion_util.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def convert_ytbvis_to_cocovid_gt(ann_json, save_path=None):
|
| 10 |
+
"""Convert YouTube VIS dataset to COCO-style video instance segmentation format.
|
| 11 |
+
|
| 12 |
+
Args:
|
| 13 |
+
ann_json (str): Path to YouTube VIS annotation JSON file
|
| 14 |
+
save_path (str): path to save converted COCO-style JSON
|
| 15 |
+
"""
|
| 16 |
+
# Initialize COCO structure
|
| 17 |
+
VIS = {
|
| 18 |
+
"info": {},
|
| 19 |
+
"images": [],
|
| 20 |
+
"videos": [],
|
| 21 |
+
"tracks": [],
|
| 22 |
+
"annotations": [],
|
| 23 |
+
"categories": [],
|
| 24 |
+
"licenses": [],
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
# Load original annotations
|
| 28 |
+
official_anns = json.load(open(ann_json))
|
| 29 |
+
VIS["categories"] = official_anns["categories"] # Direct copy categories
|
| 30 |
+
|
| 31 |
+
# Initialize counters
|
| 32 |
+
records = dict(img_id=1, ann_id=1)
|
| 33 |
+
|
| 34 |
+
# Create video-to-annotations mapping
|
| 35 |
+
vid_to_anns = defaultdict(list)
|
| 36 |
+
for ann in official_anns["annotations"]:
|
| 37 |
+
vid_to_anns[ann["video_id"]].append(ann)
|
| 38 |
+
|
| 39 |
+
# Create tracks directly
|
| 40 |
+
VIS["tracks"] = [
|
| 41 |
+
{
|
| 42 |
+
"id": ann["id"],
|
| 43 |
+
"category_id": ann["category_id"],
|
| 44 |
+
"video_id": ann["video_id"],
|
| 45 |
+
}
|
| 46 |
+
for ann in official_anns["annotations"]
|
| 47 |
+
]
|
| 48 |
+
|
| 49 |
+
# Process videos
|
| 50 |
+
for video_info in tqdm(official_anns["videos"]):
|
| 51 |
+
# Create video entry
|
| 52 |
+
video = {
|
| 53 |
+
"id": video_info["id"],
|
| 54 |
+
"name": os.path.dirname(video_info["file_names"][0]),
|
| 55 |
+
"width": video_info["width"],
|
| 56 |
+
"height": video_info["height"],
|
| 57 |
+
"length": video_info["length"],
|
| 58 |
+
"neg_category_ids": [],
|
| 59 |
+
"not_exhaustive_category_ids": [],
|
| 60 |
+
}
|
| 61 |
+
VIS["videos"].append(video)
|
| 62 |
+
|
| 63 |
+
# Process frames
|
| 64 |
+
num_frames = len(video_info["file_names"])
|
| 65 |
+
for frame_idx in range(num_frames):
|
| 66 |
+
# Create image entry
|
| 67 |
+
image = {
|
| 68 |
+
"id": records["img_id"],
|
| 69 |
+
"video_id": video_info["id"],
|
| 70 |
+
"file_name": video_info["file_names"][frame_idx],
|
| 71 |
+
"width": video_info["width"],
|
| 72 |
+
"height": video_info["height"],
|
| 73 |
+
"frame_index": frame_idx,
|
| 74 |
+
"frame_id": frame_idx,
|
| 75 |
+
}
|
| 76 |
+
VIS["images"].append(image)
|
| 77 |
+
|
| 78 |
+
# Process annotations for this frame
|
| 79 |
+
if video_info["id"] in vid_to_anns:
|
| 80 |
+
for ann in vid_to_anns[video_info["id"]]:
|
| 81 |
+
bbox = ann["bboxes"][frame_idx]
|
| 82 |
+
if bbox is None:
|
| 83 |
+
continue
|
| 84 |
+
|
| 85 |
+
# Create annotation entry
|
| 86 |
+
annotation = {
|
| 87 |
+
"id": records["ann_id"],
|
| 88 |
+
"video_id": video_info["id"],
|
| 89 |
+
"image_id": records["img_id"],
|
| 90 |
+
"track_id": ann["id"],
|
| 91 |
+
"category_id": ann["category_id"],
|
| 92 |
+
"bbox": bbox,
|
| 93 |
+
"area": ann["areas"][frame_idx],
|
| 94 |
+
"segmentation": ann["segmentations"][frame_idx],
|
| 95 |
+
"iscrowd": ann["iscrowd"],
|
| 96 |
+
}
|
| 97 |
+
VIS["annotations"].append(annotation)
|
| 98 |
+
records["ann_id"] += 1
|
| 99 |
+
|
| 100 |
+
records["img_id"] += 1
|
| 101 |
+
|
| 102 |
+
# Print summary
|
| 103 |
+
print(f"Converted {len(VIS['videos'])} videos")
|
| 104 |
+
print(f"Converted {len(VIS['images'])} images")
|
| 105 |
+
print(f"Created {len(VIS['tracks'])} tracks")
|
| 106 |
+
print(f"Created {len(VIS['annotations'])} annotations")
|
| 107 |
+
|
| 108 |
+
if save_path is None:
|
| 109 |
+
return VIS
|
| 110 |
+
|
| 111 |
+
# Save output
|
| 112 |
+
save_dir = os.path.dirname(save_path)
|
| 113 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 114 |
+
json.dump(VIS, open(save_path, "w"))
|
| 115 |
+
|
| 116 |
+
return VIS
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def convert_ytbvis_to_cocovid_pred(
|
| 120 |
+
youtubevis_pred_path: str, converted_dataset_path: str, output_path: str
|
| 121 |
+
) -> None:
|
| 122 |
+
"""
|
| 123 |
+
Convert YouTubeVIS predictions to COCO format with video_id preservation
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
youtubevis_pred_path: Path to YouTubeVIS prediction JSON
|
| 127 |
+
converted_dataset_path: Path to converted COCO dataset JSON
|
| 128 |
+
output_path: Path to save COCO format predictions
|
| 129 |
+
"""
|
| 130 |
+
|
| 131 |
+
# Load YouTubeVIS predictions
|
| 132 |
+
with open(youtubevis_pred_path) as f:
|
| 133 |
+
ytv_predictions = json.load(f)
|
| 134 |
+
|
| 135 |
+
# Load converted dataset for image ID mapping
|
| 136 |
+
with open(converted_dataset_path) as f:
|
| 137 |
+
coco_dataset = json.load(f)
|
| 138 |
+
|
| 139 |
+
# Create (video_id, frame_idx) -> image_id mapping
|
| 140 |
+
image_id_map = {
|
| 141 |
+
(img["video_id"], img["frame_index"]): img["id"]
|
| 142 |
+
for img in coco_dataset["images"]
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
coco_annotations = []
|
| 146 |
+
track_id_counter = 1 # Unique track ID generator
|
| 147 |
+
|
| 148 |
+
for pred in tqdm(ytv_predictions):
|
| 149 |
+
video_id = pred["video_id"]
|
| 150 |
+
category_id = pred["category_id"]
|
| 151 |
+
bboxes = pred["bboxes"]
|
| 152 |
+
segmentations = pred.get("segmentations", []) # Get segmentations if available
|
| 153 |
+
areas = pred.get("areas", []) # Get areas if available
|
| 154 |
+
score = pred["score"]
|
| 155 |
+
|
| 156 |
+
# Assign unique track ID for this prediction
|
| 157 |
+
track_id = track_id_counter
|
| 158 |
+
track_id_counter += 1
|
| 159 |
+
|
| 160 |
+
# Ensure segmentations and areas have the same length as bboxes
|
| 161 |
+
if len(segmentations) == 0:
|
| 162 |
+
segmentations = [None] * len(bboxes)
|
| 163 |
+
if len(areas) == 0:
|
| 164 |
+
areas = [None] * len(bboxes)
|
| 165 |
+
|
| 166 |
+
for frame_idx, (bbox, segmentation, area_from_pred) in enumerate(
|
| 167 |
+
zip(bboxes, segmentations, areas)
|
| 168 |
+
):
|
| 169 |
+
# Skip frames with missing objects (None or zero bbox)
|
| 170 |
+
if bbox is None or all(x == 0 for x in bbox):
|
| 171 |
+
continue
|
| 172 |
+
|
| 173 |
+
# Get corresponding image ID from mapping
|
| 174 |
+
image_id = image_id_map.get((video_id, frame_idx))
|
| 175 |
+
if image_id is None:
|
| 176 |
+
raise RuntimeError(
|
| 177 |
+
f"prediction {video_id=}, {frame_idx=} does not match any images in the converted COCO format"
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
# Extract bbox coordinates
|
| 181 |
+
x, y, w, h = bbox
|
| 182 |
+
|
| 183 |
+
# Calculate area - use area from prediction if available, otherwise from bbox
|
| 184 |
+
if area_from_pred is not None and area_from_pred > 0:
|
| 185 |
+
area = area_from_pred
|
| 186 |
+
else:
|
| 187 |
+
area = w * h
|
| 188 |
+
|
| 189 |
+
# Create COCO annotation with video_id
|
| 190 |
+
coco_annotation = {
|
| 191 |
+
"image_id": int(image_id),
|
| 192 |
+
"video_id": video_id, # Added video_id field
|
| 193 |
+
"track_id": track_id,
|
| 194 |
+
"category_id": category_id,
|
| 195 |
+
"bbox": [float(x), float(y), float(w), float(h)],
|
| 196 |
+
"area": float(area),
|
| 197 |
+
"iscrowd": 0,
|
| 198 |
+
"score": float(score),
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
# Add segmentation if available
|
| 202 |
+
if segmentation is not None:
|
| 203 |
+
coco_annotation["segmentation"] = segmentation
|
| 204 |
+
|
| 205 |
+
coco_annotations.append(coco_annotation)
|
| 206 |
+
|
| 207 |
+
# Save output
|
| 208 |
+
with open(output_path, "w") as f:
|
| 209 |
+
json.dump(coco_annotations, f)
|
| 210 |
+
|
| 211 |
+
print(f"Converted {len(coco_annotations)} predictions to COCO format with video_id")
|
sam3/eval/demo_eval.py
ADDED
|
@@ -0,0 +1,658 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
This evaluator is based upon COCO evaluation, but evaluates the model in a "demo" setting.
|
| 5 |
+
This means that the model's predictions are thresholded and evaluated as "hard" predictions.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pycocotools.mask as maskUtils
|
| 13 |
+
from pycocotools.cocoeval import COCOeval
|
| 14 |
+
|
| 15 |
+
from sam3.eval.coco_eval import CocoEvaluator
|
| 16 |
+
from sam3.train.masks_ops import compute_F_measure
|
| 17 |
+
from sam3.train.utils.distributed import is_main_process
|
| 18 |
+
|
| 19 |
+
from scipy.optimize import linear_sum_assignment
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class DemoEval(COCOeval):
|
| 23 |
+
"""
|
| 24 |
+
This evaluator is based upon COCO evaluation, but evaluates the model in a "demo" setting.
|
| 25 |
+
This means that the model's predictions are thresholded and evaluated as "hard" predictions.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(
|
| 29 |
+
self,
|
| 30 |
+
coco_gt=None,
|
| 31 |
+
coco_dt=None,
|
| 32 |
+
iouType="bbox",
|
| 33 |
+
threshold=0.5,
|
| 34 |
+
compute_JnF=False,
|
| 35 |
+
):
|
| 36 |
+
"""
|
| 37 |
+
Args:
|
| 38 |
+
coco_gt (COCO): ground truth COCO API
|
| 39 |
+
coco_dt (COCO): detections COCO API
|
| 40 |
+
iou_type (str): type of IoU to evaluate
|
| 41 |
+
threshold (float): threshold for predictions
|
| 42 |
+
"""
|
| 43 |
+
super().__init__(coco_gt, coco_dt, iouType)
|
| 44 |
+
self.threshold = threshold
|
| 45 |
+
|
| 46 |
+
self.params.useCats = False
|
| 47 |
+
self.params.areaRng = [[0**2, 1e5**2]]
|
| 48 |
+
self.params.areaRngLbl = ["all"]
|
| 49 |
+
self.params.maxDets = [100000]
|
| 50 |
+
self.compute_JnF = compute_JnF
|
| 51 |
+
|
| 52 |
+
def computeIoU(self, imgId, catId):
|
| 53 |
+
# Same as the original COCOeval.computeIoU, but without sorting
|
| 54 |
+
p = self.params
|
| 55 |
+
if p.useCats:
|
| 56 |
+
gt = self._gts[imgId, catId]
|
| 57 |
+
dt = self._dts[imgId, catId]
|
| 58 |
+
else:
|
| 59 |
+
gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
|
| 60 |
+
dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
|
| 61 |
+
if len(gt) == 0 and len(dt) == 0:
|
| 62 |
+
return []
|
| 63 |
+
|
| 64 |
+
if p.iouType == "segm":
|
| 65 |
+
g = [g["segmentation"] for g in gt]
|
| 66 |
+
d = [d["segmentation"] for d in dt]
|
| 67 |
+
elif p.iouType == "bbox":
|
| 68 |
+
g = [g["bbox"] for g in gt]
|
| 69 |
+
d = [d["bbox"] for d in dt]
|
| 70 |
+
else:
|
| 71 |
+
raise Exception("unknown iouType for iou computation")
|
| 72 |
+
|
| 73 |
+
# compute iou between each dt and gt region
|
| 74 |
+
iscrowd = [int(o["iscrowd"]) for o in gt]
|
| 75 |
+
ious = maskUtils.iou(d, g, iscrowd)
|
| 76 |
+
return ious
|
| 77 |
+
|
| 78 |
+
def evaluateImg(self, imgId, catId, aRng, maxDet):
|
| 79 |
+
"""
|
| 80 |
+
perform evaluation for single category and image
|
| 81 |
+
:return: dict (single image results)
|
| 82 |
+
"""
|
| 83 |
+
p = self.params
|
| 84 |
+
assert not p.useCats, "This evaluator does not support per-category evaluation."
|
| 85 |
+
assert catId == -1
|
| 86 |
+
all_gts = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
|
| 87 |
+
keep_gt = np.array([not g["ignore"] for g in all_gts], dtype=bool)
|
| 88 |
+
gt = [g for g in all_gts if not g["ignore"]]
|
| 89 |
+
all_dts = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
|
| 90 |
+
keep_dt = np.array([d["score"] >= self.threshold for d in all_dts], dtype=bool)
|
| 91 |
+
dt = [d for d in all_dts if d["score"] >= self.threshold]
|
| 92 |
+
if len(gt) == 0 and len(dt) == 0:
|
| 93 |
+
# This is a "true negative" case, where there are no GTs and no predictions
|
| 94 |
+
# The box-level metrics are ill-defined, so we don't add them to this dict
|
| 95 |
+
return {
|
| 96 |
+
"image_id": imgId,
|
| 97 |
+
"IL_TP": 0,
|
| 98 |
+
"IL_TN": 1,
|
| 99 |
+
"IL_FP": 0,
|
| 100 |
+
"IL_FN": 0,
|
| 101 |
+
"IL_perfect_neg": np.ones((len(p.iouThrs),), dtype=np.int64),
|
| 102 |
+
"num_dt": len(dt),
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
if len(gt) > 0 and len(dt) == 0:
|
| 106 |
+
# This is a "false negative" case, where there are GTs but no predictions
|
| 107 |
+
return {
|
| 108 |
+
"image_id": imgId,
|
| 109 |
+
"IL_TP": 0,
|
| 110 |
+
"IL_TN": 0,
|
| 111 |
+
"IL_FP": 0,
|
| 112 |
+
"IL_FN": 1,
|
| 113 |
+
"TPs": np.zeros((len(p.iouThrs),), dtype=np.int64),
|
| 114 |
+
"FPs": np.zeros((len(p.iouThrs),), dtype=np.int64),
|
| 115 |
+
"FNs": np.ones((len(p.iouThrs),), dtype=np.int64) * len(gt),
|
| 116 |
+
"local_F1s": np.zeros((len(p.iouThrs),), dtype=np.int64),
|
| 117 |
+
"local_positive_F1s": np.zeros((len(p.iouThrs),), dtype=np.int64),
|
| 118 |
+
"IL_perfect_pos": np.zeros((len(p.iouThrs),), dtype=np.int64),
|
| 119 |
+
"num_dt": len(dt),
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
# Load pre-computed ious
|
| 123 |
+
ious = self.ious[(imgId, catId)]
|
| 124 |
+
|
| 125 |
+
# compute matching
|
| 126 |
+
if len(ious) == 0:
|
| 127 |
+
ious = np.zeros((len(dt), len(gt)))
|
| 128 |
+
else:
|
| 129 |
+
ious = ious[keep_dt, :][:, keep_gt]
|
| 130 |
+
assert ious.shape == (len(dt), len(gt))
|
| 131 |
+
|
| 132 |
+
matched_dt, matched_gt = linear_sum_assignment(-ious)
|
| 133 |
+
|
| 134 |
+
match_scores = ious[matched_dt, matched_gt]
|
| 135 |
+
|
| 136 |
+
if self.compute_JnF and len(match_scores) > 0:
|
| 137 |
+
j_score = match_scores.mean()
|
| 138 |
+
f_measure = 0
|
| 139 |
+
for dt_id, gt_id in zip(matched_dt, matched_gt):
|
| 140 |
+
f_measure += compute_F_measure(
|
| 141 |
+
gt_boundary_rle=gt[gt_id]["boundary"],
|
| 142 |
+
gt_dilated_boundary_rle=gt[gt_id]["dilated_boundary"],
|
| 143 |
+
dt_boundary_rle=dt[dt_id]["boundary"],
|
| 144 |
+
dt_dilated_boundary_rle=dt[dt_id]["dilated_boundary"],
|
| 145 |
+
)
|
| 146 |
+
f_measure /= len(match_scores) + 1e-9
|
| 147 |
+
JnF = (j_score + f_measure) * 0.5
|
| 148 |
+
else:
|
| 149 |
+
j_score = f_measure = JnF = -1
|
| 150 |
+
|
| 151 |
+
TPs, FPs, FNs = [], [], []
|
| 152 |
+
IL_perfect = []
|
| 153 |
+
for thresh in p.iouThrs:
|
| 154 |
+
TP = (match_scores >= thresh).sum()
|
| 155 |
+
FP = len(dt) - TP
|
| 156 |
+
FN = len(gt) - TP
|
| 157 |
+
assert (
|
| 158 |
+
FP >= 0 and FN >= 0
|
| 159 |
+
), f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}"
|
| 160 |
+
TPs.append(TP)
|
| 161 |
+
FPs.append(FP)
|
| 162 |
+
FNs.append(FN)
|
| 163 |
+
|
| 164 |
+
if FP == FN and FP == 0:
|
| 165 |
+
IL_perfect.append(1)
|
| 166 |
+
else:
|
| 167 |
+
IL_perfect.append(0)
|
| 168 |
+
|
| 169 |
+
TPs = np.array(TPs, dtype=np.int64)
|
| 170 |
+
FPs = np.array(FPs, dtype=np.int64)
|
| 171 |
+
FNs = np.array(FNs, dtype=np.int64)
|
| 172 |
+
IL_perfect = np.array(IL_perfect, dtype=np.int64)
|
| 173 |
+
|
| 174 |
+
# compute precision recall and F1
|
| 175 |
+
precision = TPs / (TPs + FPs + 1e-4)
|
| 176 |
+
assert np.all(precision <= 1)
|
| 177 |
+
recall = TPs / (TPs + FNs + 1e-4)
|
| 178 |
+
assert np.all(recall <= 1)
|
| 179 |
+
F1 = 2 * precision * recall / (precision + recall + 1e-4)
|
| 180 |
+
|
| 181 |
+
result = {
|
| 182 |
+
"image_id": imgId,
|
| 183 |
+
"TPs": TPs,
|
| 184 |
+
"FPs": FPs,
|
| 185 |
+
"FNs": FNs,
|
| 186 |
+
"local_F1s": F1,
|
| 187 |
+
"IL_TP": (len(gt) > 0) and (len(dt) > 0),
|
| 188 |
+
"IL_FP": (len(gt) == 0) and (len(dt) > 0),
|
| 189 |
+
"IL_TN": (len(gt) == 0) and (len(dt) == 0),
|
| 190 |
+
"IL_FN": (len(gt) > 0) and (len(dt) == 0),
|
| 191 |
+
("IL_perfect_pos" if len(gt) > 0 else "IL_perfect_neg"): IL_perfect,
|
| 192 |
+
"F": f_measure,
|
| 193 |
+
"J": j_score,
|
| 194 |
+
"J&F": JnF,
|
| 195 |
+
"num_dt": len(dt),
|
| 196 |
+
}
|
| 197 |
+
if len(gt) > 0 and len(dt) > 0:
|
| 198 |
+
result["local_positive_F1s"] = F1
|
| 199 |
+
return result
|
| 200 |
+
|
| 201 |
+
def accumulate(self, p=None):
|
| 202 |
+
"""
|
| 203 |
+
Accumulate per image evaluation results and store the result in self.eval
|
| 204 |
+
:param p: input params for evaluation
|
| 205 |
+
:return: None
|
| 206 |
+
"""
|
| 207 |
+
if not self.evalImgs:
|
| 208 |
+
print("Please run evaluate() first")
|
| 209 |
+
# allows input customized parameters
|
| 210 |
+
if p is None:
|
| 211 |
+
p = self.params
|
| 212 |
+
|
| 213 |
+
setImgIds = set(p.imgIds)
|
| 214 |
+
|
| 215 |
+
# TPs, FPs, FNs
|
| 216 |
+
TPs = np.zeros((len(p.iouThrs),), dtype=np.int64)
|
| 217 |
+
FPs = np.zeros((len(p.iouThrs),), dtype=np.int64)
|
| 218 |
+
pmFPs = np.zeros((len(p.iouThrs),), dtype=np.int64)
|
| 219 |
+
FNs = np.zeros((len(p.iouThrs),), dtype=np.int64)
|
| 220 |
+
local_F1s = np.zeros((len(p.iouThrs),), dtype=np.float64)
|
| 221 |
+
|
| 222 |
+
# Image level metrics
|
| 223 |
+
IL_TPs = 0
|
| 224 |
+
IL_FPs = 0
|
| 225 |
+
IL_TNs = 0
|
| 226 |
+
IL_FNs = 0
|
| 227 |
+
IL_perfects_neg = np.zeros((len(p.iouThrs),), dtype=np.int64)
|
| 228 |
+
IL_perfects_pos = np.zeros((len(p.iouThrs),), dtype=np.int64)
|
| 229 |
+
|
| 230 |
+
# JnF metric
|
| 231 |
+
total_J = 0
|
| 232 |
+
total_F = 0
|
| 233 |
+
total_JnF = 0
|
| 234 |
+
|
| 235 |
+
valid_img_count = 0
|
| 236 |
+
total_pos_count = 0
|
| 237 |
+
total_neg_count = 0
|
| 238 |
+
valid_J_count = 0
|
| 239 |
+
valid_F1_count = 0
|
| 240 |
+
valid_F1_count_w0dt = 0
|
| 241 |
+
for res in self.evalImgs:
|
| 242 |
+
if res["image_id"] not in setImgIds:
|
| 243 |
+
continue
|
| 244 |
+
IL_TPs += res["IL_TP"]
|
| 245 |
+
IL_FPs += res["IL_FP"]
|
| 246 |
+
IL_TNs += res["IL_TN"]
|
| 247 |
+
IL_FNs += res["IL_FN"]
|
| 248 |
+
if "IL_perfect_neg" in res:
|
| 249 |
+
IL_perfects_neg += res["IL_perfect_neg"]
|
| 250 |
+
total_neg_count += 1
|
| 251 |
+
else:
|
| 252 |
+
assert "IL_perfect_pos" in res
|
| 253 |
+
IL_perfects_pos += res["IL_perfect_pos"]
|
| 254 |
+
total_pos_count += 1
|
| 255 |
+
|
| 256 |
+
if "TPs" not in res:
|
| 257 |
+
continue
|
| 258 |
+
|
| 259 |
+
TPs += res["TPs"]
|
| 260 |
+
FPs += res["FPs"]
|
| 261 |
+
FNs += res["FNs"]
|
| 262 |
+
valid_img_count += 1
|
| 263 |
+
|
| 264 |
+
if "local_positive_F1s" in res:
|
| 265 |
+
local_F1s += res["local_positive_F1s"]
|
| 266 |
+
pmFPs += res["FPs"]
|
| 267 |
+
valid_F1_count_w0dt += 1
|
| 268 |
+
if res["num_dt"] > 0:
|
| 269 |
+
valid_F1_count += 1
|
| 270 |
+
|
| 271 |
+
if "J" in res and res["J"] > -1e-9:
|
| 272 |
+
total_J += res["J"]
|
| 273 |
+
total_F += res["F"]
|
| 274 |
+
total_JnF += res["J&F"]
|
| 275 |
+
valid_J_count += 1
|
| 276 |
+
|
| 277 |
+
# compute precision recall and F1
|
| 278 |
+
precision = TPs / (TPs + FPs + 1e-4)
|
| 279 |
+
positive_micro_precision = TPs / (TPs + pmFPs + 1e-4)
|
| 280 |
+
assert np.all(precision <= 1)
|
| 281 |
+
recall = TPs / (TPs + FNs + 1e-4)
|
| 282 |
+
assert np.all(recall <= 1)
|
| 283 |
+
F1 = 2 * precision * recall / (precision + recall + 1e-4)
|
| 284 |
+
positive_micro_F1 = (
|
| 285 |
+
2
|
| 286 |
+
* positive_micro_precision
|
| 287 |
+
* recall
|
| 288 |
+
/ (positive_micro_precision + recall + 1e-4)
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
IL_rec = IL_TPs / (IL_TPs + IL_FNs + 1e-6)
|
| 292 |
+
IL_prec = IL_TPs / (IL_TPs + IL_FPs + 1e-6)
|
| 293 |
+
IL_F1 = 2 * IL_prec * IL_rec / (IL_prec + IL_rec + 1e-6)
|
| 294 |
+
IL_FPR = IL_FPs / (IL_FPs + IL_TNs + 1e-6)
|
| 295 |
+
IL_MCC = float(IL_TPs * IL_TNs - IL_FPs * IL_FNs) / (
|
| 296 |
+
(
|
| 297 |
+
float(IL_TPs + IL_FPs)
|
| 298 |
+
* float(IL_TPs + IL_FNs)
|
| 299 |
+
* float(IL_TNs + IL_FPs)
|
| 300 |
+
* float(IL_TNs + IL_FNs)
|
| 301 |
+
)
|
| 302 |
+
** 0.5
|
| 303 |
+
+ 1e-6
|
| 304 |
+
)
|
| 305 |
+
IL_perfect_pos = IL_perfects_pos / (total_pos_count + 1e-9)
|
| 306 |
+
IL_perfect_neg = IL_perfects_neg / (total_neg_count + 1e-9)
|
| 307 |
+
|
| 308 |
+
total_J = total_J / (valid_J_count + 1e-9)
|
| 309 |
+
total_F = total_F / (valid_J_count + 1e-9)
|
| 310 |
+
total_JnF = total_JnF / (valid_J_count + 1e-9)
|
| 311 |
+
|
| 312 |
+
self.eval = {
|
| 313 |
+
"params": p,
|
| 314 |
+
"TPs": TPs,
|
| 315 |
+
"FPs": FPs,
|
| 316 |
+
"positive_micro_FPs": pmFPs,
|
| 317 |
+
"FNs": FNs,
|
| 318 |
+
"precision": precision,
|
| 319 |
+
"positive_micro_precision": positive_micro_precision,
|
| 320 |
+
"recall": recall,
|
| 321 |
+
"F1": F1,
|
| 322 |
+
"positive_micro_F1": positive_micro_F1,
|
| 323 |
+
"positive_macro_F1": local_F1s / valid_F1_count,
|
| 324 |
+
"positive_w0dt_macro_F1": local_F1s / valid_F1_count_w0dt,
|
| 325 |
+
"IL_recall": IL_rec,
|
| 326 |
+
"IL_precision": IL_prec,
|
| 327 |
+
"IL_F1": IL_F1,
|
| 328 |
+
"IL_FPR": IL_FPR,
|
| 329 |
+
"IL_MCC": IL_MCC,
|
| 330 |
+
"IL_perfect_pos": IL_perfect_pos,
|
| 331 |
+
"IL_perfect_neg": IL_perfect_neg,
|
| 332 |
+
"J": total_J,
|
| 333 |
+
"F": total_F,
|
| 334 |
+
"J&F": total_JnF,
|
| 335 |
+
}
|
| 336 |
+
self.eval["CGF1"] = self.eval["positive_macro_F1"] * self.eval["IL_MCC"]
|
| 337 |
+
self.eval["CGF1_w0dt"] = (
|
| 338 |
+
self.eval["positive_w0dt_macro_F1"] * self.eval["IL_MCC"]
|
| 339 |
+
)
|
| 340 |
+
self.eval["CGF1_micro"] = self.eval["positive_micro_F1"] * self.eval["IL_MCC"]
|
| 341 |
+
|
| 342 |
+
def summarize(self):
|
| 343 |
+
"""
|
| 344 |
+
Compute and display summary metrics for evaluation results.
|
| 345 |
+
Note this functin can *only* be applied on the default parameter setting
|
| 346 |
+
"""
|
| 347 |
+
if not self.eval:
|
| 348 |
+
raise Exception("Please run accumulate() first")
|
| 349 |
+
|
| 350 |
+
def _summarize(iouThr=None, metric=""):
|
| 351 |
+
p = self.params
|
| 352 |
+
iStr = " {:<18} @[ IoU={:<9}] = {:0.3f}"
|
| 353 |
+
titleStr = "Average " + metric
|
| 354 |
+
iouStr = (
|
| 355 |
+
"{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1])
|
| 356 |
+
if iouThr is None
|
| 357 |
+
else "{:0.2f}".format(iouThr)
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
s = self.eval[metric]
|
| 361 |
+
# IoU
|
| 362 |
+
if iouThr is not None:
|
| 363 |
+
t = np.where(iouThr == p.iouThrs)[0]
|
| 364 |
+
s = s[t]
|
| 365 |
+
|
| 366 |
+
if len(s[s > -1]) == 0:
|
| 367 |
+
mean_s = -1
|
| 368 |
+
else:
|
| 369 |
+
mean_s = np.mean(s[s > -1])
|
| 370 |
+
print(iStr.format(titleStr, iouStr, mean_s))
|
| 371 |
+
return mean_s
|
| 372 |
+
|
| 373 |
+
def _summarize_single(metric=""):
|
| 374 |
+
titleStr = "Average " + metric
|
| 375 |
+
iStr = " {:<35} = {:0.3f}"
|
| 376 |
+
s = self.eval[metric]
|
| 377 |
+
print(iStr.format(titleStr, s))
|
| 378 |
+
return s
|
| 379 |
+
|
| 380 |
+
def _summarizeDets():
|
| 381 |
+
# note: the index of these metrics are also used in video Demo F1 evaluation
|
| 382 |
+
# when adding new metrics, please update the index in video Demo F1 evaluation
|
| 383 |
+
# in "evaluate" method of the "VideoDemoF1Evaluator" class
|
| 384 |
+
stats = np.zeros((len(DEMO_METRICS),))
|
| 385 |
+
stats[0] = _summarize(metric="CGF1")
|
| 386 |
+
stats[1] = _summarize(metric="precision")
|
| 387 |
+
stats[2] = _summarize(metric="recall")
|
| 388 |
+
stats[3] = _summarize(metric="F1")
|
| 389 |
+
stats[4] = _summarize(metric="positive_macro_F1")
|
| 390 |
+
stats[5] = _summarize_single(metric="IL_precision")
|
| 391 |
+
stats[6] = _summarize_single(metric="IL_recall")
|
| 392 |
+
stats[7] = _summarize_single(metric="IL_F1")
|
| 393 |
+
stats[8] = _summarize_single(metric="IL_FPR")
|
| 394 |
+
stats[9] = _summarize_single(metric="IL_MCC")
|
| 395 |
+
stats[10] = _summarize(metric="IL_perfect_pos")
|
| 396 |
+
stats[11] = _summarize(metric="IL_perfect_neg")
|
| 397 |
+
stats[12] = _summarize(iouThr=0.5, metric="CGF1")
|
| 398 |
+
stats[13] = _summarize(iouThr=0.5, metric="precision")
|
| 399 |
+
stats[14] = _summarize(iouThr=0.5, metric="recall")
|
| 400 |
+
stats[15] = _summarize(iouThr=0.5, metric="F1")
|
| 401 |
+
stats[16] = _summarize(iouThr=0.5, metric="positive_macro_F1")
|
| 402 |
+
stats[17] = _summarize(iouThr=0.5, metric="IL_perfect_pos")
|
| 403 |
+
stats[18] = _summarize(iouThr=0.5, metric="IL_perfect_neg")
|
| 404 |
+
stats[19] = _summarize(iouThr=0.75, metric="CGF1")
|
| 405 |
+
stats[20] = _summarize(iouThr=0.75, metric="precision")
|
| 406 |
+
stats[21] = _summarize(iouThr=0.75, metric="recall")
|
| 407 |
+
stats[22] = _summarize(iouThr=0.75, metric="F1")
|
| 408 |
+
stats[23] = _summarize(iouThr=0.75, metric="positive_macro_F1")
|
| 409 |
+
stats[24] = _summarize(iouThr=0.75, metric="IL_perfect_pos")
|
| 410 |
+
stats[25] = _summarize(iouThr=0.75, metric="IL_perfect_neg")
|
| 411 |
+
stats[26] = _summarize_single(metric="J")
|
| 412 |
+
stats[27] = _summarize_single(metric="F")
|
| 413 |
+
stats[28] = _summarize_single(metric="J&F")
|
| 414 |
+
stats[29] = _summarize(metric="CGF1_micro")
|
| 415 |
+
stats[30] = _summarize(metric="positive_micro_precision")
|
| 416 |
+
stats[31] = _summarize(metric="positive_micro_F1")
|
| 417 |
+
stats[32] = _summarize(iouThr=0.5, metric="CGF1_micro")
|
| 418 |
+
stats[33] = _summarize(iouThr=0.5, metric="positive_micro_precision")
|
| 419 |
+
stats[34] = _summarize(iouThr=0.5, metric="positive_micro_F1")
|
| 420 |
+
stats[35] = _summarize(iouThr=0.75, metric="CGF1_micro")
|
| 421 |
+
stats[36] = _summarize(iouThr=0.75, metric="positive_micro_precision")
|
| 422 |
+
stats[37] = _summarize(iouThr=0.75, metric="positive_micro_F1")
|
| 423 |
+
stats[38] = _summarize(metric="CGF1_w0dt")
|
| 424 |
+
stats[39] = _summarize(metric="positive_w0dt_macro_F1")
|
| 425 |
+
stats[40] = _summarize(iouThr=0.5, metric="CGF1_w0dt")
|
| 426 |
+
stats[41] = _summarize(iouThr=0.5, metric="positive_w0dt_macro_F1")
|
| 427 |
+
stats[42] = _summarize(iouThr=0.75, metric="CGF1_w0dt")
|
| 428 |
+
stats[43] = _summarize(iouThr=0.75, metric="positive_w0dt_macro_F1")
|
| 429 |
+
return stats
|
| 430 |
+
|
| 431 |
+
summarize = _summarizeDets
|
| 432 |
+
self.stats = summarize()
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
DEMO_METRICS = [
|
| 436 |
+
"CGF1",
|
| 437 |
+
"Precision",
|
| 438 |
+
"Recall",
|
| 439 |
+
"F1",
|
| 440 |
+
"Macro_F1",
|
| 441 |
+
"IL_Precision",
|
| 442 |
+
"IL_Recall",
|
| 443 |
+
"IL_F1",
|
| 444 |
+
"IL_FPR",
|
| 445 |
+
"IL_MCC",
|
| 446 |
+
"IL_perfect_pos",
|
| 447 |
+
"IL_perfect_neg",
|
| 448 |
+
"CGF1@0.5",
|
| 449 |
+
"Precision@0.5",
|
| 450 |
+
"Recall@0.5",
|
| 451 |
+
"F1@0.5",
|
| 452 |
+
"Macro_F1@0.5",
|
| 453 |
+
"IL_perfect_pos@0.5",
|
| 454 |
+
"IL_perfect_neg@0.5",
|
| 455 |
+
"CGF1@0.75",
|
| 456 |
+
"Precision@0.75",
|
| 457 |
+
"Recall@0.75",
|
| 458 |
+
"F1@0.75",
|
| 459 |
+
"Macro_F1@0.75",
|
| 460 |
+
"IL_perfect_pos@0.75",
|
| 461 |
+
"IL_perfect_neg@0.75",
|
| 462 |
+
"J",
|
| 463 |
+
"F",
|
| 464 |
+
"J&F",
|
| 465 |
+
"CGF1_micro",
|
| 466 |
+
"positive_micro_Precision",
|
| 467 |
+
"positive_micro_F1",
|
| 468 |
+
"CGF1_micro@0.5",
|
| 469 |
+
"positive_micro_Precision@0.5",
|
| 470 |
+
"positive_micro_F1@0.5",
|
| 471 |
+
"CGF1_micro@0.75",
|
| 472 |
+
"positive_micro_Precision@0.75",
|
| 473 |
+
"positive_micro_F1@0.75",
|
| 474 |
+
"CGF1_w0dt",
|
| 475 |
+
"positive_w0dt_macro_F1",
|
| 476 |
+
"CGF1_w0dt@0.5",
|
| 477 |
+
"positive_w0dt_macro_F1@0.5",
|
| 478 |
+
"CGF1_w0dt@0.75",
|
| 479 |
+
"positive_w0dt_macro_F1@0.75",
|
| 480 |
+
]
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
class DemoEvaluator(CocoEvaluator):
|
| 484 |
+
def __init__(
|
| 485 |
+
self,
|
| 486 |
+
coco_gt,
|
| 487 |
+
iou_types,
|
| 488 |
+
dump_dir: Optional[str],
|
| 489 |
+
postprocessor,
|
| 490 |
+
threshold=0.5,
|
| 491 |
+
average_by_rarity=False,
|
| 492 |
+
gather_pred_via_filesys=False,
|
| 493 |
+
exhaustive_only=False,
|
| 494 |
+
all_exhaustive_only=True,
|
| 495 |
+
compute_JnF=False,
|
| 496 |
+
metrics_dump_dir: Optional[str] = None,
|
| 497 |
+
):
|
| 498 |
+
self.iou_types = iou_types
|
| 499 |
+
self.threshold = threshold
|
| 500 |
+
super().__init__(
|
| 501 |
+
coco_gt=coco_gt,
|
| 502 |
+
iou_types=iou_types,
|
| 503 |
+
useCats=False,
|
| 504 |
+
dump_dir=dump_dir,
|
| 505 |
+
postprocessor=postprocessor,
|
| 506 |
+
# average_by_rarity=average_by_rarity,
|
| 507 |
+
gather_pred_via_filesys=gather_pred_via_filesys,
|
| 508 |
+
exhaustive_only=exhaustive_only,
|
| 509 |
+
all_exhaustive_only=all_exhaustive_only,
|
| 510 |
+
metrics_dump_dir=metrics_dump_dir,
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
self.use_self_evaluate = True
|
| 514 |
+
self.compute_JnF = compute_JnF
|
| 515 |
+
|
| 516 |
+
def _lazy_init(self):
|
| 517 |
+
if self.initialized:
|
| 518 |
+
return
|
| 519 |
+
super()._lazy_init()
|
| 520 |
+
self.use_self_evaluate = True
|
| 521 |
+
self.reset()
|
| 522 |
+
|
| 523 |
+
def select_best_scoring(self, scorings):
|
| 524 |
+
# This function is used for "oracle" type evaluation.
|
| 525 |
+
# It accepts the evaluation results with respect to several ground truths, and picks the best
|
| 526 |
+
if len(scorings) == 1:
|
| 527 |
+
return scorings[0]
|
| 528 |
+
|
| 529 |
+
assert (
|
| 530 |
+
scorings[0].ndim == 3
|
| 531 |
+
), f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}"
|
| 532 |
+
assert (
|
| 533 |
+
scorings[0].shape[0] == 1
|
| 534 |
+
), f"Expecting a single category, got {scorings[0].shape[0]}"
|
| 535 |
+
|
| 536 |
+
for scoring in scorings:
|
| 537 |
+
assert (
|
| 538 |
+
scoring.shape == scorings[0].shape
|
| 539 |
+
), f"Shape mismatch: {scoring.shape}, {scorings[0].shape}"
|
| 540 |
+
|
| 541 |
+
selected_imgs = []
|
| 542 |
+
for img_id in range(scorings[0].shape[-1]):
|
| 543 |
+
best = scorings[0][:, :, img_id]
|
| 544 |
+
|
| 545 |
+
for scoring in scorings[1:]:
|
| 546 |
+
current = scoring[:, :, img_id]
|
| 547 |
+
if "local_F1s" in best[0, 0] and "local_F1s" in current[0, 0]:
|
| 548 |
+
# we were able to compute a F1 score for this particular image in both evaluations
|
| 549 |
+
# best["local_F1s"] contains the results at various IoU thresholds. We simply take the average for comparision
|
| 550 |
+
best_score = best[0, 0]["local_F1s"].mean()
|
| 551 |
+
current_score = current[0, 0]["local_F1s"].mean()
|
| 552 |
+
if current_score > best_score:
|
| 553 |
+
best = current
|
| 554 |
+
|
| 555 |
+
else:
|
| 556 |
+
# If we're here, it means that in that in some evaluation we were not able to get a valid local F1
|
| 557 |
+
# This happens when both the predictions and targets are empty. In that case, we can assume it's a perfect prediction
|
| 558 |
+
if "local_F1s" not in current[0, 0]:
|
| 559 |
+
best = current
|
| 560 |
+
selected_imgs.append(best)
|
| 561 |
+
result = np.stack(selected_imgs, axis=-1)
|
| 562 |
+
assert result.shape == scorings[0].shape
|
| 563 |
+
return result
|
| 564 |
+
|
| 565 |
+
def summarize(self):
|
| 566 |
+
self._lazy_init()
|
| 567 |
+
logging.info("Demo evaluator: Summarizing")
|
| 568 |
+
if not is_main_process():
|
| 569 |
+
return {}
|
| 570 |
+
outs = {}
|
| 571 |
+
prefix = "oracle_" if len(self.coco_evals) > 1 else ""
|
| 572 |
+
# if self.rarity_buckets is None:
|
| 573 |
+
self.accumulate(self.eval_img_ids)
|
| 574 |
+
for iou_type, coco_eval in self.coco_evals[0].items():
|
| 575 |
+
print("Demo metric, IoU type={}".format(iou_type))
|
| 576 |
+
coco_eval.summarize()
|
| 577 |
+
|
| 578 |
+
if "bbox" in self.coco_evals[0]:
|
| 579 |
+
for i, value in enumerate(self.coco_evals[0]["bbox"].stats):
|
| 580 |
+
outs[f"coco_eval_bbox_{prefix}{DEMO_METRICS[i]}"] = value
|
| 581 |
+
if "segm" in self.coco_evals[0]:
|
| 582 |
+
for i, value in enumerate(self.coco_evals[0]["segm"].stats):
|
| 583 |
+
outs[f"coco_eval_masks_{prefix}{DEMO_METRICS[i]}"] = value
|
| 584 |
+
# else:
|
| 585 |
+
# total_stats = {}
|
| 586 |
+
# for bucket, img_list in self.rarity_buckets.items():
|
| 587 |
+
# self.accumulate(imgIds=img_list)
|
| 588 |
+
# bucket_name = RARITY_BUCKETS[bucket]
|
| 589 |
+
# for iou_type, coco_eval in self.coco_evals[0].items():
|
| 590 |
+
# print(
|
| 591 |
+
# "Demo metric, IoU type={}, Rarity bucket={}".format(
|
| 592 |
+
# iou_type, bucket_name
|
| 593 |
+
# )
|
| 594 |
+
# )
|
| 595 |
+
# coco_eval.summarize()
|
| 596 |
+
|
| 597 |
+
# if "bbox" in self.coco_evals[0]:
|
| 598 |
+
# if "bbox" not in total_stats:
|
| 599 |
+
# total_stats["bbox"] = np.zeros_like(
|
| 600 |
+
# self.coco_evals[0]["bbox"].stats
|
| 601 |
+
# )
|
| 602 |
+
# total_stats["bbox"] += self.coco_evals[0]["bbox"].stats
|
| 603 |
+
# for i, value in enumerate(self.coco_evals[0]["bbox"].stats):
|
| 604 |
+
# outs[
|
| 605 |
+
# f"coco_eval_bbox_{bucket_name}_{prefix}{DEMO_METRICS[i]}"
|
| 606 |
+
# ] = value
|
| 607 |
+
# if "segm" in self.coco_evals[0]:
|
| 608 |
+
# if "segm" not in total_stats:
|
| 609 |
+
# total_stats["segm"] = np.zeros_like(
|
| 610 |
+
# self.coco_evals[0]["segm"].stats
|
| 611 |
+
# )
|
| 612 |
+
# total_stats["segm"] += self.coco_evals[0]["segm"].stats
|
| 613 |
+
# for i, value in enumerate(self.coco_evals[0]["segm"].stats):
|
| 614 |
+
# outs[
|
| 615 |
+
# f"coco_eval_masks_{bucket_name}_{prefix}{DEMO_METRICS[i]}"
|
| 616 |
+
# ] = value
|
| 617 |
+
|
| 618 |
+
# if "bbox" in total_stats:
|
| 619 |
+
# total_stats["bbox"] /= len(self.rarity_buckets)
|
| 620 |
+
# for i, value in enumerate(total_stats["bbox"]):
|
| 621 |
+
# outs[f"coco_eval_bbox_{prefix}{DEMO_METRICS[i]}"] = value
|
| 622 |
+
# if "segm" in total_stats:
|
| 623 |
+
# total_stats["segm"] /= len(self.rarity_buckets)
|
| 624 |
+
# for i, value in enumerate(total_stats["segm"]):
|
| 625 |
+
# outs[f"coco_eval_masks_{prefix}{DEMO_METRICS[i]}"] = value
|
| 626 |
+
|
| 627 |
+
return outs
|
| 628 |
+
|
| 629 |
+
def accumulate(self, imgIds=None):
|
| 630 |
+
self._lazy_init()
|
| 631 |
+
logging.info(
|
| 632 |
+
f"demo evaluator: Accumulating on {len(imgIds) if imgIds is not None else 'all'} images"
|
| 633 |
+
)
|
| 634 |
+
if not is_main_process():
|
| 635 |
+
return
|
| 636 |
+
|
| 637 |
+
if imgIds is not None:
|
| 638 |
+
for coco_eval in self.coco_evals[0].values():
|
| 639 |
+
coco_eval.params.imgIds = list(imgIds)
|
| 640 |
+
|
| 641 |
+
for coco_eval in self.coco_evals[0].values():
|
| 642 |
+
coco_eval.accumulate()
|
| 643 |
+
|
| 644 |
+
def reset(self):
|
| 645 |
+
self.coco_evals = [{} for _ in range(len(self.coco_gts))]
|
| 646 |
+
for i, coco_gt in enumerate(self.coco_gts):
|
| 647 |
+
for iou_type in self.iou_types:
|
| 648 |
+
self.coco_evals[i][iou_type] = DemoEval(
|
| 649 |
+
coco_gt=coco_gt,
|
| 650 |
+
iouType=iou_type,
|
| 651 |
+
threshold=self.threshold,
|
| 652 |
+
compute_JnF=self.compute_JnF,
|
| 653 |
+
)
|
| 654 |
+
self.coco_evals[i][iou_type].useCats = False
|
| 655 |
+
self.img_ids = []
|
| 656 |
+
self.eval_imgs = {k: [] for k in self.iou_types}
|
| 657 |
+
if self.dump is not None:
|
| 658 |
+
self.dump = []
|
sam3/eval/hota_eval_toolkit/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa
|
sam3/eval/hota_eval_toolkit/run_ytvis_eval.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa
|
| 2 |
+
|
| 3 |
+
"""run_youtube_vis.py
|
| 4 |
+
Run example:
|
| 5 |
+
run_youtube_vis.py --USE_PARALLEL False --METRICS HOTA --TRACKERS_TO_EVAL STEm_Seg
|
| 6 |
+
Command Line Arguments: Defaults, # Comments
|
| 7 |
+
Eval arguments:
|
| 8 |
+
'USE_PARALLEL': False,
|
| 9 |
+
'NUM_PARALLEL_CORES': 8,
|
| 10 |
+
'BREAK_ON_ERROR': True, # Raises exception and exits with error
|
| 11 |
+
'RETURN_ON_ERROR': False, # if not BREAK_ON_ERROR, then returns from function on error
|
| 12 |
+
'LOG_ON_ERROR': os.path.join(code_path, 'error_log.txt'), # if not None, save any errors into a log file.
|
| 13 |
+
'PRINT_RESULTS': True,
|
| 14 |
+
'PRINT_ONLY_COMBINED': False,
|
| 15 |
+
'PRINT_CONFIG': True,
|
| 16 |
+
'TIME_PROGRESS': True,
|
| 17 |
+
'DISPLAY_LESS_PROGRESS': True,
|
| 18 |
+
'OUTPUT_SUMMARY': True,
|
| 19 |
+
'OUTPUT_EMPTY_CLASSES': True, # If False, summary files are not output for classes with no detections
|
| 20 |
+
'OUTPUT_DETAILED': True,
|
| 21 |
+
'PLOT_CURVES': True,
|
| 22 |
+
Dataset arguments:
|
| 23 |
+
'GT_FOLDER': os.path.join(code_path, 'data/gt/youtube_vis/youtube_vis_training'), # Location of GT data
|
| 24 |
+
'TRACKERS_FOLDER': os.path.join(code_path, 'data/trackers/youtube_vis/youtube_vis_training'),
|
| 25 |
+
# Trackers location
|
| 26 |
+
'OUTPUT_FOLDER': None, # Where to save eval results (if None, same as TRACKERS_FOLDER)
|
| 27 |
+
'TRACKERS_TO_EVAL': None, # Filenames of trackers to eval (if None, all in folder)
|
| 28 |
+
'CLASSES_TO_EVAL': None, # Classes to eval (if None, all classes)
|
| 29 |
+
'SPLIT_TO_EVAL': 'training', # Valid: 'training', 'val'
|
| 30 |
+
'PRINT_CONFIG': True, # Whether to print current config
|
| 31 |
+
'OUTPUT_SUB_FOLDER': '', # Output files are saved in OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER
|
| 32 |
+
'TRACKER_SUB_FOLDER': 'data', # Tracker files are in TRACKER_FOLDER/tracker_name/TRACKER_SUB_FOLDER
|
| 33 |
+
'TRACKER_DISPLAY_NAMES': None, # Names of trackers to display, if None: TRACKERS_TO_EVAL
|
| 34 |
+
Metric arguments:
|
| 35 |
+
'METRICS': ['TrackMAP', 'HOTA', 'CLEAR', 'Identity']
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
import argparse
|
| 39 |
+
import os
|
| 40 |
+
import sys
|
| 41 |
+
from multiprocessing import freeze_support
|
| 42 |
+
|
| 43 |
+
from . import trackeval
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def run_ytvis_eval(args=None, gt_json=None, dt_json=None):
|
| 47 |
+
# Command line interface:
|
| 48 |
+
default_eval_config = trackeval.Evaluator.get_default_eval_config()
|
| 49 |
+
# print only combined since TrackMAP is undefined for per sequence breakdowns
|
| 50 |
+
default_eval_config["PRINT_ONLY_COMBINED"] = True
|
| 51 |
+
default_dataset_config = trackeval.datasets.YouTubeVIS.get_default_dataset_config()
|
| 52 |
+
default_metrics_config = {"METRICS": ["HOTA"]}
|
| 53 |
+
config = {
|
| 54 |
+
**default_eval_config,
|
| 55 |
+
**default_dataset_config,
|
| 56 |
+
**default_metrics_config,
|
| 57 |
+
} # Merge default configs
|
| 58 |
+
parser = argparse.ArgumentParser()
|
| 59 |
+
for setting in config.keys():
|
| 60 |
+
if type(config[setting]) == list or type(config[setting]) == type(None):
|
| 61 |
+
parser.add_argument("--" + setting, nargs="+")
|
| 62 |
+
else:
|
| 63 |
+
parser.add_argument("--" + setting)
|
| 64 |
+
args = parser.parse_args(args).__dict__
|
| 65 |
+
for setting in args.keys():
|
| 66 |
+
if args[setting] is not None:
|
| 67 |
+
if type(config[setting]) == type(True):
|
| 68 |
+
if args[setting] == "True":
|
| 69 |
+
x = True
|
| 70 |
+
elif args[setting] == "False":
|
| 71 |
+
x = False
|
| 72 |
+
else:
|
| 73 |
+
raise Exception(
|
| 74 |
+
"Command line parameter " + setting + "must be True or False"
|
| 75 |
+
)
|
| 76 |
+
elif type(config[setting]) == type(1):
|
| 77 |
+
x = int(args[setting])
|
| 78 |
+
elif type(args[setting]) == type(None):
|
| 79 |
+
x = None
|
| 80 |
+
else:
|
| 81 |
+
x = args[setting]
|
| 82 |
+
config[setting] = x
|
| 83 |
+
eval_config = {k: v for k, v in config.items() if k in default_eval_config.keys()}
|
| 84 |
+
dataset_config = {
|
| 85 |
+
k: v for k, v in config.items() if k in default_dataset_config.keys()
|
| 86 |
+
}
|
| 87 |
+
metrics_config = {
|
| 88 |
+
k: v for k, v in config.items() if k in default_metrics_config.keys()
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
# Run code
|
| 92 |
+
evaluator = trackeval.Evaluator(eval_config)
|
| 93 |
+
# allow directly specifying the GT JSON data and Tracker (result)
|
| 94 |
+
# JSON data as Python objects, without reading from files.
|
| 95 |
+
dataset_config["GT_JSON_OBJECT"] = gt_json
|
| 96 |
+
dataset_config["TRACKER_JSON_OBJECT"] = dt_json
|
| 97 |
+
dataset_list = [trackeval.datasets.YouTubeVIS(dataset_config)]
|
| 98 |
+
metrics_list = []
|
| 99 |
+
# for metric in [trackeval.metrics.TrackMAP, trackeval.metrics.HOTA, trackeval.metrics.CLEAR,
|
| 100 |
+
# trackeval.metrics.Identity]:
|
| 101 |
+
for metric in [trackeval.metrics.HOTA]:
|
| 102 |
+
if metric.get_name() in metrics_config["METRICS"]:
|
| 103 |
+
metrics_list.append(metric())
|
| 104 |
+
if len(metrics_list) == 0:
|
| 105 |
+
raise Exception("No metrics selected for evaluation")
|
| 106 |
+
output_res, output_msg = evaluator.evaluate(dataset_list, metrics_list)
|
| 107 |
+
return output_res, output_msg
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
if __name__ == "__main__":
|
| 111 |
+
import sys
|
| 112 |
+
|
| 113 |
+
freeze_support()
|
| 114 |
+
run_ytvis_eval(sys.argv[1:])
|
sam3/eval/hota_eval_toolkit/trackeval/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa
|
| 2 |
+
|
| 3 |
+
from . import datasets, metrics, utils
|
| 4 |
+
from .eval import Evaluator
|
sam3/eval/hota_eval_toolkit/trackeval/_timing.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa
|
| 2 |
+
|
| 3 |
+
import inspect
|
| 4 |
+
from functools import wraps
|
| 5 |
+
from time import perf_counter
|
| 6 |
+
|
| 7 |
+
DO_TIMING = False
|
| 8 |
+
DISPLAY_LESS_PROGRESS = False
|
| 9 |
+
timer_dict = {}
|
| 10 |
+
counter = 0
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def time(f):
|
| 14 |
+
@wraps(f)
|
| 15 |
+
def wrap(*args, **kw):
|
| 16 |
+
if DO_TIMING:
|
| 17 |
+
# Run function with timing
|
| 18 |
+
ts = perf_counter()
|
| 19 |
+
result = f(*args, **kw)
|
| 20 |
+
te = perf_counter()
|
| 21 |
+
tt = te - ts
|
| 22 |
+
|
| 23 |
+
# Get function name
|
| 24 |
+
arg_names = inspect.getfullargspec(f)[0]
|
| 25 |
+
if arg_names[0] == "self" and DISPLAY_LESS_PROGRESS:
|
| 26 |
+
return result
|
| 27 |
+
elif arg_names[0] == "self":
|
| 28 |
+
method_name = type(args[0]).__name__ + "." + f.__name__
|
| 29 |
+
else:
|
| 30 |
+
method_name = f.__name__
|
| 31 |
+
|
| 32 |
+
# Record accumulative time in each function for analysis
|
| 33 |
+
if method_name in timer_dict.keys():
|
| 34 |
+
timer_dict[method_name] += tt
|
| 35 |
+
else:
|
| 36 |
+
timer_dict[method_name] = tt
|
| 37 |
+
|
| 38 |
+
# If code is finished, display timing summary
|
| 39 |
+
if method_name == "Evaluator.evaluate":
|
| 40 |
+
print("")
|
| 41 |
+
print("Timing analysis:")
|
| 42 |
+
for key, value in timer_dict.items():
|
| 43 |
+
print("%-70s %2.4f sec" % (key, value))
|
| 44 |
+
else:
|
| 45 |
+
# Get function argument values for printing special arguments of interest
|
| 46 |
+
arg_titles = ["tracker", "seq", "cls"]
|
| 47 |
+
arg_vals = []
|
| 48 |
+
for i, a in enumerate(arg_names):
|
| 49 |
+
if a in arg_titles:
|
| 50 |
+
arg_vals.append(args[i])
|
| 51 |
+
arg_text = "(" + ", ".join(arg_vals) + ")"
|
| 52 |
+
|
| 53 |
+
# Display methods and functions with different indentation.
|
| 54 |
+
if arg_names[0] == "self":
|
| 55 |
+
print("%-74s %2.4f sec" % (" " * 4 + method_name + arg_text, tt))
|
| 56 |
+
elif arg_names[0] == "test":
|
| 57 |
+
pass
|
| 58 |
+
else:
|
| 59 |
+
global counter
|
| 60 |
+
counter += 1
|
| 61 |
+
print("%i %-70s %2.4f sec" % (counter, method_name + arg_text, tt))
|
| 62 |
+
|
| 63 |
+
return result
|
| 64 |
+
else:
|
| 65 |
+
# If config["TIME_PROGRESS"] is false, or config["USE_PARALLEL"] is true, run functions normally without timing.
|
| 66 |
+
return f(*args, **kw)
|
| 67 |
+
|
| 68 |
+
return wrap
|
sam3/eval/hota_eval_toolkit/trackeval/datasets/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa
|
| 2 |
+
|
| 3 |
+
from .tao_ow import TAO_OW
|
| 4 |
+
from .youtube_vis import YouTubeVIS
|
sam3/eval/hota_eval_toolkit/trackeval/datasets/_base_dataset.py
ADDED
|
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa
|
| 2 |
+
|
| 3 |
+
import csv
|
| 4 |
+
import io
|
| 5 |
+
import os
|
| 6 |
+
import traceback
|
| 7 |
+
import zipfile
|
| 8 |
+
from abc import ABC, abstractmethod
|
| 9 |
+
from copy import deepcopy
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
from .. import _timing
|
| 14 |
+
from ..utils import TrackEvalException
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class _BaseDataset(ABC):
|
| 18 |
+
@abstractmethod
|
| 19 |
+
def __init__(self):
|
| 20 |
+
self.tracker_list = None
|
| 21 |
+
self.seq_list = None
|
| 22 |
+
self.class_list = None
|
| 23 |
+
self.output_fol = None
|
| 24 |
+
self.output_sub_fol = None
|
| 25 |
+
self.should_classes_combine = True
|
| 26 |
+
self.use_super_categories = False
|
| 27 |
+
|
| 28 |
+
# Functions to implement:
|
| 29 |
+
|
| 30 |
+
@staticmethod
|
| 31 |
+
@abstractmethod
|
| 32 |
+
def get_default_dataset_config(): ...
|
| 33 |
+
|
| 34 |
+
@abstractmethod
|
| 35 |
+
def _load_raw_file(self, tracker, seq, is_gt): ...
|
| 36 |
+
|
| 37 |
+
@_timing.time
|
| 38 |
+
@abstractmethod
|
| 39 |
+
def get_preprocessed_seq_data(self, raw_data, cls): ...
|
| 40 |
+
|
| 41 |
+
@abstractmethod
|
| 42 |
+
def _calculate_similarities(self, gt_dets_t, tracker_dets_t): ...
|
| 43 |
+
|
| 44 |
+
# Helper functions for all datasets:
|
| 45 |
+
|
| 46 |
+
@classmethod
|
| 47 |
+
def get_class_name(cls):
|
| 48 |
+
return cls.__name__
|
| 49 |
+
|
| 50 |
+
def get_name(self):
|
| 51 |
+
return self.get_class_name()
|
| 52 |
+
|
| 53 |
+
def get_output_fol(self, tracker):
|
| 54 |
+
return os.path.join(self.output_fol, tracker, self.output_sub_fol)
|
| 55 |
+
|
| 56 |
+
def get_display_name(self, tracker):
|
| 57 |
+
"""Can be overwritten if the trackers name (in files) is different to how it should be displayed.
|
| 58 |
+
By default this method just returns the trackers name as is.
|
| 59 |
+
"""
|
| 60 |
+
return tracker
|
| 61 |
+
|
| 62 |
+
def get_eval_info(self):
|
| 63 |
+
"""Return info about the dataset needed for the Evaluator"""
|
| 64 |
+
return self.tracker_list, self.seq_list, self.class_list
|
| 65 |
+
|
| 66 |
+
@_timing.time
|
| 67 |
+
def get_raw_seq_data(self, tracker, seq):
|
| 68 |
+
"""Loads raw data (tracker and ground-truth) for a single tracker on a single sequence.
|
| 69 |
+
Raw data includes all of the information needed for both preprocessing and evaluation, for all classes.
|
| 70 |
+
A later function (get_processed_seq_data) will perform such preprocessing and extract relevant information for
|
| 71 |
+
the evaluation of each class.
|
| 72 |
+
|
| 73 |
+
This returns a dict which contains the fields:
|
| 74 |
+
[num_timesteps]: integer
|
| 75 |
+
[gt_ids, tracker_ids, gt_classes, tracker_classes, tracker_confidences]:
|
| 76 |
+
list (for each timestep) of 1D NDArrays (for each det).
|
| 77 |
+
[gt_dets, tracker_dets, gt_crowd_ignore_regions]: list (for each timestep) of lists of detections.
|
| 78 |
+
[similarity_scores]: list (for each timestep) of 2D NDArrays.
|
| 79 |
+
[gt_extras]: dict (for each extra) of lists (for each timestep) of 1D NDArrays (for each det).
|
| 80 |
+
|
| 81 |
+
gt_extras contains dataset specific information used for preprocessing such as occlusion and truncation levels.
|
| 82 |
+
|
| 83 |
+
Note that similarities are extracted as part of the dataset and not the metric, because almost all metrics are
|
| 84 |
+
independent of the exact method of calculating the similarity. However datasets are not (e.g. segmentation
|
| 85 |
+
masks vs 2D boxes vs 3D boxes).
|
| 86 |
+
We calculate the similarity before preprocessing because often both preprocessing and evaluation require it and
|
| 87 |
+
we don't wish to calculate this twice.
|
| 88 |
+
We calculate similarity between all gt and tracker classes (not just each class individually) to allow for
|
| 89 |
+
calculation of metrics such as class confusion matrices. Typically the impact of this on performance is low.
|
| 90 |
+
"""
|
| 91 |
+
# Load raw data.
|
| 92 |
+
raw_gt_data = self._load_raw_file(tracker, seq, is_gt=True)
|
| 93 |
+
raw_tracker_data = self._load_raw_file(tracker, seq, is_gt=False)
|
| 94 |
+
raw_data = {**raw_tracker_data, **raw_gt_data} # Merges dictionaries
|
| 95 |
+
|
| 96 |
+
# Calculate similarities for each timestep.
|
| 97 |
+
similarity_scores = []
|
| 98 |
+
for t, (gt_dets_t, tracker_dets_t) in enumerate(
|
| 99 |
+
zip(raw_data["gt_dets"], raw_data["tracker_dets"])
|
| 100 |
+
):
|
| 101 |
+
ious = self._calculate_similarities(gt_dets_t, tracker_dets_t)
|
| 102 |
+
similarity_scores.append(ious)
|
| 103 |
+
raw_data["similarity_scores"] = similarity_scores
|
| 104 |
+
return raw_data
|
| 105 |
+
|
| 106 |
+
@staticmethod
|
| 107 |
+
def _load_simple_text_file(
|
| 108 |
+
file,
|
| 109 |
+
time_col=0,
|
| 110 |
+
id_col=None,
|
| 111 |
+
remove_negative_ids=False,
|
| 112 |
+
valid_filter=None,
|
| 113 |
+
crowd_ignore_filter=None,
|
| 114 |
+
convert_filter=None,
|
| 115 |
+
is_zipped=False,
|
| 116 |
+
zip_file=None,
|
| 117 |
+
force_delimiters=None,
|
| 118 |
+
):
|
| 119 |
+
"""Function that loads data which is in a commonly used text file format.
|
| 120 |
+
Assumes each det is given by one row of a text file.
|
| 121 |
+
There is no limit to the number or meaning of each column,
|
| 122 |
+
however one column needs to give the timestep of each det (time_col) which is default col 0.
|
| 123 |
+
|
| 124 |
+
The file dialect (deliminator, num cols, etc) is determined automatically.
|
| 125 |
+
This function automatically separates dets by timestep,
|
| 126 |
+
and is much faster than alternatives such as np.loadtext or pandas.
|
| 127 |
+
|
| 128 |
+
If remove_negative_ids is True and id_col is not None, dets with negative values in id_col are excluded.
|
| 129 |
+
These are not excluded from ignore data.
|
| 130 |
+
|
| 131 |
+
valid_filter can be used to only include certain classes.
|
| 132 |
+
It is a dict with ints as keys, and lists as values,
|
| 133 |
+
such that a row is included if "row[key].lower() is in value" for all key/value pairs in the dict.
|
| 134 |
+
If None, all classes are included.
|
| 135 |
+
|
| 136 |
+
crowd_ignore_filter can be used to read crowd_ignore regions separately. It has the same format as valid filter.
|
| 137 |
+
|
| 138 |
+
convert_filter can be used to convert value read to another format.
|
| 139 |
+
This is used most commonly to convert classes given as string to a class id.
|
| 140 |
+
This is a dict such that the key is the column to convert, and the value is another dict giving the mapping.
|
| 141 |
+
|
| 142 |
+
Optionally, input files could be a zip of multiple text files for storage efficiency.
|
| 143 |
+
|
| 144 |
+
Returns read_data and ignore_data.
|
| 145 |
+
Each is a dict (with keys as timesteps as strings) of lists (over dets) of lists (over column values).
|
| 146 |
+
Note that all data is returned as strings, and must be converted to float/int later if needed.
|
| 147 |
+
Note that timesteps will not be present in the returned dict keys if there are no dets for them
|
| 148 |
+
"""
|
| 149 |
+
|
| 150 |
+
if remove_negative_ids and id_col is None:
|
| 151 |
+
raise TrackEvalException(
|
| 152 |
+
"remove_negative_ids is True, but id_col is not given."
|
| 153 |
+
)
|
| 154 |
+
if crowd_ignore_filter is None:
|
| 155 |
+
crowd_ignore_filter = {}
|
| 156 |
+
if convert_filter is None:
|
| 157 |
+
convert_filter = {}
|
| 158 |
+
try:
|
| 159 |
+
if is_zipped: # Either open file directly or within a zip.
|
| 160 |
+
if zip_file is None:
|
| 161 |
+
raise TrackEvalException(
|
| 162 |
+
"is_zipped set to True, but no zip_file is given."
|
| 163 |
+
)
|
| 164 |
+
archive = zipfile.ZipFile(os.path.join(zip_file), "r")
|
| 165 |
+
fp = io.TextIOWrapper(archive.open(file, "r"))
|
| 166 |
+
else:
|
| 167 |
+
fp = open(file)
|
| 168 |
+
read_data = {}
|
| 169 |
+
crowd_ignore_data = {}
|
| 170 |
+
fp.seek(0, os.SEEK_END)
|
| 171 |
+
# check if file is empty
|
| 172 |
+
if fp.tell():
|
| 173 |
+
fp.seek(0)
|
| 174 |
+
dialect = csv.Sniffer().sniff(
|
| 175 |
+
fp.readline(), delimiters=force_delimiters
|
| 176 |
+
) # Auto determine structure.
|
| 177 |
+
dialect.skipinitialspace = (
|
| 178 |
+
True # Deal with extra spaces between columns
|
| 179 |
+
)
|
| 180 |
+
fp.seek(0)
|
| 181 |
+
reader = csv.reader(fp, dialect)
|
| 182 |
+
for row in reader:
|
| 183 |
+
try:
|
| 184 |
+
# Deal with extra trailing spaces at the end of rows
|
| 185 |
+
if row[-1] in "":
|
| 186 |
+
row = row[:-1]
|
| 187 |
+
timestep = str(int(float(row[time_col])))
|
| 188 |
+
# Read ignore regions separately.
|
| 189 |
+
is_ignored = False
|
| 190 |
+
for ignore_key, ignore_value in crowd_ignore_filter.items():
|
| 191 |
+
if row[ignore_key].lower() in ignore_value:
|
| 192 |
+
# Convert values in one column (e.g. string to id)
|
| 193 |
+
for (
|
| 194 |
+
convert_key,
|
| 195 |
+
convert_value,
|
| 196 |
+
) in convert_filter.items():
|
| 197 |
+
row[convert_key] = convert_value[
|
| 198 |
+
row[convert_key].lower()
|
| 199 |
+
]
|
| 200 |
+
# Save data separated by timestep.
|
| 201 |
+
if timestep in crowd_ignore_data.keys():
|
| 202 |
+
crowd_ignore_data[timestep].append(row)
|
| 203 |
+
else:
|
| 204 |
+
crowd_ignore_data[timestep] = [row]
|
| 205 |
+
is_ignored = True
|
| 206 |
+
if (
|
| 207 |
+
is_ignored
|
| 208 |
+
): # if det is an ignore region, it cannot be a normal det.
|
| 209 |
+
continue
|
| 210 |
+
# Exclude some dets if not valid.
|
| 211 |
+
if valid_filter is not None:
|
| 212 |
+
for key, value in valid_filter.items():
|
| 213 |
+
if row[key].lower() not in value:
|
| 214 |
+
continue
|
| 215 |
+
if remove_negative_ids:
|
| 216 |
+
if int(float(row[id_col])) < 0:
|
| 217 |
+
continue
|
| 218 |
+
# Convert values in one column (e.g. string to id)
|
| 219 |
+
for convert_key, convert_value in convert_filter.items():
|
| 220 |
+
row[convert_key] = convert_value[row[convert_key].lower()]
|
| 221 |
+
# Save data separated by timestep.
|
| 222 |
+
if timestep in read_data.keys():
|
| 223 |
+
read_data[timestep].append(row)
|
| 224 |
+
else:
|
| 225 |
+
read_data[timestep] = [row]
|
| 226 |
+
except Exception:
|
| 227 |
+
exc_str_init = (
|
| 228 |
+
"In file %s the following line cannot be read correctly: \n"
|
| 229 |
+
% os.path.basename(file)
|
| 230 |
+
)
|
| 231 |
+
exc_str = " ".join([exc_str_init] + row)
|
| 232 |
+
raise TrackEvalException(exc_str)
|
| 233 |
+
fp.close()
|
| 234 |
+
except Exception:
|
| 235 |
+
print("Error loading file: %s, printing traceback." % file)
|
| 236 |
+
traceback.print_exc()
|
| 237 |
+
raise TrackEvalException(
|
| 238 |
+
"File %s cannot be read because it is either not present or invalidly formatted"
|
| 239 |
+
% os.path.basename(file)
|
| 240 |
+
)
|
| 241 |
+
return read_data, crowd_ignore_data
|
| 242 |
+
|
| 243 |
+
@staticmethod
|
| 244 |
+
def _calculate_mask_ious(masks1, masks2, is_encoded=False, do_ioa=False):
|
| 245 |
+
"""Calculates the IOU (intersection over union) between two arrays of segmentation masks.
|
| 246 |
+
If is_encoded a run length encoding with pycocotools is assumed as input format, otherwise an input of numpy
|
| 247 |
+
arrays of the shape (num_masks, height, width) is assumed and the encoding is performed.
|
| 248 |
+
If do_ioa (intersection over area) , then calculates the intersection over the area of masks1 - this is commonly
|
| 249 |
+
used to determine if detections are within crowd ignore region.
|
| 250 |
+
:param masks1: first set of masks (numpy array of shape (num_masks, height, width) if not encoded,
|
| 251 |
+
else pycocotools rle encoded format)
|
| 252 |
+
:param masks2: second set of masks (numpy array of shape (num_masks, height, width) if not encoded,
|
| 253 |
+
else pycocotools rle encoded format)
|
| 254 |
+
:param is_encoded: whether the input is in pycocotools rle encoded format
|
| 255 |
+
:param do_ioa: whether to perform IoA computation
|
| 256 |
+
:return: the IoU/IoA scores
|
| 257 |
+
"""
|
| 258 |
+
|
| 259 |
+
# Only loaded when run to reduce minimum requirements
|
| 260 |
+
from pycocotools import mask as mask_utils
|
| 261 |
+
|
| 262 |
+
# use pycocotools for run length encoding of masks
|
| 263 |
+
if not is_encoded:
|
| 264 |
+
masks1 = mask_utils.encode(
|
| 265 |
+
np.array(np.transpose(masks1, (1, 2, 0)), order="F")
|
| 266 |
+
)
|
| 267 |
+
masks2 = mask_utils.encode(
|
| 268 |
+
np.array(np.transpose(masks2, (1, 2, 0)), order="F")
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
# use pycocotools for iou computation of rle encoded masks
|
| 272 |
+
ious = mask_utils.iou(masks1, masks2, [do_ioa] * len(masks2))
|
| 273 |
+
if len(masks1) == 0 or len(masks2) == 0:
|
| 274 |
+
ious = np.asarray(ious).reshape(len(masks1), len(masks2))
|
| 275 |
+
assert (ious >= 0 - np.finfo("float").eps).all()
|
| 276 |
+
assert (ious <= 1 + np.finfo("float").eps).all()
|
| 277 |
+
|
| 278 |
+
return ious
|
| 279 |
+
|
| 280 |
+
@staticmethod
|
| 281 |
+
def _calculate_box_ious(bboxes1, bboxes2, box_format="xywh", do_ioa=False):
|
| 282 |
+
"""Calculates the IOU (intersection over union) between two arrays of boxes.
|
| 283 |
+
Allows variable box formats ('xywh' and 'x0y0x1y1').
|
| 284 |
+
If do_ioa (intersection over area) , then calculates the intersection over the area of boxes1 - this is commonly
|
| 285 |
+
used to determine if detections are within crowd ignore region.
|
| 286 |
+
"""
|
| 287 |
+
if box_format in "xywh":
|
| 288 |
+
# layout: (x0, y0, w, h)
|
| 289 |
+
bboxes1 = deepcopy(bboxes1)
|
| 290 |
+
bboxes2 = deepcopy(bboxes2)
|
| 291 |
+
|
| 292 |
+
bboxes1[:, 2] = bboxes1[:, 0] + bboxes1[:, 2]
|
| 293 |
+
bboxes1[:, 3] = bboxes1[:, 1] + bboxes1[:, 3]
|
| 294 |
+
bboxes2[:, 2] = bboxes2[:, 0] + bboxes2[:, 2]
|
| 295 |
+
bboxes2[:, 3] = bboxes2[:, 1] + bboxes2[:, 3]
|
| 296 |
+
elif box_format not in "x0y0x1y1":
|
| 297 |
+
raise (TrackEvalException("box_format %s is not implemented" % box_format))
|
| 298 |
+
|
| 299 |
+
# layout: (x0, y0, x1, y1)
|
| 300 |
+
min_ = np.minimum(bboxes1[:, np.newaxis, :], bboxes2[np.newaxis, :, :])
|
| 301 |
+
max_ = np.maximum(bboxes1[:, np.newaxis, :], bboxes2[np.newaxis, :, :])
|
| 302 |
+
intersection = np.maximum(min_[..., 2] - max_[..., 0], 0) * np.maximum(
|
| 303 |
+
min_[..., 3] - max_[..., 1], 0
|
| 304 |
+
)
|
| 305 |
+
area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (
|
| 306 |
+
bboxes1[..., 3] - bboxes1[..., 1]
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
if do_ioa:
|
| 310 |
+
ioas = np.zeros_like(intersection)
|
| 311 |
+
valid_mask = area1 > 0 + np.finfo("float").eps
|
| 312 |
+
ioas[valid_mask, :] = (
|
| 313 |
+
intersection[valid_mask, :] / area1[valid_mask][:, np.newaxis]
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
return ioas
|
| 317 |
+
else:
|
| 318 |
+
area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (
|
| 319 |
+
bboxes2[..., 3] - bboxes2[..., 1]
|
| 320 |
+
)
|
| 321 |
+
union = area1[:, np.newaxis] + area2[np.newaxis, :] - intersection
|
| 322 |
+
intersection[area1 <= 0 + np.finfo("float").eps, :] = 0
|
| 323 |
+
intersection[:, area2 <= 0 + np.finfo("float").eps] = 0
|
| 324 |
+
intersection[union <= 0 + np.finfo("float").eps] = 0
|
| 325 |
+
union[union <= 0 + np.finfo("float").eps] = 1
|
| 326 |
+
ious = intersection / union
|
| 327 |
+
return ious
|
| 328 |
+
|
| 329 |
+
@staticmethod
|
| 330 |
+
def _calculate_euclidean_similarity(dets1, dets2, zero_distance=2.0):
|
| 331 |
+
"""Calculates the euclidean distance between two sets of detections, and then converts this into a similarity
|
| 332 |
+
measure with values between 0 and 1 using the following formula: sim = max(0, 1 - dist/zero_distance).
|
| 333 |
+
The default zero_distance of 2.0, corresponds to the default used in MOT15_3D, such that a 0.5 similarity
|
| 334 |
+
threshold corresponds to a 1m distance threshold for TPs.
|
| 335 |
+
"""
|
| 336 |
+
dist = np.linalg.norm(dets1[:, np.newaxis] - dets2[np.newaxis, :], axis=2)
|
| 337 |
+
sim = np.maximum(0, 1 - dist / zero_distance)
|
| 338 |
+
return sim
|
| 339 |
+
|
| 340 |
+
@staticmethod
|
| 341 |
+
def _check_unique_ids(data, after_preproc=False):
|
| 342 |
+
"""Check the requirement that the tracker_ids and gt_ids are unique per timestep"""
|
| 343 |
+
gt_ids = data["gt_ids"]
|
| 344 |
+
tracker_ids = data["tracker_ids"]
|
| 345 |
+
for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(gt_ids, tracker_ids)):
|
| 346 |
+
if len(tracker_ids_t) > 0:
|
| 347 |
+
unique_ids, counts = np.unique(tracker_ids_t, return_counts=True)
|
| 348 |
+
if np.max(counts) != 1:
|
| 349 |
+
duplicate_ids = unique_ids[counts > 1]
|
| 350 |
+
exc_str_init = (
|
| 351 |
+
"Tracker predicts the same ID more than once in a single timestep "
|
| 352 |
+
"(seq: %s, frame: %i, ids:" % (data["seq"], t + 1)
|
| 353 |
+
)
|
| 354 |
+
exc_str = (
|
| 355 |
+
" ".join([exc_str_init] + [str(d) for d in duplicate_ids]) + ")"
|
| 356 |
+
)
|
| 357 |
+
if after_preproc:
|
| 358 |
+
exc_str_init += (
|
| 359 |
+
"\n Note that this error occurred after preprocessing (but not before), "
|
| 360 |
+
"so ids may not be as in file, and something seems wrong with preproc."
|
| 361 |
+
)
|
| 362 |
+
raise TrackEvalException(exc_str)
|
| 363 |
+
if len(gt_ids_t) > 0:
|
| 364 |
+
unique_ids, counts = np.unique(gt_ids_t, return_counts=True)
|
| 365 |
+
if np.max(counts) != 1:
|
| 366 |
+
duplicate_ids = unique_ids[counts > 1]
|
| 367 |
+
exc_str_init = (
|
| 368 |
+
"Ground-truth has the same ID more than once in a single timestep "
|
| 369 |
+
"(seq: %s, frame: %i, ids:" % (data["seq"], t + 1)
|
| 370 |
+
)
|
| 371 |
+
exc_str = (
|
| 372 |
+
" ".join([exc_str_init] + [str(d) for d in duplicate_ids]) + ")"
|
| 373 |
+
)
|
| 374 |
+
if after_preproc:
|
| 375 |
+
exc_str_init += (
|
| 376 |
+
"\n Note that this error occurred after preprocessing (but not before), "
|
| 377 |
+
"so ids may not be as in file, and something seems wrong with preproc."
|
| 378 |
+
)
|
| 379 |
+
raise TrackEvalException(exc_str)
|
sam3/eval/hota_eval_toolkit/trackeval/datasets/tao_ow.py
ADDED
|
@@ -0,0 +1,891 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa
|
| 2 |
+
|
| 3 |
+
import itertools
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from scipy.optimize import linear_sum_assignment
|
| 10 |
+
|
| 11 |
+
from .. import _timing, utils
|
| 12 |
+
from ..utils import TrackEvalException
|
| 13 |
+
from ._base_dataset import _BaseDataset
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TAO_OW(_BaseDataset):
|
| 17 |
+
"""Dataset class for TAO tracking"""
|
| 18 |
+
|
| 19 |
+
@staticmethod
|
| 20 |
+
def get_default_dataset_config():
|
| 21 |
+
"""Default class config values"""
|
| 22 |
+
code_path = utils.get_code_path()
|
| 23 |
+
default_config = {
|
| 24 |
+
"GT_FOLDER": os.path.join(
|
| 25 |
+
code_path, "data/gt/tao/tao_training"
|
| 26 |
+
), # Location of GT data
|
| 27 |
+
"TRACKERS_FOLDER": os.path.join(
|
| 28 |
+
code_path, "data/trackers/tao/tao_training"
|
| 29 |
+
), # Trackers location
|
| 30 |
+
"OUTPUT_FOLDER": None, # Where to save eval results (if None, same as TRACKERS_FOLDER)
|
| 31 |
+
"TRACKERS_TO_EVAL": None, # Filenames of trackers to eval (if None, all in folder)
|
| 32 |
+
"CLASSES_TO_EVAL": None, # Classes to eval (if None, all classes)
|
| 33 |
+
"SPLIT_TO_EVAL": "training", # Valid: 'training', 'val'
|
| 34 |
+
"PRINT_CONFIG": True, # Whether to print current config
|
| 35 |
+
"TRACKER_SUB_FOLDER": "data", # Tracker files are in TRACKER_FOLDER/tracker_name/TRACKER_SUB_FOLDER
|
| 36 |
+
"OUTPUT_SUB_FOLDER": "", # Output files are saved in OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER
|
| 37 |
+
"TRACKER_DISPLAY_NAMES": None, # Names of trackers to display, if None: TRACKERS_TO_EVAL
|
| 38 |
+
"MAX_DETECTIONS": 300, # Number of maximal allowed detections per image (0 for unlimited)
|
| 39 |
+
"SUBSET": "all",
|
| 40 |
+
}
|
| 41 |
+
return default_config
|
| 42 |
+
|
| 43 |
+
def __init__(self, config=None):
|
| 44 |
+
"""Initialise dataset, checking that all required files are present"""
|
| 45 |
+
super().__init__()
|
| 46 |
+
# Fill non-given config values with defaults
|
| 47 |
+
self.config = utils.init_config(
|
| 48 |
+
config, self.get_default_dataset_config(), self.get_name()
|
| 49 |
+
)
|
| 50 |
+
self.gt_fol = self.config["GT_FOLDER"]
|
| 51 |
+
self.tracker_fol = self.config["TRACKERS_FOLDER"]
|
| 52 |
+
self.should_classes_combine = True
|
| 53 |
+
self.use_super_categories = False
|
| 54 |
+
|
| 55 |
+
self.tracker_sub_fol = self.config["TRACKER_SUB_FOLDER"]
|
| 56 |
+
self.output_fol = self.config["OUTPUT_FOLDER"]
|
| 57 |
+
if self.output_fol is None:
|
| 58 |
+
self.output_fol = self.tracker_fol
|
| 59 |
+
self.output_sub_fol = self.config["OUTPUT_SUB_FOLDER"]
|
| 60 |
+
|
| 61 |
+
gt_dir_files = [
|
| 62 |
+
file for file in os.listdir(self.gt_fol) if file.endswith(".json")
|
| 63 |
+
]
|
| 64 |
+
if len(gt_dir_files) != 1:
|
| 65 |
+
raise TrackEvalException(
|
| 66 |
+
self.gt_fol + " does not contain exactly one json file."
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
with open(os.path.join(self.gt_fol, gt_dir_files[0])) as f:
|
| 70 |
+
self.gt_data = json.load(f)
|
| 71 |
+
|
| 72 |
+
self.subset = self.config["SUBSET"]
|
| 73 |
+
if self.subset != "all":
|
| 74 |
+
# Split GT data into `known`, `unknown` or `distractor`
|
| 75 |
+
self._split_known_unknown_distractor()
|
| 76 |
+
self.gt_data = self._filter_gt_data(self.gt_data)
|
| 77 |
+
|
| 78 |
+
# merge categories marked with a merged tag in TAO dataset
|
| 79 |
+
self._merge_categories(self.gt_data["annotations"] + self.gt_data["tracks"])
|
| 80 |
+
|
| 81 |
+
# Get sequences to eval and sequence information
|
| 82 |
+
self.seq_list = [
|
| 83 |
+
vid["name"].replace("/", "-") for vid in self.gt_data["videos"]
|
| 84 |
+
]
|
| 85 |
+
self.seq_name_to_seq_id = {
|
| 86 |
+
vid["name"].replace("/", "-"): vid["id"] for vid in self.gt_data["videos"]
|
| 87 |
+
}
|
| 88 |
+
# compute mappings from videos to annotation data
|
| 89 |
+
self.videos_to_gt_tracks, self.videos_to_gt_images = self._compute_vid_mappings(
|
| 90 |
+
self.gt_data["annotations"]
|
| 91 |
+
)
|
| 92 |
+
# compute sequence lengths
|
| 93 |
+
self.seq_lengths = {vid["id"]: 0 for vid in self.gt_data["videos"]}
|
| 94 |
+
for img in self.gt_data["images"]:
|
| 95 |
+
self.seq_lengths[img["video_id"]] += 1
|
| 96 |
+
self.seq_to_images_to_timestep = self._compute_image_to_timestep_mappings()
|
| 97 |
+
self.seq_to_classes = {
|
| 98 |
+
vid["id"]: {
|
| 99 |
+
"pos_cat_ids": list(
|
| 100 |
+
{
|
| 101 |
+
track["category_id"]
|
| 102 |
+
for track in self.videos_to_gt_tracks[vid["id"]]
|
| 103 |
+
}
|
| 104 |
+
),
|
| 105 |
+
"neg_cat_ids": vid["neg_category_ids"],
|
| 106 |
+
"not_exhaustively_labeled_cat_ids": vid["not_exhaustive_category_ids"],
|
| 107 |
+
}
|
| 108 |
+
for vid in self.gt_data["videos"]
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
# Get classes to eval
|
| 112 |
+
considered_vid_ids = [self.seq_name_to_seq_id[vid] for vid in self.seq_list]
|
| 113 |
+
seen_cats = set(
|
| 114 |
+
[
|
| 115 |
+
cat_id
|
| 116 |
+
for vid_id in considered_vid_ids
|
| 117 |
+
for cat_id in self.seq_to_classes[vid_id]["pos_cat_ids"]
|
| 118 |
+
]
|
| 119 |
+
)
|
| 120 |
+
# only classes with ground truth are evaluated in TAO
|
| 121 |
+
self.valid_classes = [
|
| 122 |
+
cls["name"] for cls in self.gt_data["categories"] if cls["id"] in seen_cats
|
| 123 |
+
]
|
| 124 |
+
# cls_name_to_cls_id_map = {cls['name']: cls['id'] for cls in self.gt_data['categories']}
|
| 125 |
+
|
| 126 |
+
if self.config["CLASSES_TO_EVAL"]:
|
| 127 |
+
# self.class_list = [cls.lower() if cls.lower() in self.valid_classes else None
|
| 128 |
+
# for cls in self.config['CLASSES_TO_EVAL']]
|
| 129 |
+
self.class_list = ["object"] # class-agnostic
|
| 130 |
+
if not all(self.class_list):
|
| 131 |
+
raise TrackEvalException(
|
| 132 |
+
"Attempted to evaluate an invalid class. Only classes "
|
| 133 |
+
+ ", ".join(self.valid_classes)
|
| 134 |
+
+ " are valid (classes present in ground truth data)."
|
| 135 |
+
)
|
| 136 |
+
else:
|
| 137 |
+
# self.class_list = [cls for cls in self.valid_classes]
|
| 138 |
+
self.class_list = ["object"] # class-agnostic
|
| 139 |
+
# self.class_name_to_class_id = {k: v for k, v in cls_name_to_cls_id_map.items() if k in self.class_list}
|
| 140 |
+
self.class_name_to_class_id = {"object": 1} # class-agnostic
|
| 141 |
+
|
| 142 |
+
# Get trackers to eval
|
| 143 |
+
if self.config["TRACKERS_TO_EVAL"] is None:
|
| 144 |
+
self.tracker_list = os.listdir(self.tracker_fol)
|
| 145 |
+
else:
|
| 146 |
+
self.tracker_list = self.config["TRACKERS_TO_EVAL"]
|
| 147 |
+
|
| 148 |
+
if self.config["TRACKER_DISPLAY_NAMES"] is None:
|
| 149 |
+
self.tracker_to_disp = dict(zip(self.tracker_list, self.tracker_list))
|
| 150 |
+
elif (self.config["TRACKERS_TO_EVAL"] is not None) and (
|
| 151 |
+
len(self.config["TRACKER_DISPLAY_NAMES"]) == len(self.tracker_list)
|
| 152 |
+
):
|
| 153 |
+
self.tracker_to_disp = dict(
|
| 154 |
+
zip(self.tracker_list, self.config["TRACKER_DISPLAY_NAMES"])
|
| 155 |
+
)
|
| 156 |
+
else:
|
| 157 |
+
raise TrackEvalException(
|
| 158 |
+
"List of tracker files and tracker display names do not match."
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
self.tracker_data = {tracker: dict() for tracker in self.tracker_list}
|
| 162 |
+
|
| 163 |
+
for tracker in self.tracker_list:
|
| 164 |
+
tr_dir_files = [
|
| 165 |
+
file
|
| 166 |
+
for file in os.listdir(
|
| 167 |
+
os.path.join(self.tracker_fol, tracker, self.tracker_sub_fol)
|
| 168 |
+
)
|
| 169 |
+
if file.endswith(".json")
|
| 170 |
+
]
|
| 171 |
+
if len(tr_dir_files) != 1:
|
| 172 |
+
raise TrackEvalException(
|
| 173 |
+
os.path.join(self.tracker_fol, tracker, self.tracker_sub_fol)
|
| 174 |
+
+ " does not contain exactly one json file."
|
| 175 |
+
)
|
| 176 |
+
with open(
|
| 177 |
+
os.path.join(
|
| 178 |
+
self.tracker_fol, tracker, self.tracker_sub_fol, tr_dir_files[0]
|
| 179 |
+
)
|
| 180 |
+
) as f:
|
| 181 |
+
curr_data = json.load(f)
|
| 182 |
+
|
| 183 |
+
# limit detections if MAX_DETECTIONS > 0
|
| 184 |
+
if self.config["MAX_DETECTIONS"]:
|
| 185 |
+
curr_data = self._limit_dets_per_image(curr_data)
|
| 186 |
+
|
| 187 |
+
# fill missing video ids
|
| 188 |
+
self._fill_video_ids_inplace(curr_data)
|
| 189 |
+
|
| 190 |
+
# make track ids unique over whole evaluation set
|
| 191 |
+
self._make_track_ids_unique(curr_data)
|
| 192 |
+
|
| 193 |
+
# merge categories marked with a merged tag in TAO dataset
|
| 194 |
+
self._merge_categories(curr_data)
|
| 195 |
+
|
| 196 |
+
# get tracker sequence information
|
| 197 |
+
curr_videos_to_tracker_tracks, curr_videos_to_tracker_images = (
|
| 198 |
+
self._compute_vid_mappings(curr_data)
|
| 199 |
+
)
|
| 200 |
+
self.tracker_data[tracker]["vids_to_tracks"] = curr_videos_to_tracker_tracks
|
| 201 |
+
self.tracker_data[tracker]["vids_to_images"] = curr_videos_to_tracker_images
|
| 202 |
+
|
| 203 |
+
def get_display_name(self, tracker):
|
| 204 |
+
return self.tracker_to_disp[tracker]
|
| 205 |
+
|
| 206 |
+
def _load_raw_file(self, tracker, seq, is_gt):
|
| 207 |
+
"""Load a file (gt or tracker) in the TAO format
|
| 208 |
+
|
| 209 |
+
If is_gt, this returns a dict which contains the fields:
|
| 210 |
+
[gt_ids, gt_classes] : list (for each timestep) of 1D NDArrays (for each det).
|
| 211 |
+
[gt_dets]: list (for each timestep) of lists of detections.
|
| 212 |
+
[classes_to_gt_tracks]: dictionary with class values as keys and list of dictionaries (with frame indices as
|
| 213 |
+
keys and corresponding segmentations as values) for each track
|
| 214 |
+
[classes_to_gt_track_ids, classes_to_gt_track_areas, classes_to_gt_track_lengths]: dictionary with class values
|
| 215 |
+
as keys and lists (for each track) as values
|
| 216 |
+
|
| 217 |
+
if not is_gt, this returns a dict which contains the fields:
|
| 218 |
+
[tracker_ids, tracker_classes, tracker_confidences] : list (for each timestep) of 1D NDArrays (for each det).
|
| 219 |
+
[tracker_dets]: list (for each timestep) of lists of detections.
|
| 220 |
+
[classes_to_dt_tracks]: dictionary with class values as keys and list of dictionaries (with frame indices as
|
| 221 |
+
keys and corresponding segmentations as values) for each track
|
| 222 |
+
[classes_to_dt_track_ids, classes_to_dt_track_areas, classes_to_dt_track_lengths]: dictionary with class values
|
| 223 |
+
as keys and lists as values
|
| 224 |
+
[classes_to_dt_track_scores]: dictionary with class values as keys and 1D numpy arrays as values
|
| 225 |
+
"""
|
| 226 |
+
seq_id = self.seq_name_to_seq_id[seq]
|
| 227 |
+
# File location
|
| 228 |
+
if is_gt:
|
| 229 |
+
imgs = self.videos_to_gt_images[seq_id]
|
| 230 |
+
else:
|
| 231 |
+
imgs = self.tracker_data[tracker]["vids_to_images"][seq_id]
|
| 232 |
+
|
| 233 |
+
# Convert data to required format
|
| 234 |
+
num_timesteps = self.seq_lengths[seq_id]
|
| 235 |
+
img_to_timestep = self.seq_to_images_to_timestep[seq_id]
|
| 236 |
+
data_keys = ["ids", "classes", "dets"]
|
| 237 |
+
if not is_gt:
|
| 238 |
+
data_keys += ["tracker_confidences"]
|
| 239 |
+
raw_data = {key: [None] * num_timesteps for key in data_keys}
|
| 240 |
+
for img in imgs:
|
| 241 |
+
# some tracker data contains images without any ground truth information, these are ignored
|
| 242 |
+
try:
|
| 243 |
+
t = img_to_timestep[img["id"]]
|
| 244 |
+
except KeyError:
|
| 245 |
+
continue
|
| 246 |
+
annotations = img["annotations"]
|
| 247 |
+
raw_data["dets"][t] = np.atleast_2d(
|
| 248 |
+
[ann["bbox"] for ann in annotations]
|
| 249 |
+
).astype(float)
|
| 250 |
+
raw_data["ids"][t] = np.atleast_1d(
|
| 251 |
+
[ann["track_id"] for ann in annotations]
|
| 252 |
+
).astype(int)
|
| 253 |
+
raw_data["classes"][t] = np.atleast_1d([1 for _ in annotations]).astype(
|
| 254 |
+
int
|
| 255 |
+
) # class-agnostic
|
| 256 |
+
if not is_gt:
|
| 257 |
+
raw_data["tracker_confidences"][t] = np.atleast_1d(
|
| 258 |
+
[ann["score"] for ann in annotations]
|
| 259 |
+
).astype(float)
|
| 260 |
+
|
| 261 |
+
for t, d in enumerate(raw_data["dets"]):
|
| 262 |
+
if d is None:
|
| 263 |
+
raw_data["dets"][t] = np.empty((0, 4)).astype(float)
|
| 264 |
+
raw_data["ids"][t] = np.empty(0).astype(int)
|
| 265 |
+
raw_data["classes"][t] = np.empty(0).astype(int)
|
| 266 |
+
if not is_gt:
|
| 267 |
+
raw_data["tracker_confidences"][t] = np.empty(0)
|
| 268 |
+
|
| 269 |
+
if is_gt:
|
| 270 |
+
key_map = {"ids": "gt_ids", "classes": "gt_classes", "dets": "gt_dets"}
|
| 271 |
+
else:
|
| 272 |
+
key_map = {
|
| 273 |
+
"ids": "tracker_ids",
|
| 274 |
+
"classes": "tracker_classes",
|
| 275 |
+
"dets": "tracker_dets",
|
| 276 |
+
}
|
| 277 |
+
for k, v in key_map.items():
|
| 278 |
+
raw_data[v] = raw_data.pop(k)
|
| 279 |
+
|
| 280 |
+
# all_classes = [self.class_name_to_class_id[cls] for cls in self.class_list]
|
| 281 |
+
all_classes = [1] # class-agnostic
|
| 282 |
+
|
| 283 |
+
if is_gt:
|
| 284 |
+
classes_to_consider = all_classes
|
| 285 |
+
all_tracks = self.videos_to_gt_tracks[seq_id]
|
| 286 |
+
else:
|
| 287 |
+
# classes_to_consider = self.seq_to_classes[seq_id]['pos_cat_ids'] \
|
| 288 |
+
# + self.seq_to_classes[seq_id]['neg_cat_ids']
|
| 289 |
+
classes_to_consider = all_classes # class-agnostic
|
| 290 |
+
all_tracks = self.tracker_data[tracker]["vids_to_tracks"][seq_id]
|
| 291 |
+
|
| 292 |
+
# classes_to_tracks = {cls: [track for track in all_tracks if track['category_id'] == cls]
|
| 293 |
+
# if cls in classes_to_consider else [] for cls in all_classes}
|
| 294 |
+
classes_to_tracks = {
|
| 295 |
+
cls: [track for track in all_tracks] if cls in classes_to_consider else []
|
| 296 |
+
for cls in all_classes
|
| 297 |
+
} # class-agnostic
|
| 298 |
+
|
| 299 |
+
# mapping from classes to track information
|
| 300 |
+
raw_data["classes_to_tracks"] = {
|
| 301 |
+
cls: [
|
| 302 |
+
{
|
| 303 |
+
det["image_id"]: np.atleast_1d(det["bbox"])
|
| 304 |
+
for det in track["annotations"]
|
| 305 |
+
}
|
| 306 |
+
for track in tracks
|
| 307 |
+
]
|
| 308 |
+
for cls, tracks in classes_to_tracks.items()
|
| 309 |
+
}
|
| 310 |
+
raw_data["classes_to_track_ids"] = {
|
| 311 |
+
cls: [track["id"] for track in tracks]
|
| 312 |
+
for cls, tracks in classes_to_tracks.items()
|
| 313 |
+
}
|
| 314 |
+
raw_data["classes_to_track_areas"] = {
|
| 315 |
+
cls: [track["area"] for track in tracks]
|
| 316 |
+
for cls, tracks in classes_to_tracks.items()
|
| 317 |
+
}
|
| 318 |
+
raw_data["classes_to_track_lengths"] = {
|
| 319 |
+
cls: [len(track["annotations"]) for track in tracks]
|
| 320 |
+
for cls, tracks in classes_to_tracks.items()
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
if not is_gt:
|
| 324 |
+
raw_data["classes_to_dt_track_scores"] = {
|
| 325 |
+
cls: np.array(
|
| 326 |
+
[
|
| 327 |
+
np.mean([float(x["score"]) for x in track["annotations"]])
|
| 328 |
+
for track in tracks
|
| 329 |
+
]
|
| 330 |
+
)
|
| 331 |
+
for cls, tracks in classes_to_tracks.items()
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
if is_gt:
|
| 335 |
+
key_map = {
|
| 336 |
+
"classes_to_tracks": "classes_to_gt_tracks",
|
| 337 |
+
"classes_to_track_ids": "classes_to_gt_track_ids",
|
| 338 |
+
"classes_to_track_lengths": "classes_to_gt_track_lengths",
|
| 339 |
+
"classes_to_track_areas": "classes_to_gt_track_areas",
|
| 340 |
+
}
|
| 341 |
+
else:
|
| 342 |
+
key_map = {
|
| 343 |
+
"classes_to_tracks": "classes_to_dt_tracks",
|
| 344 |
+
"classes_to_track_ids": "classes_to_dt_track_ids",
|
| 345 |
+
"classes_to_track_lengths": "classes_to_dt_track_lengths",
|
| 346 |
+
"classes_to_track_areas": "classes_to_dt_track_areas",
|
| 347 |
+
}
|
| 348 |
+
for k, v in key_map.items():
|
| 349 |
+
raw_data[v] = raw_data.pop(k)
|
| 350 |
+
|
| 351 |
+
raw_data["num_timesteps"] = num_timesteps
|
| 352 |
+
raw_data["neg_cat_ids"] = self.seq_to_classes[seq_id]["neg_cat_ids"]
|
| 353 |
+
raw_data["not_exhaustively_labeled_cls"] = self.seq_to_classes[seq_id][
|
| 354 |
+
"not_exhaustively_labeled_cat_ids"
|
| 355 |
+
]
|
| 356 |
+
raw_data["seq"] = seq
|
| 357 |
+
return raw_data
|
| 358 |
+
|
| 359 |
+
@_timing.time
|
| 360 |
+
def get_preprocessed_seq_data(self, raw_data, cls):
|
| 361 |
+
"""Preprocess data for a single sequence for a single class ready for evaluation.
|
| 362 |
+
Inputs:
|
| 363 |
+
- raw_data is a dict containing the data for the sequence already read in by get_raw_seq_data().
|
| 364 |
+
- cls is the class to be evaluated.
|
| 365 |
+
Outputs:
|
| 366 |
+
- data is a dict containing all of the information that metrics need to perform evaluation.
|
| 367 |
+
It contains the following fields:
|
| 368 |
+
[num_timesteps, num_gt_ids, num_tracker_ids, num_gt_dets, num_tracker_dets] : integers.
|
| 369 |
+
[gt_ids, tracker_ids, tracker_confidences]: list (for each timestep) of 1D NDArrays (for each det).
|
| 370 |
+
[gt_dets, tracker_dets]: list (for each timestep) of lists of detections.
|
| 371 |
+
[similarity_scores]: list (for each timestep) of 2D NDArrays.
|
| 372 |
+
Notes:
|
| 373 |
+
General preprocessing (preproc) occurs in 4 steps. Some datasets may not use all of these steps.
|
| 374 |
+
1) Extract only detections relevant for the class to be evaluated (including distractor detections).
|
| 375 |
+
2) Match gt dets and tracker dets. Remove tracker dets that are matched to a gt det that is of a
|
| 376 |
+
distractor class, or otherwise marked as to be removed.
|
| 377 |
+
3) Remove unmatched tracker dets if they fall within a crowd ignore region or don't meet a certain
|
| 378 |
+
other criteria (e.g. are too small).
|
| 379 |
+
4) Remove gt dets that were only useful for preprocessing and not for actual evaluation.
|
| 380 |
+
After the above preprocessing steps, this function also calculates the number of gt and tracker detections
|
| 381 |
+
and unique track ids. It also relabels gt and tracker ids to be contiguous and checks that ids are
|
| 382 |
+
unique within each timestep.
|
| 383 |
+
TAO:
|
| 384 |
+
In TAO, the 4 preproc steps are as follow:
|
| 385 |
+
1) All classes present in the ground truth data are evaluated separately.
|
| 386 |
+
2) No matched tracker detections are removed.
|
| 387 |
+
3) Unmatched tracker detections are removed if there is not ground truth data and the class does not
|
| 388 |
+
belong to the categories marked as negative for this sequence. Additionally, unmatched tracker
|
| 389 |
+
detections for classes which are marked as not exhaustively labeled are removed.
|
| 390 |
+
4) No gt detections are removed.
|
| 391 |
+
Further, for TrackMAP computation track representations for the given class are accessed from a dictionary
|
| 392 |
+
and the tracks from the tracker data are sorted according to the tracker confidence.
|
| 393 |
+
"""
|
| 394 |
+
cls_id = self.class_name_to_class_id[cls]
|
| 395 |
+
is_not_exhaustively_labeled = cls_id in raw_data["not_exhaustively_labeled_cls"]
|
| 396 |
+
is_neg_category = cls_id in raw_data["neg_cat_ids"]
|
| 397 |
+
|
| 398 |
+
data_keys = [
|
| 399 |
+
"gt_ids",
|
| 400 |
+
"tracker_ids",
|
| 401 |
+
"gt_dets",
|
| 402 |
+
"tracker_dets",
|
| 403 |
+
"tracker_confidences",
|
| 404 |
+
"similarity_scores",
|
| 405 |
+
]
|
| 406 |
+
data = {key: [None] * raw_data["num_timesteps"] for key in data_keys}
|
| 407 |
+
unique_gt_ids = []
|
| 408 |
+
unique_tracker_ids = []
|
| 409 |
+
num_gt_dets = 0
|
| 410 |
+
num_tracker_dets = 0
|
| 411 |
+
for t in range(raw_data["num_timesteps"]):
|
| 412 |
+
# Only extract relevant dets for this class for preproc and eval (cls)
|
| 413 |
+
gt_class_mask = np.atleast_1d(raw_data["gt_classes"][t] == cls_id)
|
| 414 |
+
gt_class_mask = gt_class_mask.astype(bool)
|
| 415 |
+
gt_ids = raw_data["gt_ids"][t][gt_class_mask]
|
| 416 |
+
gt_dets = raw_data["gt_dets"][t][gt_class_mask]
|
| 417 |
+
|
| 418 |
+
tracker_class_mask = np.atleast_1d(raw_data["tracker_classes"][t] == cls_id)
|
| 419 |
+
tracker_class_mask = tracker_class_mask.astype(bool)
|
| 420 |
+
tracker_ids = raw_data["tracker_ids"][t][tracker_class_mask]
|
| 421 |
+
tracker_dets = raw_data["tracker_dets"][t][tracker_class_mask]
|
| 422 |
+
tracker_confidences = raw_data["tracker_confidences"][t][tracker_class_mask]
|
| 423 |
+
similarity_scores = raw_data["similarity_scores"][t][gt_class_mask, :][
|
| 424 |
+
:, tracker_class_mask
|
| 425 |
+
]
|
| 426 |
+
|
| 427 |
+
# Match tracker and gt dets (with hungarian algorithm).
|
| 428 |
+
unmatched_indices = np.arange(tracker_ids.shape[0])
|
| 429 |
+
if gt_ids.shape[0] > 0 and tracker_ids.shape[0] > 0:
|
| 430 |
+
matching_scores = similarity_scores.copy()
|
| 431 |
+
matching_scores[matching_scores < 0.5 - np.finfo("float").eps] = 0
|
| 432 |
+
match_rows, match_cols = linear_sum_assignment(-matching_scores)
|
| 433 |
+
actually_matched_mask = (
|
| 434 |
+
matching_scores[match_rows, match_cols] > 0 + np.finfo("float").eps
|
| 435 |
+
)
|
| 436 |
+
match_cols = match_cols[actually_matched_mask]
|
| 437 |
+
unmatched_indices = np.delete(unmatched_indices, match_cols, axis=0)
|
| 438 |
+
|
| 439 |
+
if gt_ids.shape[0] == 0 and not is_neg_category:
|
| 440 |
+
to_remove_tracker = unmatched_indices
|
| 441 |
+
elif is_not_exhaustively_labeled:
|
| 442 |
+
to_remove_tracker = unmatched_indices
|
| 443 |
+
else:
|
| 444 |
+
to_remove_tracker = np.array([], dtype=int)
|
| 445 |
+
|
| 446 |
+
# remove all unwanted unmatched tracker detections
|
| 447 |
+
data["tracker_ids"][t] = np.delete(tracker_ids, to_remove_tracker, axis=0)
|
| 448 |
+
data["tracker_dets"][t] = np.delete(tracker_dets, to_remove_tracker, axis=0)
|
| 449 |
+
data["tracker_confidences"][t] = np.delete(
|
| 450 |
+
tracker_confidences, to_remove_tracker, axis=0
|
| 451 |
+
)
|
| 452 |
+
similarity_scores = np.delete(similarity_scores, to_remove_tracker, axis=1)
|
| 453 |
+
|
| 454 |
+
data["gt_ids"][t] = gt_ids
|
| 455 |
+
data["gt_dets"][t] = gt_dets
|
| 456 |
+
data["similarity_scores"][t] = similarity_scores
|
| 457 |
+
|
| 458 |
+
unique_gt_ids += list(np.unique(data["gt_ids"][t]))
|
| 459 |
+
unique_tracker_ids += list(np.unique(data["tracker_ids"][t]))
|
| 460 |
+
num_tracker_dets += len(data["tracker_ids"][t])
|
| 461 |
+
num_gt_dets += len(data["gt_ids"][t])
|
| 462 |
+
|
| 463 |
+
# Re-label IDs such that there are no empty IDs
|
| 464 |
+
if len(unique_gt_ids) > 0:
|
| 465 |
+
unique_gt_ids = np.unique(unique_gt_ids)
|
| 466 |
+
gt_id_map = np.nan * np.ones((np.max(unique_gt_ids) + 1))
|
| 467 |
+
gt_id_map[unique_gt_ids] = np.arange(len(unique_gt_ids))
|
| 468 |
+
for t in range(raw_data["num_timesteps"]):
|
| 469 |
+
if len(data["gt_ids"][t]) > 0:
|
| 470 |
+
data["gt_ids"][t] = gt_id_map[data["gt_ids"][t]].astype(int)
|
| 471 |
+
if len(unique_tracker_ids) > 0:
|
| 472 |
+
unique_tracker_ids = np.unique(unique_tracker_ids)
|
| 473 |
+
tracker_id_map = np.nan * np.ones((np.max(unique_tracker_ids) + 1))
|
| 474 |
+
tracker_id_map[unique_tracker_ids] = np.arange(len(unique_tracker_ids))
|
| 475 |
+
for t in range(raw_data["num_timesteps"]):
|
| 476 |
+
if len(data["tracker_ids"][t]) > 0:
|
| 477 |
+
data["tracker_ids"][t] = tracker_id_map[
|
| 478 |
+
data["tracker_ids"][t]
|
| 479 |
+
].astype(int)
|
| 480 |
+
|
| 481 |
+
# Record overview statistics.
|
| 482 |
+
data["num_tracker_dets"] = num_tracker_dets
|
| 483 |
+
data["num_gt_dets"] = num_gt_dets
|
| 484 |
+
data["num_tracker_ids"] = len(unique_tracker_ids)
|
| 485 |
+
data["num_gt_ids"] = len(unique_gt_ids)
|
| 486 |
+
data["num_timesteps"] = raw_data["num_timesteps"]
|
| 487 |
+
data["seq"] = raw_data["seq"]
|
| 488 |
+
|
| 489 |
+
# get track representations
|
| 490 |
+
data["gt_tracks"] = raw_data["classes_to_gt_tracks"][cls_id]
|
| 491 |
+
data["gt_track_ids"] = raw_data["classes_to_gt_track_ids"][cls_id]
|
| 492 |
+
data["gt_track_lengths"] = raw_data["classes_to_gt_track_lengths"][cls_id]
|
| 493 |
+
data["gt_track_areas"] = raw_data["classes_to_gt_track_areas"][cls_id]
|
| 494 |
+
data["dt_tracks"] = raw_data["classes_to_dt_tracks"][cls_id]
|
| 495 |
+
data["dt_track_ids"] = raw_data["classes_to_dt_track_ids"][cls_id]
|
| 496 |
+
data["dt_track_lengths"] = raw_data["classes_to_dt_track_lengths"][cls_id]
|
| 497 |
+
data["dt_track_areas"] = raw_data["classes_to_dt_track_areas"][cls_id]
|
| 498 |
+
data["dt_track_scores"] = raw_data["classes_to_dt_track_scores"][cls_id]
|
| 499 |
+
data["not_exhaustively_labeled"] = is_not_exhaustively_labeled
|
| 500 |
+
data["iou_type"] = "bbox"
|
| 501 |
+
|
| 502 |
+
# sort tracker data tracks by tracker confidence scores
|
| 503 |
+
if data["dt_tracks"]:
|
| 504 |
+
idx = np.argsort(
|
| 505 |
+
[-score for score in data["dt_track_scores"]], kind="mergesort"
|
| 506 |
+
)
|
| 507 |
+
data["dt_track_scores"] = [data["dt_track_scores"][i] for i in idx]
|
| 508 |
+
data["dt_tracks"] = [data["dt_tracks"][i] for i in idx]
|
| 509 |
+
data["dt_track_ids"] = [data["dt_track_ids"][i] for i in idx]
|
| 510 |
+
data["dt_track_lengths"] = [data["dt_track_lengths"][i] for i in idx]
|
| 511 |
+
data["dt_track_areas"] = [data["dt_track_areas"][i] for i in idx]
|
| 512 |
+
# Ensure that ids are unique per timestep.
|
| 513 |
+
self._check_unique_ids(data)
|
| 514 |
+
|
| 515 |
+
return data
|
| 516 |
+
|
| 517 |
+
def _calculate_similarities(self, gt_dets_t, tracker_dets_t):
|
| 518 |
+
similarity_scores = self._calculate_box_ious(gt_dets_t, tracker_dets_t)
|
| 519 |
+
return similarity_scores
|
| 520 |
+
|
| 521 |
+
def _merge_categories(self, annotations):
|
| 522 |
+
"""
|
| 523 |
+
Merges categories with a merged tag. Adapted from https://github.com/TAO-Dataset
|
| 524 |
+
:param annotations: the annotations in which the classes should be merged
|
| 525 |
+
:return: None
|
| 526 |
+
"""
|
| 527 |
+
merge_map = {}
|
| 528 |
+
for category in self.gt_data["categories"]:
|
| 529 |
+
if "merged" in category:
|
| 530 |
+
for to_merge in category["merged"]:
|
| 531 |
+
merge_map[to_merge["id"]] = category["id"]
|
| 532 |
+
|
| 533 |
+
for ann in annotations:
|
| 534 |
+
ann["category_id"] = merge_map.get(ann["category_id"], ann["category_id"])
|
| 535 |
+
|
| 536 |
+
def _compute_vid_mappings(self, annotations):
|
| 537 |
+
"""
|
| 538 |
+
Computes mappings from Videos to corresponding tracks and images.
|
| 539 |
+
:param annotations: the annotations for which the mapping should be generated
|
| 540 |
+
:return: the video-to-track-mapping, the video-to-image-mapping
|
| 541 |
+
"""
|
| 542 |
+
vids_to_tracks = {}
|
| 543 |
+
vids_to_imgs = {}
|
| 544 |
+
vid_ids = [vid["id"] for vid in self.gt_data["videos"]]
|
| 545 |
+
|
| 546 |
+
# compute an mapping from image IDs to images
|
| 547 |
+
images = {}
|
| 548 |
+
for image in self.gt_data["images"]:
|
| 549 |
+
images[image["id"]] = image
|
| 550 |
+
|
| 551 |
+
for ann in annotations:
|
| 552 |
+
ann["area"] = ann["bbox"][2] * ann["bbox"][3]
|
| 553 |
+
|
| 554 |
+
vid = ann["video_id"]
|
| 555 |
+
if ann["video_id"] not in vids_to_tracks.keys():
|
| 556 |
+
vids_to_tracks[ann["video_id"]] = list()
|
| 557 |
+
if ann["video_id"] not in vids_to_imgs.keys():
|
| 558 |
+
vids_to_imgs[ann["video_id"]] = list()
|
| 559 |
+
|
| 560 |
+
# Fill in vids_to_tracks
|
| 561 |
+
tid = ann["track_id"]
|
| 562 |
+
exist_tids = [track["id"] for track in vids_to_tracks[vid]]
|
| 563 |
+
try:
|
| 564 |
+
index1 = exist_tids.index(tid)
|
| 565 |
+
except ValueError:
|
| 566 |
+
index1 = -1
|
| 567 |
+
if tid not in exist_tids:
|
| 568 |
+
curr_track = {
|
| 569 |
+
"id": tid,
|
| 570 |
+
"category_id": ann["category_id"],
|
| 571 |
+
"video_id": vid,
|
| 572 |
+
"annotations": [ann],
|
| 573 |
+
}
|
| 574 |
+
vids_to_tracks[vid].append(curr_track)
|
| 575 |
+
else:
|
| 576 |
+
vids_to_tracks[vid][index1]["annotations"].append(ann)
|
| 577 |
+
|
| 578 |
+
# Fill in vids_to_imgs
|
| 579 |
+
img_id = ann["image_id"]
|
| 580 |
+
exist_img_ids = [img["id"] for img in vids_to_imgs[vid]]
|
| 581 |
+
try:
|
| 582 |
+
index2 = exist_img_ids.index(img_id)
|
| 583 |
+
except ValueError:
|
| 584 |
+
index2 = -1
|
| 585 |
+
if index2 == -1:
|
| 586 |
+
curr_img = {"id": img_id, "annotations": [ann]}
|
| 587 |
+
vids_to_imgs[vid].append(curr_img)
|
| 588 |
+
else:
|
| 589 |
+
vids_to_imgs[vid][index2]["annotations"].append(ann)
|
| 590 |
+
|
| 591 |
+
# sort annotations by frame index and compute track area
|
| 592 |
+
for vid, tracks in vids_to_tracks.items():
|
| 593 |
+
for track in tracks:
|
| 594 |
+
track["annotations"] = sorted(
|
| 595 |
+
track["annotations"],
|
| 596 |
+
key=lambda x: images[x["image_id"]]["frame_index"],
|
| 597 |
+
)
|
| 598 |
+
# Computer average area
|
| 599 |
+
track["area"] = sum(x["area"] for x in track["annotations"]) / len(
|
| 600 |
+
track["annotations"]
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
# Ensure all videos are present
|
| 604 |
+
for vid_id in vid_ids:
|
| 605 |
+
if vid_id not in vids_to_tracks.keys():
|
| 606 |
+
vids_to_tracks[vid_id] = []
|
| 607 |
+
if vid_id not in vids_to_imgs.keys():
|
| 608 |
+
vids_to_imgs[vid_id] = []
|
| 609 |
+
|
| 610 |
+
return vids_to_tracks, vids_to_imgs
|
| 611 |
+
|
| 612 |
+
def _compute_image_to_timestep_mappings(self):
|
| 613 |
+
"""
|
| 614 |
+
Computes a mapping from images to the corresponding timestep in the sequence.
|
| 615 |
+
:return: the image-to-timestep-mapping
|
| 616 |
+
"""
|
| 617 |
+
images = {}
|
| 618 |
+
for image in self.gt_data["images"]:
|
| 619 |
+
images[image["id"]] = image
|
| 620 |
+
|
| 621 |
+
seq_to_imgs_to_timestep = {vid["id"]: dict() for vid in self.gt_data["videos"]}
|
| 622 |
+
for vid in seq_to_imgs_to_timestep:
|
| 623 |
+
curr_imgs = [img["id"] for img in self.videos_to_gt_images[vid]]
|
| 624 |
+
curr_imgs = sorted(curr_imgs, key=lambda x: images[x]["frame_index"])
|
| 625 |
+
seq_to_imgs_to_timestep[vid] = {
|
| 626 |
+
curr_imgs[i]: i for i in range(len(curr_imgs))
|
| 627 |
+
}
|
| 628 |
+
|
| 629 |
+
return seq_to_imgs_to_timestep
|
| 630 |
+
|
| 631 |
+
def _limit_dets_per_image(self, annotations):
|
| 632 |
+
"""
|
| 633 |
+
Limits the number of detections for each image to config['MAX_DETECTIONS']. Adapted from
|
| 634 |
+
https://github.com/TAO-Dataset/
|
| 635 |
+
:param annotations: the annotations in which the detections should be limited
|
| 636 |
+
:return: the annotations with limited detections
|
| 637 |
+
"""
|
| 638 |
+
max_dets = self.config["MAX_DETECTIONS"]
|
| 639 |
+
img_ann = defaultdict(list)
|
| 640 |
+
for ann in annotations:
|
| 641 |
+
img_ann[ann["image_id"]].append(ann)
|
| 642 |
+
|
| 643 |
+
for img_id, _anns in img_ann.items():
|
| 644 |
+
if len(_anns) <= max_dets:
|
| 645 |
+
continue
|
| 646 |
+
_anns = sorted(_anns, key=lambda x: x["score"], reverse=True)
|
| 647 |
+
img_ann[img_id] = _anns[:max_dets]
|
| 648 |
+
|
| 649 |
+
return [ann for anns in img_ann.values() for ann in anns]
|
| 650 |
+
|
| 651 |
+
def _fill_video_ids_inplace(self, annotations):
|
| 652 |
+
"""
|
| 653 |
+
Fills in missing video IDs inplace. Adapted from https://github.com/TAO-Dataset/
|
| 654 |
+
:param annotations: the annotations for which the videos IDs should be filled inplace
|
| 655 |
+
:return: None
|
| 656 |
+
"""
|
| 657 |
+
missing_video_id = [x for x in annotations if "video_id" not in x]
|
| 658 |
+
if missing_video_id:
|
| 659 |
+
image_id_to_video_id = {
|
| 660 |
+
x["id"]: x["video_id"] for x in self.gt_data["images"]
|
| 661 |
+
}
|
| 662 |
+
for x in missing_video_id:
|
| 663 |
+
x["video_id"] = image_id_to_video_id[x["image_id"]]
|
| 664 |
+
|
| 665 |
+
@staticmethod
|
| 666 |
+
def _make_track_ids_unique(annotations):
|
| 667 |
+
"""
|
| 668 |
+
Makes the track IDs unqiue over the whole annotation set. Adapted from https://github.com/TAO-Dataset/
|
| 669 |
+
:param annotations: the annotation set
|
| 670 |
+
:return: the number of updated IDs
|
| 671 |
+
"""
|
| 672 |
+
track_id_videos = {}
|
| 673 |
+
track_ids_to_update = set()
|
| 674 |
+
max_track_id = 0
|
| 675 |
+
for ann in annotations:
|
| 676 |
+
t = ann["track_id"]
|
| 677 |
+
if t not in track_id_videos:
|
| 678 |
+
track_id_videos[t] = ann["video_id"]
|
| 679 |
+
|
| 680 |
+
if ann["video_id"] != track_id_videos[t]:
|
| 681 |
+
# Track id is assigned to multiple videos
|
| 682 |
+
track_ids_to_update.add(t)
|
| 683 |
+
max_track_id = max(max_track_id, t)
|
| 684 |
+
|
| 685 |
+
if track_ids_to_update:
|
| 686 |
+
print("true")
|
| 687 |
+
next_id = itertools.count(max_track_id + 1)
|
| 688 |
+
new_track_ids = defaultdict(lambda: next(next_id))
|
| 689 |
+
for ann in annotations:
|
| 690 |
+
t = ann["track_id"]
|
| 691 |
+
v = ann["video_id"]
|
| 692 |
+
if t in track_ids_to_update:
|
| 693 |
+
ann["track_id"] = new_track_ids[t, v]
|
| 694 |
+
return len(track_ids_to_update)
|
| 695 |
+
|
| 696 |
+
def _split_known_unknown_distractor(self):
|
| 697 |
+
all_ids = set(
|
| 698 |
+
[i for i in range(1, 2000)]
|
| 699 |
+
) # 2000 is larger than the max category id in TAO-OW.
|
| 700 |
+
# `knowns` includes 78 TAO_category_ids that corresponds to 78 COCO classes.
|
| 701 |
+
# (The other 2 COCO classes do not have corresponding classes in TAO).
|
| 702 |
+
self.knowns = {
|
| 703 |
+
4,
|
| 704 |
+
13,
|
| 705 |
+
1038,
|
| 706 |
+
544,
|
| 707 |
+
1057,
|
| 708 |
+
34,
|
| 709 |
+
35,
|
| 710 |
+
36,
|
| 711 |
+
41,
|
| 712 |
+
45,
|
| 713 |
+
58,
|
| 714 |
+
60,
|
| 715 |
+
579,
|
| 716 |
+
1091,
|
| 717 |
+
1097,
|
| 718 |
+
1099,
|
| 719 |
+
78,
|
| 720 |
+
79,
|
| 721 |
+
81,
|
| 722 |
+
91,
|
| 723 |
+
1115,
|
| 724 |
+
1117,
|
| 725 |
+
95,
|
| 726 |
+
1122,
|
| 727 |
+
99,
|
| 728 |
+
1132,
|
| 729 |
+
621,
|
| 730 |
+
1135,
|
| 731 |
+
625,
|
| 732 |
+
118,
|
| 733 |
+
1144,
|
| 734 |
+
126,
|
| 735 |
+
642,
|
| 736 |
+
1155,
|
| 737 |
+
133,
|
| 738 |
+
1162,
|
| 739 |
+
139,
|
| 740 |
+
154,
|
| 741 |
+
174,
|
| 742 |
+
185,
|
| 743 |
+
699,
|
| 744 |
+
1215,
|
| 745 |
+
714,
|
| 746 |
+
717,
|
| 747 |
+
1229,
|
| 748 |
+
211,
|
| 749 |
+
729,
|
| 750 |
+
221,
|
| 751 |
+
229,
|
| 752 |
+
747,
|
| 753 |
+
235,
|
| 754 |
+
237,
|
| 755 |
+
779,
|
| 756 |
+
276,
|
| 757 |
+
805,
|
| 758 |
+
299,
|
| 759 |
+
829,
|
| 760 |
+
852,
|
| 761 |
+
347,
|
| 762 |
+
371,
|
| 763 |
+
382,
|
| 764 |
+
896,
|
| 765 |
+
392,
|
| 766 |
+
926,
|
| 767 |
+
937,
|
| 768 |
+
428,
|
| 769 |
+
429,
|
| 770 |
+
961,
|
| 771 |
+
452,
|
| 772 |
+
979,
|
| 773 |
+
980,
|
| 774 |
+
982,
|
| 775 |
+
475,
|
| 776 |
+
480,
|
| 777 |
+
993,
|
| 778 |
+
1001,
|
| 779 |
+
502,
|
| 780 |
+
1018,
|
| 781 |
+
}
|
| 782 |
+
# `distractors` is defined as in the paper "Opening up Open-World Tracking"
|
| 783 |
+
self.distractors = {
|
| 784 |
+
20,
|
| 785 |
+
63,
|
| 786 |
+
108,
|
| 787 |
+
180,
|
| 788 |
+
188,
|
| 789 |
+
204,
|
| 790 |
+
212,
|
| 791 |
+
247,
|
| 792 |
+
303,
|
| 793 |
+
403,
|
| 794 |
+
407,
|
| 795 |
+
415,
|
| 796 |
+
490,
|
| 797 |
+
504,
|
| 798 |
+
507,
|
| 799 |
+
513,
|
| 800 |
+
529,
|
| 801 |
+
567,
|
| 802 |
+
569,
|
| 803 |
+
588,
|
| 804 |
+
672,
|
| 805 |
+
691,
|
| 806 |
+
702,
|
| 807 |
+
708,
|
| 808 |
+
711,
|
| 809 |
+
720,
|
| 810 |
+
736,
|
| 811 |
+
737,
|
| 812 |
+
798,
|
| 813 |
+
813,
|
| 814 |
+
815,
|
| 815 |
+
827,
|
| 816 |
+
831,
|
| 817 |
+
851,
|
| 818 |
+
877,
|
| 819 |
+
883,
|
| 820 |
+
912,
|
| 821 |
+
971,
|
| 822 |
+
976,
|
| 823 |
+
1130,
|
| 824 |
+
1133,
|
| 825 |
+
1134,
|
| 826 |
+
1169,
|
| 827 |
+
1184,
|
| 828 |
+
1220,
|
| 829 |
+
}
|
| 830 |
+
self.unknowns = all_ids.difference(self.knowns.union(self.distractors))
|
| 831 |
+
|
| 832 |
+
def _filter_gt_data(self, raw_gt_data):
|
| 833 |
+
"""
|
| 834 |
+
Filter out irrelevant data in the raw_gt_data
|
| 835 |
+
Args:
|
| 836 |
+
raw_gt_data: directly loaded from json.
|
| 837 |
+
|
| 838 |
+
Returns:
|
| 839 |
+
filtered gt_data
|
| 840 |
+
"""
|
| 841 |
+
valid_cat_ids = list()
|
| 842 |
+
if self.subset == "known":
|
| 843 |
+
valid_cat_ids = self.knowns
|
| 844 |
+
elif self.subset == "distractor":
|
| 845 |
+
valid_cat_ids = self.distractors
|
| 846 |
+
elif self.subset == "unknown":
|
| 847 |
+
valid_cat_ids = self.unknowns
|
| 848 |
+
# elif self.subset == "test_only_unknowns":
|
| 849 |
+
# valid_cat_ids = test_only_unknowns
|
| 850 |
+
else:
|
| 851 |
+
raise Exception("The parameter `SUBSET` is incorrect")
|
| 852 |
+
|
| 853 |
+
filtered = dict()
|
| 854 |
+
filtered["videos"] = raw_gt_data["videos"]
|
| 855 |
+
# filtered["videos"] = list()
|
| 856 |
+
unwanted_vid = set()
|
| 857 |
+
# for video in raw_gt_data["videos"]:
|
| 858 |
+
# datasrc = video["name"].split('/')[1]
|
| 859 |
+
# if datasrc in data_srcs:
|
| 860 |
+
# filtered["videos"].append(video)
|
| 861 |
+
# else:
|
| 862 |
+
# unwanted_vid.add(video["id"])
|
| 863 |
+
|
| 864 |
+
filtered["annotations"] = list()
|
| 865 |
+
for ann in raw_gt_data["annotations"]:
|
| 866 |
+
if (ann["video_id"] not in unwanted_vid) and (
|
| 867 |
+
ann["category_id"] in valid_cat_ids
|
| 868 |
+
):
|
| 869 |
+
filtered["annotations"].append(ann)
|
| 870 |
+
|
| 871 |
+
filtered["tracks"] = list()
|
| 872 |
+
for track in raw_gt_data["tracks"]:
|
| 873 |
+
if (track["video_id"] not in unwanted_vid) and (
|
| 874 |
+
track["category_id"] in valid_cat_ids
|
| 875 |
+
):
|
| 876 |
+
filtered["tracks"].append(track)
|
| 877 |
+
|
| 878 |
+
filtered["images"] = list()
|
| 879 |
+
for image in raw_gt_data["images"]:
|
| 880 |
+
if image["video_id"] not in unwanted_vid:
|
| 881 |
+
filtered["images"].append(image)
|
| 882 |
+
|
| 883 |
+
filtered["categories"] = list()
|
| 884 |
+
for cat in raw_gt_data["categories"]:
|
| 885 |
+
if cat["id"] in valid_cat_ids:
|
| 886 |
+
filtered["categories"].append(cat)
|
| 887 |
+
|
| 888 |
+
filtered["info"] = raw_gt_data["info"]
|
| 889 |
+
filtered["licenses"] = raw_gt_data["licenses"]
|
| 890 |
+
|
| 891 |
+
return filtered
|
sam3/eval/hota_eval_toolkit/trackeval/datasets/youtube_vis.py
ADDED
|
@@ -0,0 +1,524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa
|
| 2 |
+
|
| 3 |
+
# note: this file has been modified from its original version in TrackEval in
|
| 4 |
+
# https://github.com/JonathonLuiten/TrackEval/blob/master/trackeval/datasets/youtube_vis.py
|
| 5 |
+
# to support the following:
|
| 6 |
+
# 1) bbox evaluation (via `IOU_TYPE`)
|
| 7 |
+
# 2) passing GT and prediction data as Python objects (via `GT_JSON_OBJECT` and `TRACKER_JSON_OBJECT`)
|
| 8 |
+
# 3) specifying a custom dataset name (via `DATASET_NAME`)
|
| 9 |
+
|
| 10 |
+
import json
|
| 11 |
+
import os
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
|
| 15 |
+
from .. import _timing, utils
|
| 16 |
+
from ..utils import TrackEvalException
|
| 17 |
+
from ._base_dataset import _BaseDataset
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class YouTubeVIS(_BaseDataset):
|
| 21 |
+
"""Dataset class for YouTubeVIS tracking"""
|
| 22 |
+
|
| 23 |
+
@staticmethod
|
| 24 |
+
def get_default_dataset_config():
|
| 25 |
+
"""Default class config values"""
|
| 26 |
+
code_path = utils.get_code_path()
|
| 27 |
+
default_config = {
|
| 28 |
+
"GT_FOLDER": os.path.join(
|
| 29 |
+
code_path, "data/gt/youtube_vis/"
|
| 30 |
+
), # Location of GT data
|
| 31 |
+
"TRACKERS_FOLDER": os.path.join(code_path, "data/trackers/youtube_vis/"),
|
| 32 |
+
# Trackers location
|
| 33 |
+
"OUTPUT_FOLDER": None, # Where to save eval results (if None, same as TRACKERS_FOLDER)
|
| 34 |
+
"TRACKERS_TO_EVAL": None, # Filenames of trackers to eval (if None, all in folder)
|
| 35 |
+
"CLASSES_TO_EVAL": None, # Classes to eval (if None, all classes)
|
| 36 |
+
"SPLIT_TO_EVAL": "train_sub_split", # Valid: 'train', 'val', 'train_sub_split'
|
| 37 |
+
"PRINT_CONFIG": True, # Whether to print current config
|
| 38 |
+
"OUTPUT_SUB_FOLDER": "", # Output files are saved in OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER
|
| 39 |
+
"TRACKER_SUB_FOLDER": "data", # Tracker files are in TRACKER_FOLDER/tracker_name/TRACKER_SUB_FOLDER
|
| 40 |
+
"TRACKER_DISPLAY_NAMES": None, # Names of trackers to display, if None: TRACKERS_TO_EVAL
|
| 41 |
+
# Added for video phrase AP evaluation -- allow directly specifying the GT JSON data and Tracker (result)
|
| 42 |
+
# JSON data as Python objects, without reading from files.
|
| 43 |
+
"GT_JSON_OBJECT": None,
|
| 44 |
+
"TRACKER_JSON_OBJECT": None,
|
| 45 |
+
"IOU_TYPE": "segm",
|
| 46 |
+
"DATASET_NAME": "video",
|
| 47 |
+
}
|
| 48 |
+
return default_config
|
| 49 |
+
|
| 50 |
+
def __init__(self, config=None):
|
| 51 |
+
"""Initialise dataset, checking that all required files are present"""
|
| 52 |
+
super().__init__()
|
| 53 |
+
# Fill non-given config values with defaults
|
| 54 |
+
self.config = utils.init_config(config, self.get_default_dataset_config())
|
| 55 |
+
self.gt_fol = (
|
| 56 |
+
self.config["GT_FOLDER"] + "youtube_vis_" + self.config["SPLIT_TO_EVAL"]
|
| 57 |
+
)
|
| 58 |
+
self.tracker_fol = (
|
| 59 |
+
self.config["TRACKERS_FOLDER"]
|
| 60 |
+
+ "youtube_vis_"
|
| 61 |
+
+ self.config["SPLIT_TO_EVAL"]
|
| 62 |
+
)
|
| 63 |
+
self.use_super_categories = False
|
| 64 |
+
self.should_classes_combine = True
|
| 65 |
+
assert self.config["IOU_TYPE"] in ["segm", "bbox"]
|
| 66 |
+
self.iou_type = self.config["IOU_TYPE"]
|
| 67 |
+
print("=" * 100)
|
| 68 |
+
print(f"Evaluate annotation type *{self.iou_type}*")
|
| 69 |
+
self.dataset_name = self.config["DATASET_NAME"]
|
| 70 |
+
|
| 71 |
+
self.output_fol = self.config["OUTPUT_FOLDER"]
|
| 72 |
+
if self.output_fol is None:
|
| 73 |
+
self.output_fol = self.tracker_fol
|
| 74 |
+
self.output_sub_fol = self.config["OUTPUT_SUB_FOLDER"]
|
| 75 |
+
self.tracker_sub_fol = self.config["TRACKER_SUB_FOLDER"]
|
| 76 |
+
|
| 77 |
+
if self.config["GT_JSON_OBJECT"] is not None:
|
| 78 |
+
# allow directly specifying the GT JSON data without reading from files
|
| 79 |
+
gt_json = self.config["GT_JSON_OBJECT"]
|
| 80 |
+
assert isinstance(gt_json, dict)
|
| 81 |
+
assert "videos" in gt_json
|
| 82 |
+
assert "categories" in gt_json
|
| 83 |
+
assert "annotations" in gt_json
|
| 84 |
+
self.gt_data = gt_json
|
| 85 |
+
else:
|
| 86 |
+
if not os.path.exists(self.gt_fol):
|
| 87 |
+
print("GT folder not found: " + self.gt_fol)
|
| 88 |
+
raise TrackEvalException(
|
| 89 |
+
"GT folder not found: " + os.path.basename(self.gt_fol)
|
| 90 |
+
)
|
| 91 |
+
gt_dir_files = [
|
| 92 |
+
file for file in os.listdir(self.gt_fol) if file.endswith(".json")
|
| 93 |
+
]
|
| 94 |
+
if len(gt_dir_files) != 1:
|
| 95 |
+
raise TrackEvalException(
|
| 96 |
+
self.gt_fol + " does not contain exactly one json file."
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
with open(os.path.join(self.gt_fol, gt_dir_files[0])) as f:
|
| 100 |
+
self.gt_data = json.load(f)
|
| 101 |
+
|
| 102 |
+
# Get classes to eval
|
| 103 |
+
self.valid_classes = [cls["name"] for cls in self.gt_data["categories"]]
|
| 104 |
+
cls_name_to_cls_id_map = {
|
| 105 |
+
cls["name"]: cls["id"] for cls in self.gt_data["categories"]
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
if self.config["CLASSES_TO_EVAL"]:
|
| 109 |
+
self.class_list = [
|
| 110 |
+
cls.lower() if cls.lower() in self.valid_classes else None
|
| 111 |
+
for cls in self.config["CLASSES_TO_EVAL"]
|
| 112 |
+
]
|
| 113 |
+
if not all(self.class_list):
|
| 114 |
+
raise TrackEvalException(
|
| 115 |
+
"Attempted to evaluate an invalid class. Only classes "
|
| 116 |
+
+ ", ".join(self.valid_classes)
|
| 117 |
+
+ " are valid."
|
| 118 |
+
)
|
| 119 |
+
else:
|
| 120 |
+
self.class_list = [cls["name"] for cls in self.gt_data["categories"]]
|
| 121 |
+
self.class_name_to_class_id = {
|
| 122 |
+
k: v for k, v in cls_name_to_cls_id_map.items() if k in self.class_list
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
# Get sequences to eval and check gt files exist
|
| 126 |
+
self.seq_list = [
|
| 127 |
+
vid["file_names"][0].split("/")[0] for vid in self.gt_data["videos"]
|
| 128 |
+
]
|
| 129 |
+
self.seq_name_to_seq_id = {
|
| 130 |
+
vid["file_names"][0].split("/")[0]: vid["id"]
|
| 131 |
+
for vid in self.gt_data["videos"]
|
| 132 |
+
}
|
| 133 |
+
self.seq_lengths = {
|
| 134 |
+
vid["id"]: len(vid["file_names"]) for vid in self.gt_data["videos"]
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
# encode masks and compute track areas
|
| 138 |
+
self._prepare_gt_annotations()
|
| 139 |
+
|
| 140 |
+
# Get trackers to eval
|
| 141 |
+
if self.config["TRACKER_JSON_OBJECT"] is not None:
|
| 142 |
+
# allow directly specifying the tracker JSON data without reading from files
|
| 143 |
+
tracker_json = self.config["TRACKER_JSON_OBJECT"]
|
| 144 |
+
assert isinstance(tracker_json, list)
|
| 145 |
+
self.tracker_list = ["tracker"]
|
| 146 |
+
elif self.config["TRACKERS_TO_EVAL"] is None:
|
| 147 |
+
self.tracker_list = os.listdir(self.tracker_fol)
|
| 148 |
+
else:
|
| 149 |
+
self.tracker_list = self.config["TRACKERS_TO_EVAL"]
|
| 150 |
+
|
| 151 |
+
if self.config["TRACKER_DISPLAY_NAMES"] is None:
|
| 152 |
+
self.tracker_to_disp = dict(zip(self.tracker_list, self.tracker_list))
|
| 153 |
+
elif (self.config["TRACKERS_TO_EVAL"] is not None) and (
|
| 154 |
+
len(self.config["TRACKER_DISPLAY_NAMES"]) == len(self.tracker_list)
|
| 155 |
+
):
|
| 156 |
+
self.tracker_to_disp = dict(
|
| 157 |
+
zip(self.tracker_list, self.config["TRACKER_DISPLAY_NAMES"])
|
| 158 |
+
)
|
| 159 |
+
else:
|
| 160 |
+
raise TrackEvalException(
|
| 161 |
+
"List of tracker files and tracker display names do not match."
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
# counter for globally unique track IDs
|
| 165 |
+
self.global_tid_counter = 0
|
| 166 |
+
|
| 167 |
+
self.tracker_data = dict()
|
| 168 |
+
if self.config["TRACKER_JSON_OBJECT"] is not None:
|
| 169 |
+
# allow directly specifying the tracker JSON data without reading from files
|
| 170 |
+
tracker = self.tracker_list[0]
|
| 171 |
+
self.tracker_data[tracker] = tracker_json
|
| 172 |
+
else:
|
| 173 |
+
for tracker in self.tracker_list:
|
| 174 |
+
tracker_dir_path = os.path.join(
|
| 175 |
+
self.tracker_fol, tracker, self.tracker_sub_fol
|
| 176 |
+
)
|
| 177 |
+
tr_dir_files = [
|
| 178 |
+
file
|
| 179 |
+
for file in os.listdir(tracker_dir_path)
|
| 180 |
+
if file.endswith(".json")
|
| 181 |
+
]
|
| 182 |
+
if len(tr_dir_files) != 1:
|
| 183 |
+
raise TrackEvalException(
|
| 184 |
+
tracker_dir_path + " does not contain exactly one json file."
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
with open(os.path.join(tracker_dir_path, tr_dir_files[0])) as f:
|
| 188 |
+
curr_data = json.load(f)
|
| 189 |
+
|
| 190 |
+
self.tracker_data[tracker] = curr_data
|
| 191 |
+
|
| 192 |
+
def get_display_name(self, tracker):
|
| 193 |
+
return self.tracker_to_disp[tracker]
|
| 194 |
+
|
| 195 |
+
def _load_raw_file(self, tracker, seq, is_gt):
|
| 196 |
+
"""Load a file (gt or tracker) in the YouTubeVIS format
|
| 197 |
+
If is_gt, this returns a dict which contains the fields:
|
| 198 |
+
[gt_ids, gt_classes] : list (for each timestep) of 1D NDArrays (for each det).
|
| 199 |
+
[gt_dets]: list (for each timestep) of lists of detections.
|
| 200 |
+
[classes_to_gt_tracks]: dictionary with class values as keys and list of dictionaries (with frame indices as
|
| 201 |
+
keys and corresponding segmentations as values) for each track
|
| 202 |
+
[classes_to_gt_track_ids, classes_to_gt_track_areas, classes_to_gt_track_iscrowd]: dictionary with class values
|
| 203 |
+
as keys and lists (for each track) as values
|
| 204 |
+
|
| 205 |
+
if not is_gt, this returns a dict which contains the fields:
|
| 206 |
+
[tracker_ids, tracker_classes, tracker_confidences] : list (for each timestep) of 1D NDArrays (for each det).
|
| 207 |
+
[tracker_dets]: list (for each timestep) of lists of detections.
|
| 208 |
+
[classes_to_dt_tracks]: dictionary with class values as keys and list of dictionaries (with frame indices as
|
| 209 |
+
keys and corresponding segmentations as values) for each track
|
| 210 |
+
[classes_to_dt_track_ids, classes_to_dt_track_areas]: dictionary with class values as keys and lists as values
|
| 211 |
+
[classes_to_dt_track_scores]: dictionary with class values as keys and 1D numpy arrays as values
|
| 212 |
+
"""
|
| 213 |
+
# select sequence tracks
|
| 214 |
+
seq_id = self.seq_name_to_seq_id[seq]
|
| 215 |
+
if is_gt:
|
| 216 |
+
tracks = [
|
| 217 |
+
ann for ann in self.gt_data["annotations"] if ann["video_id"] == seq_id
|
| 218 |
+
]
|
| 219 |
+
else:
|
| 220 |
+
tracks = self._get_tracker_seq_tracks(tracker, seq_id)
|
| 221 |
+
|
| 222 |
+
# Convert data to required format
|
| 223 |
+
num_timesteps = self.seq_lengths[seq_id]
|
| 224 |
+
data_keys = ["ids", "classes", "dets"]
|
| 225 |
+
if not is_gt:
|
| 226 |
+
data_keys += ["tracker_confidences"]
|
| 227 |
+
raw_data = {key: [None] * num_timesteps for key in data_keys}
|
| 228 |
+
result_key = "segmentations" if self.iou_type == "segm" else "bboxes"
|
| 229 |
+
for t in range(num_timesteps):
|
| 230 |
+
raw_data["dets"][t] = [
|
| 231 |
+
track[result_key][t] for track in tracks if track[result_key][t]
|
| 232 |
+
]
|
| 233 |
+
raw_data["ids"][t] = np.atleast_1d(
|
| 234 |
+
[track["id"] for track in tracks if track[result_key][t]]
|
| 235 |
+
).astype(int)
|
| 236 |
+
raw_data["classes"][t] = np.atleast_1d(
|
| 237 |
+
[track["category_id"] for track in tracks if track[result_key][t]]
|
| 238 |
+
).astype(int)
|
| 239 |
+
if not is_gt:
|
| 240 |
+
raw_data["tracker_confidences"][t] = np.atleast_1d(
|
| 241 |
+
[track["score"] for track in tracks if track[result_key][t]]
|
| 242 |
+
).astype(float)
|
| 243 |
+
|
| 244 |
+
if is_gt:
|
| 245 |
+
key_map = {"ids": "gt_ids", "classes": "gt_classes", "dets": "gt_dets"}
|
| 246 |
+
else:
|
| 247 |
+
key_map = {
|
| 248 |
+
"ids": "tracker_ids",
|
| 249 |
+
"classes": "tracker_classes",
|
| 250 |
+
"dets": "tracker_dets",
|
| 251 |
+
}
|
| 252 |
+
for k, v in key_map.items():
|
| 253 |
+
raw_data[v] = raw_data.pop(k)
|
| 254 |
+
|
| 255 |
+
all_cls_ids = {self.class_name_to_class_id[cls] for cls in self.class_list}
|
| 256 |
+
classes_to_tracks = {
|
| 257 |
+
cls: [track for track in tracks if track["category_id"] == cls]
|
| 258 |
+
for cls in all_cls_ids
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
# mapping from classes to track representations and track information
|
| 262 |
+
raw_data["classes_to_tracks"] = {
|
| 263 |
+
cls: [
|
| 264 |
+
{i: track[result_key][i] for i in range(len(track[result_key]))}
|
| 265 |
+
for track in tracks
|
| 266 |
+
]
|
| 267 |
+
for cls, tracks in classes_to_tracks.items()
|
| 268 |
+
}
|
| 269 |
+
raw_data["classes_to_track_ids"] = {
|
| 270 |
+
cls: [track["id"] for track in tracks]
|
| 271 |
+
for cls, tracks in classes_to_tracks.items()
|
| 272 |
+
}
|
| 273 |
+
raw_data["classes_to_track_areas"] = {
|
| 274 |
+
cls: [track["area"] for track in tracks]
|
| 275 |
+
for cls, tracks in classes_to_tracks.items()
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
if is_gt:
|
| 279 |
+
raw_data["classes_to_gt_track_iscrowd"] = {
|
| 280 |
+
cls: [track["iscrowd"] for track in tracks]
|
| 281 |
+
for cls, tracks in classes_to_tracks.items()
|
| 282 |
+
}
|
| 283 |
+
else:
|
| 284 |
+
raw_data["classes_to_dt_track_scores"] = {
|
| 285 |
+
cls: np.array([track["score"] for track in tracks])
|
| 286 |
+
for cls, tracks in classes_to_tracks.items()
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
if is_gt:
|
| 290 |
+
key_map = {
|
| 291 |
+
"classes_to_tracks": "classes_to_gt_tracks",
|
| 292 |
+
"classes_to_track_ids": "classes_to_gt_track_ids",
|
| 293 |
+
"classes_to_track_areas": "classes_to_gt_track_areas",
|
| 294 |
+
}
|
| 295 |
+
else:
|
| 296 |
+
key_map = {
|
| 297 |
+
"classes_to_tracks": "classes_to_dt_tracks",
|
| 298 |
+
"classes_to_track_ids": "classes_to_dt_track_ids",
|
| 299 |
+
"classes_to_track_areas": "classes_to_dt_track_areas",
|
| 300 |
+
}
|
| 301 |
+
for k, v in key_map.items():
|
| 302 |
+
raw_data[v] = raw_data.pop(k)
|
| 303 |
+
|
| 304 |
+
raw_data["num_timesteps"] = num_timesteps
|
| 305 |
+
raw_data["seq"] = seq
|
| 306 |
+
return raw_data
|
| 307 |
+
|
| 308 |
+
@_timing.time
|
| 309 |
+
def get_preprocessed_seq_data(self, raw_data, cls):
|
| 310 |
+
"""Preprocess data for a single sequence for a single class ready for evaluation.
|
| 311 |
+
Inputs:
|
| 312 |
+
- raw_data is a dict containing the data for the sequence already read in by get_raw_seq_data().
|
| 313 |
+
- cls is the class to be evaluated.
|
| 314 |
+
Outputs:
|
| 315 |
+
- data is a dict containing all of the information that metrics need to perform evaluation.
|
| 316 |
+
It contains the following fields:
|
| 317 |
+
[num_timesteps, num_gt_ids, num_tracker_ids, num_gt_dets, num_tracker_dets] : integers.
|
| 318 |
+
[gt_ids, tracker_ids, tracker_confidences]: list (for each timestep) of 1D NDArrays (for each det).
|
| 319 |
+
[gt_dets, tracker_dets]: list (for each timestep) of lists of detections.
|
| 320 |
+
[similarity_scores]: list (for each timestep) of 2D NDArrays.
|
| 321 |
+
Notes:
|
| 322 |
+
General preprocessing (preproc) occurs in 4 steps. Some datasets may not use all of these steps.
|
| 323 |
+
1) Extract only detections relevant for the class to be evaluated (including distractor detections).
|
| 324 |
+
2) Match gt dets and tracker dets. Remove tracker dets that are matched to a gt det that is of a
|
| 325 |
+
distractor class, or otherwise marked as to be removed.
|
| 326 |
+
3) Remove unmatched tracker dets if they fall within a crowd ignore region or don't meet a certain
|
| 327 |
+
other criteria (e.g. are too small).
|
| 328 |
+
4) Remove gt dets that were only useful for preprocessing and not for actual evaluation.
|
| 329 |
+
After the above preprocessing steps, this function also calculates the number of gt and tracker detections
|
| 330 |
+
and unique track ids. It also relabels gt and tracker ids to be contiguous and checks that ids are
|
| 331 |
+
unique within each timestep.
|
| 332 |
+
YouTubeVIS:
|
| 333 |
+
In YouTubeVIS, the 4 preproc steps are as follow:
|
| 334 |
+
1) There are 40 classes which are evaluated separately.
|
| 335 |
+
2) No matched tracker dets are removed.
|
| 336 |
+
3) No unmatched tracker dets are removed.
|
| 337 |
+
4) No gt dets are removed.
|
| 338 |
+
Further, for TrackMAP computation track representations for the given class are accessed from a dictionary
|
| 339 |
+
and the tracks from the tracker data are sorted according to the tracker confidence.
|
| 340 |
+
"""
|
| 341 |
+
cls_id = self.class_name_to_class_id[cls]
|
| 342 |
+
|
| 343 |
+
data_keys = [
|
| 344 |
+
"gt_ids",
|
| 345 |
+
"tracker_ids",
|
| 346 |
+
"gt_dets",
|
| 347 |
+
"tracker_dets",
|
| 348 |
+
"similarity_scores",
|
| 349 |
+
]
|
| 350 |
+
data = {key: [None] * raw_data["num_timesteps"] for key in data_keys}
|
| 351 |
+
unique_gt_ids = []
|
| 352 |
+
unique_tracker_ids = []
|
| 353 |
+
num_gt_dets = 0
|
| 354 |
+
num_tracker_dets = 0
|
| 355 |
+
|
| 356 |
+
for t in range(raw_data["num_timesteps"]):
|
| 357 |
+
# Only extract relevant dets for this class for eval (cls)
|
| 358 |
+
gt_class_mask = np.atleast_1d(raw_data["gt_classes"][t] == cls_id)
|
| 359 |
+
gt_class_mask = gt_class_mask.astype(bool)
|
| 360 |
+
gt_ids = raw_data["gt_ids"][t][gt_class_mask]
|
| 361 |
+
gt_dets = [
|
| 362 |
+
raw_data["gt_dets"][t][ind]
|
| 363 |
+
for ind in range(len(gt_class_mask))
|
| 364 |
+
if gt_class_mask[ind]
|
| 365 |
+
]
|
| 366 |
+
|
| 367 |
+
tracker_class_mask = np.atleast_1d(raw_data["tracker_classes"][t] == cls_id)
|
| 368 |
+
tracker_class_mask = tracker_class_mask.astype(bool)
|
| 369 |
+
tracker_ids = raw_data["tracker_ids"][t][tracker_class_mask]
|
| 370 |
+
tracker_dets = [
|
| 371 |
+
raw_data["tracker_dets"][t][ind]
|
| 372 |
+
for ind in range(len(tracker_class_mask))
|
| 373 |
+
if tracker_class_mask[ind]
|
| 374 |
+
]
|
| 375 |
+
similarity_scores = raw_data["similarity_scores"][t][gt_class_mask, :][
|
| 376 |
+
:, tracker_class_mask
|
| 377 |
+
]
|
| 378 |
+
|
| 379 |
+
data["tracker_ids"][t] = tracker_ids
|
| 380 |
+
data["tracker_dets"][t] = tracker_dets
|
| 381 |
+
data["gt_ids"][t] = gt_ids
|
| 382 |
+
data["gt_dets"][t] = gt_dets
|
| 383 |
+
data["similarity_scores"][t] = similarity_scores
|
| 384 |
+
|
| 385 |
+
unique_gt_ids += list(np.unique(data["gt_ids"][t]))
|
| 386 |
+
unique_tracker_ids += list(np.unique(data["tracker_ids"][t]))
|
| 387 |
+
num_tracker_dets += len(data["tracker_ids"][t])
|
| 388 |
+
num_gt_dets += len(data["gt_ids"][t])
|
| 389 |
+
|
| 390 |
+
# Re-label IDs such that there are no empty IDs
|
| 391 |
+
if len(unique_gt_ids) > 0:
|
| 392 |
+
unique_gt_ids = np.unique(unique_gt_ids)
|
| 393 |
+
gt_id_map = np.nan * np.ones((np.max(unique_gt_ids) + 1))
|
| 394 |
+
gt_id_map[unique_gt_ids] = np.arange(len(unique_gt_ids))
|
| 395 |
+
for t in range(raw_data["num_timesteps"]):
|
| 396 |
+
if len(data["gt_ids"][t]) > 0:
|
| 397 |
+
data["gt_ids"][t] = gt_id_map[data["gt_ids"][t]].astype(int)
|
| 398 |
+
if len(unique_tracker_ids) > 0:
|
| 399 |
+
unique_tracker_ids = np.unique(unique_tracker_ids)
|
| 400 |
+
tracker_id_map = np.nan * np.ones((np.max(unique_tracker_ids) + 1))
|
| 401 |
+
tracker_id_map[unique_tracker_ids] = np.arange(len(unique_tracker_ids))
|
| 402 |
+
for t in range(raw_data["num_timesteps"]):
|
| 403 |
+
if len(data["tracker_ids"][t]) > 0:
|
| 404 |
+
data["tracker_ids"][t] = tracker_id_map[
|
| 405 |
+
data["tracker_ids"][t]
|
| 406 |
+
].astype(int)
|
| 407 |
+
|
| 408 |
+
# Ensure that ids are unique per timestep.
|
| 409 |
+
self._check_unique_ids(data)
|
| 410 |
+
|
| 411 |
+
# Record overview statistics.
|
| 412 |
+
data["num_tracker_dets"] = num_tracker_dets
|
| 413 |
+
data["num_gt_dets"] = num_gt_dets
|
| 414 |
+
data["num_tracker_ids"] = len(unique_tracker_ids)
|
| 415 |
+
data["num_gt_ids"] = len(unique_gt_ids)
|
| 416 |
+
data["num_timesteps"] = raw_data["num_timesteps"]
|
| 417 |
+
data["seq"] = raw_data["seq"]
|
| 418 |
+
|
| 419 |
+
# get track representations
|
| 420 |
+
data["gt_tracks"] = raw_data["classes_to_gt_tracks"][cls_id]
|
| 421 |
+
data["gt_track_ids"] = raw_data["classes_to_gt_track_ids"][cls_id]
|
| 422 |
+
data["gt_track_areas"] = raw_data["classes_to_gt_track_areas"][cls_id]
|
| 423 |
+
data["gt_track_iscrowd"] = raw_data["classes_to_gt_track_iscrowd"][cls_id]
|
| 424 |
+
data["dt_tracks"] = raw_data["classes_to_dt_tracks"][cls_id]
|
| 425 |
+
data["dt_track_ids"] = raw_data["classes_to_dt_track_ids"][cls_id]
|
| 426 |
+
data["dt_track_areas"] = raw_data["classes_to_dt_track_areas"][cls_id]
|
| 427 |
+
data["dt_track_scores"] = raw_data["classes_to_dt_track_scores"][cls_id]
|
| 428 |
+
data["iou_type"] = "mask"
|
| 429 |
+
|
| 430 |
+
# sort tracker data tracks by tracker confidence scores
|
| 431 |
+
if data["dt_tracks"]:
|
| 432 |
+
idx = np.argsort(
|
| 433 |
+
[-score for score in data["dt_track_scores"]], kind="mergesort"
|
| 434 |
+
)
|
| 435 |
+
data["dt_track_scores"] = [data["dt_track_scores"][i] for i in idx]
|
| 436 |
+
data["dt_tracks"] = [data["dt_tracks"][i] for i in idx]
|
| 437 |
+
data["dt_track_ids"] = [data["dt_track_ids"][i] for i in idx]
|
| 438 |
+
data["dt_track_areas"] = [data["dt_track_areas"][i] for i in idx]
|
| 439 |
+
|
| 440 |
+
return data
|
| 441 |
+
|
| 442 |
+
def _calculate_similarities(self, gt_dets_t, tracker_dets_t):
|
| 443 |
+
if self.iou_type == "segm":
|
| 444 |
+
similarity_scores = self._calculate_mask_ious(
|
| 445 |
+
gt_dets_t, tracker_dets_t, is_encoded=True, do_ioa=False
|
| 446 |
+
)
|
| 447 |
+
else:
|
| 448 |
+
gt_dets_t = np.array(gt_dets_t, dtype=np.float32).reshape(-1, 4)
|
| 449 |
+
tracker_dets_t = np.array(tracker_dets_t, dtype=np.float32).reshape(-1, 4)
|
| 450 |
+
similarity_scores = self._calculate_box_ious(
|
| 451 |
+
gt_dets_t, tracker_dets_t, box_format="xywh", do_ioa=False
|
| 452 |
+
)
|
| 453 |
+
return similarity_scores
|
| 454 |
+
|
| 455 |
+
def _prepare_gt_annotations(self):
|
| 456 |
+
"""
|
| 457 |
+
Prepares GT data by rle encoding segmentations and computing the average track area.
|
| 458 |
+
:return: None
|
| 459 |
+
"""
|
| 460 |
+
if self.iou_type == "segm":
|
| 461 |
+
# only loaded when needed to reduce minimum requirements
|
| 462 |
+
from pycocotools import mask as mask_utils
|
| 463 |
+
|
| 464 |
+
for track in self.gt_data["annotations"]:
|
| 465 |
+
h = track["height"]
|
| 466 |
+
w = track["width"]
|
| 467 |
+
for i, seg in enumerate(track["segmentations"]):
|
| 468 |
+
if seg is not None and isinstance(seg["counts"], list):
|
| 469 |
+
track["segmentations"][i] = mask_utils.frPyObjects(seg, h, w)
|
| 470 |
+
areas = [a for a in track["areas"] if a]
|
| 471 |
+
if len(areas) == 0:
|
| 472 |
+
track["area"] = 0
|
| 473 |
+
else:
|
| 474 |
+
track["area"] = np.array(areas).mean()
|
| 475 |
+
else:
|
| 476 |
+
for track in self.gt_data["annotations"]:
|
| 477 |
+
# For bbox eval, compute areas from bboxes if not already available
|
| 478 |
+
areas = [a for a in track.get("areas", []) if a]
|
| 479 |
+
if not areas:
|
| 480 |
+
areas = []
|
| 481 |
+
for bbox in track.get("bboxes", []):
|
| 482 |
+
if bbox is not None:
|
| 483 |
+
areas.append(bbox[2] * bbox[3])
|
| 484 |
+
track["area"] = np.array(areas).mean() if areas else 0
|
| 485 |
+
|
| 486 |
+
def _get_tracker_seq_tracks(self, tracker, seq_id):
|
| 487 |
+
"""
|
| 488 |
+
Prepares tracker data for a given sequence. Extracts all annotations for given sequence ID, computes
|
| 489 |
+
average track area and assigns a track ID.
|
| 490 |
+
:param tracker: the given tracker
|
| 491 |
+
:param seq_id: the sequence ID
|
| 492 |
+
:return: the extracted tracks
|
| 493 |
+
"""
|
| 494 |
+
# only loaded when needed to reduce minimum requirements
|
| 495 |
+
from pycocotools import mask as mask_utils
|
| 496 |
+
|
| 497 |
+
tracks = [
|
| 498 |
+
ann for ann in self.tracker_data[tracker] if ann["video_id"] == seq_id
|
| 499 |
+
]
|
| 500 |
+
for track in tracks:
|
| 501 |
+
if "areas" not in track:
|
| 502 |
+
if self.iou_type == "segm":
|
| 503 |
+
for seg in track["segmentations"]:
|
| 504 |
+
if seg:
|
| 505 |
+
track["areas"].append(mask_utils.area(seg))
|
| 506 |
+
else:
|
| 507 |
+
track["areas"].append(None)
|
| 508 |
+
else:
|
| 509 |
+
for bbox in track["bboxes"]:
|
| 510 |
+
if bbox:
|
| 511 |
+
track["areas"].append(bbox[2] * bbox[3])
|
| 512 |
+
else:
|
| 513 |
+
track["areas"].append(None)
|
| 514 |
+
areas = [a for a in track["areas"] if a]
|
| 515 |
+
if len(areas) == 0:
|
| 516 |
+
track["area"] = 0
|
| 517 |
+
else:
|
| 518 |
+
track["area"] = np.array(areas).mean()
|
| 519 |
+
track["id"] = self.global_tid_counter
|
| 520 |
+
self.global_tid_counter += 1
|
| 521 |
+
return tracks
|
| 522 |
+
|
| 523 |
+
def get_name(self):
|
| 524 |
+
return self.dataset_name
|
sam3/eval/hota_eval_toolkit/trackeval/eval.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
import traceback
|
| 6 |
+
from functools import partial
|
| 7 |
+
from multiprocessing.pool import Pool
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
from . import _timing, utils
|
| 12 |
+
from .metrics import Count
|
| 13 |
+
from .utils import TrackEvalException
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
import tqdm
|
| 17 |
+
|
| 18 |
+
TQDM_IMPORTED = True
|
| 19 |
+
except ImportError as _:
|
| 20 |
+
TQDM_IMPORTED = False
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class Evaluator:
|
| 24 |
+
"""Evaluator class for evaluating different metrics for different datasets"""
|
| 25 |
+
|
| 26 |
+
@staticmethod
|
| 27 |
+
def get_default_eval_config():
|
| 28 |
+
"""Returns the default config values for evaluation"""
|
| 29 |
+
code_path = utils.get_code_path()
|
| 30 |
+
default_config = {
|
| 31 |
+
"USE_PARALLEL": False,
|
| 32 |
+
"NUM_PARALLEL_CORES": 8,
|
| 33 |
+
"BREAK_ON_ERROR": True, # Raises exception and exits with error
|
| 34 |
+
"RETURN_ON_ERROR": False, # if not BREAK_ON_ERROR, then returns from function on error
|
| 35 |
+
"LOG_ON_ERROR": os.path.join(
|
| 36 |
+
code_path, "error_log.txt"
|
| 37 |
+
), # if not None, save any errors into a log file.
|
| 38 |
+
"PRINT_RESULTS": True,
|
| 39 |
+
"PRINT_ONLY_COMBINED": False,
|
| 40 |
+
"PRINT_CONFIG": True,
|
| 41 |
+
"TIME_PROGRESS": True,
|
| 42 |
+
"DISPLAY_LESS_PROGRESS": True,
|
| 43 |
+
"OUTPUT_SUMMARY": True,
|
| 44 |
+
"OUTPUT_EMPTY_CLASSES": True, # If False, summary files are not output for classes with no detections
|
| 45 |
+
"OUTPUT_DETAILED": True,
|
| 46 |
+
"PLOT_CURVES": True,
|
| 47 |
+
}
|
| 48 |
+
return default_config
|
| 49 |
+
|
| 50 |
+
def __init__(self, config=None):
|
| 51 |
+
"""Initialise the evaluator with a config file"""
|
| 52 |
+
self.config = utils.init_config(config, self.get_default_eval_config(), "Eval")
|
| 53 |
+
# Only run timing analysis if not run in parallel.
|
| 54 |
+
if self.config["TIME_PROGRESS"] and not self.config["USE_PARALLEL"]:
|
| 55 |
+
_timing.DO_TIMING = True
|
| 56 |
+
if self.config["DISPLAY_LESS_PROGRESS"]:
|
| 57 |
+
_timing.DISPLAY_LESS_PROGRESS = True
|
| 58 |
+
|
| 59 |
+
def _combine_results(
|
| 60 |
+
self,
|
| 61 |
+
res,
|
| 62 |
+
metrics_list,
|
| 63 |
+
metric_names,
|
| 64 |
+
dataset,
|
| 65 |
+
res_field="COMBINED_SEQ",
|
| 66 |
+
target_tag=None,
|
| 67 |
+
):
|
| 68 |
+
assert res_field.startswith("COMBINED_SEQ")
|
| 69 |
+
# collecting combined cls keys (cls averaged, det averaged, super classes)
|
| 70 |
+
tracker_list, seq_list, class_list = dataset.get_eval_info()
|
| 71 |
+
combined_cls_keys = []
|
| 72 |
+
res[res_field] = {}
|
| 73 |
+
|
| 74 |
+
# narrow the target for evaluation
|
| 75 |
+
if target_tag is not None:
|
| 76 |
+
target_video_ids = [
|
| 77 |
+
annot["video_id"]
|
| 78 |
+
for annot in dataset.gt_data["annotations"]
|
| 79 |
+
if target_tag in annot["tags"]
|
| 80 |
+
]
|
| 81 |
+
vid2name = {
|
| 82 |
+
video["id"]: video["file_names"][0].split("/")[0]
|
| 83 |
+
for video in dataset.gt_data["videos"]
|
| 84 |
+
}
|
| 85 |
+
target_video_ids = set(target_video_ids)
|
| 86 |
+
target_video = [vid2name[video_id] for video_id in target_video_ids]
|
| 87 |
+
|
| 88 |
+
if len(target_video) == 0:
|
| 89 |
+
raise TrackEvalException(
|
| 90 |
+
"No sequences found with the tag %s" % target_tag
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
target_annotations = [
|
| 94 |
+
annot
|
| 95 |
+
for annot in dataset.gt_data["annotations"]
|
| 96 |
+
if annot["video_id"] in target_video_ids
|
| 97 |
+
]
|
| 98 |
+
assert all(target_tag in annot["tags"] for annot in target_annotations), (
|
| 99 |
+
f"Not all annotations in the target sequences have the target tag {target_tag}. "
|
| 100 |
+
"We currently only support a target tag at the sequence level, not at the annotation level."
|
| 101 |
+
)
|
| 102 |
+
else:
|
| 103 |
+
target_video = seq_list
|
| 104 |
+
|
| 105 |
+
# combine sequences for each class
|
| 106 |
+
for c_cls in class_list:
|
| 107 |
+
res[res_field][c_cls] = {}
|
| 108 |
+
for metric, metric_name in zip(metrics_list, metric_names):
|
| 109 |
+
curr_res = {
|
| 110 |
+
seq_key: seq_value[c_cls][metric_name]
|
| 111 |
+
for seq_key, seq_value in res.items()
|
| 112 |
+
if not seq_key.startswith("COMBINED_SEQ")
|
| 113 |
+
and seq_key in target_video
|
| 114 |
+
}
|
| 115 |
+
res[res_field][c_cls][metric_name] = metric.combine_sequences(curr_res)
|
| 116 |
+
# combine classes
|
| 117 |
+
if dataset.should_classes_combine:
|
| 118 |
+
combined_cls_keys += [
|
| 119 |
+
"cls_comb_cls_av",
|
| 120 |
+
"cls_comb_det_av",
|
| 121 |
+
"all",
|
| 122 |
+
]
|
| 123 |
+
res[res_field]["cls_comb_cls_av"] = {}
|
| 124 |
+
res[res_field]["cls_comb_det_av"] = {}
|
| 125 |
+
for metric, metric_name in zip(metrics_list, metric_names):
|
| 126 |
+
cls_res = {
|
| 127 |
+
cls_key: cls_value[metric_name]
|
| 128 |
+
for cls_key, cls_value in res[res_field].items()
|
| 129 |
+
if cls_key not in combined_cls_keys
|
| 130 |
+
}
|
| 131 |
+
res[res_field]["cls_comb_cls_av"][metric_name] = (
|
| 132 |
+
metric.combine_classes_class_averaged(cls_res)
|
| 133 |
+
)
|
| 134 |
+
res[res_field]["cls_comb_det_av"][metric_name] = (
|
| 135 |
+
metric.combine_classes_det_averaged(cls_res)
|
| 136 |
+
)
|
| 137 |
+
# combine classes to super classes
|
| 138 |
+
if dataset.use_super_categories:
|
| 139 |
+
for cat, sub_cats in dataset.super_categories.items():
|
| 140 |
+
combined_cls_keys.append(cat)
|
| 141 |
+
res[res_field][cat] = {}
|
| 142 |
+
for metric, metric_name in zip(metrics_list, metric_names):
|
| 143 |
+
cat_res = {
|
| 144 |
+
cls_key: cls_value[metric_name]
|
| 145 |
+
for cls_key, cls_value in res[res_field].items()
|
| 146 |
+
if cls_key in sub_cats
|
| 147 |
+
}
|
| 148 |
+
res[res_field][cat][metric_name] = (
|
| 149 |
+
metric.combine_classes_det_averaged(cat_res)
|
| 150 |
+
)
|
| 151 |
+
return res, combined_cls_keys
|
| 152 |
+
|
| 153 |
+
def _summarize_results(
|
| 154 |
+
self,
|
| 155 |
+
res,
|
| 156 |
+
tracker,
|
| 157 |
+
metrics_list,
|
| 158 |
+
metric_names,
|
| 159 |
+
dataset,
|
| 160 |
+
res_field,
|
| 161 |
+
combined_cls_keys,
|
| 162 |
+
):
|
| 163 |
+
config = self.config
|
| 164 |
+
output_fol = dataset.get_output_fol(tracker)
|
| 165 |
+
tracker_display_name = dataset.get_display_name(tracker)
|
| 166 |
+
for c_cls in res[
|
| 167 |
+
res_field
|
| 168 |
+
].keys(): # class_list + combined classes if calculated
|
| 169 |
+
summaries = []
|
| 170 |
+
details = []
|
| 171 |
+
num_dets = res[res_field][c_cls]["Count"]["Dets"]
|
| 172 |
+
if config["OUTPUT_EMPTY_CLASSES"] or num_dets > 0:
|
| 173 |
+
for metric, metric_name in zip(metrics_list, metric_names):
|
| 174 |
+
# for combined classes there is no per sequence evaluation
|
| 175 |
+
if c_cls in combined_cls_keys:
|
| 176 |
+
table_res = {res_field: res[res_field][c_cls][metric_name]}
|
| 177 |
+
else:
|
| 178 |
+
table_res = {
|
| 179 |
+
seq_key: seq_value[c_cls][metric_name]
|
| 180 |
+
for seq_key, seq_value in res.items()
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
if config["PRINT_RESULTS"] and config["PRINT_ONLY_COMBINED"]:
|
| 184 |
+
dont_print = (
|
| 185 |
+
dataset.should_classes_combine
|
| 186 |
+
and c_cls not in combined_cls_keys
|
| 187 |
+
)
|
| 188 |
+
if not dont_print:
|
| 189 |
+
metric.print_table(
|
| 190 |
+
{res_field: table_res[res_field]},
|
| 191 |
+
tracker_display_name,
|
| 192 |
+
c_cls,
|
| 193 |
+
res_field,
|
| 194 |
+
res_field,
|
| 195 |
+
)
|
| 196 |
+
elif config["PRINT_RESULTS"]:
|
| 197 |
+
metric.print_table(
|
| 198 |
+
table_res, tracker_display_name, c_cls, res_field, res_field
|
| 199 |
+
)
|
| 200 |
+
if config["OUTPUT_SUMMARY"]:
|
| 201 |
+
summaries.append(metric.summary_results(table_res))
|
| 202 |
+
if config["OUTPUT_DETAILED"]:
|
| 203 |
+
details.append(metric.detailed_results(table_res))
|
| 204 |
+
if config["PLOT_CURVES"]:
|
| 205 |
+
metric.plot_single_tracker_results(
|
| 206 |
+
table_res,
|
| 207 |
+
tracker_display_name,
|
| 208 |
+
c_cls,
|
| 209 |
+
output_fol,
|
| 210 |
+
)
|
| 211 |
+
if config["OUTPUT_SUMMARY"]:
|
| 212 |
+
utils.write_summary_results(summaries, c_cls, output_fol)
|
| 213 |
+
if config["OUTPUT_DETAILED"]:
|
| 214 |
+
utils.write_detailed_results(details, c_cls, output_fol)
|
| 215 |
+
|
| 216 |
+
@_timing.time
|
| 217 |
+
def evaluate(self, dataset_list, metrics_list, show_progressbar=False):
|
| 218 |
+
"""Evaluate a set of metrics on a set of datasets"""
|
| 219 |
+
config = self.config
|
| 220 |
+
metrics_list = metrics_list + [Count()] # Count metrics are always run
|
| 221 |
+
metric_names = utils.validate_metrics_list(metrics_list)
|
| 222 |
+
dataset_names = [dataset.get_name() for dataset in dataset_list]
|
| 223 |
+
output_res = {}
|
| 224 |
+
output_msg = {}
|
| 225 |
+
|
| 226 |
+
for dataset, dataset_name in zip(dataset_list, dataset_names):
|
| 227 |
+
# Get dataset info about what to evaluate
|
| 228 |
+
output_res[dataset_name] = {}
|
| 229 |
+
output_msg[dataset_name] = {}
|
| 230 |
+
tracker_list, seq_list, class_list = dataset.get_eval_info()
|
| 231 |
+
print(
|
| 232 |
+
"\nEvaluating %i tracker(s) on %i sequence(s) for %i class(es) on %s dataset using the following "
|
| 233 |
+
"metrics: %s\n"
|
| 234 |
+
% (
|
| 235 |
+
len(tracker_list),
|
| 236 |
+
len(seq_list),
|
| 237 |
+
len(class_list),
|
| 238 |
+
dataset_name,
|
| 239 |
+
", ".join(metric_names),
|
| 240 |
+
)
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
# Evaluate each tracker
|
| 244 |
+
for tracker in tracker_list:
|
| 245 |
+
# if not config['BREAK_ON_ERROR'] then go to next tracker without breaking
|
| 246 |
+
try:
|
| 247 |
+
# Evaluate each sequence in parallel or in series.
|
| 248 |
+
# returns a nested dict (res), indexed like: res[seq][class][metric_name][sub_metric field]
|
| 249 |
+
# e.g. res[seq_0001][pedestrian][hota][DetA]
|
| 250 |
+
print("\nEvaluating %s\n" % tracker)
|
| 251 |
+
time_start = time.time()
|
| 252 |
+
if config["USE_PARALLEL"]:
|
| 253 |
+
if show_progressbar and TQDM_IMPORTED:
|
| 254 |
+
seq_list_sorted = sorted(seq_list)
|
| 255 |
+
|
| 256 |
+
with Pool(config["NUM_PARALLEL_CORES"]) as pool, tqdm.tqdm(
|
| 257 |
+
total=len(seq_list)
|
| 258 |
+
) as pbar:
|
| 259 |
+
_eval_sequence = partial(
|
| 260 |
+
eval_sequence,
|
| 261 |
+
dataset=dataset,
|
| 262 |
+
tracker=tracker,
|
| 263 |
+
class_list=class_list,
|
| 264 |
+
metrics_list=metrics_list,
|
| 265 |
+
metric_names=metric_names,
|
| 266 |
+
)
|
| 267 |
+
results = []
|
| 268 |
+
for r in pool.imap(
|
| 269 |
+
_eval_sequence, seq_list_sorted, chunksize=20
|
| 270 |
+
):
|
| 271 |
+
results.append(r)
|
| 272 |
+
pbar.update()
|
| 273 |
+
res = dict(zip(seq_list_sorted, results))
|
| 274 |
+
|
| 275 |
+
else:
|
| 276 |
+
with Pool(config["NUM_PARALLEL_CORES"]) as pool:
|
| 277 |
+
_eval_sequence = partial(
|
| 278 |
+
eval_sequence,
|
| 279 |
+
dataset=dataset,
|
| 280 |
+
tracker=tracker,
|
| 281 |
+
class_list=class_list,
|
| 282 |
+
metrics_list=metrics_list,
|
| 283 |
+
metric_names=metric_names,
|
| 284 |
+
)
|
| 285 |
+
results = pool.map(_eval_sequence, seq_list)
|
| 286 |
+
res = dict(zip(seq_list, results))
|
| 287 |
+
else:
|
| 288 |
+
res = {}
|
| 289 |
+
if show_progressbar and TQDM_IMPORTED:
|
| 290 |
+
seq_list_sorted = sorted(seq_list)
|
| 291 |
+
for curr_seq in tqdm.tqdm(seq_list_sorted):
|
| 292 |
+
res[curr_seq] = eval_sequence(
|
| 293 |
+
curr_seq,
|
| 294 |
+
dataset,
|
| 295 |
+
tracker,
|
| 296 |
+
class_list,
|
| 297 |
+
metrics_list,
|
| 298 |
+
metric_names,
|
| 299 |
+
)
|
| 300 |
+
else:
|
| 301 |
+
for curr_seq in sorted(seq_list):
|
| 302 |
+
res[curr_seq] = eval_sequence(
|
| 303 |
+
curr_seq,
|
| 304 |
+
dataset,
|
| 305 |
+
tracker,
|
| 306 |
+
class_list,
|
| 307 |
+
metrics_list,
|
| 308 |
+
metric_names,
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
# Combine results over all sequences and then over all classes
|
| 312 |
+
res, combined_cls_keys = self._combine_results(
|
| 313 |
+
res, metrics_list, metric_names, dataset, "COMBINED_SEQ"
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
if np.all(
|
| 317 |
+
["tags" in annot for annot in dataset.gt_data["annotations"]]
|
| 318 |
+
):
|
| 319 |
+
# Combine results over the challenging sequences and then over all classes
|
| 320 |
+
# currently only support "tracking_challenging_pair"
|
| 321 |
+
res, _ = self._combine_results(
|
| 322 |
+
res,
|
| 323 |
+
metrics_list,
|
| 324 |
+
metric_names,
|
| 325 |
+
dataset,
|
| 326 |
+
"COMBINED_SEQ_CHALLENGING",
|
| 327 |
+
"tracking_challenging_pair",
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
# Print and output results in various formats
|
| 331 |
+
if config["TIME_PROGRESS"]:
|
| 332 |
+
print(
|
| 333 |
+
"\nAll sequences for %s finished in %.2f seconds"
|
| 334 |
+
% (tracker, time.time() - time_start)
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
self._summarize_results(
|
| 338 |
+
res,
|
| 339 |
+
tracker,
|
| 340 |
+
metrics_list,
|
| 341 |
+
metric_names,
|
| 342 |
+
dataset,
|
| 343 |
+
"COMBINED_SEQ",
|
| 344 |
+
combined_cls_keys,
|
| 345 |
+
)
|
| 346 |
+
if "COMBINED_SEQ_CHALLENGING" in res:
|
| 347 |
+
self._summarize_results(
|
| 348 |
+
res,
|
| 349 |
+
tracker,
|
| 350 |
+
metrics_list,
|
| 351 |
+
metric_names,
|
| 352 |
+
dataset,
|
| 353 |
+
"COMBINED_SEQ_CHALLENGING",
|
| 354 |
+
combined_cls_keys,
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
# Output for returning from function
|
| 358 |
+
output_res[dataset_name][tracker] = res
|
| 359 |
+
output_msg[dataset_name][tracker] = "Success"
|
| 360 |
+
|
| 361 |
+
except Exception as err:
|
| 362 |
+
output_res[dataset_name][tracker] = None
|
| 363 |
+
if type(err) == TrackEvalException:
|
| 364 |
+
output_msg[dataset_name][tracker] = str(err)
|
| 365 |
+
else:
|
| 366 |
+
output_msg[dataset_name][tracker] = "Unknown error occurred."
|
| 367 |
+
print("Tracker %s was unable to be evaluated." % tracker)
|
| 368 |
+
print(err)
|
| 369 |
+
traceback.print_exc()
|
| 370 |
+
if config["LOG_ON_ERROR"] is not None:
|
| 371 |
+
with open(config["LOG_ON_ERROR"], "a") as f:
|
| 372 |
+
print(dataset_name, file=f)
|
| 373 |
+
print(tracker, file=f)
|
| 374 |
+
print(traceback.format_exc(), file=f)
|
| 375 |
+
print("\n\n\n", file=f)
|
| 376 |
+
if config["BREAK_ON_ERROR"]:
|
| 377 |
+
raise err
|
| 378 |
+
elif config["RETURN_ON_ERROR"]:
|
| 379 |
+
return output_res, output_msg
|
| 380 |
+
|
| 381 |
+
return output_res, output_msg
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
@_timing.time
|
| 385 |
+
def eval_sequence(seq, dataset, tracker, class_list, metrics_list, metric_names):
|
| 386 |
+
"""Function for evaluating a single sequence"""
|
| 387 |
+
|
| 388 |
+
raw_data = dataset.get_raw_seq_data(tracker, seq)
|
| 389 |
+
seq_res = {}
|
| 390 |
+
for cls in class_list:
|
| 391 |
+
seq_res[cls] = {}
|
| 392 |
+
data = dataset.get_preprocessed_seq_data(raw_data, cls)
|
| 393 |
+
for metric, met_name in zip(metrics_list, metric_names):
|
| 394 |
+
seq_res[cls][met_name] = metric.eval_sequence(data)
|
| 395 |
+
return seq_res
|
sam3/eval/hota_eval_toolkit/trackeval/metrics/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa
|
| 2 |
+
|
| 3 |
+
from .count import Count
|
| 4 |
+
from .hota import HOTA
|
sam3/eval/hota_eval_toolkit/trackeval/metrics/_base_metric.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from .. import _timing
|
| 8 |
+
from ..utils import TrackEvalException
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class _BaseMetric(ABC):
|
| 12 |
+
@abstractmethod
|
| 13 |
+
def __init__(self):
|
| 14 |
+
self.plottable = False
|
| 15 |
+
self.integer_fields = []
|
| 16 |
+
self.float_fields = []
|
| 17 |
+
self.array_labels = []
|
| 18 |
+
self.integer_array_fields = []
|
| 19 |
+
self.float_array_fields = []
|
| 20 |
+
self.fields = []
|
| 21 |
+
self.summary_fields = []
|
| 22 |
+
self.registered = False
|
| 23 |
+
|
| 24 |
+
#####################################################################
|
| 25 |
+
# Abstract functions for subclasses to implement
|
| 26 |
+
|
| 27 |
+
@_timing.time
|
| 28 |
+
@abstractmethod
|
| 29 |
+
def eval_sequence(self, data): ...
|
| 30 |
+
|
| 31 |
+
@abstractmethod
|
| 32 |
+
def combine_sequences(self, all_res): ...
|
| 33 |
+
|
| 34 |
+
@abstractmethod
|
| 35 |
+
def combine_classes_class_averaged(self, all_res, ignore_empty_classes=False): ...
|
| 36 |
+
|
| 37 |
+
@abstractmethod
|
| 38 |
+
def combine_classes_det_averaged(self, all_res): ...
|
| 39 |
+
|
| 40 |
+
def plot_single_tracker_results(self, all_res, tracker, output_folder, cls):
|
| 41 |
+
"""Plot results of metrics, only valid for metrics with self.plottable"""
|
| 42 |
+
if self.plottable:
|
| 43 |
+
raise NotImplementedError(
|
| 44 |
+
"plot_results is not implemented for metric %s" % self.get_name()
|
| 45 |
+
)
|
| 46 |
+
else:
|
| 47 |
+
pass
|
| 48 |
+
|
| 49 |
+
#####################################################################
|
| 50 |
+
# Helper functions which are useful for all metrics:
|
| 51 |
+
|
| 52 |
+
@classmethod
|
| 53 |
+
def get_name(cls):
|
| 54 |
+
return cls.__name__
|
| 55 |
+
|
| 56 |
+
@staticmethod
|
| 57 |
+
def _combine_sum(all_res, field):
|
| 58 |
+
"""Combine sequence results via sum"""
|
| 59 |
+
return sum([all_res[k][field] for k in all_res.keys()])
|
| 60 |
+
|
| 61 |
+
@staticmethod
|
| 62 |
+
def _combine_weighted_av(all_res, field, comb_res, weight_field):
|
| 63 |
+
"""Combine sequence results via weighted average"""
|
| 64 |
+
return sum(
|
| 65 |
+
[all_res[k][field] * all_res[k][weight_field] for k in all_res.keys()]
|
| 66 |
+
) / np.maximum(1.0, comb_res[weight_field])
|
| 67 |
+
|
| 68 |
+
def print_table(
|
| 69 |
+
self, table_res, tracker, cls, res_field="COMBINED_SEQ", output_lable="COMBINED"
|
| 70 |
+
):
|
| 71 |
+
"""Prints table of results for all sequences"""
|
| 72 |
+
print("")
|
| 73 |
+
metric_name = self.get_name()
|
| 74 |
+
self._row_print(
|
| 75 |
+
[metric_name + ": " + tracker + "-" + cls] + self.summary_fields
|
| 76 |
+
)
|
| 77 |
+
for seq, results in sorted(table_res.items()):
|
| 78 |
+
if seq.startswith("COMBINED_SEQ"):
|
| 79 |
+
continue
|
| 80 |
+
summary_res = self._summary_row(results)
|
| 81 |
+
self._row_print([seq] + summary_res)
|
| 82 |
+
summary_res = self._summary_row(table_res[res_field])
|
| 83 |
+
self._row_print([output_lable] + summary_res)
|
| 84 |
+
|
| 85 |
+
def _summary_row(self, results_):
|
| 86 |
+
vals = []
|
| 87 |
+
for h in self.summary_fields:
|
| 88 |
+
if h in self.float_array_fields:
|
| 89 |
+
vals.append("{0:1.5g}".format(100 * np.mean(results_[h])))
|
| 90 |
+
elif h in self.float_fields:
|
| 91 |
+
vals.append("{0:1.5g}".format(100 * float(results_[h])))
|
| 92 |
+
elif h in self.integer_fields:
|
| 93 |
+
vals.append("{0:d}".format(int(results_[h])))
|
| 94 |
+
else:
|
| 95 |
+
raise NotImplementedError(
|
| 96 |
+
"Summary function not implemented for this field type."
|
| 97 |
+
)
|
| 98 |
+
return vals
|
| 99 |
+
|
| 100 |
+
@staticmethod
|
| 101 |
+
def _row_print(*argv):
|
| 102 |
+
"""Prints results in an evenly spaced rows, with more space in first row"""
|
| 103 |
+
if len(argv) == 1:
|
| 104 |
+
argv = argv[0]
|
| 105 |
+
to_print = "%-35s" % argv[0]
|
| 106 |
+
for v in argv[1:]:
|
| 107 |
+
to_print += "%-10s" % str(v)
|
| 108 |
+
print(to_print)
|
| 109 |
+
|
| 110 |
+
def summary_results(self, table_res):
|
| 111 |
+
"""Returns a simple summary of final results for a tracker"""
|
| 112 |
+
return dict(
|
| 113 |
+
zip(self.summary_fields, self._summary_row(table_res["COMBINED_SEQ"]))
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
def detailed_results(self, table_res):
|
| 117 |
+
"""Returns detailed final results for a tracker"""
|
| 118 |
+
# Get detailed field information
|
| 119 |
+
detailed_fields = self.float_fields + self.integer_fields
|
| 120 |
+
for h in self.float_array_fields + self.integer_array_fields:
|
| 121 |
+
for alpha in [int(100 * x) for x in self.array_labels]:
|
| 122 |
+
detailed_fields.append(h + "___" + str(alpha))
|
| 123 |
+
detailed_fields.append(h + "___AUC")
|
| 124 |
+
|
| 125 |
+
# Get detailed results
|
| 126 |
+
detailed_results = {}
|
| 127 |
+
for seq, res in table_res.items():
|
| 128 |
+
detailed_row = self._detailed_row(res)
|
| 129 |
+
if len(detailed_row) != len(detailed_fields):
|
| 130 |
+
raise TrackEvalException(
|
| 131 |
+
"Field names and data have different sizes (%i and %i)"
|
| 132 |
+
% (len(detailed_row), len(detailed_fields))
|
| 133 |
+
)
|
| 134 |
+
detailed_results[seq] = dict(zip(detailed_fields, detailed_row))
|
| 135 |
+
return detailed_results
|
| 136 |
+
|
| 137 |
+
def _detailed_row(self, res):
|
| 138 |
+
detailed_row = []
|
| 139 |
+
for h in self.float_fields + self.integer_fields:
|
| 140 |
+
detailed_row.append(res[h])
|
| 141 |
+
for h in self.float_array_fields + self.integer_array_fields:
|
| 142 |
+
for i, alpha in enumerate([int(100 * x) for x in self.array_labels]):
|
| 143 |
+
detailed_row.append(res[h][i])
|
| 144 |
+
detailed_row.append(np.mean(res[h]))
|
| 145 |
+
return detailed_row
|
sam3/eval/hota_eval_toolkit/trackeval/metrics/count.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa
|
| 2 |
+
|
| 3 |
+
from .. import _timing
|
| 4 |
+
from ._base_metric import _BaseMetric
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Count(_BaseMetric):
|
| 8 |
+
"""Class which simply counts the number of tracker and gt detections and ids."""
|
| 9 |
+
|
| 10 |
+
def __init__(self, config=None):
|
| 11 |
+
super().__init__()
|
| 12 |
+
self.integer_fields = ["Dets", "GT_Dets", "IDs", "GT_IDs"]
|
| 13 |
+
self.fields = self.integer_fields
|
| 14 |
+
self.summary_fields = self.fields
|
| 15 |
+
|
| 16 |
+
@_timing.time
|
| 17 |
+
def eval_sequence(self, data):
|
| 18 |
+
"""Returns counts for one sequence"""
|
| 19 |
+
# Get results
|
| 20 |
+
res = {
|
| 21 |
+
"Dets": data["num_tracker_dets"],
|
| 22 |
+
"GT_Dets": data["num_gt_dets"],
|
| 23 |
+
"IDs": data["num_tracker_ids"],
|
| 24 |
+
"GT_IDs": data["num_gt_ids"],
|
| 25 |
+
"Frames": data["num_timesteps"],
|
| 26 |
+
}
|
| 27 |
+
return res
|
| 28 |
+
|
| 29 |
+
def combine_sequences(self, all_res):
|
| 30 |
+
"""Combines metrics across all sequences"""
|
| 31 |
+
res = {}
|
| 32 |
+
for field in self.integer_fields:
|
| 33 |
+
res[field] = self._combine_sum(all_res, field)
|
| 34 |
+
return res
|
| 35 |
+
|
| 36 |
+
def combine_classes_class_averaged(self, all_res, ignore_empty_classes=None):
|
| 37 |
+
"""Combines metrics across all classes by averaging over the class values"""
|
| 38 |
+
res = {}
|
| 39 |
+
for field in self.integer_fields:
|
| 40 |
+
res[field] = self._combine_sum(all_res, field)
|
| 41 |
+
return res
|
| 42 |
+
|
| 43 |
+
def combine_classes_det_averaged(self, all_res):
|
| 44 |
+
"""Combines metrics across all classes by averaging over the detection values"""
|
| 45 |
+
res = {}
|
| 46 |
+
for field in self.integer_fields:
|
| 47 |
+
res[field] = self._combine_sum(all_res, field)
|
| 48 |
+
return res
|
sam3/eval/hota_eval_toolkit/trackeval/metrics/hota.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from scipy.optimize import linear_sum_assignment
|
| 7 |
+
|
| 8 |
+
from .. import _timing
|
| 9 |
+
from ._base_metric import _BaseMetric
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class HOTA(_BaseMetric):
|
| 13 |
+
"""Class which implements the HOTA metrics.
|
| 14 |
+
See: https://link.springer.com/article/10.1007/s11263-020-01375-2
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
def __init__(self, config=None):
|
| 18 |
+
super().__init__()
|
| 19 |
+
self.plottable = True
|
| 20 |
+
self.array_labels = np.arange(0.05, 0.99, 0.05)
|
| 21 |
+
self.integer_array_fields = ["HOTA_TP", "HOTA_FN", "HOTA_FP"]
|
| 22 |
+
self.float_array_fields = [
|
| 23 |
+
"HOTA",
|
| 24 |
+
"DetA",
|
| 25 |
+
"AssA",
|
| 26 |
+
"DetRe",
|
| 27 |
+
"DetPr",
|
| 28 |
+
"AssRe",
|
| 29 |
+
"AssPr",
|
| 30 |
+
"LocA",
|
| 31 |
+
"OWTA",
|
| 32 |
+
]
|
| 33 |
+
self.float_fields = ["HOTA(0)", "LocA(0)", "HOTALocA(0)"]
|
| 34 |
+
self.fields = (
|
| 35 |
+
self.float_array_fields + self.integer_array_fields + self.float_fields
|
| 36 |
+
)
|
| 37 |
+
self.summary_fields = self.float_array_fields + self.float_fields
|
| 38 |
+
|
| 39 |
+
@_timing.time
|
| 40 |
+
def eval_sequence(self, data):
|
| 41 |
+
"""Calculates the HOTA metrics for one sequence"""
|
| 42 |
+
|
| 43 |
+
# Initialise results
|
| 44 |
+
res = {}
|
| 45 |
+
for field in self.float_array_fields + self.integer_array_fields:
|
| 46 |
+
res[field] = np.zeros((len(self.array_labels)), dtype=float)
|
| 47 |
+
for field in self.float_fields:
|
| 48 |
+
res[field] = 0
|
| 49 |
+
|
| 50 |
+
# Return result quickly if tracker or gt sequence is empty
|
| 51 |
+
if data["num_tracker_dets"] == 0:
|
| 52 |
+
res["HOTA_FN"] = data["num_gt_dets"] * np.ones(
|
| 53 |
+
(len(self.array_labels)), dtype=float
|
| 54 |
+
)
|
| 55 |
+
res["LocA"] = np.ones((len(self.array_labels)), dtype=float)
|
| 56 |
+
res["LocA(0)"] = 1.0
|
| 57 |
+
return res
|
| 58 |
+
if data["num_gt_dets"] == 0:
|
| 59 |
+
res["HOTA_FP"] = data["num_tracker_dets"] * np.ones(
|
| 60 |
+
(len(self.array_labels)), dtype=float
|
| 61 |
+
)
|
| 62 |
+
res["LocA"] = np.ones((len(self.array_labels)), dtype=float)
|
| 63 |
+
res["LocA(0)"] = 1.0
|
| 64 |
+
return res
|
| 65 |
+
|
| 66 |
+
# Variables counting global association
|
| 67 |
+
potential_matches_count = np.zeros(
|
| 68 |
+
(data["num_gt_ids"], data["num_tracker_ids"])
|
| 69 |
+
)
|
| 70 |
+
gt_id_count = np.zeros((data["num_gt_ids"], 1))
|
| 71 |
+
tracker_id_count = np.zeros((1, data["num_tracker_ids"]))
|
| 72 |
+
|
| 73 |
+
# First loop through each timestep and accumulate global track information.
|
| 74 |
+
for t, (gt_ids_t, tracker_ids_t) in enumerate(
|
| 75 |
+
zip(data["gt_ids"], data["tracker_ids"])
|
| 76 |
+
):
|
| 77 |
+
# Count the potential matches between ids in each timestep
|
| 78 |
+
# These are normalised, weighted by the match similarity.
|
| 79 |
+
similarity = data["similarity_scores"][t]
|
| 80 |
+
sim_iou_denom = (
|
| 81 |
+
similarity.sum(0)[np.newaxis, :]
|
| 82 |
+
+ similarity.sum(1)[:, np.newaxis]
|
| 83 |
+
- similarity
|
| 84 |
+
)
|
| 85 |
+
sim_iou = np.zeros_like(similarity)
|
| 86 |
+
sim_iou_mask = sim_iou_denom > 0 + np.finfo("float").eps
|
| 87 |
+
sim_iou[sim_iou_mask] = (
|
| 88 |
+
similarity[sim_iou_mask] / sim_iou_denom[sim_iou_mask]
|
| 89 |
+
)
|
| 90 |
+
potential_matches_count[
|
| 91 |
+
gt_ids_t[:, np.newaxis], tracker_ids_t[np.newaxis, :]
|
| 92 |
+
] += sim_iou
|
| 93 |
+
|
| 94 |
+
# Calculate the total number of dets for each gt_id and tracker_id.
|
| 95 |
+
gt_id_count[gt_ids_t] += 1
|
| 96 |
+
tracker_id_count[0, tracker_ids_t] += 1
|
| 97 |
+
|
| 98 |
+
# Calculate overall jaccard alignment score (before unique matching) between IDs
|
| 99 |
+
global_alignment_score = potential_matches_count / (
|
| 100 |
+
gt_id_count + tracker_id_count - potential_matches_count
|
| 101 |
+
)
|
| 102 |
+
matches_counts = [
|
| 103 |
+
np.zeros_like(potential_matches_count) for _ in self.array_labels
|
| 104 |
+
]
|
| 105 |
+
|
| 106 |
+
# Calculate scores for each timestep
|
| 107 |
+
for t, (gt_ids_t, tracker_ids_t) in enumerate(
|
| 108 |
+
zip(data["gt_ids"], data["tracker_ids"])
|
| 109 |
+
):
|
| 110 |
+
# Deal with the case that there are no gt_det/tracker_det in a timestep.
|
| 111 |
+
if len(gt_ids_t) == 0:
|
| 112 |
+
for a, alpha in enumerate(self.array_labels):
|
| 113 |
+
res["HOTA_FP"][a] += len(tracker_ids_t)
|
| 114 |
+
continue
|
| 115 |
+
if len(tracker_ids_t) == 0:
|
| 116 |
+
for a, alpha in enumerate(self.array_labels):
|
| 117 |
+
res["HOTA_FN"][a] += len(gt_ids_t)
|
| 118 |
+
continue
|
| 119 |
+
|
| 120 |
+
# Get matching scores between pairs of dets for optimizing HOTA
|
| 121 |
+
similarity = data["similarity_scores"][t]
|
| 122 |
+
score_mat = (
|
| 123 |
+
global_alignment_score[
|
| 124 |
+
gt_ids_t[:, np.newaxis], tracker_ids_t[np.newaxis, :]
|
| 125 |
+
]
|
| 126 |
+
* similarity
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
# Hungarian algorithm to find best matches
|
| 130 |
+
match_rows, match_cols = linear_sum_assignment(-score_mat)
|
| 131 |
+
|
| 132 |
+
# Calculate and accumulate basic statistics
|
| 133 |
+
for a, alpha in enumerate(self.array_labels):
|
| 134 |
+
actually_matched_mask = (
|
| 135 |
+
similarity[match_rows, match_cols] >= alpha - np.finfo("float").eps
|
| 136 |
+
)
|
| 137 |
+
alpha_match_rows = match_rows[actually_matched_mask]
|
| 138 |
+
alpha_match_cols = match_cols[actually_matched_mask]
|
| 139 |
+
num_matches = len(alpha_match_rows)
|
| 140 |
+
res["HOTA_TP"][a] += num_matches
|
| 141 |
+
res["HOTA_FN"][a] += len(gt_ids_t) - num_matches
|
| 142 |
+
res["HOTA_FP"][a] += len(tracker_ids_t) - num_matches
|
| 143 |
+
if num_matches > 0:
|
| 144 |
+
res["LocA"][a] += sum(
|
| 145 |
+
similarity[alpha_match_rows, alpha_match_cols]
|
| 146 |
+
)
|
| 147 |
+
matches_counts[a][
|
| 148 |
+
gt_ids_t[alpha_match_rows], tracker_ids_t[alpha_match_cols]
|
| 149 |
+
] += 1
|
| 150 |
+
|
| 151 |
+
# Calculate association scores (AssA, AssRe, AssPr) for the alpha value.
|
| 152 |
+
# First calculate scores per gt_id/tracker_id combo and then average over the number of detections.
|
| 153 |
+
for a, alpha in enumerate(self.array_labels):
|
| 154 |
+
matches_count = matches_counts[a]
|
| 155 |
+
ass_a = matches_count / np.maximum(
|
| 156 |
+
1, gt_id_count + tracker_id_count - matches_count
|
| 157 |
+
)
|
| 158 |
+
res["AssA"][a] = np.sum(matches_count * ass_a) / np.maximum(
|
| 159 |
+
1, res["HOTA_TP"][a]
|
| 160 |
+
)
|
| 161 |
+
ass_re = matches_count / np.maximum(1, gt_id_count)
|
| 162 |
+
res["AssRe"][a] = np.sum(matches_count * ass_re) / np.maximum(
|
| 163 |
+
1, res["HOTA_TP"][a]
|
| 164 |
+
)
|
| 165 |
+
ass_pr = matches_count / np.maximum(1, tracker_id_count)
|
| 166 |
+
res["AssPr"][a] = np.sum(matches_count * ass_pr) / np.maximum(
|
| 167 |
+
1, res["HOTA_TP"][a]
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
# Calculate final scores
|
| 171 |
+
res["LocA"] = np.maximum(1e-10, res["LocA"]) / np.maximum(1e-10, res["HOTA_TP"])
|
| 172 |
+
res = self._compute_final_fields(res)
|
| 173 |
+
return res
|
| 174 |
+
|
| 175 |
+
def combine_sequences(self, all_res):
|
| 176 |
+
"""Combines metrics across all sequences"""
|
| 177 |
+
res = {}
|
| 178 |
+
for field in self.integer_array_fields:
|
| 179 |
+
res[field] = self._combine_sum(all_res, field)
|
| 180 |
+
for field in ["AssRe", "AssPr", "AssA"]:
|
| 181 |
+
res[field] = self._combine_weighted_av(
|
| 182 |
+
all_res, field, res, weight_field="HOTA_TP"
|
| 183 |
+
)
|
| 184 |
+
loca_weighted_sum = sum(
|
| 185 |
+
[all_res[k]["LocA"] * all_res[k]["HOTA_TP"] for k in all_res.keys()]
|
| 186 |
+
)
|
| 187 |
+
res["LocA"] = np.maximum(1e-10, loca_weighted_sum) / np.maximum(
|
| 188 |
+
1e-10, res["HOTA_TP"]
|
| 189 |
+
)
|
| 190 |
+
res = self._compute_final_fields(res)
|
| 191 |
+
return res
|
| 192 |
+
|
| 193 |
+
def combine_classes_class_averaged(self, all_res, ignore_empty_classes=False):
|
| 194 |
+
"""Combines metrics across all classes by averaging over the class values.
|
| 195 |
+
If 'ignore_empty_classes' is True, then it only sums over classes with at least one gt or predicted detection.
|
| 196 |
+
"""
|
| 197 |
+
res = {}
|
| 198 |
+
for field in self.integer_array_fields:
|
| 199 |
+
if ignore_empty_classes:
|
| 200 |
+
res[field] = self._combine_sum(
|
| 201 |
+
{
|
| 202 |
+
k: v
|
| 203 |
+
for k, v in all_res.items()
|
| 204 |
+
if (
|
| 205 |
+
v["HOTA_TP"] + v["HOTA_FN"] + v["HOTA_FP"]
|
| 206 |
+
> 0 + np.finfo("float").eps
|
| 207 |
+
).any()
|
| 208 |
+
},
|
| 209 |
+
field,
|
| 210 |
+
)
|
| 211 |
+
else:
|
| 212 |
+
res[field] = self._combine_sum(
|
| 213 |
+
{k: v for k, v in all_res.items()}, field
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
for field in self.float_fields + self.float_array_fields:
|
| 217 |
+
if ignore_empty_classes:
|
| 218 |
+
res[field] = np.mean(
|
| 219 |
+
[
|
| 220 |
+
v[field]
|
| 221 |
+
for v in all_res.values()
|
| 222 |
+
if (
|
| 223 |
+
v["HOTA_TP"] + v["HOTA_FN"] + v["HOTA_FP"]
|
| 224 |
+
> 0 + np.finfo("float").eps
|
| 225 |
+
).any()
|
| 226 |
+
],
|
| 227 |
+
axis=0,
|
| 228 |
+
)
|
| 229 |
+
else:
|
| 230 |
+
res[field] = np.mean([v[field] for v in all_res.values()], axis=0)
|
| 231 |
+
return res
|
| 232 |
+
|
| 233 |
+
def combine_classes_det_averaged(self, all_res):
|
| 234 |
+
"""Combines metrics across all classes by averaging over the detection values"""
|
| 235 |
+
res = {}
|
| 236 |
+
for field in self.integer_array_fields:
|
| 237 |
+
res[field] = self._combine_sum(all_res, field)
|
| 238 |
+
for field in ["AssRe", "AssPr", "AssA"]:
|
| 239 |
+
res[field] = self._combine_weighted_av(
|
| 240 |
+
all_res, field, res, weight_field="HOTA_TP"
|
| 241 |
+
)
|
| 242 |
+
loca_weighted_sum = sum(
|
| 243 |
+
[all_res[k]["LocA"] * all_res[k]["HOTA_TP"] for k in all_res.keys()]
|
| 244 |
+
)
|
| 245 |
+
res["LocA"] = np.maximum(1e-10, loca_weighted_sum) / np.maximum(
|
| 246 |
+
1e-10, res["HOTA_TP"]
|
| 247 |
+
)
|
| 248 |
+
res = self._compute_final_fields(res)
|
| 249 |
+
return res
|
| 250 |
+
|
| 251 |
+
@staticmethod
|
| 252 |
+
def _compute_final_fields(res):
|
| 253 |
+
"""Calculate sub-metric ('field') values which only depend on other sub-metric values.
|
| 254 |
+
This function is used both for both per-sequence calculation, and in combining values across sequences.
|
| 255 |
+
"""
|
| 256 |
+
res["DetRe"] = res["HOTA_TP"] / np.maximum(1, res["HOTA_TP"] + res["HOTA_FN"])
|
| 257 |
+
res["DetPr"] = res["HOTA_TP"] / np.maximum(1, res["HOTA_TP"] + res["HOTA_FP"])
|
| 258 |
+
res["DetA"] = res["HOTA_TP"] / np.maximum(
|
| 259 |
+
1, res["HOTA_TP"] + res["HOTA_FN"] + res["HOTA_FP"]
|
| 260 |
+
)
|
| 261 |
+
res["HOTA"] = np.sqrt(res["DetA"] * res["AssA"])
|
| 262 |
+
res["OWTA"] = np.sqrt(res["DetRe"] * res["AssA"])
|
| 263 |
+
|
| 264 |
+
res["HOTA(0)"] = res["HOTA"][0]
|
| 265 |
+
res["LocA(0)"] = res["LocA"][0]
|
| 266 |
+
res["HOTALocA(0)"] = res["HOTA(0)"] * res["LocA(0)"]
|
| 267 |
+
return res
|
| 268 |
+
|
| 269 |
+
def plot_single_tracker_results(self, table_res, tracker, cls, output_folder):
|
| 270 |
+
"""Create plot of results"""
|
| 271 |
+
|
| 272 |
+
# Only loaded when run to reduce minimum requirements
|
| 273 |
+
from matplotlib import pyplot as plt
|
| 274 |
+
|
| 275 |
+
res = table_res["COMBINED_SEQ"]
|
| 276 |
+
styles_to_plot = ["r", "b", "g", "b--", "b:", "g--", "g:", "m"]
|
| 277 |
+
for name, style in zip(self.float_array_fields, styles_to_plot):
|
| 278 |
+
plt.plot(self.array_labels, res[name], style)
|
| 279 |
+
plt.xlabel("alpha")
|
| 280 |
+
plt.ylabel("score")
|
| 281 |
+
plt.title(tracker + " - " + cls)
|
| 282 |
+
plt.axis([0, 1, 0, 1])
|
| 283 |
+
legend = []
|
| 284 |
+
for name in self.float_array_fields:
|
| 285 |
+
legend += [name + " (" + str(np.round(np.mean(res[name]), 2)) + ")"]
|
| 286 |
+
plt.legend(legend, loc="lower left")
|
| 287 |
+
out_file = os.path.join(output_folder, cls + "_plot.pdf")
|
| 288 |
+
os.makedirs(os.path.dirname(out_file), exist_ok=True)
|
| 289 |
+
plt.savefig(out_file)
|
| 290 |
+
plt.savefig(out_file.replace(".pdf", ".png"))
|
| 291 |
+
plt.clf()
|
sam3/eval/hota_eval_toolkit/trackeval/utils.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import csv
|
| 5 |
+
import os
|
| 6 |
+
from collections import OrderedDict
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def init_config(config, default_config, name=None):
|
| 10 |
+
"""Initialise non-given config values with defaults"""
|
| 11 |
+
if config is None:
|
| 12 |
+
config = default_config
|
| 13 |
+
else:
|
| 14 |
+
for k in default_config.keys():
|
| 15 |
+
if k not in config.keys():
|
| 16 |
+
config[k] = default_config[k]
|
| 17 |
+
if name and config["PRINT_CONFIG"]:
|
| 18 |
+
print("\n%s Config:" % name)
|
| 19 |
+
for c in config.keys():
|
| 20 |
+
print("%-20s : %-30s" % (c, config[c]))
|
| 21 |
+
return config
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def update_config(config):
|
| 25 |
+
"""
|
| 26 |
+
Parse the arguments of a script and updates the config values for a given value if specified in the arguments.
|
| 27 |
+
:param config: the config to update
|
| 28 |
+
:return: the updated config
|
| 29 |
+
"""
|
| 30 |
+
parser = argparse.ArgumentParser()
|
| 31 |
+
for setting in config.keys():
|
| 32 |
+
if type(config[setting]) == list or type(config[setting]) == type(None):
|
| 33 |
+
parser.add_argument("--" + setting, nargs="+")
|
| 34 |
+
else:
|
| 35 |
+
parser.add_argument("--" + setting)
|
| 36 |
+
args = parser.parse_args().__dict__
|
| 37 |
+
for setting in args.keys():
|
| 38 |
+
if args[setting] is not None:
|
| 39 |
+
if type(config[setting]) == type(True):
|
| 40 |
+
if args[setting] == "True":
|
| 41 |
+
x = True
|
| 42 |
+
elif args[setting] == "False":
|
| 43 |
+
x = False
|
| 44 |
+
else:
|
| 45 |
+
raise Exception(
|
| 46 |
+
"Command line parameter " + setting + "must be True or False"
|
| 47 |
+
)
|
| 48 |
+
elif type(config[setting]) == type(1):
|
| 49 |
+
x = int(args[setting])
|
| 50 |
+
elif type(args[setting]) == type(None):
|
| 51 |
+
x = None
|
| 52 |
+
else:
|
| 53 |
+
x = args[setting]
|
| 54 |
+
config[setting] = x
|
| 55 |
+
return config
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def get_code_path():
|
| 59 |
+
"""Get base path where code is"""
|
| 60 |
+
return os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def validate_metrics_list(metrics_list):
|
| 64 |
+
"""Get names of metric class and ensures they are unique, further checks that the fields within each metric class
|
| 65 |
+
do not have overlapping names.
|
| 66 |
+
"""
|
| 67 |
+
metric_names = [metric.get_name() for metric in metrics_list]
|
| 68 |
+
# check metric names are unique
|
| 69 |
+
if len(metric_names) != len(set(metric_names)):
|
| 70 |
+
raise TrackEvalException(
|
| 71 |
+
"Code being run with multiple metrics of the same name"
|
| 72 |
+
)
|
| 73 |
+
fields = []
|
| 74 |
+
for m in metrics_list:
|
| 75 |
+
fields += m.fields
|
| 76 |
+
# check metric fields are unique
|
| 77 |
+
if len(fields) != len(set(fields)):
|
| 78 |
+
raise TrackEvalException(
|
| 79 |
+
"Code being run with multiple metrics with fields of the same name"
|
| 80 |
+
)
|
| 81 |
+
return metric_names
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def write_summary_results(summaries, cls, output_folder):
|
| 85 |
+
"""Write summary results to file"""
|
| 86 |
+
|
| 87 |
+
fields = sum([list(s.keys()) for s in summaries], [])
|
| 88 |
+
values = sum([list(s.values()) for s in summaries], [])
|
| 89 |
+
|
| 90 |
+
# In order to remain consistent upon new fields being adding, for each of the following fields if they are present
|
| 91 |
+
# they will be output in the summary first in the order below. Any further fields will be output in the order each
|
| 92 |
+
# metric family is called, and within each family either in the order they were added to the dict (python >= 3.6) or
|
| 93 |
+
# randomly (python < 3.6).
|
| 94 |
+
default_order = [
|
| 95 |
+
"HOTA",
|
| 96 |
+
"DetA",
|
| 97 |
+
"AssA",
|
| 98 |
+
"DetRe",
|
| 99 |
+
"DetPr",
|
| 100 |
+
"AssRe",
|
| 101 |
+
"AssPr",
|
| 102 |
+
"LocA",
|
| 103 |
+
"OWTA",
|
| 104 |
+
"HOTA(0)",
|
| 105 |
+
"LocA(0)",
|
| 106 |
+
"HOTALocA(0)",
|
| 107 |
+
"MOTA",
|
| 108 |
+
"MOTP",
|
| 109 |
+
"MODA",
|
| 110 |
+
"CLR_Re",
|
| 111 |
+
"CLR_Pr",
|
| 112 |
+
"MTR",
|
| 113 |
+
"PTR",
|
| 114 |
+
"MLR",
|
| 115 |
+
"CLR_TP",
|
| 116 |
+
"CLR_FN",
|
| 117 |
+
"CLR_FP",
|
| 118 |
+
"IDSW",
|
| 119 |
+
"MT",
|
| 120 |
+
"PT",
|
| 121 |
+
"ML",
|
| 122 |
+
"Frag",
|
| 123 |
+
"sMOTA",
|
| 124 |
+
"IDF1",
|
| 125 |
+
"IDR",
|
| 126 |
+
"IDP",
|
| 127 |
+
"IDTP",
|
| 128 |
+
"IDFN",
|
| 129 |
+
"IDFP",
|
| 130 |
+
"Dets",
|
| 131 |
+
"GT_Dets",
|
| 132 |
+
"IDs",
|
| 133 |
+
"GT_IDs",
|
| 134 |
+
]
|
| 135 |
+
default_ordered_dict = OrderedDict(
|
| 136 |
+
zip(default_order, [None for _ in default_order])
|
| 137 |
+
)
|
| 138 |
+
for f, v in zip(fields, values):
|
| 139 |
+
default_ordered_dict[f] = v
|
| 140 |
+
for df in default_order:
|
| 141 |
+
if default_ordered_dict[df] is None:
|
| 142 |
+
del default_ordered_dict[df]
|
| 143 |
+
fields = list(default_ordered_dict.keys())
|
| 144 |
+
values = list(default_ordered_dict.values())
|
| 145 |
+
|
| 146 |
+
out_file = os.path.join(output_folder, cls + "_summary.txt")
|
| 147 |
+
os.makedirs(os.path.dirname(out_file), exist_ok=True)
|
| 148 |
+
with open(out_file, "w", newline="") as f:
|
| 149 |
+
writer = csv.writer(f, delimiter=" ")
|
| 150 |
+
writer.writerow(fields)
|
| 151 |
+
writer.writerow(values)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def write_detailed_results(details, cls, output_folder):
|
| 155 |
+
"""Write detailed results to file"""
|
| 156 |
+
sequences = details[0].keys()
|
| 157 |
+
fields = ["seq"] + sum([list(s["COMBINED_SEQ"].keys()) for s in details], [])
|
| 158 |
+
out_file = os.path.join(output_folder, cls + "_detailed.csv")
|
| 159 |
+
os.makedirs(os.path.dirname(out_file), exist_ok=True)
|
| 160 |
+
with open(out_file, "w", newline="") as f:
|
| 161 |
+
writer = csv.writer(f)
|
| 162 |
+
writer.writerow(fields)
|
| 163 |
+
for seq in sorted(sequences):
|
| 164 |
+
if seq == "COMBINED_SEQ":
|
| 165 |
+
continue
|
| 166 |
+
writer.writerow([seq] + sum([list(s[seq].values()) for s in details], []))
|
| 167 |
+
writer.writerow(
|
| 168 |
+
["COMBINED"] + sum([list(s["COMBINED_SEQ"].values()) for s in details], [])
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def load_detail(file):
|
| 173 |
+
"""Loads detailed data for a tracker."""
|
| 174 |
+
data = {}
|
| 175 |
+
with open(file) as f:
|
| 176 |
+
for i, row_text in enumerate(f):
|
| 177 |
+
row = row_text.replace("\r", "").replace("\n", "").split(",")
|
| 178 |
+
if i == 0:
|
| 179 |
+
keys = row[1:]
|
| 180 |
+
continue
|
| 181 |
+
current_values = row[1:]
|
| 182 |
+
seq = row[0]
|
| 183 |
+
if seq == "COMBINED":
|
| 184 |
+
seq = "COMBINED_SEQ"
|
| 185 |
+
if (len(current_values) == len(keys)) and seq != "":
|
| 186 |
+
data[seq] = {}
|
| 187 |
+
for key, value in zip(keys, current_values):
|
| 188 |
+
data[seq][key] = float(value)
|
| 189 |
+
return data
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class TrackEvalException(Exception):
|
| 193 |
+
"""Custom exception for catching expected errors."""
|
| 194 |
+
|
| 195 |
+
...
|
sam3/eval/postprocessors.py
ADDED
|
@@ -0,0 +1,648 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
"""Postprocessors class to transform MDETR output according to the downstream task"""
|
| 4 |
+
|
| 5 |
+
import dataclasses
|
| 6 |
+
import logging
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
from typing import Dict, List, Optional
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
from sam3.model import box_ops
|
| 13 |
+
from sam3.model.data_misc import BatchedInferenceMetadata, interpolate
|
| 14 |
+
from sam3.train.masks_ops import rle_encode, robust_rle_encode
|
| 15 |
+
from torch import nn
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class PostProcessNullOp(nn.Module):
|
| 19 |
+
def __init__(self, **kwargs):
|
| 20 |
+
super(PostProcessNullOp).__init__()
|
| 21 |
+
pass
|
| 22 |
+
|
| 23 |
+
def forward(self, input):
|
| 24 |
+
pass
|
| 25 |
+
|
| 26 |
+
def process_results(self, **kwargs):
|
| 27 |
+
return kwargs["find_stages"]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class PostProcessImage(nn.Module):
|
| 31 |
+
"""This module converts the model's output into the format expected by the coco api"""
|
| 32 |
+
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
max_dets_per_img: int,
|
| 36 |
+
iou_type="bbox",
|
| 37 |
+
to_cpu: bool = True,
|
| 38 |
+
use_original_ids: bool = False,
|
| 39 |
+
use_original_sizes_box: bool = False,
|
| 40 |
+
use_original_sizes_mask: bool = False,
|
| 41 |
+
convert_mask_to_rle: bool = False,
|
| 42 |
+
always_interpolate_masks_on_gpu: bool = True,
|
| 43 |
+
use_presence: bool = True,
|
| 44 |
+
detection_threshold: float = -1.0,
|
| 45 |
+
) -> None:
|
| 46 |
+
super().__init__()
|
| 47 |
+
self.max_dets_per_img = max_dets_per_img
|
| 48 |
+
self.iou_type = iou_type
|
| 49 |
+
self.to_cpu = to_cpu
|
| 50 |
+
self.convert_mask_to_rle = convert_mask_to_rle
|
| 51 |
+
self.always_interpolate_masks_on_gpu = always_interpolate_masks_on_gpu
|
| 52 |
+
|
| 53 |
+
self.use_presence = use_presence
|
| 54 |
+
self.detection_threshold = detection_threshold
|
| 55 |
+
self.use_original_ids = use_original_ids
|
| 56 |
+
self.use_original_sizes_box = use_original_sizes_box
|
| 57 |
+
self.use_original_sizes_mask = use_original_sizes_mask
|
| 58 |
+
|
| 59 |
+
@torch.no_grad()
|
| 60 |
+
def forward(
|
| 61 |
+
self,
|
| 62 |
+
outputs,
|
| 63 |
+
target_sizes_boxes,
|
| 64 |
+
target_sizes_masks,
|
| 65 |
+
forced_labels=None,
|
| 66 |
+
consistent=False,
|
| 67 |
+
ret_tensordict: bool = False, # This is experimental
|
| 68 |
+
):
|
| 69 |
+
"""Perform the computation
|
| 70 |
+
Parameters:
|
| 71 |
+
outputs: raw outputs of the model
|
| 72 |
+
target_sizes_boxes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
|
| 73 |
+
For evaluation, this must be the original image size (before any data augmentation)
|
| 74 |
+
For visualization, this should be the image size after data augment, but before padding
|
| 75 |
+
target_sizes_masks: same but used to resize masks
|
| 76 |
+
forced_labels: tensor of dimension [batch_size] containing the label to force for each image of the batch
|
| 77 |
+
This is useful when evaluating the model using standard metrics (eg on COCO, LVIS). In that case,
|
| 78 |
+
we query the model with every possible class label, so we when we pass the predictions to the evaluator,
|
| 79 |
+
we want to make sure that the predicted "class" matches the one that was queried.
|
| 80 |
+
consistent: whether all target sizes are equal
|
| 81 |
+
ret_tensordict: Experimental argument. If true, return a tensordict.TensorDict instead of a list of dictionaries for easier manipulation.
|
| 82 |
+
"""
|
| 83 |
+
if ret_tensordict:
|
| 84 |
+
assert (
|
| 85 |
+
consistent is True
|
| 86 |
+
), "We don't support returning TensorDict if the outputs have different shapes" # NOTE: It's possible but we don't support it.
|
| 87 |
+
assert self.detection_threshold <= 0.0, "TODO: implement?"
|
| 88 |
+
try:
|
| 89 |
+
from tensordict import TensorDict
|
| 90 |
+
except ImportError:
|
| 91 |
+
logging.info(
|
| 92 |
+
"tensordict is not installed. Install by running `pip install tensordict --no-deps`. Falling back by setting `ret_tensordict=False`"
|
| 93 |
+
)
|
| 94 |
+
ret_tensordict = False
|
| 95 |
+
|
| 96 |
+
out_bbox = outputs["pred_boxes"] if "pred_boxes" in outputs else None
|
| 97 |
+
out_logits = outputs["pred_logits"]
|
| 98 |
+
pred_masks = outputs["pred_masks"] if self.iou_type == "segm" else None
|
| 99 |
+
out_probs = out_logits.sigmoid()
|
| 100 |
+
if self.use_presence:
|
| 101 |
+
presence_score = outputs["presence_logit_dec"].sigmoid().unsqueeze(1)
|
| 102 |
+
out_probs = out_probs * presence_score
|
| 103 |
+
|
| 104 |
+
assert target_sizes_boxes.shape[1] == 2
|
| 105 |
+
assert target_sizes_masks.shape[1] == 2
|
| 106 |
+
batch_size = target_sizes_boxes.shape[0]
|
| 107 |
+
|
| 108 |
+
boxes, scores, labels, keep = self._process_boxes_and_labels(
|
| 109 |
+
target_sizes_boxes, forced_labels, out_bbox, out_probs
|
| 110 |
+
)
|
| 111 |
+
assert boxes is None or len(boxes) == batch_size
|
| 112 |
+
out_masks = self._process_masks(
|
| 113 |
+
target_sizes_masks, pred_masks, consistent=consistent, keep=keep
|
| 114 |
+
)
|
| 115 |
+
del pred_masks
|
| 116 |
+
|
| 117 |
+
if boxes is None:
|
| 118 |
+
assert out_masks is not None
|
| 119 |
+
assert not ret_tensordict, "We don't support returning TensorDict if the output does not contain boxes"
|
| 120 |
+
B = len(out_masks)
|
| 121 |
+
boxes = [None] * B
|
| 122 |
+
scores = [None] * B
|
| 123 |
+
labels = [None] * B
|
| 124 |
+
|
| 125 |
+
results = {
|
| 126 |
+
"scores": scores,
|
| 127 |
+
"labels": labels,
|
| 128 |
+
"boxes": boxes,
|
| 129 |
+
}
|
| 130 |
+
if out_masks is not None:
|
| 131 |
+
if self.convert_mask_to_rle:
|
| 132 |
+
results.update(masks_rle=out_masks)
|
| 133 |
+
else:
|
| 134 |
+
results.update(masks=out_masks)
|
| 135 |
+
|
| 136 |
+
if ret_tensordict:
|
| 137 |
+
results = TensorDict(results).auto_batch_size_()
|
| 138 |
+
if self.to_cpu:
|
| 139 |
+
results = results.cpu()
|
| 140 |
+
else:
|
| 141 |
+
# Convert a dictonary of lists/tensors to list of dictionaries
|
| 142 |
+
results = [
|
| 143 |
+
dict(zip(results.keys(), res_tuple))
|
| 144 |
+
for res_tuple in zip(*results.values())
|
| 145 |
+
]
|
| 146 |
+
|
| 147 |
+
return results
|
| 148 |
+
|
| 149 |
+
def _process_masks(self, target_sizes, pred_masks, consistent=True, keep=None):
|
| 150 |
+
if pred_masks is None:
|
| 151 |
+
return None
|
| 152 |
+
if self.always_interpolate_masks_on_gpu:
|
| 153 |
+
gpu_device = target_sizes.device
|
| 154 |
+
assert gpu_device.type == "cuda"
|
| 155 |
+
pred_masks = pred_masks.to(device=gpu_device)
|
| 156 |
+
if consistent:
|
| 157 |
+
assert keep is None, "TODO: implement?"
|
| 158 |
+
# All masks should have the same shape, expected when processing a batch of size 1
|
| 159 |
+
target_size = target_sizes.unique(dim=0)
|
| 160 |
+
assert target_size.size(0) == 1, "Expecting all target sizes to be equal"
|
| 161 |
+
out_masks = (
|
| 162 |
+
interpolate(
|
| 163 |
+
pred_masks,
|
| 164 |
+
target_size.squeeze().tolist(),
|
| 165 |
+
mode="bilinear",
|
| 166 |
+
align_corners=False,
|
| 167 |
+
).sigmoid()
|
| 168 |
+
> 0.5
|
| 169 |
+
)
|
| 170 |
+
if self.convert_mask_to_rle:
|
| 171 |
+
raise RuntimeError("TODO: implement?")
|
| 172 |
+
if self.to_cpu:
|
| 173 |
+
out_masks = out_masks.cpu()
|
| 174 |
+
else:
|
| 175 |
+
out_masks = [[]] * len(pred_masks)
|
| 176 |
+
|
| 177 |
+
assert keep is None or len(keep) == len(pred_masks)
|
| 178 |
+
for i, mask in enumerate(pred_masks):
|
| 179 |
+
h, w = target_sizes[i]
|
| 180 |
+
if keep is not None:
|
| 181 |
+
mask = mask[keep[i]]
|
| 182 |
+
# Uses the gpu version fist, moves masks to cpu if it fails"""
|
| 183 |
+
try:
|
| 184 |
+
interpolated = (
|
| 185 |
+
interpolate(
|
| 186 |
+
mask.unsqueeze(1),
|
| 187 |
+
(h, w),
|
| 188 |
+
mode="bilinear",
|
| 189 |
+
align_corners=False,
|
| 190 |
+
).sigmoid()
|
| 191 |
+
> 0.5
|
| 192 |
+
)
|
| 193 |
+
except Exception as e:
|
| 194 |
+
logging.info("Issue found, reverting to CPU mode!")
|
| 195 |
+
mask_device = mask.device
|
| 196 |
+
mask = mask.cpu()
|
| 197 |
+
interpolated = (
|
| 198 |
+
interpolate(
|
| 199 |
+
mask.unsqueeze(1),
|
| 200 |
+
(h, w),
|
| 201 |
+
mode="bilinear",
|
| 202 |
+
align_corners=False,
|
| 203 |
+
).sigmoid()
|
| 204 |
+
> 0.5
|
| 205 |
+
)
|
| 206 |
+
interpolated = interpolated.to(mask_device)
|
| 207 |
+
|
| 208 |
+
if self.convert_mask_to_rle:
|
| 209 |
+
out_masks[i] = robust_rle_encode(interpolated.squeeze(1))
|
| 210 |
+
else:
|
| 211 |
+
out_masks[i] = interpolated
|
| 212 |
+
if self.to_cpu:
|
| 213 |
+
out_masks[i] = out_masks[i].cpu()
|
| 214 |
+
|
| 215 |
+
return out_masks
|
| 216 |
+
|
| 217 |
+
def _process_boxes_and_labels(
|
| 218 |
+
self, target_sizes, forced_labels, out_bbox, out_probs
|
| 219 |
+
):
|
| 220 |
+
if out_bbox is None:
|
| 221 |
+
return None, None, None, None
|
| 222 |
+
assert len(out_probs) == len(target_sizes)
|
| 223 |
+
if self.to_cpu:
|
| 224 |
+
out_probs = out_probs.cpu()
|
| 225 |
+
scores, labels = out_probs.max(-1)
|
| 226 |
+
if forced_labels is None:
|
| 227 |
+
labels = torch.ones_like(labels)
|
| 228 |
+
else:
|
| 229 |
+
labels = forced_labels[:, None].expand_as(labels)
|
| 230 |
+
|
| 231 |
+
# convert to [x0, y0, x1, y1] format
|
| 232 |
+
boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
|
| 233 |
+
|
| 234 |
+
img_h, img_w = target_sizes.unbind(1)
|
| 235 |
+
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
|
| 236 |
+
boxes = boxes * scale_fct[:, None, :]
|
| 237 |
+
|
| 238 |
+
if self.to_cpu:
|
| 239 |
+
boxes = boxes.cpu()
|
| 240 |
+
|
| 241 |
+
keep = None
|
| 242 |
+
if self.detection_threshold > 0:
|
| 243 |
+
# Filter out the boxes with scores below the detection threshold
|
| 244 |
+
keep = scores > self.detection_threshold
|
| 245 |
+
assert len(keep) == len(boxes) == len(scores) == len(labels)
|
| 246 |
+
|
| 247 |
+
boxes = [b[k.to(b.device)] for b, k in zip(boxes, keep)]
|
| 248 |
+
scores = [s[k.to(s.device)] for s, k in zip(scores, keep)]
|
| 249 |
+
labels = [l[k.to(l.device)] for l, k in zip(labels, keep)]
|
| 250 |
+
|
| 251 |
+
return boxes, scores, labels, keep
|
| 252 |
+
|
| 253 |
+
def process_results(
|
| 254 |
+
self, find_stages, find_metadatas: List[BatchedInferenceMetadata], **kwargs
|
| 255 |
+
):
|
| 256 |
+
if find_stages.loss_stages is not None:
|
| 257 |
+
find_metadatas = [find_metadatas[i] for i in find_stages.loss_stages]
|
| 258 |
+
assert len(find_stages) == len(find_metadatas)
|
| 259 |
+
results = {}
|
| 260 |
+
for outputs, meta in zip(find_stages, find_metadatas):
|
| 261 |
+
img_size_for_boxes = (
|
| 262 |
+
meta.original_size
|
| 263 |
+
if self.use_original_sizes_box
|
| 264 |
+
else torch.ones_like(meta.original_size)
|
| 265 |
+
)
|
| 266 |
+
img_size_for_masks = (
|
| 267 |
+
meta.original_size
|
| 268 |
+
if self.use_original_sizes_mask
|
| 269 |
+
else torch.ones_like(meta.original_size)
|
| 270 |
+
)
|
| 271 |
+
detection_results = self(
|
| 272 |
+
outputs,
|
| 273 |
+
img_size_for_boxes,
|
| 274 |
+
img_size_for_masks,
|
| 275 |
+
forced_labels=(
|
| 276 |
+
meta.original_category_id if self.use_original_ids else None
|
| 277 |
+
),
|
| 278 |
+
)
|
| 279 |
+
ids = (
|
| 280 |
+
meta.original_image_id if self.use_original_ids else meta.coco_image_id
|
| 281 |
+
)
|
| 282 |
+
assert len(detection_results) == len(ids)
|
| 283 |
+
for img_id, result in zip(ids, detection_results):
|
| 284 |
+
if img_id.item() not in results:
|
| 285 |
+
results[img_id.item()] = result
|
| 286 |
+
else:
|
| 287 |
+
assert set(results[img_id.item()].keys()) == set(result.keys())
|
| 288 |
+
for k in result.keys():
|
| 289 |
+
if isinstance(result[k], torch.Tensor):
|
| 290 |
+
results[img_id.item()][k] = torch.cat(
|
| 291 |
+
[results[img_id.item()][k], result[k]], dim=0
|
| 292 |
+
)
|
| 293 |
+
elif isinstance(result[k], list):
|
| 294 |
+
results[img_id.item()][k] += result[k]
|
| 295 |
+
else:
|
| 296 |
+
raise NotImplementedError(
|
| 297 |
+
f"Unexpected type {type(result[k])} in result."
|
| 298 |
+
)
|
| 299 |
+
# Prune the results to the max number of detections per image.
|
| 300 |
+
for img_id, result in results.items():
|
| 301 |
+
if (
|
| 302 |
+
self.max_dets_per_img > 0
|
| 303 |
+
and len(result["scores"]) > self.max_dets_per_img
|
| 304 |
+
):
|
| 305 |
+
_, topk_indexes = torch.topk(
|
| 306 |
+
result["scores"], self.max_dets_per_img, dim=0
|
| 307 |
+
)
|
| 308 |
+
if self.to_cpu:
|
| 309 |
+
topk_indexes = topk_indexes.cpu()
|
| 310 |
+
for k in result.keys():
|
| 311 |
+
if isinstance(results[img_id][k], list):
|
| 312 |
+
results[img_id][k] = [
|
| 313 |
+
results[img_id][k][i] for i in topk_indexes.tolist()
|
| 314 |
+
]
|
| 315 |
+
else:
|
| 316 |
+
results[img_id][k] = results[img_id][k].to(topk_indexes.device)[
|
| 317 |
+
topk_indexes
|
| 318 |
+
]
|
| 319 |
+
|
| 320 |
+
return results
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
class PostProcessAPIVideo(PostProcessImage):
|
| 324 |
+
"""This module converts the video model's output into the format expected by the YT-VIS api"""
|
| 325 |
+
|
| 326 |
+
def __init__(
|
| 327 |
+
self,
|
| 328 |
+
*args,
|
| 329 |
+
to_cpu: bool = True,
|
| 330 |
+
convert_mask_to_rle: bool = False,
|
| 331 |
+
always_interpolate_masks_on_gpu: bool = True,
|
| 332 |
+
prob_thresh: float = 0.5,
|
| 333 |
+
use_presence: bool = False,
|
| 334 |
+
**kwargs,
|
| 335 |
+
):
|
| 336 |
+
super().__init__(
|
| 337 |
+
*args,
|
| 338 |
+
# Here we always set `convert_mask_to_rle=False` in the base `PostProcessAPI` class
|
| 339 |
+
# (so that its `_process_masks` won't return a list of RLEs). If we want to return
|
| 340 |
+
# RLEs for video masklets, we handle it in this `PostProcessAPIVideo` class instead.
|
| 341 |
+
convert_mask_to_rle=False,
|
| 342 |
+
# Here we always set `to_cpu=False` in the base `PostProcessAPI` class (so that
|
| 343 |
+
# the interpolated masks won't be automatically moved back to CPU). We will handle
|
| 344 |
+
# it in this `PostProcessAPIVideo` class instead.
|
| 345 |
+
always_interpolate_masks_on_gpu=always_interpolate_masks_on_gpu,
|
| 346 |
+
use_presence=use_presence,
|
| 347 |
+
**kwargs,
|
| 348 |
+
)
|
| 349 |
+
# Expected keys in the output dict to postprocess
|
| 350 |
+
self.EXPECTED_KEYS = [
|
| 351 |
+
"pred_logits",
|
| 352 |
+
"pred_boxes",
|
| 353 |
+
"pred_masks",
|
| 354 |
+
]
|
| 355 |
+
# Whether to post-process video masklets (under packed representation) into RLE format
|
| 356 |
+
self.convert_mask_to_rle_for_video = convert_mask_to_rle
|
| 357 |
+
self.to_cpu_for_video = to_cpu
|
| 358 |
+
self.prob_thresh = prob_thresh
|
| 359 |
+
|
| 360 |
+
def process_results(
|
| 361 |
+
self, find_stages, find_metadatas: List[BatchedInferenceMetadata], **kwargs
|
| 362 |
+
):
|
| 363 |
+
"""
|
| 364 |
+
Tracking Postprocessor for SAM 3 video model.
|
| 365 |
+
This function takes in the output of the SAM 3 video model and processes it to extract all the tracklet predictions.
|
| 366 |
+
Args:
|
| 367 |
+
find_stages: A list of tensors representing the output of the SAM 3 video model.
|
| 368 |
+
find_metadatas: A list of BatchedInferenceMetadata objects containing metadata about each frame.
|
| 369 |
+
**kwargs: Additional keyword arguments.
|
| 370 |
+
Returns:
|
| 371 |
+
A dictionary of predcitions with video_id as key.
|
| 372 |
+
"""
|
| 373 |
+
|
| 374 |
+
# Import tensordict here to avoid global dependency.
|
| 375 |
+
try:
|
| 376 |
+
from tensordict import TensorDict
|
| 377 |
+
except ImportError as e:
|
| 378 |
+
logging.error(
|
| 379 |
+
"tensordict is not installed, please install by running `pip install tensordict --no-deps`"
|
| 380 |
+
)
|
| 381 |
+
raise e
|
| 382 |
+
# Notes and assumptions:
|
| 383 |
+
# 1- This postprocessor assumes results only for a single video.
|
| 384 |
+
# 2- There are N stage outputs corresponding to N video frames
|
| 385 |
+
# 3- Each stage outputs contains PxQ preds, where P is number of prompts and Q is number of object queries. The output should also contain the tracking object ids corresponding to each object query.
|
| 386 |
+
# 4- The tracking object id has a default value of -1, indicating that the object query is not tracking any object in the frame, and hence its predictions can be ingored for a given frame.
|
| 387 |
+
# 5- Some objects may be tracked in a subset of frames only. So, we first extract the predictions in a packed representation (for efficient postprocessing -- specially memory)
|
| 388 |
+
# and then we convert the packed representation into a padded one, where we zero pad boxes/masks for objects that are not tracked in some frames.
|
| 389 |
+
# 6- We refer to objects by an object id, which is a tuple (prompt_idx, obj_id)
|
| 390 |
+
|
| 391 |
+
assert len(find_stages) > 0, "There is nothing to postprocess?"
|
| 392 |
+
PROMPT_AXIS, OBJ_QUERY_AXIS = (0, 1)
|
| 393 |
+
NO_OBJ_ID = -1
|
| 394 |
+
# Maps object ID -> [indices in packed tensor]
|
| 395 |
+
tracked_objects_packed_idx = defaultdict(list)
|
| 396 |
+
# Maps object ID -> [indices in padded tensor (abs frame index)]
|
| 397 |
+
tracked_objects_frame_idx = defaultdict(list)
|
| 398 |
+
total_num_preds = 0
|
| 399 |
+
# This will hold the packed representation of predictions.
|
| 400 |
+
vid_preds_packed: List[TensorDict] = []
|
| 401 |
+
vid_masklets_rle_packed: List[Optional[Dict]] = []
|
| 402 |
+
video_id = -1 # We assume single video postprocessing, this ID should be unique in the datapoint.
|
| 403 |
+
|
| 404 |
+
for frame_idx, (frame_outs, meta) in enumerate(
|
| 405 |
+
zip(find_stages, find_metadatas)
|
| 406 |
+
):
|
| 407 |
+
# only store keys we need to extract the results
|
| 408 |
+
frame_outs_td = TensorDict(
|
| 409 |
+
{k: frame_outs[k] for k in self.EXPECTED_KEYS}
|
| 410 |
+
).auto_batch_size_() # Shape is [P,Q,...]
|
| 411 |
+
meta_td = TensorDict(
|
| 412 |
+
dataclasses.asdict(meta)
|
| 413 |
+
).auto_batch_size_() # Shape is [P,...]
|
| 414 |
+
unique_vid_id = meta.original_image_id.unique()
|
| 415 |
+
assert unique_vid_id.size(0) == 1
|
| 416 |
+
if video_id == -1:
|
| 417 |
+
video_id = unique_vid_id.item()
|
| 418 |
+
else:
|
| 419 |
+
assert (
|
| 420 |
+
video_id == unique_vid_id.item()
|
| 421 |
+
), "We can only postprocess one video per datapoint"
|
| 422 |
+
# keeping track of which objects appear in the current frame
|
| 423 |
+
obj_ids_per_frame = frame_outs["pred_object_ids"]
|
| 424 |
+
assert obj_ids_per_frame.size(-1) == frame_outs["pred_logits"].size(-2)
|
| 425 |
+
if self.prob_thresh is not None:
|
| 426 |
+
# only keep the predictions on this frame with probability above the threshold
|
| 427 |
+
# (remove those predictions during the keep-alive period of a tracking query,
|
| 428 |
+
# where its "pred_object_ids" is still the tracked object ID rather than -1)
|
| 429 |
+
pred_probs = frame_outs["pred_logits"].sigmoid().squeeze(-1)
|
| 430 |
+
obj_ids_per_frame = torch.where(
|
| 431 |
+
pred_probs >= self.prob_thresh, obj_ids_per_frame, NO_OBJ_ID
|
| 432 |
+
)
|
| 433 |
+
tracked_obj_ids_idx = torch.where(obj_ids_per_frame != NO_OBJ_ID)
|
| 434 |
+
# Object id is a tuple of (prompt_idx, obj_id). This is because the model can assign same obj_id for two different prompts.
|
| 435 |
+
tracked_obj_ids = [
|
| 436 |
+
(p_id.item(), obj_ids_per_frame[p_id, q_id].item())
|
| 437 |
+
for p_id, q_id in zip(
|
| 438 |
+
tracked_obj_ids_idx[PROMPT_AXIS],
|
| 439 |
+
tracked_obj_ids_idx[OBJ_QUERY_AXIS],
|
| 440 |
+
)
|
| 441 |
+
]
|
| 442 |
+
if len(tracked_obj_ids) == 0:
|
| 443 |
+
continue
|
| 444 |
+
# For each object, we keep track of the packed and padded (frame index) indices
|
| 445 |
+
for oid in tracked_obj_ids:
|
| 446 |
+
tracked_objects_packed_idx[oid].append(total_num_preds)
|
| 447 |
+
tracked_objects_frame_idx[oid].append(frame_idx)
|
| 448 |
+
total_num_preds += 1
|
| 449 |
+
|
| 450 |
+
# Since we have P*Q masks per frame, mask interpolation is the GPU memory bottleneck or time bottleneck in case of cpu processing.
|
| 451 |
+
# Instead, we first extract results only for tracked objects, reducing the number of masks to K = sum_i(tracked_objs_per_ith_prompt), hopefully <<< P*Q
|
| 452 |
+
tracked_objs_outs_td = frame_outs_td[
|
| 453 |
+
tracked_obj_ids_idx
|
| 454 |
+
] # [P,Q,...] --> [K,...]
|
| 455 |
+
meta_td = meta_td[tracked_obj_ids_idx[PROMPT_AXIS].cpu()]
|
| 456 |
+
if self.always_interpolate_masks_on_gpu:
|
| 457 |
+
gpu_device = meta_td["original_size"].device
|
| 458 |
+
assert gpu_device.type == "cuda"
|
| 459 |
+
tracked_objs_outs_td = tracked_objs_outs_td.to(device=gpu_device)
|
| 460 |
+
frame_results_td = self(
|
| 461 |
+
tracked_objs_outs_td.unsqueeze(1),
|
| 462 |
+
(
|
| 463 |
+
meta_td["original_size"]
|
| 464 |
+
if self.use_original_sizes
|
| 465 |
+
else torch.ones_like(meta_td["original_size"])
|
| 466 |
+
),
|
| 467 |
+
forced_labels=(
|
| 468 |
+
meta_td["original_category_id"] if self.use_original_ids else None
|
| 469 |
+
),
|
| 470 |
+
consistent=True,
|
| 471 |
+
ret_tensordict=True,
|
| 472 |
+
).squeeze(1)
|
| 473 |
+
del tracked_objs_outs_td
|
| 474 |
+
|
| 475 |
+
# Optionally, remove "masks" from output tensor dict and directly encode them
|
| 476 |
+
# to RLE format under packed representations
|
| 477 |
+
if self.convert_mask_to_rle_for_video:
|
| 478 |
+
interpolated_binary_masks = frame_results_td.pop("masks")
|
| 479 |
+
rle_list = rle_encode(interpolated_binary_masks, return_areas=True)
|
| 480 |
+
vid_masklets_rle_packed.extend(rle_list)
|
| 481 |
+
# Optionally, move output TensorDict to CPU (do this after RLE encoding step above)
|
| 482 |
+
if self.to_cpu_for_video:
|
| 483 |
+
frame_results_td = frame_results_td.cpu()
|
| 484 |
+
vid_preds_packed.append(frame_results_td)
|
| 485 |
+
|
| 486 |
+
if len(vid_preds_packed) == 0:
|
| 487 |
+
logging.debug(f"Video {video_id} has no predictions")
|
| 488 |
+
return {video_id: []}
|
| 489 |
+
|
| 490 |
+
vid_preds_packed = torch.cat(vid_preds_packed, dim=0)
|
| 491 |
+
############### Construct a padded representation of the predictions ###############
|
| 492 |
+
num_preds = len(tracked_objects_packed_idx)
|
| 493 |
+
num_frames = len(find_stages)
|
| 494 |
+
# We zero pad any missing prediction
|
| 495 |
+
# NOTE: here, we also have padded tensors for "scores" and "labels", but we overwrite them later.
|
| 496 |
+
padded_frames_results = TensorDict(
|
| 497 |
+
{
|
| 498 |
+
k: torch.zeros(
|
| 499 |
+
num_preds, num_frames, *v.shape[1:], device=v.device, dtype=v.dtype
|
| 500 |
+
)
|
| 501 |
+
for k, v in vid_preds_packed.items()
|
| 502 |
+
},
|
| 503 |
+
batch_size=[
|
| 504 |
+
num_preds,
|
| 505 |
+
num_frames,
|
| 506 |
+
],
|
| 507 |
+
)
|
| 508 |
+
padded_frames_results["scores"][...] = -1e8 # a very low score for empty object
|
| 509 |
+
# Track scores and labels of each pred tracklet, only for frames where the model was able to track that object
|
| 510 |
+
tracklet_scores = []
|
| 511 |
+
tracklet_labels = []
|
| 512 |
+
# Optionally, fill the list of RLEs for masklets
|
| 513 |
+
# note: only frames with actual predicted masks (in packed format) will be
|
| 514 |
+
# filled with RLEs; the rest will remains None in results["masks_rle"]
|
| 515 |
+
if self.convert_mask_to_rle_for_video:
|
| 516 |
+
vid_masklets_rle_padded = [[None] * num_frames for _ in range(num_preds)]
|
| 517 |
+
for o_idx, oid in enumerate(tracked_objects_packed_idx):
|
| 518 |
+
oid2packed_idx = tracked_objects_packed_idx[oid]
|
| 519 |
+
oid2padded_idx = tracked_objects_frame_idx[oid]
|
| 520 |
+
obj_packed_results = vid_preds_packed[oid2packed_idx]
|
| 521 |
+
padded_frames_results[o_idx][oid2padded_idx] = obj_packed_results
|
| 522 |
+
if self.convert_mask_to_rle_for_video:
|
| 523 |
+
for packed_idx, padded_idx in zip(oid2packed_idx, oid2padded_idx):
|
| 524 |
+
vid_masklets_rle_padded[o_idx][padded_idx] = (
|
| 525 |
+
vid_masklets_rle_packed[packed_idx]
|
| 526 |
+
)
|
| 527 |
+
# NOTE: We need a single confidence score per tracklet for the mAP metric.
|
| 528 |
+
# We use the average confidence score across time. (How does this impact AP?)
|
| 529 |
+
tracklet_scores.append(obj_packed_results["scores"].mean())
|
| 530 |
+
# We also need to have a unique category Id per tracklet.
|
| 531 |
+
# This is not a problem for phrase AP, however, for mAP we do majority voting across time.
|
| 532 |
+
tracklet_labels.append(obj_packed_results["labels"].mode()[0])
|
| 533 |
+
|
| 534 |
+
results = padded_frames_results.to_dict()
|
| 535 |
+
results["scores"] = torch.stack(tracklet_scores, dim=0)
|
| 536 |
+
results["labels"] = torch.stack(tracklet_labels, dim=0)
|
| 537 |
+
if self.convert_mask_to_rle_for_video:
|
| 538 |
+
results["masks_rle"] = vid_masklets_rle_padded
|
| 539 |
+
# we keep the frame-level scores since it's needed by some evaluation scripts
|
| 540 |
+
results["per_frame_scores"] = padded_frames_results["scores"]
|
| 541 |
+
|
| 542 |
+
return {video_id: results}
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
class PostProcessTracking(PostProcessImage):
|
| 546 |
+
"""This module converts the model's output into the format expected by the coco api"""
|
| 547 |
+
|
| 548 |
+
def __init__(
|
| 549 |
+
self,
|
| 550 |
+
max_dets_per_img: int,
|
| 551 |
+
iou_type="bbox",
|
| 552 |
+
force_single_mask: bool = False,
|
| 553 |
+
**kwargs,
|
| 554 |
+
) -> None:
|
| 555 |
+
super().__init__(max_dets_per_img=max_dets_per_img, iou_type=iou_type, **kwargs)
|
| 556 |
+
self.force_single_mask = force_single_mask
|
| 557 |
+
|
| 558 |
+
def process_results(
|
| 559 |
+
self, find_stages, find_metadatas: BatchedInferenceMetadata, **kwargs
|
| 560 |
+
):
|
| 561 |
+
assert len(find_stages) == len(find_metadatas)
|
| 562 |
+
results = {}
|
| 563 |
+
for outputs, meta in zip(find_stages, find_metadatas):
|
| 564 |
+
if self.force_single_mask:
|
| 565 |
+
scores, labels = outputs["pred_logits"].max(-1)
|
| 566 |
+
m = []
|
| 567 |
+
for i in range(len(outputs["pred_masks"])):
|
| 568 |
+
score, idx = scores[i].max(0)
|
| 569 |
+
m.append(outputs["pred_masks"][i][idx])
|
| 570 |
+
outputs["pred_masks"] = torch.stack(m, 0).unsqueeze(1)
|
| 571 |
+
detection_results = self(outputs, meta.original_size, consistent=False)
|
| 572 |
+
assert len(detection_results) == len(meta.coco_image_id)
|
| 573 |
+
results.update(
|
| 574 |
+
{
|
| 575 |
+
(media_id.item(), object_id.item(), frame_index.item()): result
|
| 576 |
+
for media_id, object_id, frame_index, result in zip(
|
| 577 |
+
meta.original_image_id,
|
| 578 |
+
meta.object_id,
|
| 579 |
+
meta.frame_index,
|
| 580 |
+
detection_results,
|
| 581 |
+
)
|
| 582 |
+
}
|
| 583 |
+
)
|
| 584 |
+
return results
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
class PostProcessCounting(nn.Module):
|
| 588 |
+
"""This module converts the model's output to be evaluated for counting tasks"""
|
| 589 |
+
|
| 590 |
+
def __init__(
|
| 591 |
+
self,
|
| 592 |
+
use_original_ids: bool = False,
|
| 593 |
+
threshold: float = 0.5,
|
| 594 |
+
use_presence: bool = False,
|
| 595 |
+
) -> None:
|
| 596 |
+
"""
|
| 597 |
+
Args:
|
| 598 |
+
use_original_ids: whether to use the original image ids or the coco ids
|
| 599 |
+
threshold: threshold for counting (values above this are counted)
|
| 600 |
+
"""
|
| 601 |
+
super().__init__()
|
| 602 |
+
self.use_original_ids = use_original_ids
|
| 603 |
+
self.threshold = threshold
|
| 604 |
+
self.use_presence = use_presence
|
| 605 |
+
|
| 606 |
+
def forward(self, outputs, target_sizes):
|
| 607 |
+
"""Perform the computation
|
| 608 |
+
Parameters:
|
| 609 |
+
outputs: raw outputs of the model
|
| 610 |
+
target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
|
| 611 |
+
"""
|
| 612 |
+
# Extract scores from model outputs and apply sigmoid
|
| 613 |
+
scores = torch.sigmoid(outputs["pred_logits"]).squeeze(-1) # [B, N]
|
| 614 |
+
if self.use_presence:
|
| 615 |
+
presence_score = outputs["presence_logit_dec"].sigmoid()
|
| 616 |
+
if presence_score.ndim == 1:
|
| 617 |
+
presence_score = presence_score.unsqueeze(1) # [B, 1]
|
| 618 |
+
scores = scores * presence_score # [B, N]
|
| 619 |
+
|
| 620 |
+
# Calculate counts by summing values above threshold
|
| 621 |
+
counts = (scores > self.threshold).float().sum(dim=1)
|
| 622 |
+
|
| 623 |
+
assert len(counts) == len(target_sizes)
|
| 624 |
+
results = []
|
| 625 |
+
for count in counts:
|
| 626 |
+
results.append({"count": count.item()})
|
| 627 |
+
|
| 628 |
+
return results
|
| 629 |
+
|
| 630 |
+
@torch.no_grad()
|
| 631 |
+
def process_results(
|
| 632 |
+
self, find_stages, find_metadatas: List[BatchedInferenceMetadata], **kwargs
|
| 633 |
+
):
|
| 634 |
+
assert len(find_stages) == len(find_metadatas)
|
| 635 |
+
results = {}
|
| 636 |
+
for outputs, meta in zip(find_stages, find_metadatas):
|
| 637 |
+
detection_results = self(
|
| 638 |
+
outputs,
|
| 639 |
+
meta.original_size,
|
| 640 |
+
)
|
| 641 |
+
ids = (
|
| 642 |
+
meta.original_image_id if self.use_original_ids else meta.coco_image_id
|
| 643 |
+
)
|
| 644 |
+
assert len(detection_results) == len(ids)
|
| 645 |
+
for img_id, result in zip(ids, detection_results):
|
| 646 |
+
results[img_id.item()] = result
|
| 647 |
+
|
| 648 |
+
return results
|
sam3/eval/saco_veval_eval.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
import argparse
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
|
| 7 |
+
from iopath.common.file_io import g_pathmgr
|
| 8 |
+
from sam3.eval.saco_veval_evaluators import (
|
| 9 |
+
VideoCGF1Evaluator,
|
| 10 |
+
VideoPhraseApEvaluator,
|
| 11 |
+
VideoPhraseHotaEvaluator,
|
| 12 |
+
VideoTetaEvaluator,
|
| 13 |
+
YTVISPredFileEvaluator,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class VEvalEvaluator:
|
| 18 |
+
def __init__(self, gt_annot_file: str, eval_res_file: str):
|
| 19 |
+
self.gt_annot_file = gt_annot_file
|
| 20 |
+
self.eval_res_file = eval_res_file
|
| 21 |
+
self.evaluators = [
|
| 22 |
+
# mAP
|
| 23 |
+
YTVISPredFileEvaluator(gt_annot_file),
|
| 24 |
+
# Phrase AP
|
| 25 |
+
VideoPhraseApEvaluator(gt_annot_file),
|
| 26 |
+
# TETA
|
| 27 |
+
VideoTetaEvaluator(gt_annot_file, use_mask=True, is_exhaustive=True),
|
| 28 |
+
# HOTA
|
| 29 |
+
VideoPhraseHotaEvaluator(gt_annot_file),
|
| 30 |
+
# cgF1
|
| 31 |
+
VideoCGF1Evaluator(gt_annot_file),
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
def run_eval(self, pred_file: str):
|
| 35 |
+
dataset_results = {}
|
| 36 |
+
video_np_results = defaultdict(dict)
|
| 37 |
+
for evaluator in self.evaluators:
|
| 38 |
+
d_res, v_np_res = evaluator.evaluate(pred_file)
|
| 39 |
+
dataset_results.update(d_res)
|
| 40 |
+
for (video_id, category_id), res in v_np_res.items():
|
| 41 |
+
video_np_results[(video_id, category_id)].update(res)
|
| 42 |
+
|
| 43 |
+
if len(dataset_results) == 0:
|
| 44 |
+
dataset_results = {"": 0.0}
|
| 45 |
+
|
| 46 |
+
formatted_video_np_results = [
|
| 47 |
+
{"video_id": video_id, "category_id": category_id, **res}
|
| 48 |
+
for (video_id, category_id), res in video_np_results.items()
|
| 49 |
+
]
|
| 50 |
+
eval_metrics = {
|
| 51 |
+
"dataset_results": dataset_results,
|
| 52 |
+
"video_np_results": formatted_video_np_results,
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
with g_pathmgr.open(self.eval_res_file, "w") as f:
|
| 56 |
+
json.dump(eval_metrics, f)
|
| 57 |
+
|
| 58 |
+
return eval_metrics
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def run_main_all(dataset_name, args):
|
| 62 |
+
gt_annot_file = os.path.join(args.gt_annot_dir, dataset_name + ".json")
|
| 63 |
+
pred_file = os.path.join(args.pred_dir, dataset_name + "_preds.json")
|
| 64 |
+
eval_res_file = os.path.join(args.eval_res_dir, dataset_name + "_eval_res.json")
|
| 65 |
+
print(f"=== Running evaluation for Pred {pred_file} vs GT {gt_annot_file} ===")
|
| 66 |
+
veval_evaluator = VEvalEvaluator(
|
| 67 |
+
gt_annot_file=gt_annot_file, eval_res_file=eval_res_file
|
| 68 |
+
)
|
| 69 |
+
_ = veval_evaluator.run_eval(pred_file=pred_file)
|
| 70 |
+
|
| 71 |
+
print(f"=== Results saved to {eval_res_file} ===")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def main_all(args):
|
| 75 |
+
saco_veval_dataset_names = [
|
| 76 |
+
"saco_veval_sav_test",
|
| 77 |
+
"saco_veval_sav_val",
|
| 78 |
+
"saco_veval_yt1b_test",
|
| 79 |
+
"saco_veval_yt1b_val",
|
| 80 |
+
"saco_veval_smartglasses_test",
|
| 81 |
+
"saco_veval_smartglasses_val",
|
| 82 |
+
]
|
| 83 |
+
|
| 84 |
+
# multiprocessing may not really work as inner evaluator also using multiprocessing
|
| 85 |
+
# so we just for loop
|
| 86 |
+
for dataset_name in saco_veval_dataset_names:
|
| 87 |
+
print(f"=== Running evaluation for dataset {dataset_name} ===")
|
| 88 |
+
run_main_all(dataset_name=dataset_name, args=args)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def main_one(args):
|
| 92 |
+
gt_annot_file = args.gt_annot_file
|
| 93 |
+
pred_file = args.pred_file
|
| 94 |
+
eval_res_file = args.eval_res_file
|
| 95 |
+
|
| 96 |
+
print(f"=== Running evaluation for Pred {pred_file} vs GT {gt_annot_file} ===")
|
| 97 |
+
veval_evaluator = VEvalEvaluator(
|
| 98 |
+
gt_annot_file=gt_annot_file, eval_res_file=eval_res_file
|
| 99 |
+
)
|
| 100 |
+
_ = veval_evaluator.run_eval(pred_file=pred_file)
|
| 101 |
+
|
| 102 |
+
print(f"=== Results saved to {eval_res_file} ===")
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def main():
|
| 106 |
+
parser = argparse.ArgumentParser(description="Run video grounding evaluators")
|
| 107 |
+
|
| 108 |
+
# Create subparsers for different commands
|
| 109 |
+
subparsers = parser.add_subparsers(dest="command", required=True)
|
| 110 |
+
|
| 111 |
+
# Run evaluation for all datasets
|
| 112 |
+
all_parser = subparsers.add_parser("all", help="Run evaluation for all datasets")
|
| 113 |
+
all_parser.add_argument(
|
| 114 |
+
"--gt_annot_dir",
|
| 115 |
+
type=str,
|
| 116 |
+
help="Directory that contains the ground truth annotation files",
|
| 117 |
+
)
|
| 118 |
+
all_parser.add_argument(
|
| 119 |
+
"--pred_dir",
|
| 120 |
+
type=str,
|
| 121 |
+
help="Directory that contains the prediction files",
|
| 122 |
+
)
|
| 123 |
+
all_parser.add_argument(
|
| 124 |
+
"--eval_res_dir",
|
| 125 |
+
type=str,
|
| 126 |
+
help="Directory that contains the eval results files",
|
| 127 |
+
)
|
| 128 |
+
all_parser.set_defaults(func=main_all)
|
| 129 |
+
|
| 130 |
+
# Run evaluation for one dataset
|
| 131 |
+
one_parser = subparsers.add_parser("one", help="Run evaluation for one dataset")
|
| 132 |
+
one_parser.add_argument(
|
| 133 |
+
"--gt_annot_file",
|
| 134 |
+
type=str,
|
| 135 |
+
help="Path to the ground truth annotation file",
|
| 136 |
+
)
|
| 137 |
+
one_parser.add_argument(
|
| 138 |
+
"--pred_file",
|
| 139 |
+
type=str,
|
| 140 |
+
help="Path to the prediction file",
|
| 141 |
+
)
|
| 142 |
+
one_parser.add_argument(
|
| 143 |
+
"--eval_res_file",
|
| 144 |
+
type=str,
|
| 145 |
+
help="Path to the eval results file",
|
| 146 |
+
)
|
| 147 |
+
one_parser.set_defaults(func=main_one)
|
| 148 |
+
|
| 149 |
+
# Parse and dispatch
|
| 150 |
+
args = parser.parse_args()
|
| 151 |
+
args.func(args)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
if __name__ == "__main__":
|
| 155 |
+
main()
|
sam3/eval/saco_veval_evaluators.py
ADDED
|
@@ -0,0 +1,838 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import tempfile
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from typing import Dict, Optional, Sequence, Tuple
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pycocotools.mask
|
| 10 |
+
from sam3.eval.cgf1_eval import CGF1_METRICS
|
| 11 |
+
from sam3.eval.conversion_util import (
|
| 12 |
+
convert_ytbvis_to_cocovid_gt,
|
| 13 |
+
convert_ytbvis_to_cocovid_pred,
|
| 14 |
+
)
|
| 15 |
+
from sam3.eval.hota_eval_toolkit.run_ytvis_eval import run_ytvis_eval
|
| 16 |
+
from sam3.eval.teta_eval_toolkit import config, Evaluator, metrics
|
| 17 |
+
from sam3.eval.teta_eval_toolkit.datasets import COCO, TAO
|
| 18 |
+
from sam3.eval.ytvis_coco_wrapper import YTVIS
|
| 19 |
+
from sam3.eval.ytvis_eval import VideoDemoF1Eval, YTVISeval
|
| 20 |
+
from sam3.train.nms_helper import process_frame_level_nms, process_track_level_nms
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _get_metric_index(metric_name: str, iou_threshold: Optional[float] = None) -> int:
|
| 24 |
+
"""
|
| 25 |
+
Find the index of a metric in CGF1_METRICS by name and IoU threshold.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
metric_name: Name of the metric (e.g., "cgF1", "precision", "recall")
|
| 29 |
+
iou_threshold: IoU threshold (None for average over 0.5:0.95, or specific value like 0.5, 0.75)
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
Index of the metric in CGF1_METRICS
|
| 33 |
+
|
| 34 |
+
Raises:
|
| 35 |
+
ValueError: If metric not found
|
| 36 |
+
"""
|
| 37 |
+
for idx, metric in enumerate(CGF1_METRICS):
|
| 38 |
+
if metric.name == metric_name and metric.iou_threshold == iou_threshold:
|
| 39 |
+
return idx
|
| 40 |
+
raise ValueError(
|
| 41 |
+
f"Metric '{metric_name}' with IoU threshold {iou_threshold} not found in CGF1_METRICS"
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class BasePredFileEvaluator:
|
| 46 |
+
"""A base class for evaluating a prediction file."""
|
| 47 |
+
|
| 48 |
+
pass
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class YTVISPredFileEvaluator(BasePredFileEvaluator):
|
| 52 |
+
"""Evaluate class mAP for YT-VIS prediction files."""
|
| 53 |
+
|
| 54 |
+
def __init__(
|
| 55 |
+
self,
|
| 56 |
+
gt_ann_file: str,
|
| 57 |
+
dataset_name: str = "video",
|
| 58 |
+
iou_types: Optional[Sequence[str]] = None,
|
| 59 |
+
):
|
| 60 |
+
self.gt_ann_file = gt_ann_file
|
| 61 |
+
self.dataset_name = dataset_name
|
| 62 |
+
self.iou_types = list(iou_types) if iou_types is not None else ["bbox", "segm"]
|
| 63 |
+
assert all(iou_type in ["bbox", "segm"] for iou_type in self.iou_types)
|
| 64 |
+
|
| 65 |
+
def evaluate(self, pred_file: str) -> Dict[str, float]:
|
| 66 |
+
# use our internal video evaluation toolkit for YT-VIS pred file
|
| 67 |
+
# (i.e. the same one we're using for video phrase AP)
|
| 68 |
+
results = {}
|
| 69 |
+
use_cats = True # YT-VIS mAP evaluation uses categories
|
| 70 |
+
ytvisGT = YTVIS(self.gt_ann_file, ignore_gt_cats=not use_cats)
|
| 71 |
+
# the original YT-VIS GT annotations have uncompressed RLEs ("counts" is an integer list)
|
| 72 |
+
# rather than compressed RLEs ("counts" is a string), so we first convert them here.
|
| 73 |
+
if "segm" in self.iou_types:
|
| 74 |
+
for ann in ytvisGT.dataset["annotations"]:
|
| 75 |
+
ann["segmentations"] = [
|
| 76 |
+
_compress_rle(rle) for rle in ann["segmentations"]
|
| 77 |
+
]
|
| 78 |
+
|
| 79 |
+
with open(pred_file) as f:
|
| 80 |
+
dt = json.load(f)
|
| 81 |
+
# Our prediction file saves "video_id" and absolute (unnormalized) boxes.
|
| 82 |
+
# Note that we should use the official (original) YT-VIS annotations (i.e. the one
|
| 83 |
+
# saved via "scripts/datasets/training/ytvis_split.py", instead of the one saved
|
| 84 |
+
# via "scripts/api_db_to_ytvis_json.py") in this evaluator, which contain absolute
|
| 85 |
+
# boxes coordinates in its GT annotations.
|
| 86 |
+
for d in dt:
|
| 87 |
+
d["image_id"] = d["video_id"]
|
| 88 |
+
ytvisDT = ytvisGT.loadRes(dt)
|
| 89 |
+
|
| 90 |
+
for iou_type in self.iou_types:
|
| 91 |
+
ytvisEval = YTVISeval(ytvisGT, ytvisDT, iou_type)
|
| 92 |
+
|
| 93 |
+
# set the area ranges for small, medium, and large objects (using
|
| 94 |
+
# absolute pixel areas) as in the official YT-VIS evaluation toolkit:
|
| 95 |
+
# https://github.com/achalddave/ytvosapi/blob/eca601117c9f86bad084cb91f1d918e9ab665a75/PythonAPI/ytvostools/ytvoseval.py#L538
|
| 96 |
+
ytvisEval.params.areaRng = [
|
| 97 |
+
[0**2, 1e5**2],
|
| 98 |
+
[0**2, 128**2],
|
| 99 |
+
[128**2, 256**2],
|
| 100 |
+
[256**2, 1e5**2],
|
| 101 |
+
]
|
| 102 |
+
ytvisEval.params.areaRngLbl = ["all", "small", "medium", "large"]
|
| 103 |
+
ytvisEval.params.useCats = use_cats
|
| 104 |
+
|
| 105 |
+
ytvisEval.evaluate()
|
| 106 |
+
ytvisEval.accumulate()
|
| 107 |
+
ytvisEval.summarize()
|
| 108 |
+
result_key = f"{self.dataset_name}_{'mask' if iou_type == 'segm' else 'bbox'}_mAP_50_95"
|
| 109 |
+
results[result_key] = ytvisEval.stats[0]
|
| 110 |
+
|
| 111 |
+
# video-NP level results not supported for `YTVISPredFileEvaluator` yet
|
| 112 |
+
video_np_level_results = {}
|
| 113 |
+
return results, video_np_level_results
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class VideoPhraseApEvaluator(BasePredFileEvaluator):
|
| 117 |
+
"""Evaluate Video Phrase AP with YT-VIS format prediction and GT files."""
|
| 118 |
+
|
| 119 |
+
def __init__(
|
| 120 |
+
self,
|
| 121 |
+
gt_ann_file: str,
|
| 122 |
+
dataset_name: str = "video",
|
| 123 |
+
iou_types: Optional[Sequence[str]] = None,
|
| 124 |
+
):
|
| 125 |
+
self.gt_ann_file = gt_ann_file
|
| 126 |
+
self.dataset_name = dataset_name
|
| 127 |
+
self.iou_types = list(iou_types) if iou_types is not None else ["bbox", "segm"]
|
| 128 |
+
assert all(iou_type in ["bbox", "segm"] for iou_type in self.iou_types)
|
| 129 |
+
|
| 130 |
+
def evaluate(self, pred_file: str) -> Dict[str, float]:
|
| 131 |
+
with open(self.gt_ann_file) as f:
|
| 132 |
+
gt = json.load(f)
|
| 133 |
+
with open(pred_file) as f:
|
| 134 |
+
dt = json.load(f)
|
| 135 |
+
# For phrase AP and demo F1 evaluation, we need to remap each pair of (video_id, category_id) to
|
| 136 |
+
# a new unique video_id, so that we don't mix detections from different categories under `useCat=False`
|
| 137 |
+
gt, dt = remap_video_category_pairs_to_unique_video_ids(gt, dt)
|
| 138 |
+
if "segm" in self.iou_types:
|
| 139 |
+
for ann in gt["annotations"]:
|
| 140 |
+
ann["segmentations"] = [
|
| 141 |
+
_compress_rle(rle) for rle in ann["segmentations"]
|
| 142 |
+
]
|
| 143 |
+
for d in dt:
|
| 144 |
+
d["image_id"] = d["video_id"]
|
| 145 |
+
|
| 146 |
+
results = {}
|
| 147 |
+
use_cats = False # Phrase AP evaluation does not use categories
|
| 148 |
+
ytvisGT = YTVIS(annotation_file=None, ignore_gt_cats=not use_cats)
|
| 149 |
+
ytvisGT.dataset = gt
|
| 150 |
+
ytvisGT.createIndex()
|
| 151 |
+
ytvisDT = ytvisGT.loadRes(dt)
|
| 152 |
+
|
| 153 |
+
for iou_type in self.iou_types:
|
| 154 |
+
phraseApEval = YTVISeval(ytvisGT, ytvisDT, iou_type)
|
| 155 |
+
|
| 156 |
+
# set the area ranges for small, medium, and large objects (using
|
| 157 |
+
# absolute pixel areas) as in the official YT-VIS evaluation toolkit:
|
| 158 |
+
# https://github.com/achalddave/ytvosapi/blob/eca601117c9f86bad084cb91f1d918e9ab665a75/PythonAPI/ytvostools/ytvoseval.py#L538
|
| 159 |
+
phraseApEval.params.areaRng = [
|
| 160 |
+
[0**2, 1e5**2],
|
| 161 |
+
[0**2, 128**2],
|
| 162 |
+
[128**2, 256**2],
|
| 163 |
+
[256**2, 1e5**2],
|
| 164 |
+
]
|
| 165 |
+
phraseApEval.params.areaRngLbl = ["all", "small", "medium", "large"]
|
| 166 |
+
phraseApEval.params.useCats = use_cats
|
| 167 |
+
|
| 168 |
+
phraseApEval.evaluate()
|
| 169 |
+
phraseApEval.accumulate()
|
| 170 |
+
phraseApEval.summarize()
|
| 171 |
+
result_prefix = f"{self.dataset_name}"
|
| 172 |
+
result_prefix += f"_{'mask' if iou_type == 'segm' else 'bbox'}_phrase_ap"
|
| 173 |
+
# fetch Phrase AP results from the corresponding indices in `phraseApEval.stats`
|
| 174 |
+
# (see `_summarizeDets` in https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py)
|
| 175 |
+
results[result_prefix + "_50_95"] = phraseApEval.stats[0] # IoU=0.5:0.95
|
| 176 |
+
results[result_prefix + "_50"] = phraseApEval.stats[1] # IoU=0.5
|
| 177 |
+
results[result_prefix + "_75"] = phraseApEval.stats[2] # IoU=0.75
|
| 178 |
+
|
| 179 |
+
# video-NP level results not supported for `VideoPhraseApEvaluator` yet
|
| 180 |
+
video_np_level_results = {}
|
| 181 |
+
return results, video_np_level_results
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class VideoCGF1Evaluator(BasePredFileEvaluator):
|
| 185 |
+
"""Evaluate Video Demo F1 with YT-VIS format prediction and GT files."""
|
| 186 |
+
|
| 187 |
+
def __init__(
|
| 188 |
+
self,
|
| 189 |
+
gt_ann_file: str,
|
| 190 |
+
dataset_name: str = "video",
|
| 191 |
+
prob_thresh: float = 0.5,
|
| 192 |
+
iou_types: Optional[Sequence[str]] = None,
|
| 193 |
+
):
|
| 194 |
+
self.gt_ann_file = gt_ann_file
|
| 195 |
+
self.dataset_name = dataset_name
|
| 196 |
+
self.prob_thresh = prob_thresh
|
| 197 |
+
self.iou_types = list(iou_types) if iou_types is not None else ["bbox", "segm"]
|
| 198 |
+
assert all(iou_type in ["bbox", "segm"] for iou_type in self.iou_types)
|
| 199 |
+
|
| 200 |
+
def evaluate(self, pred_file: str) -> Dict[str, float]:
|
| 201 |
+
with open(self.gt_ann_file) as f:
|
| 202 |
+
gt = json.load(f)
|
| 203 |
+
with open(pred_file) as f:
|
| 204 |
+
dt = json.load(f)
|
| 205 |
+
# compute IL_MCC and CG-F1 can only be computed if we have "video_np_pairs" keys in the GT JSON
|
| 206 |
+
compute_ilmcc_and_cgf1 = "video_np_pairs" in gt
|
| 207 |
+
if not compute_ilmcc_and_cgf1:
|
| 208 |
+
print(
|
| 209 |
+
f"Warning: IL_MCC and CG-F1 are not computed for {pred_file=} as it does not have 'video_np_pairs' keys in the GT JSON"
|
| 210 |
+
)
|
| 211 |
+
# For phrase AP and demo F1 evaluation, we need to remap each pair of (video_id, category_id) to
|
| 212 |
+
# a new unique video_id, so that we don't mix detections from different categories under `useCat=False`
|
| 213 |
+
gt, dt = remap_video_category_pairs_to_unique_video_ids(
|
| 214 |
+
gt, dt, add_negative_np_pairs=compute_ilmcc_and_cgf1
|
| 215 |
+
)
|
| 216 |
+
if "segm" in self.iou_types:
|
| 217 |
+
for ann in gt["annotations"]:
|
| 218 |
+
ann["segmentations"] = [
|
| 219 |
+
_compress_rle(rle) for rle in ann["segmentations"]
|
| 220 |
+
]
|
| 221 |
+
for d in dt:
|
| 222 |
+
d["image_id"] = d["video_id"]
|
| 223 |
+
|
| 224 |
+
results = {}
|
| 225 |
+
use_cats = False # Demo F1 evaluation does not use categories
|
| 226 |
+
ytvisGT = YTVIS(annotation_file=None, ignore_gt_cats=not use_cats)
|
| 227 |
+
ytvisGT.dataset = gt
|
| 228 |
+
ytvisGT.createIndex()
|
| 229 |
+
ytvisDT = ytvisGT.loadRes(dt)
|
| 230 |
+
|
| 231 |
+
video_np_level_results = {}
|
| 232 |
+
for iou_type in self.iou_types:
|
| 233 |
+
demoF1Eval = VideoDemoF1Eval(ytvisGT, ytvisDT, iou_type, self.prob_thresh)
|
| 234 |
+
|
| 235 |
+
demoF1Eval.params.useCats = use_cats
|
| 236 |
+
demoF1Eval.params.areaRng = [[0**2, 1e5**2]]
|
| 237 |
+
demoF1Eval.params.areaRngLbl = ["all"]
|
| 238 |
+
demoF1Eval.params.maxDets = [100000]
|
| 239 |
+
|
| 240 |
+
demoF1Eval.evaluate()
|
| 241 |
+
demoF1Eval.accumulate()
|
| 242 |
+
demoF1Eval.summarize()
|
| 243 |
+
result_prefix = f"{self.dataset_name}"
|
| 244 |
+
result_prefix += f"_{'mask' if iou_type == 'segm' else 'bbox'}_demo"
|
| 245 |
+
|
| 246 |
+
stats = demoF1Eval.stats
|
| 247 |
+
|
| 248 |
+
if compute_ilmcc_and_cgf1:
|
| 249 |
+
# Average IoU threshold (0.5:0.95)
|
| 250 |
+
cgf1_micro_avg_idx = _get_metric_index("cgF1", None)
|
| 251 |
+
positive_micro_f1_avg_idx = _get_metric_index("positive_micro_F1", None)
|
| 252 |
+
ilmcc_avg_idx = _get_metric_index("IL_MCC", None)
|
| 253 |
+
results[result_prefix + "_cgf1_micro_50_95"] = stats[cgf1_micro_avg_idx]
|
| 254 |
+
results[result_prefix + "_ilmcc_50_95"] = stats[ilmcc_avg_idx]
|
| 255 |
+
results[result_prefix + "_positive_micro_f1_50_95"] = stats[
|
| 256 |
+
positive_micro_f1_avg_idx
|
| 257 |
+
]
|
| 258 |
+
|
| 259 |
+
# IoU = 0.5
|
| 260 |
+
cgf1_micro_50_idx = _get_metric_index("cgF1", 0.5)
|
| 261 |
+
positive_micro_f1_50_idx = _get_metric_index("positive_micro_F1", 0.5)
|
| 262 |
+
results[result_prefix + "_cgf1_micro_50"] = stats[cgf1_micro_50_idx]
|
| 263 |
+
results[result_prefix + "_ilmcc_50"] = float(
|
| 264 |
+
np.array(stats[cgf1_micro_50_idx])
|
| 265 |
+
/ np.array(stats[positive_micro_f1_50_idx])
|
| 266 |
+
)
|
| 267 |
+
results[result_prefix + "_positive_micro_f1_50"] = stats[
|
| 268 |
+
positive_micro_f1_50_idx
|
| 269 |
+
]
|
| 270 |
+
|
| 271 |
+
# IoU = 0.75
|
| 272 |
+
cgf1_micro_75_idx = _get_metric_index("cgF1", 0.75)
|
| 273 |
+
positive_micro_f1_75_idx = _get_metric_index("positive_micro_F1", 0.75)
|
| 274 |
+
results[result_prefix + "_cgf1_micro_75"] = stats[cgf1_micro_75_idx]
|
| 275 |
+
results[result_prefix + "_ilmcc_75"] = float(
|
| 276 |
+
np.array(stats[cgf1_micro_75_idx])
|
| 277 |
+
/ np.array(stats[positive_micro_f1_75_idx])
|
| 278 |
+
)
|
| 279 |
+
results[result_prefix + "_positive_micro_f1_75"] = stats[
|
| 280 |
+
positive_micro_f1_75_idx
|
| 281 |
+
]
|
| 282 |
+
|
| 283 |
+
self.extract_video_np_level_results(demoF1Eval, video_np_level_results)
|
| 284 |
+
|
| 285 |
+
return results, video_np_level_results
|
| 286 |
+
|
| 287 |
+
def extract_video_np_level_results(self, demoF1Eval, video_np_level_results):
|
| 288 |
+
"""Aggregate statistics for video-level metrics."""
|
| 289 |
+
num_iou_thrs = len(demoF1Eval.params.iouThrs)
|
| 290 |
+
iou_50_index = int(np.where(demoF1Eval.params.iouThrs == 0.5)[0])
|
| 291 |
+
iou_75_index = int(np.where(demoF1Eval.params.iouThrs == 0.75)[0])
|
| 292 |
+
|
| 293 |
+
result_prefix = "mask" if demoF1Eval.params.iouType == "segm" else "bbox"
|
| 294 |
+
|
| 295 |
+
assert len(demoF1Eval.evalImgs) == len(demoF1Eval.cocoGt.dataset["images"])
|
| 296 |
+
for i, video in enumerate(demoF1Eval.cocoGt.dataset["images"]):
|
| 297 |
+
# the original video id and category id before remapping
|
| 298 |
+
video_id = video["orig_video_id"]
|
| 299 |
+
category_id = video["orig_category_id"]
|
| 300 |
+
eval_img_dict = demoF1Eval.evalImgs[i]
|
| 301 |
+
|
| 302 |
+
TPs = eval_img_dict.get("TPs", np.zeros(num_iou_thrs, dtype=np.int64))
|
| 303 |
+
FPs = eval_img_dict.get("FPs", np.zeros(num_iou_thrs, dtype=np.int64))
|
| 304 |
+
FNs = eval_img_dict.get("FNs", np.zeros(num_iou_thrs, dtype=np.int64))
|
| 305 |
+
assert len(TPs) == len(FPs) == len(FNs) == num_iou_thrs
|
| 306 |
+
# F1 = 2*TP / (2*TP + FP + FN), and we set F1 to 1.0 if denominator is 0
|
| 307 |
+
denominator = 2 * TPs + FPs + FNs
|
| 308 |
+
F1s = np.where(denominator > 0, 2 * TPs / np.maximum(denominator, 1), 1.0)
|
| 309 |
+
local_results = {
|
| 310 |
+
f"{result_prefix}_TP_50_95": float(TPs.mean()),
|
| 311 |
+
f"{result_prefix}_FP_50_95": float(FPs.mean()),
|
| 312 |
+
f"{result_prefix}_FN_50_95": float(FNs.mean()),
|
| 313 |
+
f"{result_prefix}_F1_50_95": float(F1s.mean()),
|
| 314 |
+
f"{result_prefix}_TP_50": float(TPs[iou_50_index]),
|
| 315 |
+
f"{result_prefix}_FP_50": float(FPs[iou_50_index]),
|
| 316 |
+
f"{result_prefix}_FN_50": float(FNs[iou_50_index]),
|
| 317 |
+
f"{result_prefix}_F1_50": float(F1s[iou_50_index]),
|
| 318 |
+
f"{result_prefix}_TP_75": float(TPs[iou_75_index]),
|
| 319 |
+
f"{result_prefix}_FP_75": float(FPs[iou_75_index]),
|
| 320 |
+
f"{result_prefix}_FN_75": float(FNs[iou_75_index]),
|
| 321 |
+
f"{result_prefix}_F1_75": float(F1s[iou_75_index]),
|
| 322 |
+
}
|
| 323 |
+
if (video_id, category_id) not in video_np_level_results:
|
| 324 |
+
video_np_level_results[(video_id, category_id)] = {}
|
| 325 |
+
video_np_level_results[(video_id, category_id)].update(local_results)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
class VideoTetaEvaluator(BasePredFileEvaluator):
|
| 329 |
+
"""Evaluate TETA metric using YouTubeVIS format prediction and GT files."""
|
| 330 |
+
|
| 331 |
+
def __init__(
|
| 332 |
+
self,
|
| 333 |
+
gt_ann_file: str,
|
| 334 |
+
dataset_name: str = "video",
|
| 335 |
+
tracker_name: str = "Sam3",
|
| 336 |
+
nms_threshold: float = 0.5,
|
| 337 |
+
nms_strategy: str = "none", # "track", "frame", or "none"
|
| 338 |
+
prob_thresh: float = 0.5,
|
| 339 |
+
is_exhaustive: bool = False,
|
| 340 |
+
use_mask: bool = False,
|
| 341 |
+
num_parallel_cores: int = 8,
|
| 342 |
+
):
|
| 343 |
+
self.gt_ann_file = gt_ann_file
|
| 344 |
+
self.dataset_name = dataset_name
|
| 345 |
+
self.tracker_name = tracker_name
|
| 346 |
+
self.nms_threshold = nms_threshold
|
| 347 |
+
self.nms_strategy = nms_strategy.lower() # Convert to lowercase for consistency
|
| 348 |
+
self.prob_thresh = prob_thresh
|
| 349 |
+
self.metric_prefix = "TETA"
|
| 350 |
+
self.is_exhaustive = is_exhaustive
|
| 351 |
+
self.use_mask = use_mask
|
| 352 |
+
self.num_parallel_cores = num_parallel_cores
|
| 353 |
+
|
| 354 |
+
# Verify NMS strategy is valid
|
| 355 |
+
valid_strategies = ["track", "frame", "none"]
|
| 356 |
+
print("current nms_strategy:", self.nms_strategy)
|
| 357 |
+
if self.nms_strategy not in valid_strategies:
|
| 358 |
+
raise ValueError(
|
| 359 |
+
f"Invalid NMS strategy: {self.nms_strategy}. Must be one of {valid_strategies}"
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
print(f"Initialized VideoTetaEvaluator with NMS strategy: {self.nms_strategy}")
|
| 363 |
+
print(f"Probability threshold set to: {self.prob_thresh}")
|
| 364 |
+
print(f"Dataset exhaustivity set to: {self.is_exhaustive}")
|
| 365 |
+
print(f"Tracker name set to: {self.tracker_name}")
|
| 366 |
+
print(f"Dataset name set to: {self.dataset_name}")
|
| 367 |
+
print(f"Use mask set to: {self.use_mask}")
|
| 368 |
+
|
| 369 |
+
def process_predictions(self, pred_file: str, tmp_dir: str) -> str:
|
| 370 |
+
"""Process predictions with selected NMS strategy"""
|
| 371 |
+
with open(pred_file, "r") as f:
|
| 372 |
+
raw_preds = json.load(f)
|
| 373 |
+
print(f"Processing predictions with {self.nms_strategy} NMS strategy")
|
| 374 |
+
|
| 375 |
+
# Filter by score threshold
|
| 376 |
+
if self.prob_thresh > 0:
|
| 377 |
+
raw_preds = [d for d in raw_preds if d["score"] >= self.prob_thresh]
|
| 378 |
+
print(
|
| 379 |
+
f"Filtered to {len(raw_preds)} predictions with score >= {self.prob_thresh}"
|
| 380 |
+
)
|
| 381 |
+
# Group predictions by video_id
|
| 382 |
+
video_groups = defaultdict(list)
|
| 383 |
+
for pred in raw_preds:
|
| 384 |
+
video_groups[pred["video_id"]].append(pred)
|
| 385 |
+
# Process based on NMS strategy
|
| 386 |
+
if self.nms_strategy == "track":
|
| 387 |
+
process_track_level_nms(video_groups, nms_threshold=self.nms_threshold)
|
| 388 |
+
elif self.nms_strategy == "frame":
|
| 389 |
+
process_frame_level_nms(video_groups, nms_threshold=self.nms_threshold)
|
| 390 |
+
elif self.nms_strategy == "none":
|
| 391 |
+
print("Skipping NMS processing as strategy is set to 'none'")
|
| 392 |
+
# No processing needed for "none" strategy
|
| 393 |
+
# Save processed predictions
|
| 394 |
+
processed_preds = [
|
| 395 |
+
track for tracks in video_groups.values() for track in tracks
|
| 396 |
+
]
|
| 397 |
+
processed_path = os.path.join(tmp_dir, "processed_preds.json")
|
| 398 |
+
with open(processed_path, "w") as f:
|
| 399 |
+
json.dump(processed_preds, f)
|
| 400 |
+
|
| 401 |
+
print(f"Saved processed predictions to {processed_path}")
|
| 402 |
+
return processed_path
|
| 403 |
+
|
| 404 |
+
def evaluate(self, pred_file: str) -> Tuple[Dict[str, float], Dict]:
|
| 405 |
+
"""Main evaluation method"""
|
| 406 |
+
|
| 407 |
+
print(f"Evaluating TETA Metric with {self.nms_strategy.upper()} NMS strategy")
|
| 408 |
+
with tempfile.TemporaryDirectory() as tmp_dir:
|
| 409 |
+
# Process predictions first
|
| 410 |
+
processed_pred_file = self.process_predictions(pred_file, tmp_dir)
|
| 411 |
+
|
| 412 |
+
# Convert GT to COCO-vid format
|
| 413 |
+
gt_dir = os.path.join(tmp_dir, "gt")
|
| 414 |
+
os.makedirs(gt_dir, exist_ok=True)
|
| 415 |
+
gt_coco_path = os.path.join(gt_dir, "annotations.json")
|
| 416 |
+
convert_ytbvis_to_cocovid_gt(self.gt_ann_file, gt_coco_path)
|
| 417 |
+
|
| 418 |
+
# Convert processed predictions to COCO-vid format
|
| 419 |
+
pred_dir = os.path.join(tmp_dir, "predictions")
|
| 420 |
+
tracker_dir = os.path.join(pred_dir, self.tracker_name)
|
| 421 |
+
os.makedirs(tracker_dir, exist_ok=True)
|
| 422 |
+
pred_coco_path = os.path.join(tracker_dir, "track_results_cocofmt.json")
|
| 423 |
+
convert_ytbvis_to_cocovid_pred(
|
| 424 |
+
youtubevis_pred_path=processed_pred_file,
|
| 425 |
+
converted_dataset_path=gt_coco_path,
|
| 426 |
+
output_path=pred_coco_path,
|
| 427 |
+
)
|
| 428 |
+
# Configure TETA evaluator
|
| 429 |
+
default_eval_config = config.get_default_eval_config()
|
| 430 |
+
default_eval_config["PRINT_ONLY_COMBINED"] = True
|
| 431 |
+
default_eval_config["DISPLAY_LESS_PROGRESS"] = True
|
| 432 |
+
default_eval_config["OUTPUT_TEMP_RAW_DATA"] = True
|
| 433 |
+
default_eval_config["NUM_PARALLEL_CORES"] = self.num_parallel_cores
|
| 434 |
+
default_dataset_config = config.get_default_dataset_config()
|
| 435 |
+
default_dataset_config["TRACKERS_TO_EVAL"] = [self.tracker_name]
|
| 436 |
+
default_dataset_config["GT_FOLDER"] = gt_dir
|
| 437 |
+
default_dataset_config["OUTPUT_FOLDER"] = pred_dir
|
| 438 |
+
default_dataset_config["TRACKER_SUB_FOLDER"] = tracker_dir
|
| 439 |
+
default_dataset_config["USE_MASK"] = self.use_mask
|
| 440 |
+
|
| 441 |
+
evaluator = Evaluator(default_eval_config)
|
| 442 |
+
if self.is_exhaustive:
|
| 443 |
+
dataset_list = [COCO(default_dataset_config)]
|
| 444 |
+
dataset_parsing_key = "COCO"
|
| 445 |
+
else:
|
| 446 |
+
dataset_list = [TAO(default_dataset_config)]
|
| 447 |
+
dataset_parsing_key = "TAO"
|
| 448 |
+
|
| 449 |
+
# Run evaluation
|
| 450 |
+
eval_results, _ = evaluator.evaluate(
|
| 451 |
+
dataset_list, [metrics.TETA(exhaustive=self.is_exhaustive)]
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
# Extract and format results
|
| 455 |
+
results = {
|
| 456 |
+
f"{self.dataset_name}_{'mask' if self.use_mask else 'bbox'}_teta": float(
|
| 457 |
+
eval_results[dataset_parsing_key]["TETA"][0]
|
| 458 |
+
),
|
| 459 |
+
f"{self.dataset_name}_{'mask' if self.use_mask else 'bbox'}_loc_a": float(
|
| 460 |
+
eval_results[dataset_parsing_key]["TETA"][1]
|
| 461 |
+
),
|
| 462 |
+
f"{self.dataset_name}_{'mask' if self.use_mask else 'bbox'}_assoc_a": float(
|
| 463 |
+
eval_results[dataset_parsing_key]["TETA"][2]
|
| 464 |
+
),
|
| 465 |
+
f"{self.dataset_name}_{'mask' if self.use_mask else 'bbox'}_cls_a": float(
|
| 466 |
+
eval_results[dataset_parsing_key]["TETA"][3]
|
| 467 |
+
),
|
| 468 |
+
f"{self.dataset_name}_{'mask' if self.use_mask else 'bbox'}_loc_re": float(
|
| 469 |
+
eval_results[dataset_parsing_key]["TETA"][4]
|
| 470 |
+
),
|
| 471 |
+
f"{self.dataset_name}_{'mask' if self.use_mask else 'bbox'}_loc_pr": float(
|
| 472 |
+
eval_results[dataset_parsing_key]["TETA"][5]
|
| 473 |
+
),
|
| 474 |
+
f"{self.dataset_name}_{'mask' if self.use_mask else 'bbox'}_assoc_re": float(
|
| 475 |
+
eval_results[dataset_parsing_key]["TETA"][6]
|
| 476 |
+
),
|
| 477 |
+
f"{self.dataset_name}_{'mask' if self.use_mask else 'bbox'}_assoc_pr": float(
|
| 478 |
+
eval_results[dataset_parsing_key]["TETA"][7]
|
| 479 |
+
),
|
| 480 |
+
f"{self.dataset_name}_{'mask' if self.use_mask else 'bbox'}_cls_re": float(
|
| 481 |
+
eval_results[dataset_parsing_key]["TETA"][8]
|
| 482 |
+
),
|
| 483 |
+
f"{self.dataset_name}_{'mask' if self.use_mask else 'bbox'}_cls_pr": float(
|
| 484 |
+
eval_results[dataset_parsing_key]["TETA"][9]
|
| 485 |
+
),
|
| 486 |
+
}
|
| 487 |
+
|
| 488 |
+
# video-NP level results not supported for `VideoTetaEvaluator` yet
|
| 489 |
+
video_np_level_results = {}
|
| 490 |
+
return results, video_np_level_results
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
class VideoPhraseHotaEvaluator(BasePredFileEvaluator):
|
| 494 |
+
"""Evaluate Video Phrase HOTA with YT-VIS format prediction and GT files."""
|
| 495 |
+
|
| 496 |
+
def __init__(
|
| 497 |
+
self,
|
| 498 |
+
gt_ann_file: str,
|
| 499 |
+
dataset_name: str = "video",
|
| 500 |
+
prob_thresh: float = 0.5,
|
| 501 |
+
iou_types: Optional[Sequence[str]] = None,
|
| 502 |
+
compute_video_mot_hota: bool = False,
|
| 503 |
+
):
|
| 504 |
+
self.gt_ann_file = gt_ann_file
|
| 505 |
+
self.dataset_name = dataset_name
|
| 506 |
+
self.prob_thresh = prob_thresh
|
| 507 |
+
self.metric_prefix = "phrase"
|
| 508 |
+
# the list of metrics to collect from the HOTA evaluation results
|
| 509 |
+
self.metric_to_collect = [
|
| 510 |
+
"HOTA",
|
| 511 |
+
"DetA",
|
| 512 |
+
"AssA",
|
| 513 |
+
"DetRe",
|
| 514 |
+
"DetPr",
|
| 515 |
+
"AssRe",
|
| 516 |
+
"AssPr",
|
| 517 |
+
"LocA",
|
| 518 |
+
"OWTA",
|
| 519 |
+
]
|
| 520 |
+
self.iou_types = list(iou_types) if iou_types is not None else ["bbox", "segm"]
|
| 521 |
+
assert all(iou_type in ["bbox", "segm"] for iou_type in self.iou_types)
|
| 522 |
+
|
| 523 |
+
# If True, compute video MOT HOTA, aggregating predictions/GT from all categories.
|
| 524 |
+
self.compute_video_mot_hota = compute_video_mot_hota
|
| 525 |
+
|
| 526 |
+
def evaluate(self, pred_file: str) -> Dict[str, float]:
|
| 527 |
+
# use the YT-VIS evaluation toolkit in TrackEval
|
| 528 |
+
|
| 529 |
+
with open(self.gt_ann_file) as f:
|
| 530 |
+
gt = json.load(f)
|
| 531 |
+
with open(pred_file) as f:
|
| 532 |
+
dt = json.load(f)
|
| 533 |
+
# keep only predictions with score above the probability threshold
|
| 534 |
+
dt = [d for d in dt if d["score"] > self.prob_thresh]
|
| 535 |
+
for d in dt:
|
| 536 |
+
assert len(d["areas"]) == len(d["bboxes"])
|
| 537 |
+
assert len(d["areas"]) == len(d["segmentations"])
|
| 538 |
+
# remove empty boxes (otherwise they will count as false positives for during
|
| 539 |
+
# per-frame detection accuracy in HOTA evaluation)
|
| 540 |
+
for t in range(len(d["bboxes"])):
|
| 541 |
+
bbox = d["bboxes"][t]
|
| 542 |
+
if d["areas"][t] == 0 or bbox is None or all(x == 0 for x in bbox):
|
| 543 |
+
d["segmentations"][t] = None
|
| 544 |
+
d["bboxes"][t] = None
|
| 545 |
+
d["areas"][t] = None
|
| 546 |
+
# check that box occurence and mask occurence are consistent
|
| 547 |
+
for bbox, mask, area in zip(d["bboxes"], d["segmentations"], d["areas"]):
|
| 548 |
+
assert (area is None) == (bbox is None)
|
| 549 |
+
assert (area is None) == (mask is None)
|
| 550 |
+
# set all scores to 1.0 for HOTA evaluation (just like Demo F1, the exact score
|
| 551 |
+
# value is not used in HOTA metrics; it will be treated as a detection prediction
|
| 552 |
+
# as long as its score is above the threshold)
|
| 553 |
+
d["score"] = 1.0
|
| 554 |
+
|
| 555 |
+
# remap the GT and DT annotations for phrase HOTA evaluation
|
| 556 |
+
gt = _fill_in_ann_height_width(gt)
|
| 557 |
+
if not self.compute_video_mot_hota:
|
| 558 |
+
# remap the GT and DT annotations for phrase HOTA evaluation
|
| 559 |
+
gt, dt = self._remap_gt_dt(gt, dt)
|
| 560 |
+
else:
|
| 561 |
+
# Compute video-level MOT HOTA
|
| 562 |
+
# Apply track-level NMS
|
| 563 |
+
video_groups = defaultdict(list)
|
| 564 |
+
for pred in dt:
|
| 565 |
+
video_groups[pred["video_id"]].append(pred)
|
| 566 |
+
process_track_level_nms(video_groups, nms_threshold=0.5)
|
| 567 |
+
dt = [track for tracks in video_groups.values() for track in tracks]
|
| 568 |
+
|
| 569 |
+
# Remap GT track ids for class-agnostic HOTA
|
| 570 |
+
gt, dt = remap_gt_dt_class_agnostic(gt, dt)
|
| 571 |
+
|
| 572 |
+
# run the HOTA evaluation using TrackEval on the remapped (video_id, category_id) pairs
|
| 573 |
+
out_dict = {}
|
| 574 |
+
video_np_level_results = {}
|
| 575 |
+
for iou_type in self.iou_types:
|
| 576 |
+
output_res, _ = run_ytvis_eval(
|
| 577 |
+
args=[
|
| 578 |
+
"--METRICS",
|
| 579 |
+
"HOTA",
|
| 580 |
+
"--IOU_TYPE",
|
| 581 |
+
iou_type,
|
| 582 |
+
"--DATASET_NAME",
|
| 583 |
+
self.dataset_name,
|
| 584 |
+
"--USE_PARALLEL",
|
| 585 |
+
"True",
|
| 586 |
+
"--NUM_PARALLEL_CORES",
|
| 587 |
+
"8",
|
| 588 |
+
"--PLOT_CURVES",
|
| 589 |
+
"False",
|
| 590 |
+
"--LOG_ON_ERROR",
|
| 591 |
+
"None",
|
| 592 |
+
"--PRINT_ONLY_COMBINED",
|
| 593 |
+
"True",
|
| 594 |
+
"--OUTPUT_SUMMARY",
|
| 595 |
+
"False",
|
| 596 |
+
"--OUTPUT_DETAILED",
|
| 597 |
+
"False",
|
| 598 |
+
"--TIME_PROGRESS",
|
| 599 |
+
"False",
|
| 600 |
+
"--PRINT_CONFIG",
|
| 601 |
+
"False",
|
| 602 |
+
],
|
| 603 |
+
gt_json=gt,
|
| 604 |
+
dt_json=dt,
|
| 605 |
+
)
|
| 606 |
+
self.extract_video_np_level_results(
|
| 607 |
+
iou_type=iou_type,
|
| 608 |
+
remapped_gt=gt,
|
| 609 |
+
raw_results=output_res[self.dataset_name]["tracker"],
|
| 610 |
+
video_np_level_results=video_np_level_results,
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
def _summarize_results(output_res, iou_type, field, suffix):
|
| 614 |
+
eval_res = output_res[self.dataset_name]["tracker"][field]
|
| 615 |
+
result_prefix = f"{self.dataset_name}_{'mask' if iou_type == 'segm' else 'bbox'}_{suffix}"
|
| 616 |
+
for metric_name in self.metric_to_collect:
|
| 617 |
+
eval_res_hota = eval_res["cls_comb_cls_av"]["HOTA"]
|
| 618 |
+
result_key = f"{result_prefix}_{self.metric_prefix}_{metric_name}"
|
| 619 |
+
result_value = float(np.mean(eval_res_hota[metric_name]))
|
| 620 |
+
out_dict[result_key] = result_value
|
| 621 |
+
|
| 622 |
+
_summarize_results(output_res, iou_type, "COMBINED_SEQ", "all")
|
| 623 |
+
if "COMBINED_SEQ_CHALLENGING" in output_res[self.dataset_name]["tracker"]:
|
| 624 |
+
_summarize_results(
|
| 625 |
+
output_res, iou_type, "COMBINED_SEQ_CHALLENGING", "challenging"
|
| 626 |
+
)
|
| 627 |
+
|
| 628 |
+
# video-NP level results not supported for `VideoPhraseHotaEvaluator` yet
|
| 629 |
+
return out_dict, video_np_level_results
|
| 630 |
+
|
| 631 |
+
def _remap_gt_dt(self, gt, dt):
|
| 632 |
+
# For phrase HOTA evaluation, we need to remap each pair of (video_id, category_id) to
|
| 633 |
+
# a new unique video_id, so that we don't mix detections from different categories
|
| 634 |
+
gt, dt = remap_video_category_pairs_to_unique_video_ids(gt, dt)
|
| 635 |
+
# We further map all the categories to category_id=1 in HOTA evaluation toolkit
|
| 636 |
+
# for phrase HOTA (similar to "useCat=False" for video phrase AP)
|
| 637 |
+
remapped_category_id = 1
|
| 638 |
+
gt["categories"] = [
|
| 639 |
+
{
|
| 640 |
+
"supercategory": "object",
|
| 641 |
+
"id": remapped_category_id,
|
| 642 |
+
"name": "_REMAPPED_FOR_PHRASE_METRICS_",
|
| 643 |
+
}
|
| 644 |
+
]
|
| 645 |
+
for ann in gt["annotations"]:
|
| 646 |
+
ann["category_id"] = remapped_category_id
|
| 647 |
+
for d in dt:
|
| 648 |
+
d["category_id"] = remapped_category_id
|
| 649 |
+
# To be compatible with the TrackEval YT-VIS evaluation toolkit, we need to give
|
| 650 |
+
# unique filenames to each remapped video, so we add remapped video_id as prefix.
|
| 651 |
+
for video in gt["videos"]:
|
| 652 |
+
new_video_id = video["id"]
|
| 653 |
+
video["file_names"] = [
|
| 654 |
+
f"remapped_vid_{new_video_id:012d}/{name}"
|
| 655 |
+
for name in video["file_names"]
|
| 656 |
+
]
|
| 657 |
+
return gt, dt
|
| 658 |
+
|
| 659 |
+
def extract_video_np_level_results(
|
| 660 |
+
self, iou_type, remapped_gt, raw_results, video_np_level_results
|
| 661 |
+
):
|
| 662 |
+
"""Aggregate statistics for video-level metrics."""
|
| 663 |
+
result_prefix = "mask" if iou_type == "segm" else "bbox"
|
| 664 |
+
for video in remapped_gt["videos"]:
|
| 665 |
+
# the original video id and category id before remapping
|
| 666 |
+
video_id = video["orig_video_id"]
|
| 667 |
+
category_id = video["orig_category_id"]
|
| 668 |
+
video_key = f"remapped_vid_{video['id']:012d}"
|
| 669 |
+
results = raw_results[video_key]["_REMAPPED_FOR_PHRASE_METRICS_"]["HOTA"]
|
| 670 |
+
|
| 671 |
+
local_results = {}
|
| 672 |
+
for metric_name in self.metric_to_collect:
|
| 673 |
+
result_key = f"{result_prefix}_{metric_name}"
|
| 674 |
+
local_results[result_key] = float(results[metric_name].mean())
|
| 675 |
+
if (video_id, category_id) not in video_np_level_results:
|
| 676 |
+
video_np_level_results[(video_id, category_id)] = {}
|
| 677 |
+
video_np_level_results[(video_id, category_id)].update(local_results)
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
class VideoClassBasedHotaEvaluator(VideoPhraseHotaEvaluator):
|
| 681 |
+
def __init__(
|
| 682 |
+
self,
|
| 683 |
+
gt_ann_file: str,
|
| 684 |
+
dataset_name: str = "video",
|
| 685 |
+
prob_thresh: float = 0.5,
|
| 686 |
+
):
|
| 687 |
+
super().__init__(gt_ann_file, dataset_name, prob_thresh)
|
| 688 |
+
self.metric_prefix = "class"
|
| 689 |
+
|
| 690 |
+
def _remap_gt_dt(self, gt, dt):
|
| 691 |
+
return gt, dt # no remapping needed for class-based HOTA evaluation
|
| 692 |
+
|
| 693 |
+
def extract_video_np_level_results(self, *args, **kwargs):
|
| 694 |
+
pass # no video-NP level results for class-based HOTA evaluation
|
| 695 |
+
|
| 696 |
+
|
| 697 |
+
def _compress_rle(rle):
|
| 698 |
+
"""Convert RLEs from uncompressed (integer list) to compressed (string) format."""
|
| 699 |
+
if rle is None:
|
| 700 |
+
return None
|
| 701 |
+
if isinstance(rle["counts"], list):
|
| 702 |
+
rle = pycocotools.mask.frPyObjects(rle, rle["size"][0], rle["size"][1])
|
| 703 |
+
rle["counts"] = rle["counts"].decode()
|
| 704 |
+
return rle
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
def remap_video_category_pairs_to_unique_video_ids(
|
| 708 |
+
gt_json, dt_json, add_negative_np_pairs=False
|
| 709 |
+
):
|
| 710 |
+
"""
|
| 711 |
+
Remap each pair of (video_id, category_id) to a new unique video_id. This is useful
|
| 712 |
+
for phrase AP and demo F1 evaluation on videos, where we have `useCat=False` and
|
| 713 |
+
rely on separating different NPs (from the same video) into different new video ids,
|
| 714 |
+
so that we don't mix detections from different categories in computeIoU under `useCat=False`.
|
| 715 |
+
|
| 716 |
+
This is consistent with how do we phrase AP and demo F1 evaluation on images, where we
|
| 717 |
+
use a remapped unique coco_image_id for each image-NP pair (based in its query["id"] in
|
| 718 |
+
CustomCocoDetectionAPI.load_queries in modulated_detection_api.py)
|
| 719 |
+
"""
|
| 720 |
+
# collect the unique video_id-category_id pairs
|
| 721 |
+
video_id_to_video = {v["id"]: v for v in gt_json["videos"]}
|
| 722 |
+
video_id_category_id_pairs = set()
|
| 723 |
+
for pred in dt_json:
|
| 724 |
+
video_id_category_id_pairs.add((pred["video_id"], pred["category_id"]))
|
| 725 |
+
for ann in gt_json["annotations"]:
|
| 726 |
+
video_id_category_id_pairs.add((ann["video_id"], ann["category_id"]))
|
| 727 |
+
|
| 728 |
+
# assign the video_id-category_id pairs to unique video ids
|
| 729 |
+
video_id_category_id_pairs = sorted(video_id_category_id_pairs)
|
| 730 |
+
video_id_category_id_to_new_video_id = {
|
| 731 |
+
pair: (i + 1) for i, pair in enumerate(video_id_category_id_pairs)
|
| 732 |
+
}
|
| 733 |
+
# also map the negative NP pairs -- this is needed for IL_MCC and CG-F1 evaluation
|
| 734 |
+
if add_negative_np_pairs:
|
| 735 |
+
for vnp in gt_json["video_np_pairs"]:
|
| 736 |
+
pair = (vnp["video_id"], vnp["category_id"])
|
| 737 |
+
if pair not in video_id_category_id_to_new_video_id:
|
| 738 |
+
video_id_category_id_to_new_video_id[pair] = (
|
| 739 |
+
len(video_id_category_id_to_new_video_id) + 1
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
# map the "video_id" in predictions
|
| 743 |
+
for pred in dt_json:
|
| 744 |
+
pred["video_id"] = video_id_category_id_to_new_video_id[
|
| 745 |
+
(pred["video_id"], pred["category_id"])
|
| 746 |
+
]
|
| 747 |
+
# map the "video_id" in gt_json["annotations"]
|
| 748 |
+
for ann in gt_json["annotations"]:
|
| 749 |
+
ann["video_id"] = video_id_category_id_to_new_video_id[
|
| 750 |
+
(ann["video_id"], ann["category_id"])
|
| 751 |
+
]
|
| 752 |
+
# map and duplicate gt_json["videos"]
|
| 753 |
+
new_videos = []
|
| 754 |
+
for (
|
| 755 |
+
video_id,
|
| 756 |
+
category_id,
|
| 757 |
+
), new_video_id in video_id_category_id_to_new_video_id.items():
|
| 758 |
+
video = video_id_to_video[video_id].copy()
|
| 759 |
+
video["id"] = new_video_id
|
| 760 |
+
# preserve the original video_id and category_id of each remapped video entry,
|
| 761 |
+
# so that we can associate sample-level eval metrics with the original video-NP pairs
|
| 762 |
+
video["orig_video_id"] = video_id
|
| 763 |
+
video["orig_category_id"] = category_id
|
| 764 |
+
new_videos.append(video)
|
| 765 |
+
gt_json["videos"] = new_videos
|
| 766 |
+
|
| 767 |
+
return gt_json, dt_json
|
| 768 |
+
|
| 769 |
+
|
| 770 |
+
def remap_gt_dt_class_agnostic(gt, dt):
|
| 771 |
+
"""
|
| 772 |
+
For class-agnostic HOTA, merge all GT tracks for each video (across NPs),
|
| 773 |
+
ensure unique track_ids, and set all category_id to 1.
|
| 774 |
+
Also, add orig_video_id and orig_category_id for compatibility.
|
| 775 |
+
"""
|
| 776 |
+
# 1. Remap all GT track_ids to be unique per video
|
| 777 |
+
gt_anns_by_video = defaultdict(list)
|
| 778 |
+
for ann in gt["annotations"]:
|
| 779 |
+
gt_anns_by_video[ann["video_id"]].append(ann)
|
| 780 |
+
|
| 781 |
+
# Ensure unique track ids across tracks of all videos
|
| 782 |
+
next_tid = 1
|
| 783 |
+
for _, anns in gt_anns_by_video.items():
|
| 784 |
+
# Map old track_ids to new unique ones
|
| 785 |
+
old_to_new_tid = {}
|
| 786 |
+
for ann in anns:
|
| 787 |
+
old_tid = ann["id"]
|
| 788 |
+
if old_tid not in old_to_new_tid:
|
| 789 |
+
old_to_new_tid[old_tid] = next_tid
|
| 790 |
+
next_tid += 1
|
| 791 |
+
ann["id"] = old_to_new_tid[old_tid]
|
| 792 |
+
# Set category_id to 1 for class-agnostic
|
| 793 |
+
ann["category_id"] = 1
|
| 794 |
+
|
| 795 |
+
# Set all GT categories to a single category
|
| 796 |
+
gt["categories"] = [
|
| 797 |
+
{
|
| 798 |
+
"supercategory": "object",
|
| 799 |
+
"id": 1,
|
| 800 |
+
"name": "_REMAPPED_FOR_PHRASE_METRICS_",
|
| 801 |
+
}
|
| 802 |
+
]
|
| 803 |
+
|
| 804 |
+
# Add orig_video_id and orig_category_id to each video for compatibility
|
| 805 |
+
anns_by_video = defaultdict(list)
|
| 806 |
+
for ann in gt["annotations"]:
|
| 807 |
+
anns_by_video[ann["video_id"]].append(ann)
|
| 808 |
+
for video in gt["videos"]:
|
| 809 |
+
video["orig_video_id"] = video["id"]
|
| 810 |
+
# Use the first annotation's original category_id if available, else None
|
| 811 |
+
orig_cat = (
|
| 812 |
+
anns_by_video[video["id"]][0]["category_id"]
|
| 813 |
+
if anns_by_video[video["id"]]
|
| 814 |
+
else None
|
| 815 |
+
)
|
| 816 |
+
video["orig_category_id"] = orig_cat
|
| 817 |
+
video["file_names"] = [
|
| 818 |
+
f"remapped_vid_{video['id']:012d}/{name}" for name in video["file_names"]
|
| 819 |
+
]
|
| 820 |
+
|
| 821 |
+
# Set all DT category_id to 1
|
| 822 |
+
for d in dt:
|
| 823 |
+
d["category_id"] = 1
|
| 824 |
+
return gt, dt
|
| 825 |
+
|
| 826 |
+
|
| 827 |
+
def _fill_in_ann_height_width(gt_json):
|
| 828 |
+
"""Fill in missing height/width in GT annotations from its video info."""
|
| 829 |
+
video_id_to_video = {v["id"]: v for v in gt_json["videos"]}
|
| 830 |
+
for ann in gt_json["annotations"]:
|
| 831 |
+
if "height" not in ann or "width" not in ann:
|
| 832 |
+
video = video_id_to_video[ann["video_id"]]
|
| 833 |
+
if "height" not in ann:
|
| 834 |
+
ann["height"] = video["height"]
|
| 835 |
+
if "width" not in ann:
|
| 836 |
+
ann["width"] = video["width"]
|
| 837 |
+
|
| 838 |
+
return gt_json
|
sam3/eval/teta_eval_toolkit/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# fmt: off
|
| 2 |
+
# flake8: noqa
|
| 3 |
+
|
| 4 |
+
from . import config, datasets, metrics, utils
|
| 5 |
+
from .eval import Evaluator
|
sam3/eval/teta_eval_toolkit/_timing.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# fmt: off
|
| 2 |
+
# flake8: noqa
|
| 3 |
+
|
| 4 |
+
import inspect
|
| 5 |
+
from functools import wraps
|
| 6 |
+
from time import perf_counter
|
| 7 |
+
|
| 8 |
+
DO_TIMING = False
|
| 9 |
+
DISPLAY_LESS_PROGRESS = False
|
| 10 |
+
timer_dict = {}
|
| 11 |
+
counter = 0
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def time(f):
|
| 15 |
+
@wraps(f)
|
| 16 |
+
def wrap(*args, **kw):
|
| 17 |
+
if DO_TIMING:
|
| 18 |
+
# Run function with timing
|
| 19 |
+
ts = perf_counter()
|
| 20 |
+
result = f(*args, **kw)
|
| 21 |
+
te = perf_counter()
|
| 22 |
+
tt = te - ts
|
| 23 |
+
|
| 24 |
+
# Get function name
|
| 25 |
+
arg_names = inspect.getfullargspec(f)[0]
|
| 26 |
+
if arg_names[0] == "self" and DISPLAY_LESS_PROGRESS:
|
| 27 |
+
return result
|
| 28 |
+
elif arg_names[0] == "self":
|
| 29 |
+
method_name = type(args[0]).__name__ + "." + f.__name__
|
| 30 |
+
else:
|
| 31 |
+
method_name = f.__name__
|
| 32 |
+
|
| 33 |
+
# Record accumulative time in each function for analysis
|
| 34 |
+
if method_name in timer_dict.keys():
|
| 35 |
+
timer_dict[method_name] += tt
|
| 36 |
+
else:
|
| 37 |
+
timer_dict[method_name] = tt
|
| 38 |
+
|
| 39 |
+
# If code is finished, display timing summary
|
| 40 |
+
if method_name == "Evaluator.evaluate":
|
| 41 |
+
print("")
|
| 42 |
+
print("Timing analysis:")
|
| 43 |
+
for key, value in timer_dict.items():
|
| 44 |
+
print("%-70s %2.4f sec" % (key, value))
|
| 45 |
+
else:
|
| 46 |
+
# Get function argument values for printing special arguments of interest
|
| 47 |
+
arg_titles = ["tracker", "seq", "cls"]
|
| 48 |
+
arg_vals = []
|
| 49 |
+
for i, a in enumerate(arg_names):
|
| 50 |
+
if a in arg_titles:
|
| 51 |
+
arg_vals.append(args[i])
|
| 52 |
+
arg_text = "(" + ", ".join(arg_vals) + ")"
|
| 53 |
+
|
| 54 |
+
# Display methods and functions with different indentation.
|
| 55 |
+
if arg_names[0] == "self":
|
| 56 |
+
print("%-74s %2.4f sec" % (" " * 4 + method_name + arg_text, tt))
|
| 57 |
+
elif arg_names[0] == "test":
|
| 58 |
+
pass
|
| 59 |
+
else:
|
| 60 |
+
global counter
|
| 61 |
+
counter += 1
|
| 62 |
+
print("%i %-70s %2.4f sec" % (counter, method_name + arg_text, tt))
|
| 63 |
+
|
| 64 |
+
return result
|
| 65 |
+
else:
|
| 66 |
+
# If config["TIME_PROGRESS"] is false, or config["USE_PARALLEL"] is true, run functions normally without timing.
|
| 67 |
+
return f(*args, **kw)
|
| 68 |
+
|
| 69 |
+
return wrap
|