Justin331 commited on
Commit
3e8dd07
·
verified ·
1 Parent(s): d63a70e

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full list.
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. sam3/__init__.py +7 -0
  3. sam3/agent/__init__.py +1 -0
  4. sam3/agent/agent_core.py +563 -0
  5. sam3/agent/client_llm.py +205 -0
  6. sam3/agent/client_sam3.py +138 -0
  7. sam3/agent/helpers/__init__.py +1 -0
  8. sam3/agent/helpers/boxes.py +438 -0
  9. sam3/agent/helpers/color_map.py +150 -0
  10. sam3/agent/helpers/keypoints.py +244 -0
  11. sam3/agent/helpers/mask_overlap_removal.py +128 -0
  12. sam3/agent/helpers/masks.py +560 -0
  13. sam3/agent/helpers/memory.py +87 -0
  14. sam3/agent/helpers/rle.py +122 -0
  15. sam3/agent/helpers/roi_align.py +75 -0
  16. sam3/agent/helpers/rotated_boxes.py +533 -0
  17. sam3/agent/helpers/som_utils.py +406 -0
  18. sam3/agent/helpers/visualizer.py +1662 -0
  19. sam3/agent/helpers/zoom_in.py +195 -0
  20. sam3/agent/inference.py +65 -0
  21. sam3/agent/system_prompts/system_prompt.txt +242 -0
  22. sam3/agent/system_prompts/system_prompt_iterative_checking.txt +26 -0
  23. sam3/agent/viz.py +114 -0
  24. sam3/eval/__init__.py +1 -0
  25. sam3/eval/cgf1_eval.py +703 -0
  26. sam3/eval/coco_eval.py +916 -0
  27. sam3/eval/coco_eval_offline.py +181 -0
  28. sam3/eval/coco_reindex.py +230 -0
  29. sam3/eval/coco_writer.py +352 -0
  30. sam3/eval/conversion_util.py +211 -0
  31. sam3/eval/demo_eval.py +658 -0
  32. sam3/eval/hota_eval_toolkit/__init__.py +1 -0
  33. sam3/eval/hota_eval_toolkit/run_ytvis_eval.py +114 -0
  34. sam3/eval/hota_eval_toolkit/trackeval/__init__.py +4 -0
  35. sam3/eval/hota_eval_toolkit/trackeval/_timing.py +68 -0
  36. sam3/eval/hota_eval_toolkit/trackeval/datasets/__init__.py +4 -0
  37. sam3/eval/hota_eval_toolkit/trackeval/datasets/_base_dataset.py +379 -0
  38. sam3/eval/hota_eval_toolkit/trackeval/datasets/tao_ow.py +891 -0
  39. sam3/eval/hota_eval_toolkit/trackeval/datasets/youtube_vis.py +524 -0
  40. sam3/eval/hota_eval_toolkit/trackeval/eval.py +395 -0
  41. sam3/eval/hota_eval_toolkit/trackeval/metrics/__init__.py +4 -0
  42. sam3/eval/hota_eval_toolkit/trackeval/metrics/_base_metric.py +145 -0
  43. sam3/eval/hota_eval_toolkit/trackeval/metrics/count.py +48 -0
  44. sam3/eval/hota_eval_toolkit/trackeval/metrics/hota.py +291 -0
  45. sam3/eval/hota_eval_toolkit/trackeval/utils.py +195 -0
  46. sam3/eval/postprocessors.py +648 -0
  47. sam3/eval/saco_veval_eval.py +155 -0
  48. sam3/eval/saco_veval_evaluators.py +838 -0
  49. sam3/eval/teta_eval_toolkit/__init__.py +5 -0
  50. sam3/eval/teta_eval_toolkit/_timing.py +69 -0
.gitattributes CHANGED
@@ -1,2 +1,3 @@
1
  *.pt filter=lfs diff=lfs merge=lfs -text
2
  *.safetensors filter=lfs diff=lfs merge=lfs -text
 
 
1
  *.pt filter=lfs diff=lfs merge=lfs -text
2
  *.safetensors filter=lfs diff=lfs merge=lfs -text
3
+ sam3/perflib/tests/assets/masks.tiff filter=lfs diff=lfs merge=lfs -text
sam3/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from .model_builder import build_sam3_image_model
4
+
5
+ __version__ = "0.1.0"
6
+
7
+ __all__ = ["build_sam3_image_model"]
sam3/agent/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
sam3/agent/agent_core.py ADDED
@@ -0,0 +1,563 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import copy
4
+ import json
5
+ import os
6
+
7
+ import cv2
8
+ from PIL import Image
9
+
10
+ from .client_llm import send_generate_request
11
+ from .client_sam3 import call_sam_service
12
+ from .viz import visualize
13
+
14
+
15
def save_debug_messages(messages_list, debug, debug_folder_path, debug_jsonl_path):
    """Save the conversation history to a debug JSONL file if debug is enabled.

    Each message is serialized as exactly one compact JSON document per line
    (JSON Lines format), so the file can later be parsed line by line.
    The previous implementation used ``indent=4``, which spread every record
    over multiple lines and made the "jsonl" file unparseable per-line.

    Args:
        messages_list: List of message dicts (the conversation history).
        debug: Master switch; nothing is written when False.
        debug_folder_path: Directory that will hold the debug file.
        debug_jsonl_path: Full path of the JSONL file to (over)write.
    """
    if debug and debug_jsonl_path:
        # Ensure the debug directory exists before writing
        os.makedirs(debug_folder_path, exist_ok=True)
        with open(debug_jsonl_path, "w") as f:
            for msg in messages_list:
                # One compact JSON document per line (valid JSONL).
                f.write(json.dumps(msg) + "\n")
23
+
24
+
25
def cleanup_debug_files(debug, debug_folder_path, debug_jsonl_path):
    """Clean up debug files when function successfully returns"""
    # Only act when debugging was on and a folder was actually created.
    if not (debug and debug_folder_path):
        return
    try:
        # Remove the history file first, then the (now empty) folder.
        if os.path.exists(debug_jsonl_path):
            os.remove(debug_jsonl_path)
        if os.path.exists(debug_folder_path):
            os.rmdir(debug_folder_path)
    except Exception as e:
        # Best-effort cleanup: a leftover debug folder is not fatal.
        print(f"Warning: Could not clean up debug files: {e}")
35
+
36
+
37
def count_images(messages):
    """Count the total number of images present in the messages history."""
    num_images = 0
    for msg in messages:
        content = msg.get("content")
        # Only list-shaped content can carry image items.
        if not isinstance(content, list):
            continue
        num_images += sum(
            1
            for item in content
            if isinstance(item, dict) and item.get("type") == "image"
        )
    return num_images
52
+
53
+
54
+ def _prune_messages_for_next_round(
55
+ messages_list,
56
+ used_text_prompts,
57
+ latest_sam3_text_prompt,
58
+ img_path,
59
+ initial_text_prompt,
60
+ ):
61
+ """Return a new messages list that contains only:
62
+ 1) messages[:2] (with optional warning text added to the second message's content)
63
+ 2) the latest assistant message (and everything after it) that contains a segment_phrase tool call
64
+ """
65
+ # There should not be more than 10 messages in the conversation history
66
+ assert len(messages_list) < 10
67
+
68
+ # Part 1: always keep the first two message JSONs
69
+ part1 = copy.deepcopy(messages_list[:2])
70
+
71
+ # Part 2: search backwards for the latest assistant message containing a segment_phrase tool call
72
+ part2_start_idx = None
73
+ for idx in range(len(messages_list) - 1, 1, -1):
74
+ msg = messages_list[idx]
75
+ # We only consider assistant messages with a "content" list
76
+ if msg.get("role") != "assistant" or "content" not in msg:
77
+ continue
78
+ # Look for any content element that is a text containing the segment_phrase tool call
79
+ for content in msg["content"]:
80
+ if (
81
+ isinstance(content, dict)
82
+ and content.get("type") == "text"
83
+ and "<tool>" in content.get("text", "")
84
+ and "segment_phrase" in content.get("text", "")
85
+ ):
86
+ part2_start_idx = idx
87
+ break
88
+ if part2_start_idx is not None:
89
+ break
90
+
91
+ part2 = messages_list[part2_start_idx:] if part2_start_idx is not None else []
92
+
93
+ # Part 3: decide whether to add warning text to the second message in part1
94
+ previously_used = (
95
+ [p for p in used_text_prompts if p != latest_sam3_text_prompt]
96
+ if latest_sam3_text_prompt
97
+ else list(used_text_prompts)
98
+ )
99
+ if part2 and len(previously_used) > 0:
100
+ warning_text = f'Note that we have previously called the segment_phrase tool with each "text_prompt" in this list: {list(previously_used)}, but none of the generated results were satisfactory. So make sure that you do not use any of these phrases as the "text_prompt" to call the segment_phrase tool again.'
101
+ # Replace the second message entirely to keep exactly 2 content items
102
+ part1[1] = {
103
+ "role": "user",
104
+ "content": [
105
+ {"type": "image", "image": img_path},
106
+ {
107
+ "type": "text",
108
+ "text": f"The above image is the raw input image. The initial user input query is: '{initial_text_prompt}'."
109
+ + " "
110
+ + warning_text,
111
+ },
112
+ ],
113
+ }
114
+ assert len(part1[1]["content"]) == 2
115
+
116
+ # Build the new messages list: part1 (with optional warning), then part2
117
+ new_messages = list(part1)
118
+ new_messages.extend(part2)
119
+ return new_messages
120
+
121
+
122
def agent_inference(
    img_path: str,
    initial_text_prompt: str,
    debug: bool = False,
    send_generate_request=send_generate_request,
    call_sam_service=call_sam_service,
    max_generations: int = 100,
    output_dir="../../sam3_agent_out",
):
    """
    Given a text prompt and an image, this tool will perform all aspects of agentic problem solving,
    while saving sam3 and MLLM outputs to their respective directories.

    Args:
        img_path: Path to the input image
        initial_text_prompt: Initial text prompt from the user
        debug: Whether to enable debug mode
        send_generate_request: Callable used to query the MLLM (injectable for testing)
        call_sam_service: Callable used to run SAM 3 segmentation (injectable for testing)
        max_generations: Maximum number of send_generate_request calls allowed (default: 100)
        output_dir: Root directory for SAM outputs, error dumps and debug files

    Returns:
        Tuple of (messages history, final outputs dict, rendered output image).

    Raises:
        AssertionError: If the MLLM response violates the tool-call protocol.
        ValueError: On invalid tool-call JSON, an unknown tool name, an
            unexpected verdict, exceeding ``max_generations``, or a ``None``
            MLLM response.
    """
    # Set up output directories (SAM results, error dumps, debug files).
    sam_output_dir = os.path.join(output_dir, "sam_out")
    error_save_dir = os.path.join(output_dir, "none_out")
    debug_save_dir = os.path.join(output_dir, "agent_debug_out")
    os.makedirs(sam_output_dir, exist_ok=True)
    os.makedirs(error_save_dir, exist_ok=True)
    os.makedirs(debug_save_dir, exist_ok=True)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    MLLM_SYSTEM_PROMPT_PATH = os.path.join(
        current_dir, "system_prompts/system_prompt.txt"
    )
    ITERATIVE_CHECKING_SYSTEM_PROMPT_PATH = os.path.join(
        current_dir, "system_prompts/system_prompt_iterative_checking.txt"
    )
    # Loop state.
    PATH_TO_LATEST_OUTPUT_JSON = ""
    LATEST_SAM3_TEXT_PROMPT = ""
    # Track all previously used text prompts for segment_phrase.
    USED_TEXT_PROMPTS = set()
    generation_count = 0  # Counter for number of send_generate_request calls

    # Debug setup: one folder per input image, holding the rolling history.
    debug_folder_path = None
    debug_jsonl_path = None
    if debug:
        debug_folder_path = os.path.join(
            debug_save_dir, f"{img_path.rsplit('/', 1)[-1].rsplit('.', 1)[0]}"
        )
        debug_jsonl_path = os.path.join(debug_folder_path, "debug_history.json")
        os.makedirs(debug_folder_path, exist_ok=True)

    # Load the system prompts for the main agent loop and the per-mask
    # iterative checking loop.
    with open(MLLM_SYSTEM_PROMPT_PATH, "r") as f:
        system_prompt = f.read().strip()
    with open(ITERATIVE_CHECKING_SYSTEM_PROMPT_PATH, "r") as f:
        iterative_checking_system_prompt = f.read().strip()

    # Construct the initial message list
    messages = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img_path},
                {
                    "type": "text",
                    "text": f"The above image is the raw input image. The initial user input query is: '{initial_text_prompt}'.",
                },
            ],
        },
    ]
    print(f"> Text prompt: {initial_text_prompt}")
    print(f"> Image path: {img_path}")

    print("\n\n")
    print("-" * 30 + f" Round {str(generation_count + 1)}" + "-" * 30)
    print("\n\n")
    generated_text = send_generate_request(messages)
    print(f"\n>>> MLLM Response [start]\n{generated_text}\n<<< MLLM Response [end]\n")
    while generated_text is not None:
        save_debug_messages(messages, debug, debug_folder_path, debug_jsonl_path)
        # BUGFIX: this was `assert (cond, msg)` — asserting a non-empty tuple,
        # which is always truthy and could never fail. Assert the condition.
        assert (
            "<tool>" in generated_text
        ), f"Generated text does not contain <tool> tag: {generated_text}"
        # Sometimes the MLLM doesn't know when to stop and emits multiple tool
        # calls in one round; keep only the first one.
        generated_text = generated_text.split("</tool>", 1)[0] + "</tool>"
        tool_call_json_str = (
            generated_text.split("<tool>")[-1]
            .split("</tool>")[0]
            .strip()
            .replace(r"}}}", r"}}")  # remove extra } if any
        )
        try:
            tool_call = json.loads(tool_call_json_str)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON in tool call: {tool_call_json_str}") from e

        if PATH_TO_LATEST_OUTPUT_JSON == "":
            # The first tool call must be segment_phrase or report_no_mask
            assert (
                tool_call["name"] == "segment_phrase"
                or tool_call["name"] == "report_no_mask"
            )

        if tool_call["name"] == "segment_phrase":
            print("🔍 Calling segment_phrase tool...")
            assert list(tool_call["parameters"].keys()) == ["text_prompt"]

            # Check if this text_prompt has been used before
            current_text_prompt = tool_call["parameters"]["text_prompt"]
            if current_text_prompt in USED_TEXT_PROMPTS:
                print(
                    f"❌ Text prompt '{current_text_prompt}' has been used before. Requesting a different prompt."
                )
                duplicate_prompt_message = f"You have previously used '{current_text_prompt}' as your text_prompt to call the segment_phrase tool. You may not use it again. Please call the segment_phrase tool again with a different, perhaps more general, or more creative simple noun phrase prompt, while adhering to all the rules stated in the system prompt. You must also never use any of the following text_prompt(s): {str(list(USED_TEXT_PROMPTS))}."
                messages.append(
                    {
                        "role": "assistant",
                        "content": [{"type": "text", "text": generated_text}],
                    }
                )
                messages.append(
                    {
                        "role": "user",
                        "content": [{"type": "text", "text": duplicate_prompt_message}],
                    }
                )
            else:
                # Record the prompt and run SAM 3 segmentation with it.
                USED_TEXT_PROMPTS.add(current_text_prompt)
                LATEST_SAM3_TEXT_PROMPT = current_text_prompt
                PATH_TO_LATEST_OUTPUT_JSON = call_sam_service(
                    image_path=img_path,
                    text_prompt=current_text_prompt,
                    output_folder_path=sam_output_dir,
                )
                # Use a context manager so the file handle is not leaked.
                with open(PATH_TO_LATEST_OUTPUT_JSON, "r") as fin:
                    sam3_outputs = json.load(fin)
                sam3_output_image_path = sam3_outputs["output_image_path"]
                num_masks = len(sam3_outputs["pred_boxes"])

                messages.append(
                    {
                        "role": "assistant",
                        "content": [{"type": "text", "text": generated_text}],
                    }
                )
                if num_masks == 0:
                    print("❌ No masks generated by SAM3, reporting no mask to Qwen.")
                    sam3_output_text_message = f"The segment_phrase tool did not generate any masks for the text_prompt '{current_text_prompt}'. Now, please call the segment_phrase tool again with a different, perhaps more general, or more creative simple noun phrase text_prompt, while adhering to all the rules stated in the system prompt. Please be reminded that the original user query was '{initial_text_prompt}'."
                    messages.append(
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": sam3_output_text_message}
                            ],
                        }
                    )
                else:
                    sam3_output_text_message = rf"The segment_phrase tool generated {num_masks} available masks. All {num_masks} available masks are rendered in this image below, now you must analyze the {num_masks} available mask(s) carefully, compare them against the raw input image and the original user query, and determine your next action. Please be reminded that the original user query was '{initial_text_prompt}'."
                    messages.append(
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": sam3_output_text_message},
                                {"type": "image", "image": sam3_output_image_path},
                            ],
                        }
                    )
                print("\n\n>>> sam3_output_text_message:\n", sam3_output_text_message)

        elif tool_call["name"] == "examine_each_mask":
            print("🔍 Calling examine_each_mask tool...")
            assert LATEST_SAM3_TEXT_PROMPT != ""

            # Make sure that the last message carries the rendered-masks image.
            assert (
                messages[-1]["content"][1]["type"] == "image"
            ), "Second content element should be an image"
            messages.pop()  # Remove the last user message
            # Add simplified replacement message (drops the image to keep the
            # running context small).
            simplified_message = {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "The segment_phrase tool generated several masks. Now you must analyze the mask(s) carefully, compare them against the raw input image and the original user query, and determine your next action.",
                    }
                ],
            }
            messages.append(simplified_message)

            with open(PATH_TO_LATEST_OUTPUT_JSON, "r") as fin:
                current_outputs = json.load(fin)
            num_masks = len(current_outputs["pred_masks"])
            masks_to_keep = []

            # MLLM checks the masks one by one against the original query.
            for i in range(num_masks):
                print(f"🔍 Checking mask {i+1}/{num_masks}...")
                image_w_mask_i, image_w_zoomed_in_mask_i = visualize(current_outputs, i)

                image_w_zoomed_in_mask_i_path = os.path.join(
                    sam_output_dir, rf"{LATEST_SAM3_TEXT_PROMPT}.png".replace("/", "_")
                ).replace(".png", f"_zoom_in_mask_{i + 1}.png")
                image_w_mask_i_path = os.path.join(
                    sam_output_dir, rf"{LATEST_SAM3_TEXT_PROMPT}.png".replace("/", "_")
                ).replace(".png", f"_selected_mask_{i + 1}.png")
                image_w_zoomed_in_mask_i.save(image_w_zoomed_in_mask_i_path)
                image_w_mask_i.save(image_w_mask_i_path)

                # Fresh, single-mask conversation for the checking model.
                iterative_checking_messages = [
                    {"role": "system", "content": iterative_checking_system_prompt},
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": "The raw input image: "},
                            {"type": "image", "image": img_path},
                            {
                                "type": "text",
                                "text": f"The initial user input query is: '{initial_text_prompt}'",
                            },
                            {
                                "type": "text",
                                "text": "Image with the predicted segmentation mask rendered on it: ",
                            },
                            {
                                "type": "text",
                                "text": "Image with the zoomed-in mask: ",
                            },
                            {"type": "image", "image": image_w_zoomed_in_mask_i_path},
                        ],
                    },
                ]
                # Insert the rendered-mask image right after its caption.
                iterative_checking_messages[1]["content"].insert(
                    4, {"type": "image", "image": image_w_mask_i_path}
                )
                checking_generated_text = send_generate_request(
                    iterative_checking_messages
                )

                # Process the verdict to decide whether to keep the mask.
                if checking_generated_text is None:
                    raise ValueError(
                        "Generated text is None, which is unexpected. Please check the Qwen server and the input parameters."
                    )
                print(f"Generated text for mask {i+1}: {checking_generated_text}")
                verdict = (
                    checking_generated_text.split("<verdict>")[-1]
                    .split("</verdict>")[0]
                    .strip()
                )
                if "Accept" in verdict:
                    assert "Reject" not in verdict
                    print(f"Mask {i+1} accepted, keeping it in the outputs.")
                    masks_to_keep.append(i)
                elif "Reject" in verdict:
                    assert "Accept" not in verdict
                    print(f"Mask {i+1} rejected, removing it from the outputs.")
                else:
                    raise ValueError(
                        f"Unexpected verdict in generated text: {checking_generated_text}. Expected 'Accept' or 'Reject'."
                    )

            updated_outputs = {
                "original_image_path": current_outputs["original_image_path"],
                "orig_img_h": current_outputs["orig_img_h"],
                "orig_img_w": current_outputs["orig_img_w"],
                "pred_boxes": [current_outputs["pred_boxes"][i] for i in masks_to_keep],
                "pred_scores": [
                    current_outputs["pred_scores"][i] for i in masks_to_keep
                ],
                "pred_masks": [current_outputs["pred_masks"][i] for i in masks_to_keep],
            }

            image_w_check_masks = visualize(updated_outputs)
            image_w_check_masks_path = os.path.join(
                sam_output_dir, rf"{LATEST_SAM3_TEXT_PROMPT}.png"
            ).replace(
                ".png",
                f"_selected_masks_{'-'.join(map(str, [i+1 for i in masks_to_keep]))}.png".replace(
                    "/", "_"
                ),
            )
            image_w_check_masks.save(image_w_check_masks_path)
            # Save the updated json outputs and append to message history.
            messages.append(
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": generated_text}],
                }
            )
            if len(masks_to_keep) == 0:
                messages.append(
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": f"The original user query was: '{initial_text_prompt}'. The examine_each_mask tool examined and rejected all of the masks generated by the segment_phrase tool. Now, please call the segment_phrase tool again with a different, perhaps more general, or more creative simple noun phrase text_prompt, while adhering to all the rules stated in the system prompt.",
                            }
                        ],
                    }
                )
            else:
                messages.append(
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": f"The original user query was: '{initial_text_prompt}'. After calling the examine_each_mask tool on the available masks, the number of available masks is now {len(masks_to_keep)}. All {len(masks_to_keep)} available masks are rendered in this image below, now you must analyze the {len(masks_to_keep)} available mask(s) carefully, compare them against the raw input image and the original user query, and determine your next action.",
                            },
                            {"type": "image", "image": image_w_check_masks_path},
                        ],
                    }
                )

            # Create a new filename based on the original path to avoid filename length issues
            base_path = PATH_TO_LATEST_OUTPUT_JSON
            # Remove any existing "masks_" suffix to avoid duplication
            if "masks_" in base_path:
                base_path = base_path.split("masks_")[0] + ".json"
            # Create new filename with current masks; use a clearer suffix when empty
            if len(masks_to_keep) == 0:
                PATH_TO_LATEST_OUTPUT_JSON = base_path.replace(
                    ".json", "masks_none.json"
                )
            else:
                PATH_TO_LATEST_OUTPUT_JSON = base_path.replace(
                    ".json", f"masks_{'_'.join(map(str, masks_to_keep))}.json"
                )
            with open(PATH_TO_LATEST_OUTPUT_JSON, "w") as fout:
                json.dump(updated_outputs, fout, indent=4)

        elif tool_call["name"] == "select_masks_and_return":
            print("🔍 Calling select_masks_and_return tool...")
            with open(PATH_TO_LATEST_OUTPUT_JSON, "r") as fin:
                current_outputs = json.load(fin)

            assert list(tool_call["parameters"].keys()) == ["final_answer_masks"]
            masks_to_keep = tool_call["parameters"]["final_answer_masks"]

            # Keep only valid mask indices, remove duplicates, and preserve
            # deterministic ascending order (indices are 1-based here).
            available_masks = set(range(1, len(current_outputs["pred_masks"]) + 1))
            masks_to_keep = sorted({i for i in masks_to_keep if i in available_masks})
            # TODO: change this to an update message telling the model to try
            # again along with information about the errors made.

            final_outputs = {
                "original_image_path": current_outputs["original_image_path"],
                "orig_img_h": current_outputs["orig_img_h"],
                "orig_img_w": current_outputs["orig_img_w"],
                "pred_boxes": [
                    current_outputs["pred_boxes"][i - 1] for i in masks_to_keep
                ],
                "pred_scores": [
                    current_outputs["pred_scores"][i - 1] for i in masks_to_keep
                ],
                "pred_masks": [
                    current_outputs["pred_masks"][i - 1] for i in masks_to_keep
                ],
            }

            rendered_final_output = visualize(final_outputs)
            messages.append(
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": generated_text}],
                }
            )

            # Clean up debug files before successful return
            cleanup_debug_files(debug, debug_folder_path, debug_jsonl_path)
            return messages, final_outputs, rendered_final_output

        elif tool_call["name"] == "report_no_mask":
            print("🔍 Calling report_no_mask tool...")
            height, width = cv2.imread(img_path).shape[:2]
            final_outputs = {
                "original_image_path": img_path,
                "orig_img_h": height,
                "orig_img_w": width,
                "pred_boxes": [],
                "pred_scores": [],
                "pred_masks": [],
            }
            rendered_final_output = Image.open(img_path)
            messages.append(
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": generated_text}],
                }
            )
            # FIX: clean up debug files on this successful return too, for
            # consistency with the select_masks_and_return path.
            cleanup_debug_files(debug, debug_folder_path, debug_jsonl_path)
            return messages, final_outputs, rendered_final_output

        else:
            raise ValueError(f"Unknown tool call: {tool_call['name']}")

        # Truncate every assistant message to its first tool call (the MLLM
        # sometimes generates several in one round).
        for message in messages:
            if message["role"] == "assistant" and "content" in message:
                for content in message["content"]:
                    if (
                        isinstance(content, dict)
                        and content.get("type") == "text"
                        and "text" in content
                    ):
                        content["text"] = (
                            content["text"].split("</tool>", 1)[0] + "</tool>\n\n"
                        )
        # Prune the messages history before the next MLLM generation round according to the 3-part rules.
        # This keeps history compact and ensures the model sees only the allowed parts.
        messages = _prune_messages_for_next_round(
            messages,
            USED_TEXT_PROMPTS,
            LATEST_SAM3_TEXT_PROMPT,
            img_path,
            initial_text_prompt,
        )
        # make sure there can never be more than 2 images in the context
        assert count_images(messages) <= 2
        generation_count += 1
        if generation_count > max_generations:
            raise ValueError(
                f"Exceeded maximum number of allowed generation requests ({max_generations})"
            )

        print("\n\n")
        print("-" * 30 + f" Round {str(generation_count + 1)}" + "-" * 30)
        print("\n\n")
        generated_text = send_generate_request(messages)
        print(
            f"\n>>> MLLM Response [start]\n{generated_text}\n<<< MLLM Response [end]\n"
        )

    print("\n\n>>> SAM 3 Agent execution ended.\n\n")

    # A None response from the MLLM is unrecoverable: dump the conversation
    # history for post-mortem debugging, then raise.
    error_save_path = os.path.join(
        error_save_dir,
        f"{img_path.rsplit('/', 1)[-1].rsplit('.', 1)[0]}_error_history.json",
    )
    with open(error_save_path, "w") as f:
        json.dump(messages, f, indent=4)
    print("Saved messages history that caused error to:", error_save_path)
    raise ValueError(
        rf"Generated text is None, which is unexpected. Please check the Qwen server and the input parameters for image path: {img_path} and initial text prompt: {initial_text_prompt}."
    )
sam3/agent/client_llm.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import base64
4
+ import os
5
+ from typing import Any, Optional
6
+
7
+ from openai import OpenAI
8
+
9
+
10
def get_image_base64_and_mime(image_path):
    """Convert image file to base64 string and get MIME type"""
    # Known extensions; anything unknown falls back to JPEG.
    mime_by_extension = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".gif": "image/gif",
        ".webp": "image/webp",
        ".bmp": "image/bmp",
    }
    try:
        extension = os.path.splitext(image_path)[1].lower()
        mime_type = mime_by_extension.get(extension, "image/jpeg")

        # Read the raw bytes and base64-encode them.
        with open(image_path, "rb") as image_file:
            encoded = base64.b64encode(image_file.read()).decode("utf-8")
        return encoded, mime_type
    except Exception as e:
        # Any failure (missing file, permissions, ...) yields (None, None).
        print(f"Error converting image to base64: {e}")
        return None, None
32
+
33
+
34
def _encode_image_messages(messages):
    """Return a copy of *messages* where user-message image entries
    ({"type": "image", "image": <path>}) are replaced by base64-encoded
    OpenAI ``image_url`` entries. Unreadable images are dropped with a
    warning so one bad file does not abort the whole request."""
    processed_messages = []
    for message in messages:
        processed_message = message.copy()
        if message["role"] == "user" and "content" in message:
            processed_content = []
            for c in message["content"]:
                if isinstance(c, dict) and c.get("type") == "image":
                    image_path = c["image"]
                    print("image_path", image_path)
                    # Escape ? in the path
                    new_image_path = image_path.replace("?", "%3F")
                    try:
                        base64_image, mime_type = get_image_base64_and_mime(
                            new_image_path
                        )
                        if base64_image is None:
                            print(
                                f"Warning: Could not convert image to base64: {new_image_path}"
                            )
                            continue
                        # Proper image_url structure with inline base64 data.
                        processed_content.append(
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:{mime_type};base64,{base64_image}",
                                    "detail": "high",
                                },
                            }
                        )
                    except FileNotFoundError:
                        print(f"Warning: Image file not found: {new_image_path}")
                        continue
                    except Exception as e:
                        print(f"Warning: Error processing image {new_image_path}: {e}")
                        continue
                else:
                    processed_content.append(c)
            processed_message["content"] = processed_content
        processed_messages.append(processed_message)
    return processed_messages


def send_generate_request(
    messages,
    server_url=None,
    model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
    api_key=None,
    max_tokens=4096,
):
    """
    Sends a request to an OpenAI-compatible API endpoint using the OpenAI client library.

    Args:
        messages (list): A list of message dicts, each containing role and content.
            Local image entries are converted to base64 image_url entries first.
        server_url (str): The base URL of the server, e.g. "http://127.0.0.1:8000";
            None uses the OpenAI client default.
        model (str): The model to use for generation
            (default: "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8").
        api_key (str): API key for the endpoint; None uses the client default.
        max_tokens (int): Maximum number of tokens to generate (default: 4096).

    Returns:
        str or None: The generated response text, or None if the request failed.
    """
    # Convert any local image paths to inline base64 payloads.
    processed_messages = _encode_image_messages(messages)

    # Create OpenAI client with custom base URL
    client = OpenAI(api_key=api_key, base_url=server_url)

    try:
        print(f"🔍 Calling model {model}...")
        response = client.chat.completions.create(
            model=model,
            messages=processed_messages,
            max_completion_tokens=max_tokens,
            n=1,
        )

        # Extract the response content
        if response.choices and len(response.choices) > 0:
            return response.choices[0].message.content
        else:
            print(f"Unexpected response format: {response}")
            return None

    except Exception as e:
        # Network / API failures are reported to the caller as None.
        print(f"Request failed: {e}")
        return None
126
+
127
+
128
def send_direct_request(
    llm: Any,
    messages: list[dict[str, Any]],
    sampling_params: Any,
) -> Optional[str]:
    """
    Run inference on a vLLM model instance directly without using a server.

    Args:
        llm: Initialized vLLM LLM instance (passed from external initialization)
        messages: List of message dicts with role and content (OpenAI format)
        sampling_params: vLLM SamplingParams instance (initialized externally)

    Returns:
        str: Generated response text, or None if inference fails
    """
    try:
        # Rewrite user messages so local image paths become inline base64
        # image_url payloads understood by vLLM.
        prepared = []
        for msg in messages:
            out_msg = msg.copy()
            if msg["role"] == "user" and "content" in msg:
                new_content = []
                for item in msg["content"]:
                    if not (isinstance(item, dict) and item.get("type") == "image"):
                        new_content.append(item)
                        continue
                    # Convert image path to base64 format (escape '?').
                    new_image_path = item["image"].replace("?", "%3F")
                    try:
                        base64_image, mime_type = get_image_base64_and_mime(
                            new_image_path
                        )
                    except Exception as e:
                        print(
                            f"Warning: Error processing image {new_image_path}: {e}"
                        )
                        continue
                    if base64_image is None:
                        print(
                            f"Warning: Could not convert image: {new_image_path}"
                        )
                        continue
                    # vLLM expects image_url format
                    new_content.append(
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{mime_type};base64,{base64_image}"
                            },
                        }
                    )
                out_msg["content"] = new_content
            prepared.append(out_msg)

        print("🔍 Running direct inference with vLLM...")

        # Run inference using vLLM's chat interface.
        outputs = llm.chat(
            messages=prepared,
            sampling_params=sampling_params,
        )

        # First output's first candidate carries the generated text.
        if not outputs:
            print(f"Unexpected output format: {outputs}")
            return None
        return outputs[0].outputs[0].text

    except Exception as e:
        print(f"Direct inference failed: {e}")
        return None
sam3/agent/client_sam3.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import json
4
+ import os
5
+
6
+ import torch
7
+ from PIL import Image
8
+
9
+ from sam3.model.box_ops import box_xyxy_to_xywh
10
+ from sam3.train.masks_ops import rle_encode
11
+
12
+ from .helpers.mask_overlap_removal import remove_overlapping_masks
13
+ from .viz import visualize
14
+
15
+
16
def sam3_inference(processor, image_path, text_prompt):
    """Run SAM 3 image inference for a text prompt and assemble a JSON-friendly result.

    Returns a dict with the original image size, xywh boxes normalized to
    [0, 1], RLE-encoded masks (``counts`` strings only) and per-prediction
    scores.
    """
    img = Image.open(image_path)
    width, height = img.size

    # Prompted inference: encode the image once, then apply the text prompt.
    state = processor.set_image(img)
    state = processor.set_text_prompt(state=state, prompt=text_prompt)

    # Normalize absolute xyxy corners into [0, 1] image coordinates.
    bx = state["boxes"]
    boxes_xyxy = torch.stack(
        (
            bx[:, 0] / width,
            bx[:, 1] / height,
            bx[:, 2] / width,
            bx[:, 3] / height,
        ),
        dim=-1,
    )
    boxes_xywh = box_xyxy_to_xywh(boxes_xyxy).tolist()

    # Drop the singleton channel dim before RLE-encoding the masks; only the
    # run-length "counts" payload is kept for serialization.
    rles = rle_encode(state["masks"].squeeze(1))

    return {
        "orig_img_h": height,
        "orig_img_w": width,
        "pred_boxes": boxes_xywh,
        "pred_masks": [rle["counts"] for rle in rles],
        "pred_scores": state["scores"].tolist(),
    }
48
+
49
+
50
def call_sam_service(
    sam3_processor,
    image_path: str,
    text_prompt: str,
    output_folder_path: str = "sam3_output",
):
    """
    Run SAM 3 on one image/prompt pair, post-process the detections,
    persist them as JSON and render a visualization image.

    Returns the JSON output path (even when processing fails, so callers
    can probe for the file's existence).
    """
    print(f"📞 Loading image '{image_path}' and sending with prompt '{text_prompt}'...")

    # Slashes in the prompt/image path would create sub-directories; flatten them.
    safe_prompt = text_prompt.replace("/", "_")
    run_dir = os.path.join(output_folder_path, image_path.replace("/", "-"))
    os.makedirs(run_dir, exist_ok=True)
    output_json_path = os.path.join(run_dir, f"{safe_prompt}.json")
    output_image_path = os.path.join(run_dir, f"{safe_prompt}.png")

    try:
        # Run local inference on the image and text prompt.
        results = sam3_inference(sam3_processor, image_path, text_prompt)

        # 1. Suppress heavily-overlapping masks and record the I/O paths.
        results = remove_overlapping_masks(results)
        results = {
            "original_image_path": image_path,
            "output_image_path": output_image_path,
            **results,
        }

        # 2. Reorder predictions by confidence, highest first (when present).
        scores = results.get("pred_scores")
        if scores:
            order = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
            for key in ("pred_scores", "pred_boxes", "pred_masks"):
                results[key] = [results[key][i] for i in order]

        # 3. Discard degenerate RLE masks (counts strings of < 5 characters),
        #    keeping boxes/scores aligned with the surviving masks.
        keep = [i for i, rle in enumerate(results["pred_masks"]) if len(rle) > 4]
        results["pred_masks"] = [results["pred_masks"][i] for i in keep]
        results["pred_boxes"] = [results["pred_boxes"][i] for i in keep]
        results["pred_scores"] = [results["pred_scores"][i] for i in keep]

        with open(output_json_path, "w") as f:
            json.dump(results, f, indent=4)
        print(f"✅ Raw JSON response saved to '{output_json_path}'")

        # 4. Render the detections on the image and save next to the JSON.
        print("🔍 Rendering visualizations on the image ...")
        overlay = visualize(results)
        os.makedirs(os.path.dirname(output_image_path), exist_ok=True)
        overlay.save(output_image_path)
        print("✅ Saved visualization at:", output_image_path)
    except Exception as e:
        # Best-effort by design: report and still return the expected path.
        print(f"❌ Error calling service: {e}")

    return output_json_path
sam3/agent/helpers/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
sam3/agent/helpers/boxes.py ADDED
@@ -0,0 +1,438 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import math
4
+ from enum import IntEnum, unique
5
+ from typing import List, Tuple, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch import device
10
+
11
_RawBoxType = Union[List[float], Tuple[float, ...], torch.Tensor, np.ndarray]


@unique
class BoxMode(IntEnum):
    """
    Enum of the supported box encodings.
    """

    XYXY_ABS = 0
    """
    (x0, y0, x1, y1) in absolute floating points coordinates.
    The coordinates in range [0, width or height].
    """
    XYWH_ABS = 1
    """
    (x0, y0, w, h) in absolute floating points coordinates.
    """
    XYXY_REL = 2
    """
    Not yet supported!
    (x0, y0, x1, y1) in range [0, 1]. They are relative to the size of the image.
    """
    XYWH_REL = 3
    """
    Not yet supported!
    (x0, y0, w, h) in range [0, 1]. They are relative to the size of the image.
    """
    XYWHA_ABS = 4
    """
    (xc, yc, w, h, a) in absolute floating points coordinates.
    (xc, yc) is the center of the rotated box, and the angle a is in degrees ccw.
    """

    @staticmethod
    def convert(
        box: "_RawBoxType", from_mode: "BoxMode", to_mode: "BoxMode"
    ) -> "_RawBoxType":
        """
        Convert ``box`` between encodings.

        Args:
            box: can be a k-tuple, k-list or an Nxk array/tensor, where k = 4 or 5
            from_mode, to_mode (BoxMode)

        Returns:
            The converted box of the same type.
        """
        if from_mode == to_mode:
            return box

        box_type = type(box)
        from_numpy = isinstance(box, np.ndarray)
        from_sequence = isinstance(box, (list, tuple))
        if from_sequence:
            assert len(box) == 4 or len(box) == 5, (
                "BoxMode.convert takes either a k-tuple/list or an Nxk array/tensor,"
                " where k == 4 or 5"
            )
            arr = torch.tensor(box)[None, :]
        elif from_numpy:
            # Copy so the caller's array is never mutated in place.
            arr = torch.from_numpy(np.asarray(box)).clone()
        else:
            arr = box.clone()

        unsupported = (BoxMode.XYXY_REL, BoxMode.XYWH_REL)
        assert (
            to_mode not in unsupported and from_mode not in unsupported
        ), "Relative mode not yet supported!"

        if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS:
            assert (
                arr.shape[-1] == 5
            ), "The last dimension of input shape must be 5 for XYWHA format"
            in_dtype = arr.dtype
            arr = arr.double()

            width = arr[:, 2]
            height = arr[:, 3]
            angle = arr[:, 4]
            cos_a = torch.abs(torch.cos(angle * math.pi / 180.0))
            sin_a = torch.abs(torch.sin(angle * math.pi / 180.0))
            # Axis-aligned extent of the horizontal bounding rectangle that
            # encloses the rotated box.
            new_w = cos_a * width + sin_a * height
            new_h = cos_a * height + sin_a * width

            # Center -> top-left corner, then derive the bottom-right corner.
            arr[:, 0] -= new_w / 2.0
            arr[:, 1] -= new_h / 2.0
            arr[:, 2] = arr[:, 0] + new_w
            arr[:, 3] = arr[:, 1] + new_h

            arr = arr[:, :4].to(dtype=in_dtype)
        elif from_mode == BoxMode.XYWH_ABS and to_mode == BoxMode.XYWHA_ABS:
            in_dtype = arr.dtype
            arr = arr.double()
            # Top-left -> center, then append a zero rotation angle.
            arr[:, 0] += arr[:, 2] / 2.0
            arr[:, 1] += arr[:, 3] / 2.0
            zero_angle = torch.zeros((arr.shape[0], 1), dtype=arr.dtype)
            arr = torch.cat((arr, zero_angle), dim=1).to(dtype=in_dtype)
        elif from_mode == BoxMode.XYWH_ABS and to_mode == BoxMode.XYXY_ABS:
            arr[:, 2] += arr[:, 0]
            arr[:, 3] += arr[:, 1]
        elif from_mode == BoxMode.XYXY_ABS and to_mode == BoxMode.XYWH_ABS:
            arr[:, 2] -= arr[:, 0]
            arr[:, 3] -= arr[:, 1]
        else:
            raise NotImplementedError(
                "Conversion from BoxMode {} to {} is not supported yet".format(
                    from_mode, to_mode
                )
            )

        if from_sequence:
            return box_type(arr.flatten().tolist())
        return arr.numpy() if from_numpy else arr
135
+
136
+
137
class Boxes:
    """
    A list of N axis-aligned boxes stored as an Nx4 float32 tensor, one
    (x1, y1, x2, y2) row per box. Supports common box operations
    (`area`, `clip`, `nonempty`, ...) and behaves like a tensor for
    indexing, iteration, `.to(device)` and `.device`.

    Attributes:
        tensor (torch.Tensor): float matrix of Nx4. Each row is (x1, y1, x2, y2).
    """

    def __init__(self, tensor: torch.Tensor):
        """
        Args:
            tensor (Tensor[float]): a Nx4 matrix. Each row is (x1, y1, x2, y2).
        """
        if isinstance(tensor, torch.Tensor):
            tensor = tensor.to(torch.float32)
        else:
            tensor = torch.as_tensor(
                tensor, dtype=torch.float32, device=torch.device("cpu")
            )
        if tensor.numel() == 0:
            # Reshape instead of allocating fresh storage so jit/tracing still
            # sees a tensor that depends on the input.
            tensor = tensor.reshape((-1, 4)).to(dtype=torch.float32)
        assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()

        self.tensor = tensor

    def clone(self) -> "Boxes":
        """
        Return a Boxes backed by a copy of the underlying tensor.

        Returns:
            Boxes
        """
        return Boxes(self.tensor.clone())

    def to(self, device: torch.device):
        # Boxes are always float32; only the device may change via to().
        return Boxes(self.tensor.to(device=device))

    def area(self) -> torch.Tensor:
        """
        Compute the area of every box.

        Returns:
            torch.Tensor: a vector with areas of each box.
        """
        t = self.tensor
        return (t[:, 2] - t[:, 0]) * (t[:, 3] - t[:, 1])

    def clip(self, box_size: Tuple[int, int]) -> None:
        """
        Clamp (in place) x coordinates to [0, width] and y coordinates to
        [0, height].

        Args:
            box_size (height, width): The clipping box's size.
        """
        assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!"
        h, w = box_size
        cols = (
            self.tensor[:, 0].clamp(min=0, max=w),
            self.tensor[:, 1].clamp(min=0, max=h),
            self.tensor[:, 2].clamp(min=0, max=w),
            self.tensor[:, 3].clamp(min=0, max=h),
        )
        self.tensor = torch.stack(cols, dim=-1)

    def nonempty(self, threshold: float = 0.0) -> torch.Tensor:
        """
        Return a boolean vector that is True exactly for boxes whose width
        and height are both larger than `threshold`.
        """
        t = self.tensor
        return ((t[:, 2] - t[:, 0]) > threshold) & ((t[:, 3] - t[:, 1]) > threshold)

    def __getitem__(self, item) -> "Boxes":
        """
        Index with an int, a slice, or a BoolTensor of length len(boxes);
        always returns a Boxes (an int yields a single-row Boxes).

        Note that the returned Boxes might share storage with this Boxes,
        subject to Pytorch's indexing semantics.
        """
        if isinstance(item, int):
            return Boxes(self.tensor[item].view(1, -1))
        selected = self.tensor[item]
        assert (
            selected.dim() == 2
        ), "Indexing on Boxes with {} failed to return a matrix!".format(item)
        return Boxes(selected)

    def __len__(self) -> int:
        return self.tensor.shape[0]

    def __repr__(self) -> str:
        return "Boxes(" + str(self.tensor) + ")"

    def inside_box(
        self, box_size: Tuple[int, int], boundary_threshold: int = 0
    ) -> torch.Tensor:
        """
        Args:
            box_size (height, width): Size of the reference box.
            boundary_threshold (int): Boxes that extend beyond the reference box
                boundary by more than boundary_threshold are considered "outside".

        Returns:
            a binary vector, indicating whether each box is inside the reference box.
        """
        height, width = box_size
        t = self.tensor
        return (
            (t[..., 0] >= -boundary_threshold)
            & (t[..., 1] >= -boundary_threshold)
            & (t[..., 2] < width + boundary_threshold)
            & (t[..., 3] < height + boundary_threshold)
        )

    def get_centers(self) -> torch.Tensor:
        """Return the box centers as an Nx2 array of (x, y)."""
        return (self.tensor[:, :2] + self.tensor[:, 2:]) / 2

    def scale(self, scale_x: float, scale_y: float) -> None:
        """Scale all boxes in place by per-axis factors."""
        self.tensor[:, 0::2] *= scale_x
        self.tensor[:, 1::2] *= scale_y

    @classmethod
    def cat(cls, boxes_list: List["Boxes"]) -> "Boxes":
        """
        Concatenate a list of Boxes into a single Boxes.

        Arguments:
            boxes_list (list[Boxes])

        Returns:
            Boxes: the concatenated Boxes
        """
        assert isinstance(boxes_list, (list, tuple))
        if not boxes_list:
            return cls(torch.empty(0))
        assert all([isinstance(box, Boxes) for box in boxes_list])

        # torch.cat guarantees the result never shares storage with the inputs.
        return cls(torch.cat([b.tensor for b in boxes_list], dim=0))

    @property
    def device(self) -> device:
        return self.tensor.device

    # yield / iter() are not supported by torchscript; keep this off the
    # scripted path: https://github.com/pytorch/pytorch/issues/18627
    @torch.jit.unused
    def __iter__(self):
        """Yield each box as a Tensor of shape (4,)."""
        yield from self.tensor
323
+
324
+
325
def pairwise_intersection(boxes1: "Boxes", boxes2: "Boxes") -> torch.Tensor:
    """
    Compute the intersection area between __all__ N x M pairs of boxes from
    two collections. The box order must be (xmin, ymin, xmax, ymax).

    Args:
        boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively.

    Returns:
        Tensor: intersection, sized [N,M].
    """
    t1, t2 = boxes1.tensor, boxes2.tensor
    # Per-axis overlap: innermost max-corner minus outermost min-corner;
    # negative extents mean "no overlap" and are clamped to zero.
    hi = torch.min(t1[:, None, 2:], t2[:, 2:])
    lo = torch.max(t1[:, None, :2], t2[:, :2])
    extent = (hi - lo).clamp_(min=0)  # [N,M,2]
    return extent.prod(dim=2)  # [N,M]
345
+
346
+
347
+ # implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
348
+ # with slight modifications
349
def pairwise_iou(boxes1: "Boxes", boxes2: "Boxes") -> torch.Tensor:
    """
    Compute the IoU (intersection over union) between **all** N x M pairs
    of boxes. The box order must be (xmin, ymin, xmax, ymax).

    Args:
        boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively.

    Returns:
        Tensor: IoU, sized [N,M].
    """
    area1 = boxes1.area()  # [N]
    area2 = boxes2.area()  # [M]
    inter = pairwise_intersection(boxes1, boxes2)

    # Non-overlapping (including empty-box) pairs get IoU 0; the masked
    # division keeps degenerate unions from producing NaN in the output.
    union = area1[:, None] + area2 - inter
    zero = torch.zeros(1, dtype=inter.dtype, device=inter.device)
    return torch.where(inter > 0, inter / union, zero)
372
+
373
+
374
def pairwise_ioa(boxes1: "Boxes", boxes2: "Boxes") -> torch.Tensor:
    """
    Similar to :func:`pairwise_iou` but computes the IoA (intersection over
    the area of the *second* box).

    Args:
        boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively.

    Returns:
        Tensor: IoA, sized [N,M].
    """
    area2 = boxes2.area()  # [M]
    inter = pairwise_intersection(boxes1, boxes2)

    # Empty intersections map straight to 0 instead of dividing by a
    # possibly-zero area.
    zero = torch.zeros(1, dtype=inter.dtype, device=inter.device)
    return torch.where(inter > 0, inter / area2, zero)
392
+
393
+
394
def pairwise_point_box_distance(points: torch.Tensor, boxes: "Boxes"):
    """
    Pairwise distances from N points to the four edges of M boxes. All
    four distances are positive exactly when the point lies inside the box.

    Args:
        points: Nx2 coordinates. Each row is (x, y)
        boxes: M boxes

    Returns:
        Tensor: distances of size (N, M, 4). The 4 values are distances from
        the point to the left, top, right, bottom of the box.
    """
    px, py = points.unsqueeze(dim=2).unbind(dim=1)  # (N, 1)
    left, top, right, bottom = boxes.tensor.unsqueeze(dim=0).unbind(dim=2)  # (1, M)
    return torch.stack([px - left, py - top, right - px, bottom - py], dim=2)
411
+
412
+
413
def matched_pairwise_iou(boxes1: "Boxes", boxes2: "Boxes") -> torch.Tensor:
    """
    IoU between boxes matched by position: element i of the result is the
    IoU of boxes1[i] with boxes2[i], i.e. only the diagonal of the matrix
    :func:`pairwise_iou` would produce.

    Args:
        boxes1 (Boxes): bounding boxes, sized [N,4].
        boxes2 (Boxes): same length as boxes1
    Returns:
        Tensor: iou, sized [N].
    """
    assert len(boxes1) == len(boxes2), (
        "boxlists should have the same" "number of entries, got {}, {}".format(
            len(boxes1), len(boxes2)
        )
    )
    area1 = boxes1.area()  # [N]
    area2 = boxes2.area()  # [N]
    t1, t2 = boxes1.tensor, boxes2.tensor
    # Row-wise overlap rectangle; clamp removes non-overlapping pairs.
    overlap_lt = torch.max(t1[:, :2], t2[:, :2])  # [N,2]
    overlap_rb = torch.min(t1[:, 2:], t2[:, 2:])  # [N,2]
    extent = (overlap_rb - overlap_lt).clamp(min=0)  # [N,2]
    inter = extent[:, 0] * extent[:, 1]  # [N]
    return inter / (area1 + area2 - inter)  # [N]
sam3/agent/helpers/color_map.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ """
4
+ An awesome colormap for really neat visualizations.
5
+ Copied from Detectron, and removed gray colors.
6
+ """
7
+
8
+ import random
9
+
10
+ import numpy as np
11
+
12
+ __all__ = ["colormap", "random_color", "random_colors"]
13
+
14
+
15
# A list of 20 bright and sharp colors for segmentation masks,
# generated from the edges of the sRGB color space for maximum intensity.
# Stored flat as N*3 floats in [0, 1] and reshaped below to (N, 3) RGB rows.
_COLORS = (
    np.array(
        [
            # The original 8 sharp colors
            1.000, 1.000, 0.000,  # 1. Yellow
            0.000, 1.000, 0.000,  # 2. Lime
            0.000, 1.000, 1.000,  # 3. Cyan
            1.000, 0.000, 1.000,  # 4. Magenta
            1.000, 0.000, 0.000,  # 5. Red
            1.000, 0.498, 0.000,  # 6. Orange
            0.498, 1.000, 0.000,  # 7. Chartreuse
            0.000, 1.000, 0.498,  # 8. Spring Green
            1.000, 0.000, 0.498,  # 9. Rose
            0.498, 0.000, 1.000,  # 10. Violet
            0.753, 1.000, 0.000,  # 11. Electric Lime
            1.000, 0.753, 0.000,  # 12. Vivid Orange
            0.000, 1.000, 0.753,  # 13. Turquoise
            0.753, 0.000, 1.000,  # 14. Bright Violet
            1.000, 0.000, 0.753,  # 15. Bright Pink
            1.000, 0.251, 0.000,  # 16. Fiery Orange
            0.251, 1.000, 0.000,  # 17. Bright Chartreuse
            0.000, 1.000, 0.251,  # 18. Malachite Green
            0.251, 0.000, 1.000,  # 19. Deep Violet
            1.000, 0.000, 0.251,  # 20. Hot Pink
        ]
    )
    .astype(np.float32)
    .reshape(-1, 3)
)
86
+
87
+
88
def colormap(rgb=False, maximum=255):
    """
    Return the whole palette as an Nx3 array.

    Args:
        rgb (bool): whether to return RGB colors or BGR colors.
        maximum (int): either 255 or 1

    Returns:
        ndarray: a float32 array of Nx3 colors, in range [0, 255] or [0, 1]
    """
    assert maximum in [255, 1], maximum
    palette = _COLORS * maximum
    # Channel order flip converts the RGB palette to OpenCV-style BGR.
    return palette if rgb else palette[:, ::-1]
102
+
103
+
104
def random_color(rgb=False, maximum=255):
    """
    Pick one palette color at random.

    Args:
        rgb (bool): whether to return RGB colors or BGR colors.
        maximum (int): either 255 or 1

    Returns:
        ndarray: a vector of 3 numbers
    """
    choice = _COLORS[np.random.randint(0, len(_COLORS))] * maximum
    # Reverse the channel order for BGR output.
    return choice if rgb else choice[::-1]
118
+
119
+
120
def random_colors(N, rgb=False, maximum=255):
    """
    Draw N palette colors at random.

    Colors are unique while N does not exceed the palette size; beyond that
    the palette is reused (sampling with replacement) instead of letting
    `random.sample` raise ValueError as the previous implementation did.

    Args:
        N (int): number of colors needed
        rgb (bool): whether to return RGB colors or BGR colors.
        maximum (int): either 255 or 1

    Returns:
        list[ndarray]: N color vectors of 3 numbers each
    """
    if N <= len(_COLORS):
        # Unchanged fast path: unique colors, same distribution as before.
        indices = random.sample(range(len(_COLORS)), N)
    else:
        # More colors requested than the palette holds: allow repeats.
        indices = random.choices(range(len(_COLORS)), k=N)
    ret = [_COLORS[i] * maximum for i in indices]
    if not rgb:
        ret = [x[::-1] for x in ret]
    return ret
135
+
136
+
137
if __name__ == "__main__":
    import cv2

    # Visual sanity check for the palette: paint each color into a
    # 100x100-pixel cell of a 10x10 grid over random noise and display it.
    size = 100
    H, W = 10, 10
    canvas = np.random.rand(H * size, W * size, 3).astype("float32")
    for h in range(H):
        for w in range(W):
            idx = h * W + w
            if idx >= len(_COLORS):
                # Fewer palette entries than grid cells; leave the rest as noise.
                break
            canvas[h * size : (h + 1) * size, w * size : (w + 1) * size] = _COLORS[idx]
    cv2.imshow("a", canvas)
    cv2.waitKey(0)
sam3/agent/helpers/keypoints.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from typing import Any, List, Tuple, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch.nn import functional as F
8
+
9
+
10
class Keypoints:
    """
    Container for keypoint **annotation** data, stored as an (N, K, 3)
    float32 tensor of (x, y, visibility) triplets — N instances with K
    keypoints each. The visibility flag follows the COCO format and must
    be one of three integers:

    * v=0: not labeled (in which case x=y=0)
    * v=1: labeled but not visible
    * v=2: labeled and visible
    """

    def __init__(self, keypoints: Union[torch.Tensor, np.ndarray, List[List[float]]]):
        """
        Arguments:
            keypoints: A Tensor, numpy array, or list of the x, y, and
                visibility of each keypoint, shaped (N, K, 3).
        """
        if isinstance(keypoints, torch.Tensor):
            target_device = keypoints.device
        else:
            target_device = torch.device("cpu")
        data = torch.as_tensor(keypoints, dtype=torch.float32, device=target_device)
        assert data.dim() == 3 and data.shape[2] == 3, data.shape
        self.tensor = data

    def __len__(self) -> int:
        return self.tensor.size(0)

    def to(self, *args: Any, **kwargs: Any) -> "Keypoints":
        return type(self)(self.tensor.to(*args, **kwargs))

    @property
    def device(self) -> torch.device:
        return self.tensor.device

    def to_heatmap(self, boxes: torch.Tensor, heatmap_size: int) -> torch.Tensor:
        """
        Encode the keypoints as one-hot heatmap labels for training, as
        described in :paper:`Mask R-CNN`.

        Arguments:
            boxes: Nx4 tensor, the boxes to draw the keypoints to

        Returns:
            heatmaps:
                A tensor of shape (N, K), each element is integer spatial label
                in the range [0, heatmap_size**2 - 1] for each keypoint in the input.
            valid:
                A tensor of shape (N, K) containing whether each keypoint is in the roi or not.
        """
        return _keypoints_to_heatmap(self.tensor, boxes, heatmap_size)

    def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Keypoints":
        """
        Index with an int, a slice, or a boolean mask of length len(kpts);
        always returns a Keypoints (which might share storage with this
        one, subject to PyTorch's indexing semantics).
        """
        if isinstance(item, int):
            return Keypoints([self.tensor[item]])
        return Keypoints(self.tensor[item])

    def __repr__(self) -> str:
        return "{}(num_instances={})".format(self.__class__.__name__, len(self.tensor))

    @staticmethod
    def cat(keypoints_list: List["Keypoints"]) -> "Keypoints":
        """
        Concatenate a list of Keypoints into a single Keypoints.

        Arguments:
            keypoints_list (list[Keypoints])

        Returns:
            Keypoints: the concatenated Keypoints
        """
        assert isinstance(keypoints_list, (list, tuple))
        assert len(keypoints_list) > 0
        assert all(isinstance(kpts, Keypoints) for kpts in keypoints_list)

        merged = torch.cat([kpts.tensor for kpts in keypoints_list], dim=0)
        return type(keypoints_list[0])(merged)
108
+
109
+
110
+ def _keypoints_to_heatmap(
111
+ keypoints: torch.Tensor, rois: torch.Tensor, heatmap_size: int
112
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
113
+ """
114
+ Encode keypoint locations into a target heatmap for use in SoftmaxWithLoss across space.
115
+
116
+ Maps keypoints from the half-open interval [x1, x2) on continuous image coordinates to the
117
+ closed interval [0, heatmap_size - 1] on discrete image coordinates. We use the
118
+ continuous-discrete conversion from Heckbert 1990 ("What is the coordinate of a pixel?"):
119
+ d = floor(c) and c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate.
120
+
121
+ Arguments:
122
+ keypoints: tensor of keypoint locations in of shape (N, K, 3).
123
+ rois: Nx4 tensor of rois in xyxy format
124
+ heatmap_size: integer side length of square heatmap.
125
+
126
+ Returns:
127
+ heatmaps: A tensor of shape (N, K) containing an integer spatial label
128
+ in the range [0, heatmap_size**2 - 1] for each keypoint in the input.
129
+ valid: A tensor of shape (N, K) containing whether each keypoint is in
130
+ the roi or not.
131
+ """
132
+
133
+ if rois.numel() == 0:
134
+ return rois.new().long(), rois.new().long()
135
+ offset_x = rois[:, 0]
136
+ offset_y = rois[:, 1]
137
+ scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
138
+ scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
139
+
140
+ offset_x = offset_x[:, None]
141
+ offset_y = offset_y[:, None]
142
+ scale_x = scale_x[:, None]
143
+ scale_y = scale_y[:, None]
144
+
145
+ x = keypoints[..., 0]
146
+ y = keypoints[..., 1]
147
+
148
+ x_boundary_inds = x == rois[:, 2][:, None]
149
+ y_boundary_inds = y == rois[:, 3][:, None]
150
+
151
+ x = (x - offset_x) * scale_x
152
+ x = x.floor().long()
153
+ y = (y - offset_y) * scale_y
154
+ y = y.floor().long()
155
+
156
+ x[x_boundary_inds] = heatmap_size - 1
157
+ y[y_boundary_inds] = heatmap_size - 1
158
+
159
+ valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
160
+ vis = keypoints[..., 2] > 0
161
+ valid = (valid_loc & vis).long()
162
+
163
+ lin_ind = y * heatmap_size + x
164
+ heatmaps = lin_ind * valid
165
+
166
+ return heatmaps, valid
167
+
168
+
169
@torch.jit.script_if_tracing
def heatmaps_to_keypoints(maps: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
    """
    Extract predicted keypoint locations from heatmaps.

    Args:
        maps (Tensor): (#ROIs, #keypoints, POOL_H, POOL_W). The predicted heatmap of logits for
            each ROI and each keypoint.
        rois (Tensor): (#ROIs, 4). The box of each ROI.

    Returns:
        Tensor of shape (#ROIs, #keypoints, 4) with the last dimension corresponding to
        (x, y, logit, score) for each keypoint.

    When converting discrete pixel indices in an NxN image to a continuous keypoint coordinate,
    we maintain consistency with :meth:`Keypoints.to_heatmap` by using the conversion from
    Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate.
    """

    offset_x = rois[:, 0]
    offset_y = rois[:, 1]

    # Clamp to at least one pixel so degenerate rois still give a valid resize target.
    widths = (rois[:, 2] - rois[:, 0]).clamp(min=1)
    heights = (rois[:, 3] - rois[:, 1]).clamp(min=1)
    widths_ceil = widths.ceil()
    heights_ceil = heights.ceil()

    num_rois, num_keypoints = maps.shape[:2]
    xy_preds = maps.new_zeros(rois.shape[0], num_keypoints, 4)

    # Ratio between the fractional roi extent and the integer upsample size;
    # used to map bin centers back to continuous image coordinates.
    width_corrections = widths / widths_ceil
    height_corrections = heights / heights_ceil

    keypoints_idx = torch.arange(num_keypoints, device=maps.device)

    for i in range(num_rois):
        # Upsample this roi's heatmap to (approximately) its pixel size.
        outsize = (int(heights_ceil[i]), int(widths_ceil[i]))
        roi_map = F.interpolate(
            maps[[i]], size=outsize, mode="bicubic", align_corners=False
        )

        # Although semantically equivalent, `reshape` is used instead of `squeeze` due
        # to limitation during ONNX export of `squeeze` in scripting mode
        roi_map = roi_map.reshape(roi_map.shape[1:])  # keypoints x H x W

        # softmax over the spatial region
        max_score, _ = roi_map.view(num_keypoints, -1).max(1)
        max_score = max_score.view(num_keypoints, 1, 1)
        tmp_full_resolution = (roi_map - max_score).exp_()
        tmp_pool_resolution = (maps[i] - max_score).exp_()
        # Produce scores over the region H x W, but normalize with POOL_H x POOL_W,
        # so that the scores of objects of different absolute sizes will be more comparable
        roi_map_scores = tmp_full_resolution / tmp_pool_resolution.sum(
            (1, 2), keepdim=True
        )

        w = roi_map.shape[2]
        pos = roi_map.view(num_keypoints, -1).argmax(1)

        # Convert the flat argmax back to 2D (x, y) bin indices.
        x_int = pos % w
        y_int = (pos - x_int) // w

        # The per-keypoint normalization above is monotonic, so the argmax of
        # the raw map and of the normalized scores must coincide.
        assert (
            roi_map_scores[keypoints_idx, y_int, x_int]
            == roi_map_scores.view(num_keypoints, -1).max(1)[0]
        ).all()

        # Bin center (d + 0.5), rescaled to image space and offset by the roi origin.
        x = (x_int.float() + 0.5) * width_corrections[i]
        y = (y_int.float() + 0.5) * height_corrections[i]

        xy_preds[i, :, 0] = x + offset_x[i]
        xy_preds[i, :, 1] = y + offset_y[i]
        xy_preds[i, :, 2] = roi_map[keypoints_idx, y_int, x_int]
        xy_preds[i, :, 3] = roi_map_scores[keypoints_idx, y_int, x_int]

    return xy_preds
sam3/agent/helpers/mask_overlap_removal.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from typing import Dict, List
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ try:
9
+ from pycocotools import mask as mask_utils
10
+ except Exception:
11
+ mask_utils = None
12
+
13
+
14
def mask_intersection(
    masks1: torch.Tensor, masks2: torch.Tensor, block_size: int = 16
) -> torch.Tensor:
    """Pairwise intersection areas (in pixels) between two sets of bool masks.

    The (N, M, H, W) broadcast is evaluated block-by-block to bound peak
    memory usage.

    Args:
        masks1: (N, H, W) bool tensor.
        masks2: (M, H, W) bool tensor with the same spatial size as masks1.
        block_size: number of rows/columns of the output processed per chunk.

    Returns:
        (N, M) long tensor; entry (i, j) is the overlap pixel count of
        masks1[i] and masks2[j].
    """
    assert masks1.shape[1:] == masks2.shape[1:]
    assert masks1.dtype == torch.bool and masks2.dtype == torch.bool
    num_a, num_b = masks1.shape[0], masks2.shape[0]
    result = torch.zeros(num_a, num_b, device=masks1.device, dtype=torch.long)
    for row in range(0, num_a, block_size):
        chunk_a = masks1[row : row + block_size]
        for col in range(0, num_b, block_size):
            chunk_b = masks2[col : col + block_size]
            # Broadcast to (block, block, H, W), AND, then count overlap pixels.
            counts = (chunk_a[:, None] & chunk_b[None, :]).flatten(-2).sum(-1)
            result[row : row + block_size, col : col + block_size] = counts
    return result
28
+
29
+
30
def mask_iom(masks1: torch.Tensor, masks2: torch.Tensor) -> torch.Tensor:
    """Pairwise intersection-over-minimum (IoM) between two sets of bool masks.

    IoM = |A ∩ B| / min(|A|, |B|). It approaches 1 when one mask is (nearly)
    contained in the other, which makes it a good criterion for duplicate
    suppression.

    Args:
        masks1: (N, H, W) bool tensor.
        masks2: (M, H, W) bool tensor with the same spatial size as masks1.

    Returns:
        (N, M) float tensor of IoM values.
    """
    assert masks1.shape[1:] == masks2.shape[1:]
    assert masks1.dtype == torch.bool and masks2.dtype == torch.bool
    overlap = mask_intersection(masks1, masks2)
    areas_a = masks1.flatten(-2).sum(-1)  # (N,)
    areas_b = masks2.flatten(-2).sum(-1)  # (M,)
    # clamp_min(1) guards against empty masks; the epsilon is extra insurance.
    smaller = torch.min(areas_a[:, None], areas_b[None, :]).clamp_min(1)
    return overlap.float() / (smaller.float() + 1e-8)
38
+
39
+
40
def _decode_single_mask(mask_repr, h: int, w: int) -> np.ndarray:
    """Decode one mask representation into a binary (H, W) uint8 array.

    Accepts either a dense 2D array-like (list/tuple/ndarray) or a COCO
    compressed-RLE ``counts`` string (str/bytes); the latter requires
    pycocotools.

    Args:
        mask_repr: dense 2D array-like, or an RLE ``counts`` payload.
        h: mask height (used only for RLE decoding).
        w: mask width (used only for RLE decoding).

    Returns:
        (H, W) uint8 array with values in {0, 1}.

    Raises:
        ValueError: if a dense input is not 2D, or the type is unsupported.
        ImportError: if an RLE string is given but pycocotools is unavailable.
    """
    if isinstance(mask_repr, (list, tuple, np.ndarray)):
        arr = np.array(mask_repr)
        if arr.ndim != 2:
            raise ValueError("Mask array must be 2D (H, W).")
        return (arr > 0).astype(np.uint8)

    if mask_utils is None:
        raise ImportError(
            "pycocotools is required to decode RLE mask strings. pip install pycocotools"
        )

    if not isinstance(mask_repr, (str, bytes)):
        raise ValueError("Unsupported mask representation type for RLE decode.")

    # mask_repr is guaranteed str/bytes here, so it can be used directly as
    # the RLE "counts" payload (the original `else str(mask_repr)` branch was
    # unreachable after the isinstance guard above and has been removed).
    rle = {"counts": mask_repr, "size": [h, w]}
    decoded = mask_utils.decode(rle)
    if decoded.ndim == 3:
        decoded = decoded[:, :, 0]
    return (decoded > 0).astype(np.uint8)
63
+
64
+
65
def _decode_masks_to_torch_bool(pred_masks: List, h: int, w: int) -> torch.Tensor:
    """Decode a list of mask representations into one (N, H, W) bool tensor."""
    decoded = [_decode_single_mask(mask, h, w) for mask in pred_masks]
    stacked = np.stack(decoded, axis=0).astype(np.uint8)  # (N, H, W)
    return torch.from_numpy(stacked > 0)
69
+
70
+
71
def remove_overlapping_masks(sample: Dict, iom_thresh: float = 0.3) -> Dict:
    """
    Greedy keep: sort by score desc; keep a mask if IoM to all kept masks <= threshold.
    If pred_masks has length 0 or 1, returns sample unchanged (no extra keys).

    Args:
        sample: prediction dict. Must contain "pred_masks" (list); when there
            are >= 2 masks it must also contain "orig_img_h"/"orig_img_w".
            Optional: "pred_scores" (defaults to all 1.0) and "pred_boxes",
            each parallel to "pred_masks".
        iom_thresh: maximum allowed intersection-over-minimum with any
            already-kept (higher-scored) mask.

    Returns:
        A new dict with filtered "pred_masks"/"pred_scores" (and "pred_boxes"
        if present), preserving the original list order, plus bookkeeping keys
        "kept_indices", "removed_indices", "iom_threshold". For N <= 1 the
        input dict is returned untouched.
    """
    # Basic presence checks
    if "pred_masks" not in sample or not isinstance(sample["pred_masks"], list):
        return sample  # nothing to do / preserve as-is

    pred_masks = sample["pred_masks"]
    N = len(pred_masks)

    # --- Early exit: 0 or 1 mask -> do NOT modify the JSON at all ---
    if N <= 1:
        return sample

    # From here on we have at least 2 masks
    h = int(sample["orig_img_h"])
    w = int(sample["orig_img_w"])
    pred_scores = sample.get("pred_scores", [1.0] * N)  # fallback if scores missing
    pred_boxes = sample.get("pred_boxes", None)

    assert N == len(pred_scores), "pred_masks and pred_scores must have same length"
    if pred_boxes is not None:
        assert N == len(pred_boxes), "pred_masks and pred_boxes must have same length"

    masks_bool = _decode_masks_to_torch_bool(pred_masks, h, w)  # (N, H, W)

    # Visit masks from highest to lowest score so higher-scored masks win ties.
    order = sorted(range(N), key=lambda i: float(pred_scores[i]), reverse=True)
    kept_idx: List[int] = []
    kept_masks: List[torch.Tensor] = []

    for i in order:
        cand = masks_bool[i].unsqueeze(0)  # (1, H, W)
        if len(kept_masks) == 0:
            # First (highest-scored) mask is always kept.
            kept_idx.append(i)
            kept_masks.append(masks_bool[i])
            continue

        kept_stack = torch.stack(kept_masks, dim=0)  # (K, H, W)
        iom_vals = mask_iom(cand, kept_stack).squeeze(0)  # (K,)
        if torch.any(iom_vals > iom_thresh):
            continue  # overlaps too much with a higher-scored kept mask
        kept_idx.append(i)
        kept_masks.append(masks_bool[i])

    # Restore original list order for the surviving detections.
    kept_idx_sorted = sorted(kept_idx)

    # Build filtered JSON (this *does* modify fields; only for N>=2 case)
    out = dict(sample)
    out["pred_masks"] = [pred_masks[i] for i in kept_idx_sorted]
    out["pred_scores"] = [pred_scores[i] for i in kept_idx_sorted]
    if pred_boxes is not None:
        out["pred_boxes"] = [pred_boxes[i] for i in kept_idx_sorted]
    out["kept_indices"] = kept_idx_sorted
    out["removed_indices"] = [i for i in range(N) if i not in set(kept_idx_sorted)]
    out["iom_threshold"] = float(iom_thresh)
    return out
sam3/agent/helpers/masks.py ADDED
@@ -0,0 +1,560 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import copy
4
+ import itertools
5
+ from typing import Any, Iterator, List, Union
6
+
7
+ import numpy as np
8
+ import pycocotools.mask as mask_util
9
+ import torch
10
+ from torch import device
11
+
12
+ from .boxes import Boxes
13
+ from .memory import retry_if_cuda_oom
14
+
15
+ from .roi_align import ROIAlign
16
+
17
+
18
def polygon_area(x, y):
    """Area of a simple polygon given its vertex coordinates, via the
    shoelace formula.

    https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates

    Args:
        x: 1D array of vertex x-coordinates, in traversal order.
        y: 1D array of vertex y-coordinates, in traversal order.

    Returns:
        Non-negative polygon area.
    """
    cross_sum = np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1))
    return 0.5 * np.abs(cross_sum)
22
+
23
+
24
def polygons_to_bitmask(
    polygons: List[np.ndarray], height: int, width: int
) -> np.ndarray:
    """Rasterize a set of polygons belonging to one instance into a bitmask.

    Args:
        polygons (list[ndarray]): each array has shape (Nx2,) -- flattened
            [x0, y0, x1, y1, ...] coordinates.
        height, width (int): output mask size.

    Returns:
        ndarray: a bool mask of shape (height, width).
    """
    if not polygons:
        # COCOAPI does not support empty polygons
        return np.zeros((height, width)).astype(bool)
    encoded = mask_util.frPyObjects(polygons, height, width)
    merged = mask_util.merge(encoded)
    return mask_util.decode(merged).astype(bool)
41
+
42
+
43
def rasterize_polygons_within_box(
    polygons: List[np.ndarray], box: np.ndarray, mask_size: int
) -> torch.Tensor:
    """
    Rasterize the polygons into a mask image and crop the mask content within
    the given box; the cropped mask is resized to (mask_size, mask_size).

    This is used when generating training targets for the mask head in
    Mask R-CNN: given original ground-truth masks for an image, a
    `mask_size x mask_size` target must be produced for each predicted box.

    Args:
        polygons (list[ndarray[float]]): the polygons that make up one instance.
        box: 4-element numpy array (x0, y0, x1, y1).
        mask_size (int): side length of the rasterized output.

    Returns:
        Tensor: BoolTensor of shape (mask_size, mask_size)
    """
    # Work on copies: the shift/scale steps below mutate the coordinate arrays.
    polygons = copy.deepcopy(polygons)

    # 1. Shift the polygons so the box's top-left corner becomes the origin.
    box_w, box_h = box[2] - box[0], box[3] - box[1]
    for poly in polygons:
        poly[0::2] = poly[0::2] - box[0]
        poly[1::2] = poly[1::2] - box[1]

    # 2. Rescale the polygons to the target mask size.
    # max() avoids division by a near-zero box dimension.
    ratio_h = mask_size / max(box_h, 0.1)
    ratio_w = mask_size / max(box_w, 0.1)
    if ratio_h == ratio_w:
        for poly in polygons:
            poly *= ratio_h
    else:
        for poly in polygons:
            poly[0::2] *= ratio_w
            poly[1::2] *= ratio_h

    # 3. Rasterize with the COCO API and hand back a bool tensor.
    return torch.from_numpy(polygons_to_bitmask(polygons, mask_size, mask_size))
90
+
91
+
92
class BitMasks:
    """
    This class stores the segmentation masks for all objects in one image, in
    the form of bitmaps.

    Attributes:
        tensor: bool Tensor of N,H,W, representing N instances in the image.
    """

    def __init__(self, tensor: Union[torch.Tensor, np.ndarray]):
        """
        Args:
            tensor: bool Tensor of N,H,W, representing N instances in the image.
        """
        if isinstance(tensor, torch.Tensor):
            tensor = tensor.to(torch.bool)
        else:
            # Non-tensor inputs (e.g. ndarray) are materialized on CPU.
            tensor = torch.as_tensor(
                tensor, dtype=torch.bool, device=torch.device("cpu")
            )
        assert tensor.dim() == 3, tensor.size()
        # image_size is (H, W); tensor is the (N, H, W) bool mask stack.
        self.image_size = tensor.shape[1:]
        self.tensor = tensor

    @torch.jit.unused
    def to(self, *args: Any, **kwargs: Any) -> "BitMasks":
        # Delegates to torch.Tensor.to; returns a new BitMasks wrapper.
        return BitMasks(self.tensor.to(*args, **kwargs))

    @property
    def device(self) -> torch.device:
        return self.tensor.device

    @torch.jit.unused
    def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "BitMasks":
        """
        Returns:
            BitMasks: Create a new :class:`BitMasks` by indexing.

        The following usage are allowed:

        1. `new_masks = masks[3]`: return a `BitMasks` which contains only one mask.
        2. `new_masks = masks[2:10]`: return a slice of masks.
        3. `new_masks = masks[vector]`, where vector is a torch.BoolTensor
           with `length = len(masks)`. Nonzero elements in the vector will be selected.

        Note that the returned object might share storage with this object,
        subject to Pytorch's indexing semantics.
        """
        if isinstance(item, int):
            # Re-add the instance dimension so the result stays (1, H, W).
            return BitMasks(self.tensor[item].unsqueeze(0))
        m = self.tensor[item]
        assert (
            m.dim() == 3
        ), "Indexing on BitMasks with {} returns a tensor with shape {}!".format(
            item, m.shape
        )
        return BitMasks(m)

    @torch.jit.unused
    def __iter__(self) -> torch.Tensor:
        # Yields one (H, W) bool tensor per instance.
        yield from self.tensor

    @torch.jit.unused
    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "num_instances={})".format(len(self.tensor))
        return s

    def __len__(self) -> int:
        return self.tensor.shape[0]

    def nonempty(self) -> torch.Tensor:
        """
        Find masks that are non-empty.

        Returns:
            Tensor: a BoolTensor which represents
                whether each mask is empty (False) or non-empty (True).
        """
        # A mask is non-empty iff any of its pixels is set.
        return self.tensor.flatten(1).any(dim=1)

    @staticmethod
    def from_polygon_masks(
        polygon_masks: Union["PolygonMasks", List[List[np.ndarray]]],
        height: int,
        width: int,
    ) -> "BitMasks":
        """
        Rasterize polygon masks into a BitMasks of the given image size.

        Args:
            polygon_masks (list[list[ndarray]] or PolygonMasks)
            height, width (int)
        """
        if isinstance(polygon_masks, PolygonMasks):
            polygon_masks = polygon_masks.polygons
        masks = [polygons_to_bitmask(p, height, width) for p in polygon_masks]
        if len(masks):
            return BitMasks(torch.stack([torch.from_numpy(x) for x in masks]))
        else:
            # No instances: return an empty (0, H, W) BitMasks.
            return BitMasks(torch.empty(0, height, width, dtype=torch.bool))

    @staticmethod
    def from_roi_masks(roi_masks: "ROIMasks", height: int, width: int) -> "BitMasks":
        """
        Paste ROI-local masks into full-image BitMasks.

        Args:
            roi_masks: ROIMasks to paste.
            height, width (int): output image size.
        """
        # NOTE(review): ROIMasks.to_bitmasks is declared as
        # (boxes, height, width, threshold=0.5) but is called here with only
        # (height, width) -- confirm against callers before relying on this.
        return roi_masks.to_bitmasks(height, width)

    def crop_and_resize(self, boxes: torch.Tensor, mask_size: int) -> torch.Tensor:
        """
        Crop each bitmask by the given box, and resize results to (mask_size, mask_size).
        This can be used to prepare training targets for Mask R-CNN.
        It has less reconstruction error compared to rasterization with polygons.
        However we observe no difference in accuracy,
        but BitMasks requires more memory to store all the masks.

        Args:
            boxes (Tensor): Nx4 tensor storing the boxes for each mask
            mask_size (int): the size of the rasterized mask.

        Returns:
            Tensor:
                A bool tensor of shape (N, mask_size, mask_size), where
                N is the number of predicted boxes for this image.
        """
        assert len(boxes) == len(self), "{} != {}".format(len(boxes), len(self))
        device = self.tensor.device

        # Prepend the batch index to each box: ROIAlign expects Nx5 rois,
        # pairing mask i with box i.
        batch_inds = torch.arange(len(boxes), device=device).to(dtype=boxes.dtype)[
            :, None
        ]
        rois = torch.cat([batch_inds, boxes], dim=1)  # Nx5

        # ROIAlign interpolates float values, so cast the bool masks first.
        bit_masks = self.tensor.to(dtype=torch.float32)
        rois = rois.to(device=device)
        output = (
            ROIAlign((mask_size, mask_size), 1.0, 0, aligned=True)
            .forward(bit_masks[:, None, :, :], rois)
            .squeeze(1)
        )
        # Re-binarize the interpolated result.
        output = output >= 0.5
        return output

    def get_bounding_boxes(self) -> Boxes:
        """
        Returns:
            Boxes: tight bounding boxes around bitmasks.
            If a mask is empty, it's bounding box will be all zero.
        """
        boxes = torch.zeros(self.tensor.shape[0], 4, dtype=torch.float32)
        # Reducing over H (dim=1) flags columns with foreground -> x extent;
        # reducing over W (dim=2) flags rows with foreground -> y extent.
        x_any = torch.any(self.tensor, dim=1)
        y_any = torch.any(self.tensor, dim=2)
        for idx in range(self.tensor.shape[0]):
            x = torch.where(x_any[idx, :])[0]
            y = torch.where(y_any[idx, :])[0]
            if len(x) > 0 and len(y) > 0:
                # +1 makes the box exclusive of the last foreground pixel.
                boxes[idx, :] = torch.as_tensor(
                    [x[0], y[0], x[-1] + 1, y[-1] + 1], dtype=torch.float32
                )
        return Boxes(boxes)

    @staticmethod
    def cat(bitmasks_list: List["BitMasks"]) -> "BitMasks":
        """
        Concatenates a list of BitMasks into a single BitMasks

        Arguments:
            bitmasks_list (list[BitMasks])

        Returns:
            BitMasks: the concatenated BitMasks
        """
        assert isinstance(bitmasks_list, (list, tuple))
        assert len(bitmasks_list) > 0
        assert all(isinstance(bitmask, BitMasks) for bitmask in bitmasks_list)

        # type(...) keeps subclasses of BitMasks intact.
        cat_bitmasks = type(bitmasks_list[0])(
            torch.cat([bm.tensor for bm in bitmasks_list], dim=0)
        )
        return cat_bitmasks
273
+
274
+
275
class PolygonMasks:
    """
    This class stores the segmentation masks for all objects in one image, in the form of polygons.

    Attributes:
        polygons: list[list[ndarray]]. Each ndarray is a float64 vector representing a polygon.
    """

    def __init__(self, polygons: List[List[Union[torch.Tensor, np.ndarray]]]):
        """
        Arguments:
            polygons (list[list[np.ndarray]]): The first
                level of the list correspond to individual instances,
                the second level to all the polygons that compose the
                instance, and the third level to the polygon coordinates.
                The third level array should have the format of
                [x0, y0, x1, y1, ..., xn, yn] (n >= 3).
        """
        if not isinstance(polygons, list):
            raise ValueError(
                "Cannot create PolygonMasks: Expect a list of list of polygons per image. "
                "Got '{}' instead.".format(type(polygons))
            )

        def _make_array(t: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
            # Use float64 for higher precision, because why not?
            # Always put polygons on CPU (self.to is a no-op) since they
            # are supposed to be small tensors.
            # May need to change this assumption if GPU placement becomes useful
            if isinstance(t, torch.Tensor):
                t = t.cpu().numpy()
            return np.asarray(t).astype("float64")

        def process_polygons(
            polygons_per_instance: List[Union[torch.Tensor, np.ndarray]],
        ) -> List[np.ndarray]:
            # Validate and normalize all polygons of one instance.
            if not isinstance(polygons_per_instance, list):
                raise ValueError(
                    "Cannot create polygons: Expect a list of polygons per instance. "
                    "Got '{}' instead.".format(type(polygons_per_instance))
                )
            # transform each polygon to a numpy array
            polygons_per_instance = [_make_array(p) for p in polygons_per_instance]
            for polygon in polygons_per_instance:
                # A valid polygon needs at least 3 (x, y) vertex pairs.
                if len(polygon) % 2 != 0 or len(polygon) < 6:
                    raise ValueError(
                        f"Cannot create a polygon from {len(polygon)} coordinates."
                    )
            return polygons_per_instance

        self.polygons: List[List[np.ndarray]] = [
            process_polygons(polygons_per_instance)
            for polygons_per_instance in polygons
        ]

    def to(self, *args: Any, **kwargs: Any) -> "PolygonMasks":
        # Polygons always live on CPU, so device moves are a no-op.
        return self

    @property
    def device(self) -> torch.device:
        return torch.device("cpu")

    def get_bounding_boxes(self) -> Boxes:
        """
        Returns:
            Boxes: tight bounding boxes around polygon masks.
        """
        boxes = torch.zeros(len(self.polygons), 4, dtype=torch.float32)
        for idx, polygons_per_instance in enumerate(self.polygons):
            # Accumulate the min/max corner over all polygons of the instance.
            minxy = torch.as_tensor([float("inf"), float("inf")], dtype=torch.float32)
            maxxy = torch.zeros(2, dtype=torch.float32)
            for polygon in polygons_per_instance:
                coords = torch.from_numpy(polygon).view(-1, 2).to(dtype=torch.float32)
                minxy = torch.min(minxy, torch.min(coords, dim=0).values)
                maxxy = torch.max(maxxy, torch.max(coords, dim=0).values)
            boxes[idx, :2] = minxy
            boxes[idx, 2:] = maxxy
        return Boxes(boxes)

    def nonempty(self) -> torch.Tensor:
        """
        Find masks that are non-empty.

        Returns:
            Tensor:
                a BoolTensor which represents whether each mask is empty (False) or not (True).
        """
        # An instance is non-empty iff it has at least one polygon.
        keep = [1 if len(polygon) > 0 else 0 for polygon in self.polygons]
        return torch.from_numpy(np.asarray(keep, dtype=bool))

    def __getitem__(
        self, item: Union[int, slice, List[int], torch.BoolTensor]
    ) -> "PolygonMasks":
        """
        Support indexing over the instances and return a `PolygonMasks` object.
        `item` can be:

        1. An integer. It will return an object with only one instance.
        2. A slice. It will return an object with the selected instances.
        3. A list[int]. It will return an object with the selected instances,
           correpsonding to the indices in the list.
        4. A vector mask of type BoolTensor, whose length is num_instances.
           It will return an object with the instances whose mask is nonzero.
        """
        # NOTE(review): an unsupported item type (e.g. a tuple) falls through
        # every branch and leaves `selected_polygons` unbound -> NameError.
        if isinstance(item, int):
            selected_polygons = [self.polygons[item]]
        elif isinstance(item, slice):
            selected_polygons = self.polygons[item]
        elif isinstance(item, list):
            selected_polygons = [self.polygons[i] for i in item]
        elif isinstance(item, torch.Tensor):
            # Polygons is a list, so we have to move the indices back to CPU.
            if item.dtype == torch.bool:
                assert item.dim() == 1, item.shape
                item = item.nonzero().squeeze(1).cpu().numpy().tolist()
            elif item.dtype in [torch.int32, torch.int64]:
                item = item.cpu().numpy().tolist()
            else:
                raise ValueError(
                    "Unsupported tensor dtype={} for indexing!".format(item.dtype)
                )
            selected_polygons = [self.polygons[i] for i in item]
        return PolygonMasks(selected_polygons)

    def __iter__(self) -> Iterator[List[np.ndarray]]:
        """
        Yields:
            list[ndarray]: the polygons for one instance.
            Each Tensor is a float64 vector representing a polygon.
        """
        return iter(self.polygons)

    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "num_instances={})".format(len(self.polygons))
        return s

    def __len__(self) -> int:
        return len(self.polygons)

    def crop_and_resize(self, boxes: torch.Tensor, mask_size: int) -> torch.Tensor:
        """
        Crop each mask by the given box, and resize results to (mask_size, mask_size).
        This can be used to prepare training targets for Mask R-CNN.

        Args:
            boxes (Tensor): Nx4 tensor storing the boxes for each mask
            mask_size (int): the size of the rasterized mask.

        Returns:
            Tensor: A bool tensor of shape (N, mask_size, mask_size), where
            N is the number of predicted boxes for this image.
        """
        assert len(boxes) == len(self), "{} != {}".format(len(boxes), len(self))

        device = boxes.device
        # Put boxes on the CPU, as the polygon representation is not efficient GPU-wise
        # (several small tensors for representing a single instance mask)
        boxes = boxes.to(torch.device("cpu"))

        results = [
            rasterize_polygons_within_box(poly, box.numpy(), mask_size)
            for poly, box in zip(self.polygons, boxes)
        ]
        """
        poly: list[list[float]], the polygons for one instance
        box: a tensor of shape (4,)
        """
        if len(results) == 0:
            return torch.empty(0, mask_size, mask_size, dtype=torch.bool, device=device)
        # Return on the caller's original device.
        return torch.stack(results, dim=0).to(device=device)

    def area(self):
        """
        Computes area of the mask.
        Only works with Polygons, using the shoelace formula:
        https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates

        Returns:
            Tensor: a vector, area for each instance
        """

        area = []
        for polygons_per_instance in self.polygons:
            # Sum over all the parts that make up one instance.
            area_per_instance = 0
            for p in polygons_per_instance:
                area_per_instance += polygon_area(p[0::2], p[1::2])
            area.append(area_per_instance)

        return torch.tensor(area)

    @staticmethod
    def cat(polymasks_list: List["PolygonMasks"]) -> "PolygonMasks":
        """
        Concatenates a list of PolygonMasks into a single PolygonMasks

        Arguments:
            polymasks_list (list[PolygonMasks])

        Returns:
            PolygonMasks: the concatenated PolygonMasks
        """
        assert isinstance(polymasks_list, (list, tuple))
        assert len(polymasks_list) > 0
        assert all(isinstance(polymask, PolygonMasks) for polymask in polymasks_list)

        # type(...) keeps subclasses of PolygonMasks intact.
        cat_polymasks = type(polymasks_list[0])(
            list(itertools.chain.from_iterable(pm.polygons for pm in polymasks_list))
        )
        return cat_polymasks
485
+
486
+
487
class ROIMasks:
    """
    Represent masks by N smaller masks defined in some ROIs. Once ROI boxes are given,
    full-image bitmask can be obtained by "pasting" the mask on the region defined
    by the corresponding ROI box.
    """

    def __init__(self, tensor: torch.Tensor):
        """
        Args:
            tensor: (N, M, M) mask tensor that defines the mask within each ROI.
        """
        if tensor.dim() != 3:
            raise ValueError("ROIMasks must take a masks of 3 dimension.")
        self.tensor = tensor

    def to(self, device: torch.device) -> "ROIMasks":
        # Returns a new wrapper around the moved tensor.
        return ROIMasks(self.tensor.to(device))

    @property
    def device(self) -> device:
        return self.tensor.device

    def __len__(self):
        return self.tensor.shape[0]

    def __getitem__(self, item) -> "ROIMasks":
        """
        Returns:
            ROIMasks: Create a new :class:`ROIMasks` by indexing.

        The following usage are allowed:

        1. `new_masks = masks[2:10]`: return a slice of masks.
        2. `new_masks = masks[vector]`, where vector is a torch.BoolTensor
           with `length = len(masks)`. Nonzero elements in the vector will be selected.

        Note that the returned object might share storage with this object,
        subject to Pytorch's indexing semantics.
        """
        t = self.tensor[item]
        # Integer indexing would drop the instance dim, hence the check.
        if t.dim() != 3:
            raise ValueError(
                f"Indexing on ROIMasks with {item} returns a tensor with shape {t.shape}!"
            )
        return ROIMasks(t)

    @torch.jit.unused
    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "num_instances={})".format(len(self.tensor))
        return s

    @torch.jit.unused
    def to_bitmasks(self, boxes: torch.Tensor, height, width, threshold=0.5):
        """
        Paste the ROI-local masks into a full-image BitMasks.

        Args: see documentation of :func:`paste_masks_in_image`.
        """
        # Imported lazily so detectron2 is only required when pasting.
        from detectron2.layers.mask_ops import (
            _paste_masks_tensor_shape,
            paste_masks_in_image,
        )

        if torch.jit.is_tracing():
            # Under tracing, pick the shape-tensor-aware variant when the
            # output size is itself a tensor.
            if isinstance(height, torch.Tensor):
                paste_func = _paste_masks_tensor_shape
            else:
                paste_func = paste_masks_in_image
        else:
            # In eager mode, wrap with the CUDA-OOM retry helper.
            paste_func = retry_if_cuda_oom(paste_masks_in_image)
        bitmasks = paste_func(
            self.tensor, boxes.tensor, (height, width), threshold=threshold
        )
        return BitMasks(bitmasks)
+ return BitMasks(bitmasks)
sam3/agent/helpers/memory.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import logging
4
+ from contextlib import contextmanager
5
+ from functools import wraps
6
+
7
+ import torch
8
+
9
+ __all__ = ["retry_if_cuda_oom"]
10
+
11
+
12
@contextmanager
def _ignore_torch_cuda_oom():
    """Context manager that swallows PyTorch CUDA out-of-memory errors.

    Every other RuntimeError (and any non-RuntimeError) propagates unchanged.
    """
    try:
        yield
    except RuntimeError as err:
        # PyTorch raises OOM as a plain RuntimeError, so match on the message.
        # NOTE: the string may change across PyTorch versions.
        if "CUDA out of memory. " not in str(err):
            raise
25
+
26
+
27
def retry_if_cuda_oom(func):
    """
    Makes a function retry itself after encountering
    pytorch's CUDA OOM error.
    It will first retry after calling `torch.cuda.empty_cache()`.

    If that still fails, it will then retry by trying to convert inputs to CPUs.
    In this case, it expects the function to dispatch to CPU implementation.
    The return values may become CPU tensors as well and it's user's
    responsibility to convert it back to CUDA tensor if needed.

    Args:
        func: a stateless callable that takes tensor-like objects as arguments

    Returns:
        a callable which retries `func` if OOM is encountered.

    Examples:
    ::
        output = retry_if_cuda_oom(some_torch_function)(input1, input2)
        # output may be on CPU even if inputs are on GPU

    Note:
        1. When converting inputs to CPU, it will only look at each argument and check
           if it has `.device` and `.to` for conversion. Nested structures of tensors
           are not supported.

        2. Since the function might be called more than once, it has to be
           stateless.
    """

    def _cpu_fallback(value):
        # Move plain tensor-like objects (with .device and .to) off the GPU;
        # anything else is passed through untouched.
        try:
            looks_like_cuda_tensor = value.device.type == "cuda" and hasattr(
                value, "to"
            )
        except AttributeError:
            looks_like_cuda_tensor = False
        return value.to(device="cpu") if looks_like_cuda_tensor else value

    @wraps(func)
    def wrapped(*args, **kwargs):
        # Attempt 1: run as-is; an OOM is swallowed and we fall through.
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # Attempt 2: free cached blocks and retry on the same devices.
        torch.cuda.empty_cache()
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # Attempt 3: copy tensor inputs to CPU. This is slow, so log a notice.
        logger = logging.getLogger(__name__)
        logger.info(
            "Attempting to copy inputs of {} to CPU due to CUDA OOM".format(str(func))
        )
        cpu_args = (_cpu_fallback(a) for a in args)
        cpu_kwargs = {key: _cpu_fallback(val) for key, val in kwargs.items()}
        return func(*cpu_args, **cpu_kwargs)

    return wrapped
sam3/agent/helpers/rle.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ """Some utilities for RLE encoding that doesn't require downloading the masks to the cpu"""
4
+
5
+ import numpy as np
6
+ import torch
7
+ from pycocotools import mask as mask_util
8
+
9
+
10
@torch.no_grad()
def rle_encode(orig_mask, return_areas=False):
    """Encodes a collection of masks in RLE format

    This function emulates the behavior of the COCO API's encode function, but
    is executed partially on the GPU for faster execution.

    Args:
        orig_mask (torch.Tensor): A mask of shape (N, H, W) with dtype=torch.bool
        return_areas (bool): If True, add the areas of the masks as a part of
            the RLE output dict under the "area" key. Default is False.

    Returns:
        list[dict]: one COCO-style RLE dict per mask, with utf-8 "counts".
    """
    assert orig_mask.ndim == 3, "Mask must be of shape (N, H, W)"
    assert orig_mask.dtype == torch.bool, "Mask must have dtype=torch.bool"

    if orig_mask.numel() == 0:
        return []

    # First, transpose the spatial dimensions.
    # This is necessary because the COCO API uses Fortran order
    mask = orig_mask.transpose(1, 2)

    # Flatten the mask
    flat_mask = mask.reshape(mask.shape[0], -1)
    if return_areas:
        mask_areas = flat_mask.sum(-1).tolist()
    # Find the indices where the mask changes
    # differences[i, k] is True where a new run starts; the sentinel last
    # column (left as the initial 1) closes the final run of each mask.
    differences = torch.ones(
        mask.shape[0], flat_mask.shape[1] + 1, device=mask.device, dtype=torch.bool
    )
    differences[:, 1:-1] = flat_mask[:, :-1] != flat_mask[:, 1:]
    differences[:, 0] = flat_mask[:, 0]
    _, change_indices = torch.where(differences)

    # Per-mask cumulative counts of change points, used to slice the flat
    # change_indices vector back into per-mask segments.
    try:
        boundaries = torch.cumsum(differences.sum(-1), 0).cpu()
    except RuntimeError as _:
        # Fall back to summing on CPU (e.g. on OOM).
        boundaries = torch.cumsum(differences.cpu().sum(-1), 0)

    change_indices_clone = change_indices.clone()
    # First pass computes the RLEs on GPU, in a flatten format
    # (differencing consecutive change positions yields run lengths).
    for i in range(mask.shape[0]):
        # Get the change indices for this batch item
        beg = 0 if i == 0 else boundaries[i - 1].item()
        end = boundaries[i].item()
        change_indices[beg + 1 : end] -= change_indices_clone[beg : end - 1]

    # Now we can split the RLES of each batch item, and convert them to strings
    # No more gpu at this point
    change_indices = change_indices.tolist()

    batch_rles = []
    # Process each mask in the batch separately
    for i in range(mask.shape[0]):
        beg = 0 if i == 0 else boundaries[i - 1].item()
        end = boundaries[i].item()
        run_lengths = change_indices[beg:end]

        # Compress the uncompressed RLE with pycocotools, then decode the
        # bytes "counts" to a utf-8 string for JSON friendliness.
        uncompressed_rle = {"counts": run_lengths, "size": list(orig_mask.shape[1:])}
        h, w = uncompressed_rle["size"]
        rle = mask_util.frPyObjects(uncompressed_rle, h, w)
        rle["counts"] = rle["counts"].decode("utf-8")
        if return_areas:
            rle["area"] = mask_areas[i]
        batch_rles.append(rle)

    return batch_rles
80
+
81
+
82
def robust_rle_encode(masks):
    """Encode a batch of masks in RLE format, preferring the GPU path.

    Falls back to the pycocotools CPU encoder when the GPU path raises a
    RuntimeError (e.g. CUDA OOM).

    Args:
        masks (torch.Tensor): (N, H, W) tensor with dtype=torch.bool.

    Returns:
        list[dict]: COCO-style RLE dicts with utf-8 "counts" strings.
    """
    assert masks.ndim == 3, "Mask must be of shape (N, H, W)"
    assert masks.dtype == torch.bool, "Mask must have dtype=torch.bool"

    try:
        return rle_encode(masks)
    except RuntimeError:
        pass

    # CPU fallback: encode one Fortran-ordered uint8 mask at a time.
    np_masks = masks.cpu().numpy()
    rles = []
    for single in np_masks:
        encoded = mask_util.encode(
            np.array(single[:, :, np.newaxis], dtype=np.uint8, order="F")
        )[0]
        encoded["counts"] = encoded["counts"].decode("utf-8")
        rles.append(encoded)
    return rles
+
102
+
103
def ann_to_rle(segm, im_info):
    """Convert a segmentation (polygons or uncompressed RLE) to compressed RLE.

    Args:
        segm: a polygon list, an uncompressed-RLE dict (list "counts"), or an
            already-compressed RLE dict.
        im_info (dict): must provide "height" and "width" of the image.

    Returns:
        dict: compressed RLE.
    """
    h, w = im_info["height"], im_info["width"]
    if isinstance(segm, list):
        # polygon -- a single object might consist of multiple parts;
        # merge all parts into one mask RLE.
        return mask_util.merge(mask_util.frPyObjects(segm, h, w))
    if isinstance(segm["counts"], list):
        # uncompressed RLE
        return mask_util.frPyObjects(segm, h, w)
    # already compressed RLE: pass through unchanged
    return segm
sam3/agent/helpers/roi_align.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from torch import nn
4
+ from torchvision.ops import roi_align
5
+
6
+
7
# NOTE: torchvision's RoIAlign has a different default aligned=False
class ROIAlign(nn.Module):
    def __init__(self, output_size, spatial_scale, sampling_ratio, aligned=True):
        """
        Args:
            output_size (tuple): h, w
            spatial_scale (float): scale the input boxes by this number
            sampling_ratio (int): number of inputs samples to take for each output
                sample. 0 to take samples densely.
            aligned (bool): if False, use the legacy implementation in
                Detectron. If True, align the results more perfectly.

        Note:
            The meaning of aligned=True:

            Given a continuous coordinate c, its two neighboring pixel indices (in our
            pixel model) are computed by floor(c - 0.5) and ceil(c - 0.5). For example,
            c=1.3 has pixel neighbors with discrete indices [0] and [1] (which are sampled
            from the underlying signal at continuous coordinates 0.5 and 1.5). But the original
            roi_align (aligned=False) does not subtract the 0.5 when computing neighboring
            pixel indices and therefore it uses pixels with a slightly incorrect alignment
            (relative to our pixel model) when performing bilinear interpolation.

            With `aligned=True`,
            we first appropriately scale the ROI and then shift it by -0.5
            prior to calling roi_align. This produces the correct neighbors; see
            detectron2/tests/test_roi_align.py for verification.

            The difference does not make a difference to the model's performance if
            ROIAlign is used together with conv layers.
        """
        super().__init__()
        self.output_size = output_size
        self.spatial_scale = spatial_scale
        self.sampling_ratio = sampling_ratio
        self.aligned = aligned

        # Imported lazily so the check runs at construction time.
        from torchvision import __version__

        version = tuple(int(x) for x in __version__.split(".")[:2])
        # https://github.com/pytorch/vision/pull/2438
        assert version >= (0, 7), "Require torchvision >= 0.7"

    def forward(self, input, rois):
        """
        Args:
            input: NCHW images
            rois: Bx5 boxes. First column is the index into N. The other 4 columns are xyxy.
        """
        assert rois.dim() == 2 and rois.size(1) == 5
        # torchvision's roi_align operates on float tensors only.
        if input.is_quantized:
            input = input.dequantize()
        return roi_align(
            input,
            rois.to(dtype=input.dtype),
            self.output_size,
            self.spatial_scale,
            self.sampling_ratio,
            self.aligned,
        )

    def __repr__(self):
        tmpstr = self.__class__.__name__ + "("
        tmpstr += "output_size=" + str(self.output_size)
        tmpstr += ", spatial_scale=" + str(self.spatial_scale)
        tmpstr += ", sampling_ratio=" + str(self.sampling_ratio)
        tmpstr += ", aligned=" + str(self.aligned)
        tmpstr += ")"
        return tmpstr
sam3/agent/helpers/rotated_boxes.py ADDED
@@ -0,0 +1,533 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from __future__ import absolute_import, division, print_function, unicode_literals
4
+
5
+ import math
6
+ from typing import List, Tuple
7
+
8
+ import torch
9
+
10
+ # from detectron2.layers.rotated_boxes import pairwise_iou_rotated
11
+
12
+ from .boxes import Boxes
13
+
14
+
15
def pairwise_iou_rotated(boxes1, boxes2):
    """
    Return intersection-over-union (Jaccard index) of boxes.

    Both sets of boxes are expected to be in
    (x_center, y_center, width, height, angle) format.

    Arguments:
        boxes1 (Tensor[N, 5])
        boxes2 (Tensor[M, 5])

    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """
    # NOTE(review): dispatches to the custom op registered by detectron2's
    # C extensions; detectron2 must be installed/built for this to resolve.
    return torch.ops.detectron2.box_iou_rotated(boxes1, boxes2)
31
+
32
+
33
class RotatedBoxes(Boxes):
    """
    This structure stores a list of rotated boxes as a Nx5 torch.Tensor.
    It supports some common methods about boxes
    (`area`, `clip`, `nonempty`, etc),
    and also behaves like a Tensor
    (support indexing, `to(device)`, `.device`, and iteration over all boxes)
    """

    def __init__(self, tensor: torch.Tensor):
        """
        Args:
            tensor (Tensor[float]): a Nx5 matrix. Each row is
                (x_center, y_center, width, height, angle),
                in which angle is represented in degrees.
                While there's no strict range restriction for it,
                the recommended principal range is between [-180, 180) degrees.

        Assume we have a horizontal box B = (x_center, y_center, width, height),
        where width is along the x-axis and height is along the y-axis.
        The rotated box B_rot (x_center, y_center, width, height, angle)
        can be seen as:

        1. When angle == 0:
           B_rot == B
        2. When angle > 0:
           B_rot is obtained by rotating B w.r.t its center by :math:`|angle|` degrees CCW;
        3. When angle < 0:
           B_rot is obtained by rotating B w.r.t its center by :math:`|angle|` degrees CW.

        Mathematically, since the right-handed coordinate system for image space
        is (y, x), where y is top->down and x is left->right, the 4 vertices of the
        rotated rectangle :math:`(yr_i, xr_i)` (i = 1, 2, 3, 4) can be obtained from
        the vertices of the horizontal rectangle :math:`(y_i, x_i)` (i = 1, 2, 3, 4)
        in the following way (:math:`\\theta = angle*\\pi/180` is the angle in radians,
        :math:`(y_c, x_c)` is the center of the rectangle):

        .. math::

            yr_i = \\cos(\\theta) (y_i - y_c) - \\sin(\\theta) (x_i - x_c) + y_c,

            xr_i = \\sin(\\theta) (y_i - y_c) + \\cos(\\theta) (x_i - x_c) + x_c,

        which is the standard rigid-body rotation transformation.

        Intuitively, the angle is
        (1) the rotation angle from y-axis in image space
        to the height vector (top->down in the box's local coordinate system)
        of the box in CCW, and
        (2) the rotation angle from x-axis in image space
        to the width vector (left->right in the box's local coordinate system)
        of the box in CCW.

        More intuitively, consider the following horizontal box ABCD represented
        in (x1, y1, x2, y2): (3, 2, 7, 4),
        covering the [3, 7] x [2, 4] region of the continuous coordinate system
        which looks like this:

        .. code:: none

            O--------> x
            |
            |  A---B
            |  |   |
            |  D---C
            |
            v y

        Note that each capital letter represents one 0-dimensional geometric point
        instead of a 'square pixel' here.

        In the example above, using (x, y) to represent a point we have:

        .. math::

            O = (0, 0), A = (3, 2), B = (7, 2), C = (7, 4), D = (3, 4)

        We name vector AB = vector DC as the width vector in box's local coordinate system, and
        vector AD = vector BC as the height vector in box's local coordinate system. Initially,
        when angle = 0 degree, they're aligned with the positive directions of x-axis and y-axis
        in the image space, respectively.

        For better illustration, we denote the center of the box as E,

        .. code:: none

            O--------> x
            |
            |  A---B
            |  | E |
            |  D---C
            |
            v y

        where the center E = ((3+7)/2, (2+4)/2) = (5, 3).

        Also,

        .. math::

            width = |AB| = |CD| = 7 - 3 = 4,
            height = |AD| = |BC| = 4 - 2 = 2.

        Therefore, the corresponding representation for the same shape in rotated box in
        (x_center, y_center, width, height, angle) format is:

        (5, 3, 4, 2, 0),

        Now, let's consider (5, 3, 4, 2, 90), which is rotated by 90 degrees
        CCW (counter-clockwise) by definition. It looks like this:

        .. code:: none

            O--------> x
            |   B-C
            |   | |
            |   |E|
            |   | |
            |   A-D
            v y

        The center E is still located at the same point (5, 3), while the vertices
        ABCD are rotated by 90 degrees CCW with regard to E:
        A = (4, 5), B = (4, 1), C = (6, 1), D = (6, 5)

        Here, 90 degrees can be seen as the CCW angle to rotate from y-axis to
        vector AD or vector BC (the top->down height vector in box's local coordinate system),
        or the CCW angle to rotate from x-axis to vector AB or vector DC (the left->right
        width vector in box's local coordinate system).

        .. math::

            width = |AB| = |CD| = 5 - 1 = 4,
            height = |AD| = |BC| = 6 - 4 = 2.

        Next, how about (5, 3, 4, 2, -90), which is rotated by 90 degrees CW (clockwise)
        by definition? It looks like this:

        .. code:: none

            O--------> x
            |   D-A
            |   | |
            |   |E|
            |   | |
            |   C-B
            v y

        The center E is still located at the same point (5, 3), while the vertices
        ABCD are rotated by 90 degrees CW with regard to E:
        A = (6, 1), B = (6, 5), C = (4, 5), D = (4, 1)

        .. math::

            width = |AB| = |CD| = 5 - 1 = 4,
            height = |AD| = |BC| = 6 - 4 = 2.

        This covers exactly the same region as (5, 3, 4, 2, 90) does, and their IoU
        will be 1. However, these two will generate different RoI Pooling results and
        should not be treated as an identical box.

        On the other hand, it's easy to see that (X, Y, W, H, A) is identical to
        (X, Y, W, H, A+360N), for any integer N. For example (5, 3, 4, 2, 270) would be
        identical to (5, 3, 4, 2, -90), because rotating the shape 270 degrees CCW is
        equivalent to rotating the same shape 90 degrees CW.

        We could rotate further to get (5, 3, 4, 2, 180), or (5, 3, 4, 2, -180):

        .. code:: none

            O--------> x
            |
            |  C---D
            |  | E |
            |  B---A
            |
            v y

        .. math::

            A = (7, 4), B = (3, 4), C = (3, 2), D = (7, 2),

            width = |AB| = |CD| = 7 - 3 = 4,
            height = |AD| = |BC| = 4 - 2 = 2.

        Finally, this is a very inaccurate (heavily quantized) illustration of
        how (5, 3, 4, 2, 60) looks like in case anyone wonders:

        .. code:: none

            O--------> x
            |     B\
            |    /  C
            |   /E /
            |  A  /
            |   `D
            v y

        It's still a rectangle with center of (5, 3), width of 4 and height of 2,
        but its angle (and thus orientation) is somewhere between
        (5, 3, 4, 2, 0) and (5, 3, 4, 2, 90).
        """
        device = (
            tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
        )
        tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
        if tensor.numel() == 0:
            # Use reshape, so we don't end up creating a new tensor that does not depend on
            # the inputs (and consequently confuses jit)
            tensor = tensor.reshape((0, 5)).to(dtype=torch.float32, device=device)
        assert tensor.dim() == 2 and tensor.size(-1) == 5, tensor.size()

        self.tensor = tensor

    def clone(self) -> "RotatedBoxes":
        """
        Clone the RotatedBoxes.

        Returns:
            RotatedBoxes
        """
        return RotatedBoxes(self.tensor.clone())

    def to(self, device: torch.device, non_blocking: bool = False):
        # Boxes are assumed float32 and does not support to(dtype)
        return RotatedBoxes(self.tensor.to(device=device, non_blocking=non_blocking))

    def area(self) -> torch.Tensor:
        """
        Computes the area of all the boxes.

        Returns:
            torch.Tensor: a vector with areas of each box.
        """
        box = self.tensor
        area = box[:, 2] * box[:, 3]
        return area

    # Avoid in-place operations so that we can torchscript; NOTE: this creates a new tensor
    def normalize_angles(self) -> None:
        """
        Restrict angles to the range of [-180, 180) degrees
        """
        angle_tensor = (self.tensor[:, 4] + 180.0) % 360.0 - 180.0
        self.tensor = torch.cat((self.tensor[:, :4], angle_tensor[:, None]), dim=1)

    def clip(
        self, box_size: Tuple[int, int], clip_angle_threshold: float = 1.0
    ) -> None:
        """
        Clip (in place) the boxes by limiting x coordinates to the range [0, width]
        and y coordinates to the range [0, height].

        For RRPN:
        Only clip boxes that are almost horizontal with a tolerance of
        clip_angle_threshold to maintain backward compatibility.

        Rotated boxes beyond this threshold are not clipped for two reasons:

        1. There are potentially multiple ways to clip a rotated box to make it
           fit within the image.
        2. It's tricky to make the entire rectangular box fit within the image
           and still be able to not leave out pixels of interest.

        Therefore we rely on ops like RoIAlignRotated to safely handle this.

        Args:
            box_size (height, width): The clipping box's size.
            clip_angle_threshold:
                Iff. abs(normalized(angle)) <= clip_angle_threshold (in degrees),
                we do the clipping as horizontal boxes.
        """
        h, w = box_size

        # normalize angles to be within [-180, 180) degrees
        self.normalize_angles()

        idx = torch.where(torch.abs(self.tensor[:, 4]) <= clip_angle_threshold)[0]

        # convert to (x1, y1, x2, y2)
        x1 = self.tensor[idx, 0] - self.tensor[idx, 2] / 2.0
        y1 = self.tensor[idx, 1] - self.tensor[idx, 3] / 2.0
        x2 = self.tensor[idx, 0] + self.tensor[idx, 2] / 2.0
        y2 = self.tensor[idx, 1] + self.tensor[idx, 3] / 2.0

        # clip
        x1.clamp_(min=0, max=w)
        y1.clamp_(min=0, max=h)
        x2.clamp_(min=0, max=w)
        y2.clamp_(min=0, max=h)

        # convert back to (xc, yc, w, h)
        self.tensor[idx, 0] = (x1 + x2) / 2.0
        self.tensor[idx, 1] = (y1 + y2) / 2.0
        # make sure widths and heights do not increase due to numerical errors
        self.tensor[idx, 2] = torch.min(self.tensor[idx, 2], x2 - x1)
        self.tensor[idx, 3] = torch.min(self.tensor[idx, 3], y2 - y1)

    def nonempty(self, threshold: float = 0.0) -> torch.Tensor:
        """
        Find boxes that are non-empty.
        A box is considered empty, if either of its side is no larger than threshold.

        Returns:
            Tensor: a binary vector which represents
            whether each box is empty (False) or non-empty (True).
        """
        box = self.tensor
        widths = box[:, 2]
        heights = box[:, 3]
        keep = (widths > threshold) & (heights > threshold)
        return keep

    def __getitem__(self, item) -> "RotatedBoxes":
        """
        Returns:
            RotatedBoxes: Create a new :class:`RotatedBoxes` by indexing.

        The following usage are allowed:

        1. `new_boxes = boxes[3]`: return a `RotatedBoxes` which contains only one box.
        2. `new_boxes = boxes[2:10]`: return a slice of boxes.
        3. `new_boxes = boxes[vector]`, where vector is a torch.ByteTensor
           with `length = len(boxes)`. Nonzero elements in the vector will be selected.

        Note that the returned RotatedBoxes might share storage with this RotatedBoxes,
        subject to Pytorch's indexing semantics.
        """
        if isinstance(item, int):
            return RotatedBoxes(self.tensor[item].view(1, -1))
        b = self.tensor[item]
        assert (
            b.dim() == 2
        ), "Indexing on RotatedBoxes with {} failed to return a matrix!".format(item)
        return RotatedBoxes(b)

    def __len__(self) -> int:
        return self.tensor.shape[0]

    def __repr__(self) -> str:
        return "RotatedBoxes(" + str(self.tensor) + ")"

    def inside_box(
        self, box_size: Tuple[int, int], boundary_threshold: int = 0
    ) -> torch.Tensor:
        """
        Args:
            box_size (height, width): Size of the reference box covering
                [0, width] x [0, height]
            boundary_threshold (int): Boxes that extend beyond the reference box
                boundary by more than boundary_threshold are considered "outside".

        For RRPN, it might not be necessary to call this function since it's common
        for rotated box to extend to outside of the image boundaries
        (the clip function only clips the near-horizontal boxes)

        Returns:
            a binary vector, indicating whether each box is inside the reference box.
        """
        height, width = box_size

        cnt_x = self.tensor[..., 0]
        cnt_y = self.tensor[..., 1]
        half_w = self.tensor[..., 2] / 2.0
        half_h = self.tensor[..., 3] / 2.0
        a = self.tensor[..., 4]
        c = torch.abs(torch.cos(a * math.pi / 180.0))
        s = torch.abs(torch.sin(a * math.pi / 180.0))
        # This basically computes the horizontal bounding rectangle of the rotated box
        max_rect_dx = c * half_w + s * half_h
        max_rect_dy = c * half_h + s * half_w

        inds_inside = (
            (cnt_x - max_rect_dx >= -boundary_threshold)
            & (cnt_y - max_rect_dy >= -boundary_threshold)
            & (cnt_x + max_rect_dx < width + boundary_threshold)
            & (cnt_y + max_rect_dy < height + boundary_threshold)
        )

        return inds_inside

    def get_centers(self) -> torch.Tensor:
        """
        Returns:
            The box centers in a Nx2 array of (x, y).
        """
        return self.tensor[:, :2]

    def scale(self, scale_x: float, scale_y: float) -> None:
        """
        Scale the rotated box with horizontal and vertical scaling factors
        Note: when scale_factor_x != scale_factor_y,
        the rotated box does not preserve the rectangular shape when the angle
        is not a multiple of 90 degrees under resize transformation.
        Instead, the shape is a parallelogram (that has skew)
        Here we make an approximation by fitting a rotated rectangle to the parallelogram.
        """
        self.tensor[:, 0] *= scale_x
        self.tensor[:, 1] *= scale_y
        theta = self.tensor[:, 4] * math.pi / 180.0
        c = torch.cos(theta)
        s = torch.sin(theta)

        # In image space, y is top->down and x is left->right
        # Consider the local coordintate system for the rotated box,
        # where the box center is located at (0, 0), and the four vertices ABCD are
        # A(-w / 2, -h / 2), B(w / 2, -h / 2), C(w / 2, h / 2), D(-w / 2, h / 2)
        # the midpoint of the left edge AD of the rotated box E is:
        # E = (A+D)/2 = (-w / 2, 0)
        # the midpoint of the top edge AB of the rotated box F is:
        # F(0, -h / 2)
        # To get the old coordinates in the global system, apply the rotation transformation
        # (Note: the right-handed coordinate system for image space is yOx):
        # (old_x, old_y) = (s * y + c * x, c * y - s * x)
        # E(old) = (s * 0 + c * (-w/2), c * 0 - s * (-w/2)) = (-c * w / 2, s * w / 2)
        # F(old) = (s * (-h / 2) + c * 0, c * (-h / 2) - s * 0) = (-s * h / 2, -c * h / 2)
        # After applying the scaling factor (sfx, sfy):
        # E(new) = (-sfx * c * w / 2, sfy * s * w / 2)
        # F(new) = (-sfx * s * h / 2, -sfy * c * h / 2)
        # The new width after scaling tranformation becomes:

        # w(new) = |E(new) - O| * 2
        #        = sqrt[(sfx * c * w / 2)^2 + (sfy * s * w / 2)^2] * 2
        #        = sqrt[(sfx * c)^2 + (sfy * s)^2] * w
        # i.e., scale_factor_w = sqrt[(sfx * c)^2 + (sfy * s)^2]
        #
        # For example,
        # when angle = 0 or 180, |c| = 1, s = 0, scale_factor_w == scale_factor_x;
        # when |angle| = 90, c = 0, |s| = 1, scale_factor_w == scale_factor_y
        self.tensor[:, 2] *= torch.sqrt((scale_x * c) ** 2 + (scale_y * s) ** 2)

        # h(new) = |F(new) - O| * 2
        #        = sqrt[(sfx * s * h / 2)^2 + (sfy * c * h / 2)^2] * 2
        #        = sqrt[(sfx * s)^2 + (sfy * c)^2] * h
        # i.e., scale_factor_h = sqrt[(sfx * s)^2 + (sfy * c)^2]
        #
        # For example,
        # when angle = 0 or 180, |c| = 1, s = 0, scale_factor_h == scale_factor_y;
        # when |angle| = 90, c = 0, |s| = 1, scale_factor_h == scale_factor_x
        self.tensor[:, 3] *= torch.sqrt((scale_x * s) ** 2 + (scale_y * c) ** 2)

        # The angle is the rotation angle from y-axis in image space to the height
        # vector (top->down in the box's local coordinate system) of the box in CCW.
        #
        # angle(new) = angle_yOx(O - F(new))
        #            = angle_yOx( (sfx * s * h / 2, sfy * c * h / 2) )
        #            = atan2(sfx * s * h / 2, sfy * c * h / 2)
        #            = atan2(sfx * s, sfy * c)
        #
        # For example,
        # when sfx == sfy, angle(new) == atan2(s, c) == angle(old)
        self.tensor[:, 4] = torch.atan2(scale_x * s, scale_y * c) * 180 / math.pi

    @classmethod
    def cat(cls, boxes_list: List["RotatedBoxes"]) -> "RotatedBoxes":
        """
        Concatenates a list of RotatedBoxes into a single RotatedBoxes

        Arguments:
            boxes_list (list[RotatedBoxes])

        Returns:
            RotatedBoxes: the concatenated RotatedBoxes
        """
        assert isinstance(boxes_list, (list, tuple))
        if len(boxes_list) == 0:
            return cls(torch.empty(0))
        assert all([isinstance(box, RotatedBoxes) for box in boxes_list])

        # use torch.cat (v.s. layers.cat) so the returned boxes never share storage with input
        cat_boxes = cls(torch.cat([b.tensor for b in boxes_list], dim=0))
        return cat_boxes

    @property
    def device(self) -> torch.device:
        return self.tensor.device

    @torch.jit.unused
    def __iter__(self):
        """
        Yield a box as a Tensor of shape (5,) at a time.
        """
        yield from self.tensor
516
+
517
+
518
def pairwise_iou(boxes1: RotatedBoxes, boxes2: RotatedBoxes) -> torch.Tensor:
    """
    Given two lists of rotated boxes of size N and M,
    compute the IoU (intersection over union)
    between **all** N x M pairs of boxes.
    The box order must be (x_center, y_center, width, height, angle).

    Args:
        boxes1, boxes2 (RotatedBoxes):
            two `RotatedBoxes`. Contains N & M rotated boxes, respectively.

    Returns:
        Tensor: IoU, sized [N,M].
    """
    # The return annotation was previously "-> None", but the function
    # returns the NxM IoU matrix from pairwise_iou_rotated.
    return pairwise_iou_rotated(boxes1.tensor, boxes2.tensor)
sam3/agent/helpers/som_utils.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import colorsys
4
+ from dataclasses import dataclass
5
+ from typing import List, Tuple
6
+
7
+ import cv2
8
+ import matplotlib as mpl
9
+ import matplotlib.colors as mplc
10
+ import numpy as np
11
+ import pycocotools.mask as mask_utils
12
+
13
+
14
def rgb_to_hex(rgb_color):
    """
    Convert an RGB color to a hex color string.

    Args:
        rgb_color (tuple/list of ints): RGB color, each channel in [0, 255].

    Returns:
        str: Hex color, lowercase, prefixed with "#".

    Example:
        ```
        >>> rgb_to_hex((255, 0, 244))
        '#ff00f4'
        ```
    """
    # "02x" zero-pads each channel to two lowercase hex digits.
    return "#" + "".join(format(c, "02x") for c in rgb_color)
31
+
32
+
33
+ # DEFAULT_COLOR_HEX_TO_NAME = {
34
+ # rgb_to_hex((255, 0, 0)): "red",
35
+ # rgb_to_hex((0, 255, 0)): "lime",
36
+ # rgb_to_hex((0, 0, 255)): "blue",
37
+ # rgb_to_hex((255, 255, 0)): "yellow",
38
+ # rgb_to_hex((255, 0, 255)): "fuchsia",
39
+ # rgb_to_hex((0, 255, 255)): "aqua",
40
+ # rgb_to_hex((255, 165, 0)): "orange",
41
+ # rgb_to_hex((128, 0, 128)): "purple",
42
+ # rgb_to_hex((255, 215, 0)): "gold",
43
+ # }
44
+
45
+ # Assuming rgb_to_hex is a function that converts an (R, G, B) tuple to a hex string.
46
+ # For example: def rgb_to_hex(rgb): return '#%02x%02x%02x' % rgb
47
+
48
# Maps a hex color string to a human-readable name. Insertion order defines
# the palette order used by DEFAULT_COLOR_PALETTE below.
DEFAULT_COLOR_HEX_TO_NAME = {
    # The top 20 approved colors
    rgb_to_hex((255, 255, 0)): "yellow",
    rgb_to_hex((0, 255, 0)): "lime",
    rgb_to_hex((0, 255, 255)): "cyan",
    rgb_to_hex((255, 0, 255)): "magenta",
    rgb_to_hex((255, 0, 0)): "red",
    rgb_to_hex((255, 127, 0)): "orange",
    rgb_to_hex((127, 255, 0)): "chartreuse",
    rgb_to_hex((0, 255, 127)): "spring green",
    rgb_to_hex((255, 0, 127)): "rose",
    rgb_to_hex((127, 0, 255)): "violet",
    rgb_to_hex((192, 255, 0)): "electric lime",
    rgb_to_hex((255, 192, 0)): "vivid orange",
    rgb_to_hex((0, 255, 192)): "turquoise",
    rgb_to_hex((192, 0, 255)): "bright violet",
    rgb_to_hex((255, 0, 192)): "bright pink",
    rgb_to_hex((255, 64, 0)): "fiery orange",
    rgb_to_hex((64, 255, 0)): "bright chartreuse",
    rgb_to_hex((0, 255, 64)): "malachite",
    rgb_to_hex((64, 0, 255)): "deep violet",
    rgb_to_hex((255, 0, 64)): "hot pink",
}


# Ordered list of the hex strings above (dict preserves insertion order).
DEFAULT_COLOR_PALETTE = list(DEFAULT_COLOR_HEX_TO_NAME.keys())
74
+
75
+
76
+ def _validate_color_hex(color_hex: str):
77
+ color_hex = color_hex.lstrip("#")
78
+ if not all(c in "0123456789abcdefABCDEF" for c in color_hex):
79
+ raise ValueError("Invalid characters in color hash")
80
+ if len(color_hex) not in (3, 6):
81
+ raise ValueError("Invalid length of color hash")
82
+
83
+
84
# copied from https://github.com/roboflow/supervision/blob/c8f557af0c61b5c03392bad2cc36c8835598b1e1/supervision/draw/color.py
@dataclass
class Color:
    """
    An RGB color with integer channels.

    Attributes:
        r (int): Red channel.
        g (int): Green channel.
        b (int): Blue channel.
    """

    r: int
    g: int
    b: int

    @classmethod
    def from_hex(cls, color_hex: str):
        """
        Create a Color instance from a hex string.

        Args:
            color_hex (str): Hex string of the color (e.g. '#ff00ff' or 'f0f').

        Returns:
            Color: Instance representing the color.

        Example:
            ```
            >>> Color.from_hex('#ff00ff')
            Color(r=255, g=0, b=255)
            ```
        """
        _validate_color_hex(color_hex)
        digits = color_hex.lstrip("#")
        # Expand shorthand "f0f" to "ff00ff" before parsing.
        if len(digits) == 3:
            digits = "".join(ch + ch for ch in digits)
        value = int(digits, 16)
        return cls((value >> 16) & 0xFF, (value >> 8) & 0xFF, value & 0xFF)

    @classmethod
    def to_hex(cls, color):
        """
        Convert a Color instance to a hex string.

        Args:
            color (Color): Color instance of color.

        Returns:
            str: hex string for the color.
        """
        return rgb_to_hex((color.r, color.g, color.b))

    def as_rgb(self) -> Tuple[int, int, int]:
        """
        Return the color as an (r, g, b) tuple.

        Example:
            ```
            >>> color.as_rgb()
            (255, 0, 255)
            ```
        """
        return self.r, self.g, self.b

    def as_bgr(self) -> Tuple[int, int, int]:
        """
        Return the color as a (b, g, r) tuple.

        Example:
            ```
            >>> color.as_bgr()
            (255, 0, 255)
            ```
        """
        return self.b, self.g, self.r

    @classmethod
    def white(cls):
        return Color.from_hex(color_hex="#ffffff")

    @classmethod
    def black(cls):
        return Color.from_hex(color_hex="#000000")

    @classmethod
    def red(cls):
        return Color.from_hex(color_hex="#ff0000")

    @classmethod
    def green(cls):
        return Color.from_hex(color_hex="#00ff00")

    @classmethod
    def blue(cls):
        return Color.from_hex(color_hex="#0000ff")
186
+
187
+
188
@dataclass
class ColorPalette:
    # Ordered list of Color entries; by_idx lookups wrap around this list.
    colors: List[Color]

    @classmethod
    def default(cls):
        """
        Returns a default color palette.

        Returns:
            ColorPalette: A ColorPalette instance with default colors.

        Example:
            ```
            >>> ColorPalette.default()
            ColorPalette(colors=[Color(r=255, g=0, b=0), Color(r=0, g=255, b=0), ...])
            ```
        """
        return ColorPalette.from_hex(color_hex_list=DEFAULT_COLOR_PALETTE)

    @classmethod
    def from_hex(cls, color_hex_list: List[str]):
        """
        Create a ColorPalette instance from a list of hex strings.

        Args:
            color_hex_list (List[str]): List of color hex strings.

        Returns:
            ColorPalette: A ColorPalette instance.

        Example:
            ```
            >>> ColorPalette.from_hex(['#ff0000', '#00ff00', '#0000ff'])
            ColorPalette(colors=[Color(r=255, g=0, b=0), Color(r=0, g=255, b=0), ...])
            ```
        """
        colors = [Color.from_hex(color_hex) for color_hex in color_hex_list]
        return cls(colors)

    def by_idx(self, idx: int) -> Color:
        """
        Return the color at a given index in the palette (wrapping modulo
        the palette length).

        Args:
            idx (int): Index of the color in the palette. Must be >= 0.

        Returns:
            Color: Color at the given index.

        Raises:
            ValueError: if idx is negative.

        Example:
            ```
            >>> color_palette.by_idx(1)
            Color(r=0, g=255, b=0)
            ```
        """
        if idx < 0:
            raise ValueError("idx argument should not be negative")
        idx = idx % len(self.colors)
        return self.colors[idx]

    def find_farthest_color(self, img_array):
        """
        Return the palette color that is farthest, on average (Euclidean
        distance in RGB space), from the pixels of the given image array.

        Args:
            img_array (np array): any *x3 np array, 3 is the RGB color channel.

        Returns:
            Tuple[Color, str]: the farthest color and its human-readable name
            ("unknown" if its hex is not in DEFAULT_COLOR_HEX_TO_NAME).
        """
        # Reshape the image array for broadcasting
        img_array = img_array.reshape((-1, 3))

        # Convert colors dictionary to a NumPy array
        color_values = np.array([[c.r, c.g, c.b] for c in self.colors])

        # Calculate the Euclidean distance between the colors and each pixel in the image
        # Broadcasting happens here: img_array shape is (num_pixels, 3), color_values shape is (num_colors, 3)
        distances = np.sqrt(
            np.sum((img_array[:, np.newaxis, :] - color_values) ** 2, axis=2)
        )

        # Average the distances for each color
        mean_distances = np.mean(distances, axis=0)

        # return the farthest color
        farthest_idx = np.argmax(mean_distances)
        farthest_color = self.colors[farthest_idx]
        farthest_color_hex = Color.to_hex(farthest_color)
        if farthest_color_hex in DEFAULT_COLOR_HEX_TO_NAME:
            farthest_color_name = DEFAULT_COLOR_HEX_TO_NAME[farthest_color_hex]
        else:
            farthest_color_name = "unknown"

        return farthest_color, farthest_color_name
285
+
286
+
287
def draw_box(ax, box_coord, alpha=0.8, edge_color="g", line_style="-", linewidth=2.0):
    """Draw an unfilled rectangle on a matplotlib axes.

    Args:
        ax: Matplotlib axes to draw on.
        box_coord: ``(x0, y0, width, height)`` of the box in data coordinates.
        alpha (float): Transparency of the outline.
        edge_color: Any matplotlib color spec for the outline.
        line_style (str): Matplotlib line style for the outline.
        linewidth (float): Outline width in points.
    """
    x0, y0, box_w, box_h = box_coord
    rect = mpl.patches.Rectangle(
        (x0, y0),
        box_w,
        box_h,
        fill=False,
        edgecolor=edge_color,
        linewidth=linewidth,
        alpha=alpha,
        linestyle=line_style,
    )
    ax.add_patch(rect)
301
+
302
+
303
def draw_text(
    ax,
    text,
    position,
    font_size=None,
    color="g",
    horizontal_alignment="left",
    rotation=0,
):
    """Draw a text label on a matplotlib axes.

    The requested color is normalized for legibility: every RGB channel is
    floored at 0.2 and the dominant channel is pushed to at least 0.8.

    Args:
        ax: Matplotlib axes to draw on.
        text (str): The label to render.
        position: ``(x, y)`` anchor in data coordinates (text hangs below it).
        font_size: Point size; defaults to ``mpl.rcParams["font.size"]``.
        color: Any matplotlib color spec.
        horizontal_alignment (str): Matplotlib horizontal alignment.
        rotation: Rotation of the text in degrees.
    """
    if not font_size:
        font_size = mpl.rcParams["font.size"]

    # Keep the label readable: lift dark channels, boost the dominant one.
    rgb = np.maximum(list(mplc.to_rgb(color)), 0.2)
    rgb[np.argmax(rgb)] = max(0.8, np.max(rgb))

    x, y = position
    ax.text(
        x,
        y,
        text,
        size=font_size,
        family="sans-serif",
        bbox={"facecolor": "none", "alpha": 0.5, "pad": 0.7, "edgecolor": "none"},
        verticalalignment="top",
        horizontalalignment=horizontal_alignment,
        color=rgb,
        rotation=rotation,
    )
331
+
332
+
333
def draw_mask(
    ax, rle, color, show_holes=True, alpha=0.15, upsample_factor=1.0, rle_upsampled=None
):
    """Draw a segmentation mask (translucent fill plus contour outline) on an axes.

    Args:
        ax: Matplotlib axes to draw on.
        rle: COCO RLE dict or a binary ndarray mask.
        color (np.ndarray): RGB color as a length-3 array in [0, 1].
        show_holes (bool): If True, render the fill from the (possibly
            upsampled) mask image so holes stay visible; contours are then
            drawn unfilled.
        alpha (float): Fill opacity.
        upsample_factor (float): Scale applied to contour coordinates; when
            > 1 together with ``show_holes``, ``rle_upsampled`` must be given.
        rle_upsampled: Upsampled counterpart of ``rle`` (RLE dict or ndarray).

    Raises:
        ValueError: If ``rle`` or ``rle_upsampled`` has an unsupported type.
    """
    if isinstance(rle, dict):
        mask = mask_utils.decode(rle)
    elif isinstance(rle, np.ndarray):
        mask = rle
    else:
        raise ValueError(f"Unsupported type for rle: {type(rle)}")

    mask_upsampled = None
    if upsample_factor > 1.0 and show_holes:
        assert rle_upsampled is not None
        if isinstance(rle_upsampled, dict):
            mask_upsampled = mask_utils.decode(rle_upsampled)
        elif isinstance(rle_upsampled, np.ndarray):
            mask_upsampled = rle_upsampled
        else:
            # Fixed: this message previously blamed `rle` even though the
            # failing value is `rle_upsampled`.
            raise ValueError(
                f"Unsupported type for rle_upsampled: {type(rle_upsampled)}"
            )

    if show_holes:
        if mask_upsampled is None:
            mask_upsampled = mask
        # Paint the fill as an RGBA image so interior holes stay transparent.
        h, w = mask_upsampled.shape
        mask_img = np.zeros((h, w, 4))
        mask_img[:, :, :-1] = color[np.newaxis, np.newaxis, :]
        mask_img[:, :, -1] = mask_upsampled * alpha
        ax.imshow(mask_img)

    # Star-unpacking keeps this compatible with both OpenCV 3 (3-tuple return)
    # and OpenCV 4 (2-tuple return) signatures of findContours.
    *_, contours, _ = cv2.findContours(
        mask.astype(np.uint8).copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
    )
    # Shift to pixel centers before scaling so contours align after upsampling.
    upsampled_contours = [(cont + 0.5) * upsample_factor - 0.5 for cont in contours]
    facecolor = (0, 0, 0, 0) if show_holes else color
    if alpha > 0.8:
        # A nearly opaque fill needs a darker edge to remain visible.
        edge_color = _change_color_brightness(color, brightness_factor=-0.7)
    else:
        edge_color = color
    for cont in upsampled_contours:
        polygon = mpl.patches.Polygon(
            [el[0] for el in cont],
            edgecolor=edge_color,
            linewidth=2.0,
            facecolor=facecolor,
        )
        ax.add_patch(polygon)
379
+
380
+
381
def _change_color_brightness(color, brightness_factor):
    """
    Shift a color's lightness up or down in HLS space.

    Args:
        color: color of the polygon. Refer to `matplotlib.colors` for a full
            list of formats that are accepted.
        brightness_factor (float): a value in [-1.0, 1.0] range. 0 means no
            change; a factor in [-1.0, 0) darkens the color and a factor in
            (0, 1.0] lightens it.

    Returns:
        modified_color (tuple[double]): RGB values of the modified color,
        each in the [0.0, 1.0] range.
    """
    assert -1.0 <= brightness_factor <= 1.0
    hue, lightness, saturation = colorsys.rgb_to_hls(*mplc.to_rgb(color))
    lightness = lightness + brightness_factor * lightness
    lightness = min(1.0, max(0.0, lightness))
    return colorsys.hls_to_rgb(hue, lightness, saturation)
sam3/agent/helpers/visualizer.py ADDED
@@ -0,0 +1,1662 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import colorsys
4
+ import logging
5
+ import math
6
+ import random
7
+ from enum import Enum, unique
8
+
9
+ import cv2
10
+ import matplotlib as mpl
11
+ import matplotlib.colors as mplc
12
+ import matplotlib.figure as mplfigure
13
+ import numpy as np
14
+ import pycocotools.mask as mask_util
15
+ import torch
16
+ from iopath.common.file_io import PathManager
17
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
18
+ from PIL import Image
19
+
20
+ from .boxes import Boxes, BoxMode
21
+
22
+ from .color_map import random_color
23
+ from .keypoints import Keypoints
24
+ from .masks import BitMasks, PolygonMasks
25
+ from .rotated_boxes import RotatedBoxes
26
+
27
# Module-level logger for visualization diagnostics.
logger = logging.getLogger(__name__)


__all__ = ["ColorMode", "VisImage", "Visualizer"]


# Area thresholds in pixels^2 used by drawing heuristics elsewhere in this
# module (e.g. label placement / mask labeling) — see their usage below.
_SMALL_OBJECT_AREA_THRESH = 1000
_LARGE_MASK_AREA_THRESH = 120000
# Common colors as RGB fractions in [0, 1].
_OFF_WHITE = (1.0, 1.0, 240.0 / 255)
_BLACK = (0, 0, 0)
_RED = (1.0, 0, 0)

# Default minimum confidence for drawing keypoints; copied onto
# Visualizer.keypoint_threshold in Visualizer.__init__.
_KEYPOINT_THRESHOLD = 0.05
40
+
41
+
42
@unique
class ColorMode(Enum):
    """
    Enum of different color modes to use for instance visualizations.
    """

    # Picks a random color for every instance and overlays segmentations
    # with low opacity.
    IMAGE = 0

    # Instances of the same category get similar colors (from
    # metadata.thing_colors), overlaid with high opacity. This puts more
    # attention on the quality of segmentation.
    SEGMENTATION = 1

    # Same as IMAGE, but converts all areas without masks to gray-scale.
    # Only available for drawing per-instance mask predictions.
    IMAGE_BW = 2
63
+
64
+
65
class GenericMask:
    """
    Lazily-converted segmentation mask.

    Accepts a COCO-style RLE dict, a list of polygons, or a binary ndarray,
    and converts between representations on demand.

    Attribute:
        polygons (list[ndarray]): polygons for this mask, each a flat
            [x, y, x, y, ...] array.
        mask (ndarray): a binary mask of shape (height, width).
    """

    def __init__(self, mask_or_polygons, height, width):
        self._mask = None
        self._polygons = None
        self._has_holes = None
        self.height = height
        self.width = width

        m = mask_or_polygons
        if isinstance(m, dict):
            # COCO RLE, compressed or uncompressed.
            assert "counts" in m and "size" in m
            if isinstance(m["counts"], list):  # uncompressed RLE
                h, w = m["size"]
                assert h == height and w == width
                m = mask_util.frPyObjects(m, h, w)
            self._mask = mask_util.decode(m)[:, :]
        elif isinstance(m, list):  # list of polygons
            self._polygons = [np.asarray(p).reshape(-1) for p in m]
        elif isinstance(m, np.ndarray):  # assumed to be a binary mask
            assert m.shape[1] != 2, m.shape
            assert m.shape == (
                height,
                width,
            ), f"mask shape: {m.shape}, target dims: {height}, {width}"
            self._mask = m.astype("uint8")
        else:
            raise ValueError(
                "GenericMask cannot handle object {} of type '{}'".format(m, type(m))
            )

    @property
    def mask(self):
        # Rasterize polygons lazily on first access.
        if self._mask is None:
            self._mask = self.polygons_to_mask(self._polygons)
        return self._mask

    @property
    def polygons(self):
        # Trace contours lazily on first access.
        if self._polygons is None:
            self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
        return self._polygons

    @property
    def has_holes(self):
        if self._has_holes is None:
            if self._mask is None:
                # Polygon input cannot encode holes.
                self._has_holes = False
            else:
                self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
        return self._has_holes

    def mask_to_polygons(self, mask):
        """Trace a binary mask into polygons; also report whether it has holes."""
        # cv2.RETR_CCOMP builds a 2-level hierarchy: external boundaries in
        # level 1, holes in level 2. CHAIN_APPROX_NONE keeps every vertex.
        # Some cv2 versions cannot handle non-contiguous arrays.
        mask = np.ascontiguousarray(mask)
        res = cv2.findContours(
            mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE
        )
        hierarchy = res[-1]
        if hierarchy is None:  # empty mask
            return [], False
        has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
        contours = [c.flatten() for c in res[-2]]
        # OpenCV vertices are integers in [0, W-1 or H-1]; add 0.5 to move them
        # into real-valued coordinate space. A better solution would be to
        # first +0.5 and then dilate the returned polygon by 0.5.
        # Degenerate polygons with fewer than 3 points are dropped.
        return [c + 0.5 for c in contours if len(c) >= 6], has_holes

    def polygons_to_mask(self, polygons):
        """Rasterize polygons into a binary mask of this object's size."""
        rle = mask_util.merge(mask_util.frPyObjects(polygons, self.height, self.width))
        return mask_util.decode(rle)[:, :]

    def area(self):
        """Return the number of foreground pixels."""
        return self.mask.sum()

    def bbox(self):
        """Return the tight [x0, y0, x1, y1] bounding box of the polygons."""
        p = mask_util.merge(mask_util.frPyObjects(self.polygons, self.height, self.width))
        bbox = mask_util.toBbox(p)
        # toBbox returns XYWH; convert to XYXY in place.
        bbox[2] += bbox[0]
        bbox[3] += bbox[1]
        return bbox
167
+
168
+
169
+ class _PanopticPrediction:
170
+ """
171
+ Unify different panoptic annotation/prediction formats
172
+ """
173
+
174
+ def __init__(self, panoptic_seg, segments_info, metadata=None):
175
+ if segments_info is None:
176
+ assert metadata is not None
177
+ # If "segments_info" is None, we assume "panoptic_img" is a
178
+ # H*W int32 image storing the panoptic_id in the format of
179
+ # category_id * label_divisor + instance_id. We reserve -1 for
180
+ # VOID label.
181
+ label_divisor = metadata.label_divisor
182
+ segments_info = []
183
+ for panoptic_label in np.unique(panoptic_seg.numpy()):
184
+ if panoptic_label == -1:
185
+ # VOID region.
186
+ continue
187
+ pred_class = panoptic_label // label_divisor
188
+ isthing = (
189
+ pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
190
+ )
191
+ segments_info.append(
192
+ {
193
+ "id": int(panoptic_label),
194
+ "category_id": int(pred_class),
195
+ "isthing": bool(isthing),
196
+ }
197
+ )
198
+ del metadata
199
+
200
+ self._seg = panoptic_seg
201
+
202
+ self._sinfo = {s["id"]: s for s in segments_info} # seg id -> seg info
203
+ segment_ids, areas = torch.unique(panoptic_seg, sorted=True, return_counts=True)
204
+ areas = areas.numpy()
205
+ sorted_idxs = np.argsort(-areas)
206
+ self._seg_ids, self._seg_areas = segment_ids[sorted_idxs], areas[sorted_idxs]
207
+ self._seg_ids = self._seg_ids.tolist()
208
+ for sid, area in zip(self._seg_ids, self._seg_areas):
209
+ if sid in self._sinfo:
210
+ self._sinfo[sid]["area"] = float(area)
211
+
212
+ def non_empty_mask(self):
213
+ """
214
+ Returns:
215
+ (H, W) array, a mask for all pixels that have a prediction
216
+ """
217
+ empty_ids = []
218
+ for id in self._seg_ids:
219
+ if id not in self._sinfo:
220
+ empty_ids.append(id)
221
+ if len(empty_ids) == 0:
222
+ return np.zeros(self._seg.shape, dtype=np.uint8)
223
+ assert (
224
+ len(empty_ids) == 1
225
+ ), ">1 ids corresponds to no labels. This is currently not supported"
226
+ return (self._seg != empty_ids[0]).numpy().astype(np.bool)
227
+
228
+ def semantic_masks(self):
229
+ for sid in self._seg_ids:
230
+ sinfo = self._sinfo.get(sid)
231
+ if sinfo is None or sinfo["isthing"]:
232
+ # Some pixels (e.g. id 0 in PanopticFPN) have no instance or semantic predictions.
233
+ continue
234
+ yield (self._seg == sid).numpy().astype(np.bool), sinfo
235
+
236
+ def instance_masks(self):
237
+ for sid in self._seg_ids:
238
+ sinfo = self._sinfo.get(sid)
239
+ if sinfo is None or not sinfo["isthing"]:
240
+ continue
241
+ mask = (self._seg == sid).numpy().astype(np.bool)
242
+ if mask.sum() > 0:
243
+ yield mask, sinfo
244
+
245
+
246
+ def _create_text_labels(classes, scores, class_names, is_crowd=None):
247
+ """
248
+ Args:
249
+ classes (list[int] or None):
250
+ scores (list[float] or None):
251
+ class_names (list[str] or None):
252
+ is_crowd (list[bool] or None):
253
+
254
+ Returns:
255
+ list[str] or None
256
+ """
257
+ labels = None
258
+ if classes is not None:
259
+ if class_names is not None and len(class_names) > 0:
260
+ labels = [class_names[i] for i in classes]
261
+ else:
262
+ labels = [str(i) for i in classes]
263
+ if scores is not None:
264
+ if labels is None:
265
+ labels = ["{:.0f}%".format(s * 100) for s in scores]
266
+ else:
267
+ labels = ["{} {:.0f}%".format(l, s * 100) for l, s in zip(labels, scores)]
268
+ if labels is not None and is_crowd is not None:
269
+ labels = [l + ("|crowd" if crowd else "") for l, crowd in zip(labels, is_crowd)]
270
+ return labels
271
+
272
+
273
class VisImage:
    """Matplotlib (Agg) canvas wrapping a single RGB image for annotation."""

    def __init__(self, img, scale=1.0):
        """
        Args:
            img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].
            scale (float): scale the input image
        """
        self.img = img
        self.scale = scale
        self.width = img.shape[1]
        self.height = img.shape[0]
        self._setup_figure(img)

    def _setup_figure(self, img):
        """
        Build a frameless figure + full-bleed axes sized to the image.

        Args:
            Same as in :meth:`__init__()`.
        """
        fig = mplfigure.Figure(frameon=False)
        self.dpi = fig.get_dpi()
        # Add a small 1e-2 to avoid precision loss due to matplotlib's
        # truncation (https://github.com/matplotlib/matplotlib/issues/15363).
        fig.set_size_inches(
            (self.width * self.scale + 1e-2) / self.dpi,
            (self.height * self.scale + 1e-2) / self.dpi,
        )
        self.canvas = FigureCanvasAgg(fig)
        ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
        ax.axis("off")
        self.fig = fig
        self.ax = ax
        self.reset_image(img)

    def reset_image(self, img):
        """
        Replace the displayed base image.

        Args:
            img: same as in __init__
        """
        self.ax.imshow(
            img.astype("uint8"),
            extent=(0, self.width, self.height, 0),
            interpolation="nearest",
        )

    def save(self, filepath):
        """
        Args:
            filepath (str): absolute path, including the file name, where the
                visualized image will be saved.
        """
        self.fig.savefig(filepath)

    def get_image(self):
        """
        Returns:
            ndarray:
                the visualized image of shape (H, W, 3) (RGB) in uint8 type.
                The shape is scaled w.r.t the input image using the given
                `scale` argument.
        """
        s, (width, height) = self.canvas.print_to_buffer()
        buffer = np.frombuffer(s, dtype="uint8")
        img_rgba = buffer.reshape(height, width, 4)
        # Drop the alpha channel; Agg renders RGBA.
        rgb, _alpha = np.split(img_rgba, [3], axis=2)
        return rgb.astype("uint8")
347
+
348
+
349
+ class Visualizer:
350
+ """
351
+ Visualizer that draws data about detection/segmentation on images.
352
+
353
+ It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}`
354
+ that draw primitive objects to images, as well as high-level wrappers like
355
+ `draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}`
356
+ that draw composite data in some pre-defined style.
357
+
358
+ Note that the exact visualization style for the high-level wrappers are subject to change.
359
+ Style such as color, opacity, label contents, visibility of labels, or even the visibility
360
+ of objects themselves (e.g. when the object is too small) may change according
361
+ to different heuristics, as long as the results still look visually reasonable.
362
+
363
+ To obtain a consistent style, you can implement custom drawing functions with the
364
+ abovementioned primitive methods instead. If you need more customized visualization
365
+ styles, you can process the data yourself following their format documented in
366
+ tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not
367
+ intend to satisfy everyone's preference on drawing styles.
368
+
369
+ This visualizer focuses on high rendering quality rather than performance. It is not
370
+ designed to be used for real-time applications.
371
+ """
372
+
373
+ def __init__(
374
+ self,
375
+ img_rgb,
376
+ metadata=None,
377
+ scale=1.0,
378
+ instance_mode=ColorMode.IMAGE,
379
+ font_size_multiplier=1.3,
380
+ boarder_width_multiplier=1.5,
381
+ ):
382
+ """
383
+ Args:
384
+ img_rgb: a numpy array of shape (H, W, C), where H and W correspond to
385
+ the height and width of the image respectively. C is the number of
386
+ color channels. The image is required to be in RGB format since that
387
+ is a requirement of the Matplotlib library. The image is also expected
388
+ to be in the range [0, 255].
389
+ metadata (Metadata): dataset metadata (e.g. class names and colors)
390
+ instance_mode (ColorMode): defines one of the pre-defined style for drawing
391
+ instances on an image.
392
+ """
393
+ self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
394
+ self.boarder_width_multiplier = boarder_width_multiplier
395
+ # if metadata is None:
396
+ # metadata = MetadataCatalog.get("__nonexist__")
397
+ # self.metadata = metadata
398
+ self.output = VisImage(self.img, scale=scale)
399
+ self.cpu_device = torch.device("cpu")
400
+
401
+ # too small texts are useless, therefore clamp to 9
402
+ self._default_font_size = (
403
+ max(np.sqrt(self.output.height * self.output.width) // 60, 15 // scale)
404
+ * font_size_multiplier
405
+ )
406
+ # self._default_font_size = 18
407
+ self._instance_mode = instance_mode
408
+ self.keypoint_threshold = _KEYPOINT_THRESHOLD
409
+
410
+ import matplotlib.colors as mcolors
411
+
412
+ css4_colors = mcolors.CSS4_COLORS
413
+ self.color_proposals = [
414
+ list(mcolors.hex2color(color)) for color in css4_colors.values()
415
+ ]
416
+
417
+ def draw_instance_predictions(self, predictions):
418
+ """
419
+ Draw instance-level prediction results on an image.
420
+
421
+ Args:
422
+ predictions (Instances): the output of an instance detection/segmentation
423
+ model. Following fields will be used to draw:
424
+ "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle").
425
+
426
+ Returns:
427
+ output (VisImage): image object with visualizations.
428
+ """
429
+ boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
430
+ scores = predictions.scores if predictions.has("scores") else None
431
+ classes = (
432
+ predictions.pred_classes.tolist()
433
+ if predictions.has("pred_classes")
434
+ else None
435
+ )
436
+ labels = _create_text_labels(
437
+ classes, scores, self.metadata.get("thing_classes", None)
438
+ )
439
+ keypoints = (
440
+ predictions.pred_keypoints if predictions.has("pred_keypoints") else None
441
+ )
442
+
443
+ keep = (scores > 0.5).cpu()
444
+ boxes = boxes[keep]
445
+ scores = scores[keep]
446
+ classes = np.array(classes)
447
+ classes = classes[np.array(keep)]
448
+ labels = np.array(labels)
449
+ labels = labels[np.array(keep)]
450
+
451
+ if predictions.has("pred_masks"):
452
+ masks = np.asarray(predictions.pred_masks)
453
+ masks = masks[np.array(keep)]
454
+ masks = [
455
+ GenericMask(x, self.output.height, self.output.width) for x in masks
456
+ ]
457
+ else:
458
+ masks = None
459
+
460
+ if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get(
461
+ "thing_colors"
462
+ ):
463
+ # if self.metadata.get("thing_colors"):
464
+ colors = [
465
+ self._jitter([x / 255 for x in self.metadata.thing_colors[c]])
466
+ for c in classes
467
+ ]
468
+ alpha = 0.4
469
+ else:
470
+ colors = None
471
+ alpha = 0.4
472
+
473
+ if self._instance_mode == ColorMode.IMAGE_BW:
474
+ self.output.reset_image(
475
+ self._create_grayscale_image(
476
+ (predictions.pred_masks.any(dim=0) > 0).numpy()
477
+ if predictions.has("pred_masks")
478
+ else None
479
+ )
480
+ )
481
+ alpha = 0.3
482
+
483
+ self.overlay_instances(
484
+ masks=masks,
485
+ boxes=boxes,
486
+ labels=labels,
487
+ keypoints=keypoints,
488
+ assigned_colors=colors,
489
+ alpha=alpha,
490
+ )
491
+ return self.output
492
+
493
+ def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.7):
494
+ """
495
+ Draw semantic segmentation predictions/labels.
496
+
497
+ Args:
498
+ sem_seg (Tensor or ndarray): the segmentation of shape (H, W).
499
+ Each value is the integer label of the pixel.
500
+ area_threshold (int): segments with less than `area_threshold` are not drawn.
501
+ alpha (float): the larger it is, the more opaque the segmentations are.
502
+
503
+ Returns:
504
+ output (VisImage): image object with visualizations.
505
+ """
506
+ if isinstance(sem_seg, torch.Tensor):
507
+ sem_seg = sem_seg.numpy()
508
+ labels, areas = np.unique(sem_seg, return_counts=True)
509
+ sorted_idxs = np.argsort(-areas).tolist()
510
+ labels = labels[sorted_idxs]
511
+ for label in filter(lambda l: l < len(self.metadata.stuff_classes), labels):
512
+ try:
513
+ mask_color = [x / 255 for x in self.metadata.stuff_colors[label]]
514
+ except (AttributeError, IndexError):
515
+ mask_color = None
516
+
517
+ binary_mask = (sem_seg == label).astype(np.uint8)
518
+ text = self.metadata.stuff_classes[label]
519
+ self.draw_binary_mask(
520
+ binary_mask,
521
+ color=mask_color,
522
+ edge_color=_OFF_WHITE,
523
+ text=text,
524
+ alpha=alpha,
525
+ area_threshold=area_threshold,
526
+ )
527
+ return self.output
528
+
529
    def draw_panoptic_seg(
        self, panoptic_seg, segments_info, area_threshold=None, alpha=0.7
    ):
        """
        Draw panoptic prediction annotations or results.

        Args:
            panoptic_seg (Tensor): of shape (height, width) where the values are ids for each
                segment.
            segments_info (list[dict] or None): Describe each segment in `panoptic_seg`.
                If it is a ``list[dict]``, each dict contains keys "id", "category_id".
                If None, category id of each pixel is computed by
                ``pixel // metadata.label_divisor``.
            area_threshold (int): stuff segments with less than `area_threshold` are not drawn.
            alpha (float): opacity of the drawn masks.

        Returns:
            output (VisImage): image object with visualizations.
        """
        pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)

        # In black-and-white mode, gray out every pixel with no prediction.
        if self._instance_mode == ColorMode.IMAGE_BW:
            self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask()))

        # draw mask for all semantic segments first i.e. "stuff"
        for mask, sinfo in pred.semantic_masks():
            category_idx = sinfo["category_id"]
            try:
                mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]]
            except AttributeError:
                # Metadata without stuff_colors: let draw_binary_mask choose.
                mask_color = None

            # Strip dataset-specific suffixes from class names for display.
            text = (
                self.metadata.stuff_classes[category_idx]
                .replace("-other", "")
                .replace("-merged", "")
            )
            self.draw_binary_mask(
                mask,
                color=mask_color,
                edge_color=_OFF_WHITE,
                text=text,
                alpha=alpha,
                area_threshold=area_threshold,
            )

        # draw mask for all instances second
        all_instances = list(pred.instance_masks())
        if len(all_instances) == 0:
            return self.output
        masks, sinfo = list(zip(*all_instances))
        category_ids = [x["category_id"] for x in sinfo]

        try:
            scores = [x["score"] for x in sinfo]
        except KeyError:
            # Ground-truth segments carry no "score" key.
            scores = None
        class_names = [
            name.replace("-other", "").replace("-merged", "")
            for name in self.metadata.thing_classes
        ]
        labels = _create_text_labels(
            category_ids, scores, class_names, [x.get("iscrowd", 0) for x in sinfo]
        )

        try:
            colors = [
                self._jitter([x / 255 for x in self.metadata.thing_colors[c]])
                for c in category_ids
            ]
        except AttributeError:
            colors = None
        self.overlay_instances(
            masks=masks, labels=labels, assigned_colors=colors, alpha=alpha
        )

        return self.output

    # Old public name kept so existing callers continue to work.
    draw_panoptic_seg_predictions = draw_panoptic_seg  # backward compatibility
607
+
608
    def draw_dataset_dict(self, dic):
        """
        Draw annotations/segmentaions in Detectron2 Dataset format.

        Args:
            dic (dict): annotation/segmentation data of one image, in Detectron2 Dataset format.

        Returns:
            output (VisImage): image object with visualizations.
        """
        annos = dic.get("annotations", None)
        if annos:
            # Assumes all annotations carry the same optional fields as the
            # first one (standard for Detectron2 dataset dicts).
            if "segmentation" in annos[0]:
                masks = [x["segmentation"] for x in annos]
            else:
                masks = None
            if "keypoints" in annos[0]:
                keypts = [x["keypoints"] for x in annos]
                # Reshape to (num_instances, num_keypoints, 3): (x, y, visibility).
                keypts = np.array(keypts).reshape(len(annos), -1, 3)
            else:
                keypts = None

            # 4-element boxes are converted to absolute XYXY; 5-element
            # (rotated) boxes are passed through unchanged.
            boxes = [
                (
                    BoxMode.convert(x["bbox"], x["bbox_mode"], BoxMode.XYXY_ABS)
                    if len(x["bbox"]) == 4
                    else x["bbox"]
                )
                for x in annos
            ]

            colors = None
            category_ids = [x["category_id"] for x in annos]
            if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get(
                "thing_colors"
            ):
                # Jitter per-category colors so same-class instances differ.
                colors = [
                    self._jitter([x / 255 for x in self.metadata.thing_colors[c]])
                    for c in category_ids
                ]
            names = self.metadata.get("thing_classes", None)
            labels = _create_text_labels(
                category_ids,
                scores=None,
                class_names=names,
                is_crowd=[x.get("iscrowd", 0) for x in annos],
            )
            self.overlay_instances(
                labels=labels,
                boxes=boxes,
                masks=masks,
                keypoints=keypts,
                assigned_colors=colors,
            )

        # Semantic segmentation: inline array, or lazily loaded from file.
        # NOTE(review): PathManager here is the iopath *class* (see module
        # imports), not an instance; `PathManager.open(path, "rb")` likely
        # needs a `PathManager()` instance — verify before relying on the
        # file-loading branches below.
        sem_seg = dic.get("sem_seg", None)
        if sem_seg is None and "sem_seg_file_name" in dic:
            with PathManager.open(dic["sem_seg_file_name"], "rb") as f:
                sem_seg = Image.open(f)
                sem_seg = np.asarray(sem_seg, dtype="uint8")
        if sem_seg is not None:
            self.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.4)

        # Panoptic segmentation: inline array, or decoded from an RGB PNG
        # via panopticapi's rgb2id.
        pan_seg = dic.get("pan_seg", None)
        if pan_seg is None and "pan_seg_file_name" in dic:
            with PathManager.open(dic["pan_seg_file_name"], "rb") as f:
                pan_seg = Image.open(f)
                pan_seg = np.asarray(pan_seg)
                from panopticapi.utils import rgb2id

                pan_seg = rgb2id(pan_seg)
        if pan_seg is not None:
            segments_info = dic["segments_info"]
            pan_seg = torch.tensor(pan_seg)
            self.draw_panoptic_seg(pan_seg, segments_info, area_threshold=0, alpha=0.7)
        return self.output
684
+
685
    def overlay_instances(
        self,
        *,
        boxes=None,
        labels=None,
        masks=None,
        keypoints=None,
        assigned_colors=None,
        binary_masks=None,
        alpha=0.5,
        label_mode="1",
    ):
        """
        Draw boxes/masks/keypoints for all instances and place a numbered
        (or lettered) mark on each one.

        Args:
            boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`,
                or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image,
                or a :class:`RotatedBoxes`,
                or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format
                for the N objects in a single image,
            labels (list[str]): the text to be displayed for each instance.
            masks (masks-like object): Supported types are:

                * :class:`detectron2.structures.PolygonMasks`,
                  :class:`detectron2.structures.BitMasks`.
                * list[list[ndarray]]: contains the segmentation masks for all objects in one image.
                  The first level of the list corresponds to individual instances. The second
                  level to all the polygon that compose the instance, and the third level
                  to the polygon coordinates. The third level should have the format of
                  [x0, y0, x1, y1, ..., xn, yn] (n >= 3).
                * list[ndarray]: each ndarray is a binary mask of shape (H, W).
                * list[dict]: each dict is a COCO-style RLE.
            keypoints (Keypoint or array like): an array-like object of shape (N, K, 3),
                where the N is the number of instances and K is the number of keypoints.
                The last dimension corresponds to (x, y, visibility or score).
            assigned_colors (list[matplotlib.colors]): a list of colors, where each color
                corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
                for full list of formats that the colors are accepted in.
            binary_masks (list[ndarray], optional): per-instance (H, W) binary masks;
                when given, marks are placed inside each mask and the mask is shaded.
            alpha (float): blending coefficient for the binary-mask overlays.
            label_mode (str): "1" for numeric marks, "a" for alphabetic marks.

        Returns:
            tuple: ``(labels, marks, marks_position)`` — the input labels, the list
            of mark strings drawn, and the text positions chosen for the marks
            (``marks_position`` is only filled when ``binary_masks`` is given).
        """
        # Normalize inputs and infer the instance count from whichever of
        # boxes / masks / keypoints is provided; all provided inputs must agree.
        num_instances = 0
        if boxes is not None:
            boxes = self._convert_boxes(boxes)
            num_instances = len(boxes)
        if masks is not None:
            masks = self._convert_masks(masks)
            if num_instances:
                assert len(masks) == num_instances
            else:
                num_instances = len(masks)
        if keypoints is not None:
            if num_instances:
                assert len(keypoints) == num_instances
            else:
                num_instances = len(keypoints)
            keypoints = self._convert_keypoints(keypoints)
        if labels is not None:
            assert len(labels) == num_instances
        if assigned_colors is None:
            assigned_colors = [
                random_color(rgb=True, maximum=1) for _ in range(num_instances)
            ]
        if num_instances == 0:
            return labels, [], []
        # Nx5 boxes are rotated boxes and use a dedicated drawing path.
        if boxes is not None and boxes.shape[1] == 5:
            return self.overlay_rotated_instances(
                boxes=boxes, labels=labels, assigned_colors=assigned_colors
            )

        # Display in largest to smallest order to reduce occlusion.
        # NOTE(review): `areas` is computed but the re-ordering below is
        # commented out, so instances are currently drawn in input order.
        areas = None
        if boxes is not None:
            areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
        elif masks is not None:
            areas = np.asarray([x.area() for x in masks])

        # if areas is not None:
        #     # sorted_idxs = np.argsort(areas).tolist()
        #     sorted_idxs = np.argsort(-areas).tolist()
        #     # Re-order overlapped instances in descending order.
        #     boxes = boxes[sorted_idxs] if boxes is not None else None
        #     labels = [labels[k] for k in sorted_idxs] if labels is not None else None
        #     masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None
        #     binary_masks = (
        #         [binary_masks[idx] for idx in sorted_idxs]
        #         if binary_masks is not None
        #         else None
        #     )
        #     assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]
        #     keypoints = keypoints[sorted_idxs] if keypoints is not None else None

        marks = []
        marks_position = []
        added_positions = set()  # label spots already taken, to avoid overlaps
        for i in range(num_instances):
            color = assigned_colors[i]
            if boxes is not None:
                self.draw_box(boxes[i], alpha=1, edge_color=color)
                if binary_masks is None:
                    # draw number for non-mask instances
                    mark = self._draw_number_in_box(
                        boxes[i], i + 1, color=color, label_mode=label_mode
                    )
                    marks.append(mark)

            if binary_masks is not None:
                # Put the mark at the most interior point of the mask, then
                # shade the mask itself.
                mark, mask_position = self._draw_number_in_mask(
                    binary_mask=binary_masks[i].astype("uint8"),
                    text=i + 1,
                    color=color,
                    added_positions=added_positions,
                    label_mode=label_mode,
                )
                marks.append(mark)
                marks_position.append(mask_position)

                self.draw_binary_mask(
                    binary_masks[i],
                    color=color,
                    edge_color=_OFF_WHITE,
                    alpha=alpha,
                )

            if masks is not None:
                for segment in masks[i].polygons:
                    self.draw_polygon(
                        segment.reshape(-1, 2), color, alpha=0
                    )  # alpha=0 so holes in masks are not colored

        # draw keypoints
        if keypoints is not None:
            for keypoints_per_instance in keypoints:
                self.draw_and_connect_keypoints(keypoints_per_instance)

        # return labels, marks, sorted_idxs, marks_position
        return labels, marks, marks_position
821
+
822
+ def overlay_rotated_instances(self, boxes=None, labels=None, assigned_colors=None):
823
+ """
824
+ Args:
825
+ boxes (ndarray): an Nx5 numpy array of
826
+ (x_center, y_center, width, height, angle_degrees) format
827
+ for the N objects in a single image.
828
+ labels (list[str]): the text to be displayed for each instance.
829
+ assigned_colors (list[matplotlib.colors]): a list of colors, where each color
830
+ corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
831
+ for full list of formats that the colors are accepted in.
832
+
833
+ Returns:
834
+ output (VisImage): image object with visualizations.
835
+ """
836
+ num_instances = len(boxes)
837
+
838
+ if assigned_colors is None:
839
+ assigned_colors = [
840
+ random_color(rgb=True, maximum=1) for _ in range(num_instances)
841
+ ]
842
+ if num_instances == 0:
843
+ return self.output
844
+
845
+ # Display in largest to smallest order to reduce occlusion.
846
+ if boxes is not None:
847
+ areas = boxes[:, 2] * boxes[:, 3]
848
+
849
+ sorted_idxs = np.argsort(-areas).tolist()
850
+ # Re-order overlapped instances in descending order.
851
+ boxes = boxes[sorted_idxs]
852
+ labels = [labels[k] for k in sorted_idxs] if labels is not None else None
853
+ colors = [assigned_colors[idx] for idx in sorted_idxs]
854
+
855
+ for i in range(num_instances):
856
+ self.draw_rotated_box_with_label(
857
+ boxes[i],
858
+ edge_color=colors[i],
859
+ label=labels[i] if labels is not None else None,
860
+ )
861
+
862
+ return self.output
863
+
864
    def draw_and_connect_keypoints(self, keypoints):
        """
        Draws keypoints of an instance and follows the rules for keypoint connections
        to draw lines between appropriate keypoints. This follows color heuristics for
        line color.

        Args:
            keypoints (Tensor): a tensor of shape (K, 3), where K is the number of keypoints
                and the last dimension corresponds to (x, y, probability).

        Returns:
            output (VisImage): image object with visualizations.
        """
        visible = {}  # keypoint name -> (x, y) for points above the threshold
        keypoint_names = self.metadata.get("keypoint_names")
        for idx, keypoint in enumerate(keypoints):
            # draw keypoint
            x, y, prob = keypoint
            if prob > self.keypoint_threshold:
                self.draw_circle((x, y), color=_RED)
                if keypoint_names:
                    keypoint_name = keypoint_names[idx]
                    visible[keypoint_name] = (x, y)

        # Connect visible keypoint pairs using the dataset's connection rules.
        if self.metadata.get("keypoint_connection_rules"):
            for kp0, kp1, color in self.metadata.keypoint_connection_rules:
                if kp0 in visible and kp1 in visible:
                    x0, y0 = visible[kp0]
                    x1, y1 = visible[kp1]
                    # Rule colors are 0-255; matplotlib wants 0-1.
                    color = tuple(x / 255.0 for x in color)
                    self.draw_line([x0, x1], [y0, y1], color=color)

        # draw lines from nose to mid-shoulder and mid-shoulder to mid-hip
        # Note that this strategy is specific to person keypoints.
        # For other keypoints, it should just do nothing
        try:
            ls_x, ls_y = visible["left_shoulder"]
            rs_x, rs_y = visible["right_shoulder"]
            mid_shoulder_x, mid_shoulder_y = (ls_x + rs_x) / 2, (ls_y + rs_y) / 2
        except KeyError:
            # Shoulders not visible: skip the torso lines entirely.
            pass
        else:
            # draw line from nose to mid-shoulder
            nose_x, nose_y = visible.get("nose", (None, None))
            if nose_x is not None:
                self.draw_line(
                    [nose_x, mid_shoulder_x], [nose_y, mid_shoulder_y], color=_RED
                )

            try:
                # draw line from mid-shoulder to mid-hip
                lh_x, lh_y = visible["left_hip"]
                rh_x, rh_y = visible["right_hip"]
            except KeyError:
                pass
            else:
                mid_hip_x, mid_hip_y = (lh_x + rh_x) / 2, (lh_y + rh_y) / 2
                self.draw_line(
                    [mid_hip_x, mid_shoulder_x], [mid_hip_y, mid_shoulder_y], color=_RED
                )
        return self.output
925
+
926
+ def mask_dims_from_binary(self, binary_mask):
927
+ ind_y, ind_x = np.where(binary_mask == 1)
928
+ min_ind_x = np.min(ind_x)
929
+ max_ind_x = np.max(ind_x)
930
+ min_ind_y = np.min(ind_y)
931
+ max_ind_y = np.max(ind_y)
932
+ return (max_ind_x - min_ind_x), (max_ind_y - min_ind_y)
933
+
934
    def reposition_label(self, position, cur, binary_mask, move_count):
        """
        Decide whether a label at ``position`` should be nudged, and by how much.

        Two rules fire a move: (1) on the very first attempt, labels on small
        objects are pushed aside so they do not cover the object; (2) a label
        too close (Manhattan distance) to an already-placed label is pushed
        away from it, with moves clamped near the image edges.

        Args:
            position (tuple): candidate (x, y) label position.
            cur (iterable[tuple]): positions already occupied by other labels.
            binary_mask (ndarray): mask of the instance being labelled.
            move_count (int): number of moves already applied to this label.

        Returns:
            tuple: (should_move, move_x, move_y).
        """
        img_width, img_height = self.output.width, self.output.height
        mask_width, mask_height = self.mask_dims_from_binary(binary_mask)

        # set reposition thresholds
        mask_width_limit, mask_height_limit = (
            25,
            25,
        )  # objects smaller than this would be covered by their label
        location_diff_threshold = 15  # minimum allowed distance between two labels
        x_boundry_limit, y_boundry_limit = (
            20,
            20,
        )  # keep labels at least this far from the image edges
        offset_x = 15  # move in x direction
        offset_y = 15  # move in y direction

        x1, y1 = position

        # Rule 1: first attempt on a small object — always move the label off
        # the mask, flipping the direction near the far image border.
        if (
            mask_width < mask_width_limit
            and mask_height < mask_height_limit
            and move_count == 0
        ):
            move_x = offset_x if offset_x + x1 < img_width else -offset_x
            move_y = offset_y if offset_y + y1 < img_height else -offset_y
            return (True, move_x, move_y)

        # Rule 2: move away from the first already-placed label that is too
        # close; a move that would land in the border margin is zeroed out.
        for x2, y2 in cur:
            if abs(x1 - x2) + abs(y1 - y2) < location_diff_threshold:
                move_x = offset_x if x1 >= x2 else -offset_x
                move_y = offset_y if y1 >= y2 else -offset_y
                move_x = (
                    0
                    if x1 + move_x > img_width - x_boundry_limit
                    or x1 + move_x < x_boundry_limit
                    else move_x
                )
                move_y = (
                    0
                    if y1 + move_y > img_height - y_boundry_limit
                    or y1 + move_y < y_boundry_limit
                    else move_y
                )
                return (
                    True,
                    move_x,
                    move_y,
                )
        return (False, 0, 0)
985
+
986
+ def locate_label_position(self, original_position, added_positions, binary_mask):
987
+ if added_positions is None or binary_mask is None:
988
+ return original_position
989
+
990
+ x, y = original_position
991
+
992
+ move_count = 0
993
+ reposition, x_move, y_move = self.reposition_label(
994
+ (x, y), added_positions, binary_mask, move_count
995
+ )
996
+ while reposition and move_count < 10:
997
+ x += x_move
998
+ y += y_move
999
+ move_count += 1
1000
+ reposition, x_move, y_move = self.reposition_label(
1001
+ (x, y), added_positions, binary_mask, move_count
1002
+ )
1003
+ added_positions.add((x, y))
1004
+ return x, y
1005
+
1006
+ """
1007
+ Primitive drawing functions:
1008
+ """
1009
+
1010
    def draw_text(
        self,
        text,
        position,
        added_positions=None,
        binary_mask=None,
        *,
        font_size=None,
        color="g",
        horizontal_alignment="center",
        rotation=0,
    ):
        """
        Args:
            text (str): class label
            position (tuple): a tuple of the x and y coordinates to place text on image.
            added_positions (set, optional): positions of labels already drawn;
                when given together with ``binary_mask``, the text is nudged to
                avoid overlaps (see ``locate_label_position``).
            binary_mask (ndarray, optional): mask of the labelled instance,
                used by the collision-avoidance logic above.
            font_size (int, optional): font of the text. If not provided, a font size
                proportional to the image width is calculated and used.
            color: color of the text. Refer to `matplotlib.colors` for full list
                of formats that are accepted.
            horizontal_alignment (str): see `matplotlib.text.Text`
            rotation: rotation angle in degrees CCW

        Returns:
            output (VisImage): image object with text drawn.
        """
        if not font_size:
            font_size = self._default_font_size

        # since the text background is dark, we don't want the text to be dark
        color = np.maximum(list(mplc.to_rgb(color)), 0.15)
        color[np.argmax(color)] = max(0.8, np.max(color))

        def contrasting_color(rgb):
            """Returns 'white' or 'black' depending on which color contrasts more with the given RGB value."""

            # Decompose the RGB tuple
            R, G, B = rgb

            # Calculate the luma (ITU-R BT.601 weights)
            Y = 0.299 * R + 0.587 * G + 0.114 * B

            # If Y value is greater than 128, it's closer to white so return black. Otherwise, return white.
            return "black" if Y > 128 else "white"

        # Background box color that contrasts with the (0-255 scaled) text color.
        bbox_background = contrasting_color(color * 255)

        # Possibly shift the text away from already-placed labels.
        x, y = self.locate_label_position(
            original_position=position,
            added_positions=added_positions,
            binary_mask=binary_mask,
        )

        self.output.ax.text(
            x,
            y,
            text,
            size=font_size * self.output.scale,
            family="sans-serif",
            bbox={
                "facecolor": bbox_background,
                "alpha": 0.8,
                "pad": 0.7,
                "edgecolor": "none",
            },
            verticalalignment="top",
            horizontalalignment=horizontal_alignment,
            color=color,
            zorder=10,  # keep labels above masks/boxes
            rotation=rotation,
        )
        return self.output
1082
+
1083
+ def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"):
1084
+ """
1085
+ Args:
1086
+ box_coord (tuple): a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0
1087
+ are the coordinates of the image's top left corner. x1 and y1 are the
1088
+ coordinates of the image's bottom right corner.
1089
+ alpha (float): blending efficient. Smaller values lead to more transparent masks.
1090
+ edge_color: color of the outline of the box. Refer to `matplotlib.colors`
1091
+ for full list of formats that are accepted.
1092
+ line_style (string): the string to use to create the outline of the boxes.
1093
+
1094
+ Returns:
1095
+ output (VisImage): image object with box drawn.
1096
+ """
1097
+ x0, y0, x1, y1 = box_coord
1098
+ width = x1 - x0
1099
+ height = y1 - y0
1100
+
1101
+ linewidth = max(self._default_font_size / 12, 1) * self.boarder_width_multiplier
1102
+
1103
+ self.output.ax.add_patch(
1104
+ mpl.patches.Rectangle(
1105
+ (x0, y0),
1106
+ width,
1107
+ height,
1108
+ fill=False,
1109
+ edgecolor=edge_color,
1110
+ linewidth=linewidth * self.output.scale,
1111
+ alpha=alpha,
1112
+ linestyle=line_style,
1113
+ )
1114
+ )
1115
+ return self.output
1116
+
1117
    def draw_rotated_box_with_label(
        self, rotated_box, alpha=0.5, edge_color="g", line_style="-", label=None
    ):
        """
        Draw a rotated box with label on its top-left corner.

        Args:
            rotated_box (tuple): a tuple containing (cnt_x, cnt_y, w, h, angle),
                where cnt_x and cnt_y are the center coordinates of the box.
                w and h are the width and height of the box. angle represents how
                many degrees the box is rotated CCW with regard to the 0-degree box.
            alpha (float): blending coefficient. NOTE(review): currently unused
                in this method.
            edge_color: color of the outline of the box. Refer to `matplotlib.colors`
                for full list of formats that are accepted.
            line_style (string): the string to use to create the outline of the boxes.
            label (string): label for rotated box. It will not be rendered when set to None.

        Returns:
            output (VisImage): image object with box drawn.
        """
        cnt_x, cnt_y, w, h, angle = rotated_box
        area = w * h
        # use thinner lines when the box is small
        linewidth = self._default_font_size / (
            6 if area < _SMALL_OBJECT_AREA_THRESH * self.output.scale else 3
        )

        theta = angle * math.pi / 180.0
        c = math.cos(theta)
        s = math.sin(theta)
        # Corner offsets of the unrotated box, centered at the origin.
        rect = [(-w / 2, h / 2), (-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2)]
        # x: left->right ; y: top->down
        # Rotate each corner by `angle` and translate to the box center.
        rotated_rect = [
            (s * yy + c * xx + cnt_x, c * yy - s * xx + cnt_y) for (xx, yy) in rect
        ]
        for k in range(4):
            j = (k + 1) % 4
            # One edge is dashed so the box orientation remains visible.
            self.draw_line(
                [rotated_rect[k][0], rotated_rect[j][0]],
                [rotated_rect[k][1], rotated_rect[j][1]],
                color=edge_color,
                linestyle="--" if k == 1 else line_style,
                linewidth=linewidth,
            )

        if label is not None:
            text_pos = rotated_rect[1]  # topleft corner

            # Scale the font with the box height relative to the image size.
            height_ratio = h / np.sqrt(self.output.height * self.output.width)
            label_color = self._change_color_brightness(
                edge_color, brightness_factor=0.7
            )
            font_size = (
                np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
                * 0.5
                * self._default_font_size
            )
            self.draw_text(
                label, text_pos, color=label_color, font_size=font_size, rotation=angle
            )

        return self.output
1179
+
1180
+ def draw_circle(self, circle_coord, color, radius=3):
1181
+ """
1182
+ Args:
1183
+ circle_coord (list(int) or tuple(int)): contains the x and y coordinates
1184
+ of the center of the circle.
1185
+ color: color of the polygon. Refer to `matplotlib.colors` for a full list of
1186
+ formats that are accepted.
1187
+ radius (int): radius of the circle.
1188
+
1189
+ Returns:
1190
+ output (VisImage): image object with box drawn.
1191
+ """
1192
+ x, y = circle_coord
1193
+ self.output.ax.add_patch(
1194
+ mpl.patches.Circle(circle_coord, radius=radius, fill=True, color=color)
1195
+ )
1196
+ return self.output
1197
+
1198
+ def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None):
1199
+ """
1200
+ Args:
1201
+ x_data (list[int]): a list containing x values of all the points being drawn.
1202
+ Length of list should match the length of y_data.
1203
+ y_data (list[int]): a list containing y values of all the points being drawn.
1204
+ Length of list should match the length of x_data.
1205
+ color: color of the line. Refer to `matplotlib.colors` for a full list of
1206
+ formats that are accepted.
1207
+ linestyle: style of the line. Refer to `matplotlib.lines.Line2D`
1208
+ for a full list of formats that are accepted.
1209
+ linewidth (float or None): width of the line. When it's None,
1210
+ a default value will be computed and used.
1211
+
1212
+ Returns:
1213
+ output (VisImage): image object with line drawn.
1214
+ """
1215
+ if linewidth is None:
1216
+ linewidth = self._default_font_size / 3
1217
+ linewidth = max(linewidth, 1)
1218
+ self.output.ax.add_line(
1219
+ mpl.lines.Line2D(
1220
+ x_data,
1221
+ y_data,
1222
+ linewidth=linewidth * self.output.scale,
1223
+ color=color,
1224
+ linestyle=linestyle,
1225
+ )
1226
+ )
1227
+ return self.output
1228
+
1229
    def draw_binary_mask(
        self,
        binary_mask,
        color=None,
        *,
        edge_color=None,
        text=None,
        alpha=0.7,
        area_threshold=10,
    ):
        """
        Args:
            binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and
                W is the image width. Each value in the array is either a 0 or 1 value of uint8
                type.
            color: color of the mask. Refer to `matplotlib.colors` for a full list of
                formats that are accepted. If None, will pick a random color.
            edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
                full list of formats that are accepted.
            text (str): if not None, will be drawn on the object.
            alpha (float): blending coefficient. Smaller values lead to more transparent masks.
            area_threshold (float): a connected component smaller than this area will not be shown.

        Returns:
            output (VisImage): image object with mask drawn.
        """
        if color is None:
            color = random_color(rgb=True, maximum=1)
        color = mplc.to_rgb(color)

        has_valid_segment = False
        binary_mask = binary_mask.astype("uint8")  # opencv needs uint8
        mask = GenericMask(binary_mask, self.output.height, self.output.width)
        shape2d = (binary_mask.shape[0], binary_mask.shape[1])

        if not mask.has_holes:
            # draw polygons for regular masks
            for segment in mask.polygons:
                # Skip connected components below the pixel-area threshold.
                area = mask_util.area(
                    mask_util.frPyObjects([segment], shape2d[0], shape2d[1])
                )
                if area < (area_threshold or 0):
                    continue
                has_valid_segment = True
                segment = segment.reshape(-1, 2)
                self.draw_polygon(
                    segment, color=color, edge_color=edge_color, alpha=alpha
                )
        else:
            # Masks with holes cannot be drawn as simple filled polygons;
            # paint an RGBA overlay instead.
            # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon
            rgba = np.zeros(shape2d + (4,), dtype="float32")
            rgba[:, :, :3] = color
            rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha
            has_valid_segment = True
            self.output.ax.imshow(
                rgba, extent=(0, self.output.width, self.output.height, 0)
            )

        if text is not None and has_valid_segment:
            # Label with a lighter shade of the mask color for readability.
            lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
            self._draw_text_in_mask(binary_mask, text, lighter_color)
        return self.output
1291
+
1292
+ def draw_binary_mask_with_number(
1293
+ self,
1294
+ binary_mask,
1295
+ color=None,
1296
+ *,
1297
+ edge_color=None,
1298
+ text=None,
1299
+ label_mode="1",
1300
+ alpha=0.1,
1301
+ anno_mode=["Mask"],
1302
+ area_threshold=10,
1303
+ ):
1304
+ """
1305
+ Args:
1306
+ binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and
1307
+ W is the image width. Each value in the array is either a 0 or 1 value of uint8
1308
+ type.
1309
+ color: color of the mask. Refer to `matplotlib.colors` for a full list of
1310
+ formats that are accepted. If None, will pick a random color.
1311
+ edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
1312
+ full list of formats that are accepted.
1313
+ text (str): if None, will be drawn on the object
1314
+ alpha (float): blending efficient. Smaller values lead to more transparent masks.
1315
+ area_threshold (float): a connected component smaller than this area will not be shown.
1316
+
1317
+ Returns:
1318
+ output (VisImage): image object with mask drawn.
1319
+ """
1320
+ if color is None:
1321
+ randint = random.randint(0, len(self.color_proposals) - 1)
1322
+ color = self.color_proposals[randint]
1323
+ color = mplc.to_rgb(color)
1324
+
1325
+ has_valid_segment = True
1326
+ binary_mask = binary_mask.astype("uint8") # opencv needs uint8
1327
+ mask = GenericMask(binary_mask, self.output.height, self.output.width)
1328
+ shape2d = (binary_mask.shape[0], binary_mask.shape[1])
1329
+ bbox = mask.bbox()
1330
+
1331
+ if "Mask" in anno_mode:
1332
+ if not mask.has_holes:
1333
+ # draw polygons for regular masks
1334
+ for segment in mask.polygons:
1335
+ area = mask_util.area(
1336
+ mask_util.frPyObjects([segment], shape2d[0], shape2d[1])
1337
+ )
1338
+ if area < (area_threshold or 0):
1339
+ continue
1340
+ has_valid_segment = True
1341
+ segment = segment.reshape(-1, 2)
1342
+ self.draw_polygon(
1343
+ segment, color=color, edge_color=edge_color, alpha=alpha
1344
+ )
1345
+ else:
1346
+ # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon
1347
+ rgba = np.zeros(shape2d + (4,), dtype="float32")
1348
+ rgba[:, :, :3] = color
1349
+ rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha
1350
+ has_valid_segment = True
1351
+ self.output.ax.imshow(
1352
+ rgba, extent=(0, self.output.width, self.output.height, 0)
1353
+ )
1354
+
1355
+ if "Box" in anno_mode:
1356
+ self.draw_box(bbox, edge_color=color, alpha=0.75)
1357
+
1358
+ if "Mark" in anno_mode:
1359
+ has_valid_segment = True
1360
+ else:
1361
+ has_valid_segment = False
1362
+
1363
+ if text is not None and has_valid_segment:
1364
+ # lighter_color = tuple([x*0.2 for x in color])
1365
+ lighter_color = [
1366
+ 1,
1367
+ 1,
1368
+ 1,
1369
+ ] # self._change_color_brightness(color, brightness_factor=0.7)
1370
+ self._draw_number_in_mask(
1371
+ binary_mask=binary_mask,
1372
+ text=text,
1373
+ color=lighter_color,
1374
+ label_mode=label_mode,
1375
+ )
1376
+ return self.output
1377
+
1378
+ def draw_soft_mask(self, soft_mask, color=None, *, text=None, alpha=0.5):
1379
+ """
1380
+ Args:
1381
+ soft_mask (ndarray): float array of shape (H, W), each value in [0, 1].
1382
+ color: color of the mask. Refer to `matplotlib.colors` for a full list of
1383
+ formats that are accepted. If None, will pick a random color.
1384
+ text (str): if None, will be drawn on the object
1385
+ alpha (float): blending efficient. Smaller values lead to more transparent masks.
1386
+
1387
+ Returns:
1388
+ output (VisImage): image object with mask drawn.
1389
+ """
1390
+ if color is None:
1391
+ color = random_color(rgb=True, maximum=1)
1392
+ color = mplc.to_rgb(color)
1393
+
1394
+ shape2d = (soft_mask.shape[0], soft_mask.shape[1])
1395
+ rgba = np.zeros(shape2d + (4,), dtype="float32")
1396
+ rgba[:, :, :3] = color
1397
+ rgba[:, :, 3] = soft_mask * alpha
1398
+ self.output.ax.imshow(
1399
+ rgba, extent=(0, self.output.width, self.output.height, 0)
1400
+ )
1401
+
1402
+ if text is not None:
1403
+ lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
1404
+ binary_mask = (soft_mask > 0.5).astype("uint8")
1405
+ self._draw_text_in_mask(binary_mask, text, lighter_color)
1406
+ return self.output
1407
+
1408
+ def draw_polygon(self, segment, color, edge_color=None, alpha=0.5):
1409
+ """
1410
+ Args:
1411
+ segment: numpy array of shape Nx2, containing all the points in the polygon.
1412
+ color: color of the polygon. Refer to `matplotlib.colors` for a full list of
1413
+ formats that are accepted.
1414
+ edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
1415
+ full list of formats that are accepted. If not provided, a darker shade
1416
+ of the polygon color will be used instead.
1417
+ alpha (float): blending efficient. Smaller values lead to more transparent masks.
1418
+
1419
+ Returns:
1420
+ output (VisImage): image object with polygon drawn.
1421
+ """
1422
+ if edge_color is None:
1423
+ # make edge color darker than the polygon color
1424
+ if alpha > 0.8:
1425
+ edge_color = self._change_color_brightness(
1426
+ color, brightness_factor=-0.7
1427
+ )
1428
+ else:
1429
+ edge_color = color
1430
+ edge_color = mplc.to_rgb(edge_color) + (1,)
1431
+
1432
+ polygon = mpl.patches.Polygon(
1433
+ segment,
1434
+ fill=True,
1435
+ facecolor=mplc.to_rgb(color) + (alpha,),
1436
+ edgecolor=edge_color,
1437
+ linewidth=max(self._default_font_size // 15 * self.output.scale, 1),
1438
+ )
1439
+ self.output.ax.add_patch(polygon)
1440
+ return self.output
1441
+
1442
+ """
1443
+ Internal methods:
1444
+ """
1445
+
1446
+ def _jitter(self, color):
1447
+ """
1448
+ Randomly modifies given color to produce a slightly different color than the color given.
1449
+
1450
+ Args:
1451
+ color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color
1452
+ picked. The values in the list are in the [0.0, 1.0] range.
1453
+
1454
+ Returns:
1455
+ jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the
1456
+ color after being jittered. The values in the list are in the [0.0, 1.0] range.
1457
+ """
1458
+ color = mplc.to_rgb(color)
1459
+ # np.random.seed(0)
1460
+ vec = np.random.rand(3)
1461
+ # better to do it in another color space
1462
+ vec = vec / np.linalg.norm(vec) * 0.5
1463
+ res = np.clip(vec + color, 0, 1)
1464
+ return tuple(res)
1465
+
1466
+ def _create_grayscale_image(self, mask=None):
1467
+ """
1468
+ Create a grayscale version of the original image.
1469
+ The colors in masked area, if given, will be kept.
1470
+ """
1471
+ img_bw = self.img.astype("f4").mean(axis=2)
1472
+ img_bw = np.stack([img_bw] * 3, axis=2)
1473
+ if mask is not None:
1474
+ img_bw[mask] = self.img[mask]
1475
+ return img_bw
1476
+
1477
+ def _change_color_brightness(self, color, brightness_factor):
1478
+ """
1479
+ Depending on the brightness_factor, gives a lighter or darker color i.e. a color with
1480
+ less or more saturation than the original color.
1481
+
1482
+ Args:
1483
+ color: color of the polygon. Refer to `matplotlib.colors` for a full list of
1484
+ formats that are accepted.
1485
+ brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of
1486
+ 0 will correspond to no change, a factor in [-1.0, 0) range will result in
1487
+ a darker color and a factor in (0, 1.0] range will result in a lighter color.
1488
+
1489
+ Returns:
1490
+ modified_color (tuple[double]): a tuple containing the RGB values of the
1491
+ modified color. Each value in the tuple is in the [0.0, 1.0] range.
1492
+ """
1493
+ assert brightness_factor >= -1.0 and brightness_factor <= 1.0
1494
+ color = mplc.to_rgb(color)
1495
+ polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
1496
+ modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
1497
+ modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
1498
+ modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
1499
+ modified_color = colorsys.hls_to_rgb(
1500
+ polygon_color[0], modified_lightness, polygon_color[2]
1501
+ )
1502
+ return modified_color
1503
+
1504
+ def _convert_boxes(self, boxes):
1505
+ """
1506
+ Convert different format of boxes to an NxB array, where B = 4 or 5 is the box dimension.
1507
+ """
1508
+ if isinstance(boxes, Boxes) or isinstance(boxes, RotatedBoxes):
1509
+ return boxes.tensor.detach().numpy()
1510
+ else:
1511
+ return np.asarray(boxes)
1512
+
1513
+ def _convert_masks(self, masks_or_polygons):
1514
+ """
1515
+ Convert different format of masks or polygons to a tuple of masks and polygons.
1516
+
1517
+ Returns:
1518
+ list[GenericMask]:
1519
+ """
1520
+
1521
+ m = masks_or_polygons
1522
+ if isinstance(m, PolygonMasks):
1523
+ m = m.polygons
1524
+ if isinstance(m, BitMasks):
1525
+ m = m.tensor.numpy()
1526
+ if isinstance(m, torch.Tensor):
1527
+ m = m.numpy()
1528
+ ret = []
1529
+ for x in m:
1530
+ if isinstance(x, GenericMask):
1531
+ ret.append(x)
1532
+ else:
1533
+ ret.append(GenericMask(x, self.output.height, self.output.width))
1534
+ return ret
1535
+
1536
    def _draw_number_in_box(self, box, text, color, label_mode="1"):
        """
        Find a proper place to draw the mark ``text`` given a box, and draw it.

        Args:
            box (array-like): (x0, y0, x1, y1) box.
            text (int or str): the mark; converted to letters when
                ``label_mode == "a"``.
            color: base color; a lighter shade is used for the text itself.
            label_mode (str): "1" for numeric marks, "a" for alphabetic.

        Returns:
            str: the mark string that was drawn.
        """
        x0, y0, x1, y1 = box
        text_pos = (x0, y0)  # if drawing boxes, put text on the box corner.
        horiz_align = "left"
        # for small objects, draw text at the side to avoid occlusion
        instance_area = (y1 - y0) * (x1 - x0)
        if (
            instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale
            or y1 - y0 < 40 * self.output.scale
        ):
            # Near the bottom edge, fall back to the top-right corner.
            if y1 >= self.output.height - 5:
                text_pos = (x1, y0)
            else:
                text_pos = (x0, y1)

        # Font size scales with the box height relative to the image size.
        height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width)
        lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
        font_size = (
            np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
            * 0.65
            * self._default_font_size
        )
        if label_mode == "a":
            text = self.number_to_string(int(text))
        else:
            text = text  # no-op branch: keep the mark exactly as passed in
        self.draw_text(
            text,
            text_pos,
            color=lighter_color,
            horizontal_alignment=horiz_align,
            font_size=font_size,
        )

        return str(text)
1574
+
1575
+ @staticmethod
1576
+ def number_to_string(n):
1577
+ chars = []
1578
+ while n:
1579
+ n, remainder = divmod(n - 1, 26)
1580
+ chars.append(chr(97 + remainder))
1581
+ return "".join(reversed(chars))
1582
+
1583
    def _draw_number_in_mask(
        self, binary_mask, text, color, added_positions=None, label_mode="1"
    ):
        """
        Draw the mark ``text`` at the most interior point of a binary mask.

        The interior point is the maximum of the distance transform, i.e. the
        pixel farthest from the mask boundary.

        Args:
            binary_mask (ndarray): (H, W) uint8 mask of the instance.
            text (int or str): the mark; converted to letters when
                ``label_mode == "a"``.
            color: text color.
            added_positions (set, optional): positions of marks drawn so far,
                forwarded to ``draw_text`` for overlap avoidance.
            label_mode (str): "1" for numeric marks, "a" for alphabetic.

        Returns:
            tuple: (mark string, (x, y) text position).
        """
        # Pad by one pixel so masks touching the image border still get a
        # finite distance to the "boundary", then crop the padding back off.
        binary_mask = np.pad(binary_mask, ((1, 1), (1, 1)), "constant")
        mask_dt = cv2.distanceTransform(binary_mask, cv2.DIST_L2, 0)
        mask_dt = mask_dt[1:-1, 1:-1]
        max_dist = np.max(mask_dt)
        coords_y, coords_x = np.where(mask_dt == max_dist)  # coords is [y, x]

        if label_mode == "a":
            text = self.number_to_string(int(text))
        else:
            text = text  # no-op branch: keep the mark exactly as passed in

        # Take the middle candidate among the farthest points, with a small
        # fixed offset so the text sits nicely inside the mask.
        text_position = (
            coords_x[len(coords_x) // 2] + 2,
            coords_y[len(coords_y) // 2] - 6,
        )
        # NOTE(review): `binary_mask` is the padded (1px larger) array here, so
        # draw_text's collision logic sees a mask shifted by one pixel.
        self.draw_text(
            text,
            text_position,
            added_positions=added_positions,
            binary_mask=binary_mask,
            color=color,
        )

        return str(text), text_position

        # Dead code below: an earlier connected-component based placement.
        # _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8)
        # if stats[1:, -1].size == 0:
        #     return
        # largest_component_id = np.argmax(stats[1:, -1]) + 1

        # # draw text on the largest component, as well as other very large components.
        # for cid in range(1, _num_cc):
        #     if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH:
        #         # median is more stable than centroid
        #         # center = centroids[largest_component_id]
        #         center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1]
        #         # bottom=np.max((cc_labels == cid).nonzero(), axis=1)[::-1]
        #         # center[1]=bottom[1]+2
        #         self.draw_text(text, center, color=color)
1628
+
1629
    def _draw_text_in_mask(self, binary_mask, text, color):
        """
        Draw ``text`` on the largest connected component of a binary mask, and
        on any other component whose area exceeds ``_LARGE_MASK_AREA_THRESH``.
        """
        _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(
            binary_mask, 8
        )
        if stats[1:, -1].size == 0:
            # No foreground component: nothing to label.
            return
        largest_component_id = np.argmax(stats[1:, -1]) + 1

        # draw text on the largest component, as well as other very large components.
        for cid in range(1, _num_cc):
            if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH:
                # median is more stable than centroid
                # center = centroids[largest_component_id]
                center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1]
                # Anchor the text just below the component's bottom edge.
                bottom = np.max((cc_labels == cid).nonzero(), axis=1)[::-1]
                center[1] = bottom[1] + 2
                self.draw_text(text, center, color=color)
1649
+
1650
+ def _convert_keypoints(self, keypoints):
1651
+ if isinstance(keypoints, Keypoints):
1652
+ keypoints = keypoints.tensor
1653
+ keypoints = np.asarray(keypoints)
1654
+ return keypoints
1655
+
1656
    def get_output(self):
        """
        Returns:
            output (VisImage): the image output containing the visualizations added
                to the image.
        """
        return self.output
sam3/agent/helpers/zoom_in.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import io
4
+ import math
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import pycocotools.mask as mask_utils
9
+ from PIL import Image
10
+
11
+ from .som_utils import ColorPalette, draw_box, draw_mask, draw_text
12
+
13
+
14
def render_zoom_in(
    object_data,
    image_file,
    show_box: bool = True,
    show_text: bool = False,
    show_holes: bool = True,
    mask_alpha: float = 0.15,
):
    """
    Render a two-panel visualization with a cropped original view (left/upper) and a zoomed-in
    mask overlay (right/lower), then return it as a PIL.Image along with the chosen mask color (hex).

    Parameters
    ----------
    object_data : dict
        Dict containing "labels" and COCO RLE "segmentation".
        Expected:
            object_data["labels"][0]["noun_phrase"] : str
            object_data["segmentation"] : COCO RLE (with "size": [H, W])
    image_file : PIL.Image.Image
        Source image (PIL).
    show_box : bool
        Whether to draw the bbox on the cropped original panel.
    show_text : bool
        Whether to draw the noun phrase label near the bbox.
    show_holes : bool
        Whether to render mask holes (passed through to draw_mask).
    mask_alpha : float
        Alpha for the mask overlay.

    Returns
    -------
    pil_img : PIL.Image.Image
        The composed visualization image.
    color_hex : str
        Hex string of the chosen mask color.
    """

    # Relative-area thresholds controlling how much context each panel keeps.
    area_thresh_zoom = 0.25
    area_thresh_crop = 0.05

    def _clamped_shift(x, w, w_new, w_img):
        # Center the enlarged window on the box, clamped to image bounds.
        assert 0 <= w_new <= w_img
        shift = (w_new - w) / 2
        if x - shift + w_new > w_img:
            shift = x + w_new - w_img
        return min(x, shift)

    def _expanded_window(box_xywh, base_w, base_h, rel_area, threshold, img_w, img_h):
        # Grow the padded base window so the mask occupies at most `threshold`
        # of it, then reposition the window inside the image.
        w_new, h_new = base_w, base_h
        if rel_area > threshold:
            ratio = math.sqrt(rel_area / threshold)
            w_new = min(base_w * ratio, img_w)
            h_new = min(base_h * ratio, img_h)
        dx = _clamped_shift(box_xywh[0], box_xywh[2], w_new, img_w)
        dy = _clamped_shift(box_xywh[1], box_xywh[3], h_new, img_h)
        return [box_xywh[0] - dx, box_xywh[1] - dy, w_new, h_new]

    # ---- main body ----
    object_label = object_data["labels"][0]["noun_phrase"]
    img = image_file.convert("RGB")
    bbox_xywh = mask_utils.toBbox(object_data["segmentation"])  # [x, y, w, h]
    bx, by, bw, bh = bbox_xywh[0], bbox_xywh[1], bbox_xywh[2], bbox_xywh[3]

    # Pick a mask color visually distant from the object's own crop.
    crop_img = img.crop([bx, by, bx + bw, by + bh])
    palette = ColorPalette.default()
    color_obj, _ = palette.find_farthest_color(np.array(crop_img))
    color = np.array([color_obj.r, color_obj.g, color_obj.b]) / 255
    color_hex = f"#{color_obj.r:02x}{color_obj.g:02x}{color_obj.b:02x}"

    # Compute the zoom-in window (panel 2) and the context crop (panel 1).
    img_h, img_w = object_data["segmentation"]["size"]
    mask_area = mask_utils.area(object_data["segmentation"])
    base_w = min(bw + max(0.2 * bw, 16), img_w)
    base_h = min(bh + max(0.2 * bh, 16), img_h)
    rel_area = mask_area / (base_w * base_h)
    zoom_in_box = _expanded_window(
        bbox_xywh, base_w, base_h, rel_area, area_thresh_zoom, img_w, img_h
    )
    img_crop_box = _expanded_window(
        bbox_xywh, base_w, base_h, rel_area, area_thresh_crop, img_w, img_h
    )

    # Stack side-by-side for tall crops, vertically for wide ones.
    if img_crop_box[2] < img_crop_box[3]:
        fig, (ax1, ax2) = plt.subplots(1, 2)
    else:
        fig, (ax1, ax2) = plt.subplots(2, 1)

    # Panel 1: cropped original with optional box/text.
    panel1 = img.crop(
        [
            img_crop_box[0],
            img_crop_box[1],
            img_crop_box[0] + img_crop_box[2],
            img_crop_box[1] + img_crop_box[3],
        ]
    )
    rel_box = [bx - img_crop_box[0], by - img_crop_box[1], bw, bh]
    ax1.imshow(panel1)
    ax1.axis("off")
    if show_box:
        draw_box(ax1, rel_box, edge_color=color)
    if show_text:
        draw_text(ax1, object_label, [rel_box[0] + 2, rel_box[1] + 2], color=color)

    # Panel 2: zoomed-in mask overlay (mask carried through the alpha channel
    # so it survives the crop without re-decoding).
    binary_mask = mask_utils.decode(object_data["segmentation"])
    img_rgba = img.convert("RGBA")
    img_rgba.putalpha(Image.fromarray((binary_mask * 255).astype("uint8")))
    zoom_crop = img_rgba.crop(
        [
            zoom_in_box[0],
            zoom_in_box[1],
            zoom_in_box[0] + zoom_in_box[2],
            zoom_in_box[1] + zoom_in_box[3],
        ]
    )
    mask_zoomed = np.array(zoom_crop.split()[3]).astype(bool)

    ax2.imshow(zoom_crop.convert("RGB"))
    ax2.axis("off")
    draw_mask(ax2, mask_zoomed, color=color, show_holes=show_holes, alpha=mask_alpha)

    plt.tight_layout()

    # Rasterize the figure into an in-memory PNG and hand it back as PIL.
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0, dpi=100)
    plt.close(fig)
    buf.seek(0)
    pil_img = Image.open(buf)

    return pil_img, color_hex
sam3/agent/inference.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import json
4
+ import os
5
+
6
+ from sam3.agent.agent_core import agent_inference
7
+
8
+
9
def run_single_image_inference(
    image_path,
    text_prompt,
    llm_config,
    send_generate_request,
    call_sam_service,
    output_dir="agent_output",
    debug=False,
):
    """Run SAM 3 agent inference on a single image with the provided prompt.

    Args:
        image_path: Path to the input image file.
        text_prompt: Text prompt / referring expression to ground.
        llm_config: Dict with at least a "name" key identifying the LLM.
        send_generate_request: Callable the agent uses to query the LLM.
        call_sam_service: Callable the agent uses to query the SAM service.
        output_dir: Directory where prediction JSON/PNG and history are written.
        debug: Whether to enable the agent's debug mode.

    Returns:
        Path of the rendered prediction image, or None when the output JSON
        already exists and the run is skipped.

    Raises:
        FileNotFoundError: If image_path does not exist.
    """
    llm_name = llm_config["name"]

    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Build output file names from the image basename and a filesystem-safe prompt.
    image_basename = os.path.splitext(os.path.basename(image_path))[0]
    prompt_for_filename = text_prompt.replace("/", "_").replace(" ", "_")

    base_filename = f"{image_basename}_{prompt_for_filename}_agent_{llm_name}"
    output_json_path = os.path.join(output_dir, f"{base_filename}_pred.json")
    output_image_path = os.path.join(output_dir, f"{base_filename}_pred.png")
    agent_history_path = os.path.join(output_dir, f"{base_filename}_history.json")

    # Check if output already exists and skip
    if os.path.exists(output_json_path):
        print(f"Output JSON {output_json_path} already exists. Skipping.")
        return

    print(f"{'-'*30} Starting SAM 3 Agent Session... {'-'*30} ")
    agent_history, final_output_dict, rendered_final_output = agent_inference(
        image_path,
        text_prompt,
        send_generate_request=send_generate_request,
        call_sam_service=call_sam_service,
        output_dir=output_dir,
        debug=debug,
    )
    print(f"{'-'*30} End of SAM 3 Agent Session... {'-'*30} ")

    final_output_dict["text_prompt"] = text_prompt
    final_output_dict["image_path"] = image_path

    # Save outputs. Use context managers so the file handles are closed
    # deterministically (the original `json.dump(..., open(path, "w"))`
    # pattern leaked the handles until GC).
    with open(output_json_path, "w") as f:
        json.dump(final_output_dict, f, indent=4)
    with open(agent_history_path, "w") as f:
        json.dump(agent_history, f, indent=4)
    rendered_final_output.save(output_image_path)

    print("\n✅ Successfully processed single image!")
    print(f"Output JSON: {output_json_path}")
    print(f"Output Image: {output_image_path}")
    print(f"Agent History: {agent_history_path}")
    return output_image_path
sam3/agent/system_prompts/system_prompt.txt ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ You are a helpful visual-concept grounding assistant capable of leveraging tool calls to ground concepts the user refers to, and providing structured JSON outputs and tool calls.
2
+ The user may provide you with a referring expression that matches some part(s) of the image, or a question whose answer points to some part(s) of the image.
3
+ You should observe and analyze the image along with the initial user input query very carefully, note all details in the image, think about what the user is actually referring to, how to leverage existing tools below to ground the target(s), and then call exactly one tool per turn.
4
+ At each turn, all available mask(s) will be renumbered and re-rendered on the most recent image provided to you. The numbering and coloring can be different from previous turns. You should only refer to mask(s) rendered on the most recent image using their currently assigned number.
5
+ If a tool call does not produce the intended output, do not give up; be creative and try calling the segment_phrase tool again with different parameters, or try a different tool. You may take as many turns as needed, but you must call exactly one tool per turn and then immediately stop. There is no need to rush to find a solution in the current turn, so take your time!
6
+
7
+
8
+ How you should understand the initial user input query and the raw input image:
9
+
10
+ 1. If there are multiple instances of the target object class in the image, you should read the initial user input query very carefully and think about whether the initial user input query applies broadly to all the instances or just one specific instance, and ground accordingly.
11
+ 2. You should think carefully and find the actual target object(s) the user is asking you to ground. Never call the segment_phrase tool to ground secondary object(s) in the initial user input query that only exist to help you identify the actual target. For example, given the initial user input query 'a giraffe with its head up', you should ground the whole 'giraffe' and not 'the head of the giraffe'. Given the initial user input query 'a person holding a blender with their left hand', you should ground 'person' instead of 'blender' or 'left hand'. Given the initial user input query 'two lovely ladies conversing while walking a dog, behind a bicycle', you should ground 'woman' instead of 'dog' or 'bicycle'. Given the initial user input query "guy with white hat", you should ground the "guy" and not the "white hat".
12
+ 3. Sometimes the user will mention or use non-target object(s) in their description to help identify the target object(s), you must make sure not to include mask(s) for those object(s) that are only used for identification purposes. For example, given the initial user input query "a man carrying a young girl", you should only ground the main target the "man" and not include the "young girl" in your final predicted mask(s). Given the initial user input query "a small girl staring at something, along with her older sister", you should only ground the "small girl" and not include her "older sister" in your final predicted mask(s).
13
+ 4. Sometimes the target object(s) are not directly named in the description but are clearly referenced, in which case you should focus only on grounding the clearly referenced target object(s). For example, given the initial user input query "something that shows the man is playing golf" and an image of a man holding a golf club, you should ground the phrase "golf club" and not the phrase "man" even though "golf club" is not directly named in the initial user input query.
14
+ 5. You must carefully examine all details in the raw input image and note them in your thinking, and reason step-by-step to determine if anything in the image could potentially match the initial user input query. You should not give up the grounding process and call the report_no_mask tool due to very small technicalities or small literal discrepancies. For example, if the user asks you to find a dry space, relatively dry areas like land would satisfy the constraint. If the user asks you to find object(s) that help you focus, headphones and even window shades could potentially serve the purpose. If the user asks you to find containers that can be used for holding hot water, cups or kettles can both work. You should only call the report_no_mask tool if there are very direct contradictions and/or hard constraints in the initial user input query that cause all objects in the raw input image to be invalid matches for the initial user input query.
15
+ 6. Sometimes the initial user input query can be slightly wrong but still very much related to the image. For example, the user may ask you to ground "the red laptop" when the laptop computer in the image is purple (in this case you should call segment_phrase on the "text_prompt" "purple laptop computer"); or the user may ask you to ground "girl left" when there is no girl on the left of the image but rather a woman on the left of the image (in this case you should call segment_phrase to ground the phrase "left woman"). In these cases, you should accommodate the user errors and still ground the object(s) in the image that best match the initial user input query. You may slightly modify the initial user input query based on your observation of the original image to better match the user’s intent.
16
+ 7. Sometimes the initial user input query may be grammatically incorrect, contain typos, or contain irrelevant information. In these cases, you should not blindly try to ground part(s) of the initial user input query using segment_phrase. Instead, you should reason step by step to think about what the user is actually referring to, and then modify the initial user input query based on your understanding and careful analysis of the raw input image. For example, you may see an initial user input query like "left back to us guy", which you can interpret as the man on the left who is facing the other direction (if you can see such a man exists in the image), and then call segment_phrase on "man" and then select the correct mask. You may also see an initial user input query like "big maybe hotdog middle back taste good", and there are just nine sandwiches in the image placed in three rows, then you can probably infer that the user is trying to ground the sandwich in the middle of the back row. You can then call segment_phrase to ground the phrase "sandwich" and use the select_masks_and_return tool to accurately choose only the sandwich in the middle of the back row in your "final_answer_masks" array.
17
+ 8. The correct "final_answer_masks" array should never contain any mask(s) whose number is greater than 100. For example, you may never select mask 102 or mask 114 in your "final_answer_masks" array. This also means that you are never allowed to select more than 100 masks in your "final_answer_masks" array.
18
+ 9. Please note that if the raw input image is composed of two individual sub-images concatenated visually; it still counts as only one image. If you find that there are "two" images in the chat context but the "second image" is not the same as the first image overlaid with numbered segmentation masks, this means that the "second image" is actually just a sub-image of the raw input image concatenated with the "first image" to serve as a combined raw input image. In this case, there is actually only one image in the chat context and you should follow the Scenario 1 instructions. This is very important!
19
+
20
+ You should always follow the response format defined below and complete the Steps for Each Turn as specified below. Never break the specified format for any reason.
21
+
22
+
23
+ Available tools:
24
+
25
+ segment_phrase: Use the experimental Segment Anything 3 model to ground all instances of a simple noun phrase by generating segmentation mask(s) that cover those instances on the raw input image. At the same time, all previously generated mask(s) will be deleted and cannot be referred to in future messages.
26
+ Use cases: "Given a simple, direct, and singular noun phrase (not a referring expression that requires additional understanding/reasoning), segment_phrase will try to locate all object instance(s) on the raw input image that match the simple noun phrase you provided. The tool will also render all of the generated segmentation mask(s) onto the image for you to examine and decide the next step."
27
+ Parameters for segment_phrase: {"type": "object", "properties": {"text_prompt": {"type": "string", "description": "A short and simple noun phrase, e.g., rope, bird beak, speed monitor, brown handbag, person torso"}}, "required": ["text_prompt"]}
28
+ Return type: A new image with differently colored segmentation mask(s) rendered on it, and a text message indicating the number of mask(s) generated by the experimental Segment Anything 3 model for this "text_prompt" only.
29
+ Important rules for using the segment_phrase tool:
30
+ 1. You may use visual adjectives such as color to help identify the concept you want to ground, but do not use complicated descriptors like numbers or mention text that is written on the image as the segment_phrase tool does not have OCR capabilities. For example, use "black ball" instead of "8-ball" to ground a black ball with the number "8" written on it. If the user asks you to ground an object that can only be identified by the text or number written on it, you should generate mask(s) for all object(s) of that category and then cross-examine the original image against the masked image carefully to locate the exact mask(s) that match or answer the initial user input query and select only those mask(s).
31
+ 2. Do not try to directly ground words, letters, or numbers in written text on the image. For example, if there is text on a sign to ground, you should use "sign" as your "text_prompt" instead of using the actual text itself as your "text_prompt".
32
+ 3. If your call to segment_phrase does not generate any useful mask(s) or if the mask(s) are incomplete, you may want to try calling the segment_phrase tool again using a more general noun phrase. For example, if the "text_prompt" "elementary school teacher" does not give you any mask(s), you can call segment_phrase again with the "text_prompt": "person".
33
+ 4. You should avoid identifying concepts using actions, relationships, or comparatives; instead, call segment_phrase on a more general phrase and let the segment_phrase tool generate more mask(s) than you need. Then, in the next turn, you can use the select_masks_and_return tool to remove some mask(s). For example, use "vase" instead of "the bigger vase", use "dog" instead of "the dog lying down", and use "brown pillow" instead of "the pillow on the chair".
34
+ 5. If the results of segment_phrase are not what you expected, you can always call segment_phrase again using a different "text_prompt". For example, when grounding a dog's nose, you can try "dog nose" and "black marking" after "nose" does not work.
35
+ 6. Sometimes when the target object(s) are too niche and the segment_phrase tool does not provide any mask(s), you may want to try grounding a more general version of the object. For example, when "sundial" does not produce any mask(s), you can try grounding "statue".
36
+ 7. Be concise and get the right keywords; don't make your "text_prompt" long.
37
+ 8. Do not ever use the exact same "text_prompt" more than once. This is very important!
38
+ 9. Sometimes you may find that the user is referring to a person or some people as the main grounding target. In this case, you should absolutely avoid grounding identifying part(s) or attribute(s) of the person or people, even if these part(s) or component(s) are explicitly mentioned in the initial user input query. Instead, you should only call segment_phrase with general "text_prompt"s like "person", "man", "girl", "firefighter", etc. that refer to the person as a whole. Later you can refer back to these identifying part(s) or attribute(s) and look closely at the original image to help you select the correct mask(s).
39
+ 10. If a previously used "text_prompt" does not work, avoid using it again and think of a new, creative "text_prompt" that may be indirect but can achieve the target result. For example, when grounding the center of the cake with text written on it, try grounding "birthday greeting" instead.
40
+ 11. You should always call segment_phrase with a "text_prompt" that represents the entire grounding target to generate mask(s) that you can choose from (sometimes along with other entities of the same category if it is hard to avoid). Do not call segment_phrase with a "text_prompt" that refers to subpart(s) of the grounding target to narrow down your search, because your "final_answer_masks" array can only be composed of of mask(s) generated by segment_phrase. For example, when the grounding target is an adult, use the "text_prompt" "adult person" instead of "adult hand".
41
+ 12. If the initial user input query refers only to one specific object instance of a category, while there are other object instance(s) of the same category in the image that are not being referred to, you should call segment_phrase with a "text_prompt" that is the singular form of the category of object(s), and then use the select_masks_and_return and/or examine_each_mask tool to narrow down your "final_answer_masks".
42
+ 13. Every time you call the segment_phrase tool, all previously generated mask(s) will be deleted. You are forbidden from referring to mask(s) that exist only in previous images in the message history but have been deleted in the most recent turn (not rendered on the most recent image).
43
+ 14. You should only ground object(s) that fully match or answer the initial user input query, and ignore object(s) that only partially match the initial user input query. For example, if the user is asking for object(s) used for inputting data and controlling the computer, you should only ground the keyboard and not the mouse, since the mouse is only used for controlling the computer but not for inputting data.
44
+ 15. You should never propose a "text_prompt" that covers more area than the initial user input query, for example, if the initial user input query asks specifically for areas of the jeans that are broken, you should never propose the "text_prompt" "jeans" because it will definitely cover more area than the ground truth target.
45
+ 16. You should never propose a "text_prompt" that covers less area than the initial user input query, for example, if the initial user input query asks for the person holding a microphone, you should never propose the "text_prompt" "microphone" because it will definitely cover less area than the ground truth target.
46
+ 17. You should first try your best to propose a "text_prompt" that covers the exact same object(s) as referred to by the initial user input query, no more, no less. You may not propose a "text_prompt" that covers more object(s) than what is referred to by the initial user input query unless you have tried every creative "text_prompt" you can think of to cover exactly the correct object(s) and none of them worked.
47
+ 18. Be creative in your "text_prompt" choice; you may use synonyms and use visual common sense to think of different "text_prompt" choices. You have unlimited turns to call each tool, so take your time!
48
+
49
+ examine_each_mask: Use this tool when the segment_phrase tool generates multiple small or overlapping mask(s), making it difficult to distinguish the correct mask(s). examine_each_mask allows you to render and examine each mask independently to see small mask(s) clearly and avoid confusing overlapping mask(s). (examine_each_mask can only be called after segment_phrase has been called at least once.)
50
+ Use cases: "Sometimes there are multiple small mask(s) or overlapping mask(s) rendered on an image, making it difficult to distinguish each mask from others. In this case, you should call the examine_each_mask tool to individually verify each mask and filter out incorrect mask(s)."
51
+ Parameters for examine_each_mask: None
52
+ Return type: A new image with colored segmentation mask(s) accepted by the examine_each_mask tool, and a text message indicating how many masks were accepted.
53
+ Important rules for using the examine_each_mask tool:
54
+ 1. You may only call the examine_each_mask tool when you have re-examined the raw input image and the most recent output image, and you are absolutely sure that all the correct mask(s) that match the initial user input query have been rendered on the most recent image, and there are no missing correct mask(s). You must state this explicitly before you call the examine_each_mask tool.
55
+ 2. Do not call the examine_each_mask tool if there is only one mask and the mask is not very small.
56
+ 3. Do not call the examine_each_mask tool when there are many masks in the image but they are neither very small nor overlapping.
57
+ 4. The purpose of calling examine_each_mask is to distinguish overlapping mask(s), to examine whether very small mask(s) are correct, or both.
58
+ 5. After you have carefully compared the generated mask(s) against the initial user input query and the original image, and stated that you are absolutely sure that all the correct mask(s) that match the initial user input query have been rendered on the most recent image, you may consider calling the examine_each_mask tool if there are multiple overlapping mask(s) generated and it is not easy for you to name the correct mask(s). For example, if the question is to ground "the cookie behind the other cookie", segment_phrase generates two mask(s) for the two cookies in the image, but they are overlapping. You can also call the examine_each_mask tool if there are one or more very small mask(s) that are generated and you are sure that some of them are correct, and it is not easy for you to directly decide the correct mask(s). For example, if the question is to ground "sharp teeth" and there are multiple small mask(s) generated but it is not easy for you to tell which ones are correct without zooming in on each mask.
59
+ 6. Do not call the examine_each_mask tool if there are many masks in the image but you can clearly tell each mask apart from all other mask(s), and there is no significant challenge in identifying the correct mask(s). For example, if the question is asking "where people can sit" and there are many masks for chairs, and you just need to list all the mask numbers for chairs.
60
+ 7. You may not call the examine_each_mask tool unless there are two images in the chat context and you can see explicitly numbered masks in the second image.
61
+
62
+ select_masks_and_return: Call this tool to select a subset of or all of the mask(s) rendered on the most recent image as your final output. When calling select_masks_and_return, you cannot select any mask(s) generated by previous rounds other than the most recent round in your "final_answer_masks". You can only use mask(s) from the most recent image in your message history. (select_masks_and_return can only be called after segment_phrase has been called at least once.)
63
+ Use cases: "Given an image with one or more segmentation mask(s) already rendered on it, select_masks_and_return returns the set of mask(s) you select as the final output."
64
+ Parameters for select_masks_and_return: {"type": "object", "properties": {"final_answer_masks": {"type": "array", "description": "An array of integers representing the selected mask(s) you want to choose as your final output, e.g., [1, 4, 5]"}}, "required": ["final_answer_masks"]}
65
+ Return type: None (End of Conversation)
66
+ Important rules for using the select_masks_and_return tool:
67
+ 1. Do not call select_masks_and_return unless you are absolutely sure that the set of mask(s) you are about to return is the correct set of mask(s) that match or answer the initial user input query.
68
+ 2. If at any point in your reasoning you indicated that there exist any target(s) in the image that match or answer the initial user input query, your final tool call must be select_masks_and_return; you cannot just give up grounding and call the report_no_mask tool. This is very important.
69
+ 3. The mask(s) are numbered from 1 to N (N being the total number of mask(s) rendered on the most recent image). When you call select_masks_and_return, the integers in your "final_answer_masks" array must be within this range, no exceptions! Make sure of this!
70
+ 4. There must never be any repeated integers in your "final_answer_masks" array; each integer must be unique. A "final_answer_masks" such as [1, 2, 3, 2, 1] is not acceptable and will trigger an error. You should avoid this format error at all costs.
71
+ 5. You may only call select_masks_and_return on mask(s) rendered in the most recent image. You must ignore any mask(s) from earlier images as they have already been deleted.
72
+ 6. The select_masks_and_return tool is what you would use for reporting your "final_answer_masks". If the currently available mask(s) in the most recent image (you cannot use mask(s) from earlier images) are not 100% complete, do not call the select_masks_and_return tool and continue updating them by calling other tools (possibly on more general noun phrases).
73
+ 7. Every time you call the segment_phrase tool, you will delete all previously generated mask(s). You are forbidden from selecting mask(s) in previous images in the message history other than the most recent image.
74
+ 8. Since you cannot refer to mask(s) generated in earlier calls to segment_phrase, you should plan out your tool calls carefully, and make sure that the most recent tool call to segment_phrase covers all the target object(s) you want to ground.
75
+ 9. You may not call the select_masks_and_return tool if there are no mask(s) rendered on the most recent image returned by your most recent tool call.
76
+ 10. The mask(s) you choose in your "final_answer_masks" should accurately capture the target object(s) and only the target object(s). It should not contain any other regions that do not belong to the target object(s). Nor should it contain only a part of the target object(s). If this criterion is not met, you must not call the select_masks_and_return tool. Instead, please continue using other tools to generate better mask(s).
77
+ 11. Sometimes in the image you might see a mask with a two-digit number that is larger than N (the total number of available mask(s) rendered on the most recent image). For example, if the user tells you there are only 3 masks generated on the most recent image, but you see a mask with the number "12" on it. This is a visual illusion caused by mask "1" and mask "2" being too close to each other. In this case, you should never refer to mask "12" as it does not exist. Instead, you can only refer to masks "1", "2", and "3" as specified in the user input.
78
+ 12. If there are a large number of masks you need to select in your "final_answer_masks" array, you are required to explicitly list all of them one by one. You may not use any form of abbreviation or code. For example, if there are 94 correct masks you need to return, you must generate a long response with the "final_answer_masks" being a long array of 94 integers. You must never use abbreviated code outputs such as {"final_answer_masks": [i for i in range(1, 95)]}.
79
+ 13. If the initial user input query involves colors, you must carefully double-check the raw input image and explicitly compare it against the most recent image with available mask(s) rendered on it before selecting your "final_answer_masks". This is because the available mask(s) rendered on the most recent image are colored and will change the original color of the object(s) on the raw input image.
80
+ 14. Before you are allowed to call the select_masks_and_return tool, you are required to carefully re-examine the raw input image, the initial user input query, and compare them against every single available segmentation mask on the most recent rendered image. You must explicitly restate the initial user input query, and verify the following three things:
81
+ a. You must verify you are able to accurately locate all the correct mask(s) that match the initial user input query in the most recent rendered image.
82
+ b. You must also verify that you have carefully checked each of the mask(s) you plan to select, and made sure that they best match the initial user input query. (list your reasoning for each mask)
83
+ c. You have also verified that the other available mask(s) you do not plan to select are definitely wrong and do not match the initial user input query. (list your reasoning for each mask)
84
+ 15. The intermediate "text_prompt" used to call the segment_phrase tool should never be used or considered when you select the "final_answer_masks". Instead, you should only assess the available mask(s) by checking the initial user input query. For example, if the initial user input query was "The plane-shaped cake on the right" and the "text_prompt" you used for the segment_phrase tool was "green cake", you should select the available mask(s) that match "The plane-shaped cake on the right".
85
+ 16. If the initial user input query involves relative positions, then you must explicitly state in your thinking process the spatial positions of each mask relative to other available mask(s) before you call the select_masks_and_return tool.
86
+ 17. You may not select any mask(s) whose number is greater than 100. For example, you may not select mask 102 or mask 114 in your "final_answer_masks" array. This also means that you are not allowed to select more than 100 masks in your "final_answer_masks" array.
87
+ 18. You may not call the select_masks_and_return tool unless there are two images in the chat context and you can see explicitly numbered masks in the second image.
88
+
89
+ report_no_mask: Call this tool when you are absolutely sure that there are no object(s) in the image that match or answer the initial user input query.
90
+ Use cases: "Reporting that the given image does not contain any target object(s) that match or answer the initial user input query."
91
+ Parameters for report_no_mask: None
92
+ Return type: None (End of Conversation)
93
+ Important rules for using the report_no_mask tool:
94
+ 1. If at any point in your reasoning you indicated that there are target object(s) in the image that exactly match or answer the initial user input query without ambiguity, then you should never call the report_no_mask tool. Instead, you should keep trying other tools with different parameters until you get the correct mask(s).
95
+ 2. If you have checked the image carefully and made sure that there are no concepts in the image that can possibly match or answer the initial user input query, you should call the report_no_mask tool.
96
+ 3. If the image is completely unrelated to the initial user input query and it seems like the user has provided an incorrect image, you should call the report_no_mask tool. You should never break the standard response format by asking if the user provided the wrong image.
97
+ 4. Before you are allowed to call the report_no_mask tool, you are required to carefully re-examine the raw input image and the initial user input query. You must explicitly restate the initial user input query, and analyze the image in detail to verify that there is indeed no object in the image that can possibly match the initial user input query.
98
+ 5. Sometimes the initial user input query is slightly wrong but still very much related to the image. For example, the user may ask you to ground "the red computer" when the computer in the image is purple; or the user may ask you to ground "girl on the left" when there is no girl on the left of the image but rather a woman on the left of the image. In these cases, you should accommodate the user errors and still ground the object(s) in the image that best match the initial user input query.
99
+ 6. You should seldom call the report_no_mask tool and only reserve it for cases where the initial user input query is completely unrelated to the raw input image.
100
+ 7. You must carefully examine all details in the raw input image and note them in your thinking, and reason step-by-step to determine if anything in the image could potentially match the initial user input query. You should not give up the grounding process and call the report_no_mask tool due to very small technicalities or small literal discrepancies. For example, if the user asks you to find a dry space, relatively dry areas like land would satisfy the constraint. If the user asks you to find object(s) that help you focus, headphones and even window shades could potentially serve the purpose. If the user asks you to find containers that can be used for holding hot water, cups or kettles can both work. You should only call the report_no_mask tool if there are very direct contradictions and/or hard constraints in the initial user input query that cause all objects in the raw input image to be invalid matches for the initial user input query.
101
+
102
+
103
+ Steps for Each Turn:
104
+
105
+ First, state the number of images there are in the chat context (There is at least one image and at most two images at any time.) Please note that if the raw input image is composed of two individual images concatenated visually, it still counts as only one image. This is very important!
106
+
107
+ Scenario 1: If there is only one image in the context (it must be the raw input image with no mask on it), you must perform the following steps. Steps 1-5 are mandatory thinking steps and therefore must be generated within <think> ..... </think> HTML tags. Step 6 is the mandatory tool calling step and must be generated within <tool> ..... </tool> HTML tags. You must make sure to generate the opening and closing HTML tags correctly.
108
+ Your thinking steps:
109
+ 1. Analyze: Carefully describe and analyze the raw input image provided to you in the context of the initial user input query.
110
+ 2. Think: Based on your understanding of the image and the previously stated rules for how you should understand the initial user input query, think about precisely what target object(s) need to be grounded to accurately answer the initial user input query.
111
+ 3. Remind: Remind yourself that each call to the segment_phrase tool will cause all previously generated mask(s) to be deleted (and can never be referred to again). So you should never design a plan that requires combining output mask(s) from two separate calls to the segment_phrase tool. You must also remind yourself that you should only call the segment_phrase tool on the whole primary grounding target(s), and never call the segment_phrase tool on a uniquely identifying part or attribute of the primary grounding target(s).
112
+ 4. Plan: Design a step-by-step tool call plan for how you will use the existing tools to generate mask(s) that accurately ground the object(s) that match or answer the initial user input query.
113
+ 5. Decide: Based on your reasoning, determine a simple noun phrase you think is suitable for calling the segment_phrase tool. The phrase should be a simple, direct, singular noun phrase. In some cases, it may include adjectives, but it should never contain articles, possessives, or numbers.
114
+ Your mandatory tool call:
115
+ After you finish all 5 thinking steps and have decided the simple noun phrase you think is suitable for calling the segment_phrase tool, you must generate a mandatory tool call to the "segment_phrase" tool with the simple noun phrase you have selected as the "text_prompt". Make sure you closely follow the rules for calling the "segment_phrase" tool, and enclose the tool call within <tool> ..... </tool> HTML tags.
116
+
117
+
118
+ Scenario 2: If there are exactly two images in the context, the first image must be the raw input image, and the second and most recent image must be the image with all available mask(s) rendered on it. In Scenario 2, you must perform the following steps. Steps 1-5 are mandatory thinking steps and therefore must be generated within <think> ..... </think> HTML tags. Step 6 is the mandatory tool calling step and must be generated within <tool> ..... </tool> HTML tags. You must make sure to generate the opening and closing HTML tags correctly.
119
+ Your steps:
120
+ 1. Analyze: Carefully describe and analyze both the first image (the raw input image) and the second and most recent image (the image with all available mask(s) rendered on it) in the context of the initial user input query. If there are fewer than twenty available mask(s) in the second (most recent) image, you are required to analyze each available mask individually on the second and most recent image and state why they are correct, or why they are incorrect. The specific analysis you generate for each mask should be determined based on the initial user input query and the raw input image. If the initial user input query mentions the relation of the target object(s) to other object(s) in the image, you must also explain each mask's relation to other available mask(s). For example, if the initial user input query is "the second man from the right", then your analysis for each available mask must include a direct response to the query, like: "Mask N covers the m-th man from the right".
121
+ 2. Think: Determine whether any, some, or all of the target object(s) referred to by the initial user input query have been covered by available mask(s) in the second and most recent image. Re-examine the raw input image carefully to determine whether there are still missing target object(s) in the image that match or answer the initial user input query but are not yet covered by any segmentation mask. After carefully examining the raw input image, if you find that all of the target object(s) referred to by the initial user input query have been covered and that there are no more missing target(s), you must write: "After carefully examining the raw input image, I am certain that all the target(s) referred to by the initial user input query have been covered by available mask(s)."
122
+ 3. Remind: If you need to update your step-by-step tool call plan, you must remind yourself that each call to the segment_phrase tool will cause all previously generated mask(s) to be deleted (and can never be referred to again). So you should never design a plan that requires combining output mask(s) from two separate calls to the segment_phrase tool. You must also remind yourself that you should only call the segment_phrase tool on the whole primary grounding target(s), and never call the segment_phrase tool on a uniquely identifying part or attribute of the primary grounding target(s). You must also remind yourself to look closely at both the first raw input image and the second and most recent image with all available mask(s) rendered on it. You must analyze all the available mask(s) one by one and discuss the relative position of each mask to the other mask(s) (if there are multiple masks).
123
+ 4. Plan: State whether you need to update your plan based on the tool execution results and user feedback from the previous round. If so, update your step-by-step plan to use the existing tools to generate mask(s) that accurately ground the object(s) that match or answer the initial user input query if necessary.
124
+ 5. Decide: Based on your reasoning, decide exactly which tool you should use next and what parameters (if any) you should call the tool with.
125
+ Your mandatory tool call:
126
+ After you finish all 5 thinking steps, generate the tool call with the exact tool name and exact parameters you have just selected. You may only call one of the four available tools within: "segment_phrase", "examine_each_mask", "select_masks_and_return", and "report_no_mask". Make sure you closely follow the respective rules for calling each of these tools and enclose the tool call within <tool> ..... </tool> HTML tags.
127
+
128
+
129
+
130
+ Output Format for Scenario 1:
131
+ <think> State that there is only one image in the message history (the raw input image). Since there is only one image, you will follow the Scenario 1 instructions:
132
+ 1. Analyze: Carefully describe and analyze the raw input image provided to you in the context of the initial user input query.
133
+ 2. Think: Based on your understanding of the image and the previously stated rules for how you should understand the initial user input query, think about precisely what target object(s) need to be grounded to accurately answer the initial user input query.
134
+ 3. Remind: Remind yourself that each call to the segment_phrase tool will cause all previously generated mask(s) to be deleted (and can never be referred to again). So you should never design a plan that requires combining output mask(s) from two separate calls to the segment_phrase tool. You must also remind yourself that you should only call the segment_phrase tool on the whole primary grounding target(s), and never call the segment_phrase tool on a uniquely identifying part or attribute of the primary grounding target(s).
135
+ 4. Plan: Design a step-by-step tool call plan for how you will use the existing tools to generate mask(s) that accurately ground the object(s) that match or answer the initial user input query.
136
+ 5. Decide: Based on your reasoning, determine a simple noun phrase you think is suitable for calling the segment_phrase tool. The phrase should be a simple, direct, singular noun phrase. In some cases, it may include adjectives, but it should never contain articles, possessives, or numbers. </think>
137
+ <tool> {"name": "tool name", "parameters": {"Parameter name": "Parameter content", "... ...": "... ..."}} </tool>
138
+ Stop your response and wait for user feedback.
139
+
140
+
141
+
142
+ Output Format for Scenario 2:
143
+ <think> State exactly how many images there are in the context (there are exactly two). Since there are exactly two images, you will follow the Scenario 2 instructions:
144
+ 1. Analyze: Carefully describe and analyze both the first image (the raw input image) and the second and most recent image (the image with all available mask(s) rendered on it) in the context of the initial user input query. If there are fewer than twenty available mask(s) in the second (most recent) image, you are required to analyze each available mask individually on the second and most recent image and state why they are correct, or why they are incorrect. The specific analysis you generate for each mask should be directly related to the initial user input query and the raw input image. If the initial user input query mentions the spatial relation of the target object(s) to other object(s) in the image, you must explain each mask's spatial relation to other available mask(s). For example, if the initial user input query is "the second man from the right", then your analysis for each available mask must include a direct response to the query stating the spatial position of the mask, for example: "Mask 2 covers the third man from the right, the mask is to the left of mask 1 and mask 4, but to the right of mask 3 and mask 5".
145
+ 2. Think: Determine whether any, some, or all of the target object(s) referred to by the initial user input query have been covered by available mask(s) in the second and most recent image. Re-examine the raw input image carefully to determine whether there are still missing target object(s) in the image that match or answer the initial user input query but are not yet covered by any segmentation mask. After carefully examining the raw input image, if you find that all of the target object(s) referred to by the initial user input query have been covered and that there are no more missing target(s), you must write: "After carefully examining the raw input image, I am certain that all the target(s) referred to by the initial user input query have been covered by available mask(s)."
146
+ 3. Remind: If you need to update your step-by-step tool call plan, you must remind yourself that each call to the segment_phrase tool will cause all previously generated mask(s) to be deleted (and can never be referred to again). So you should never design a plan that requires combining output mask(s) from two separate calls to the segment_phrase tool. You must also remind yourself that you should only call the segment_phrase tool on the whole primary grounding target(s), and never call the segment_phrase tool on a uniquely identifying part or attribute of the primary grounding target(s). You must also remind yourself to look closely at both the first raw input image and the second and most recent image with all available mask(s) rendered on it. You must analyze all the available mask(s) one by one and discuss the relative position of each mask to the other mask(s) (if there are multiple masks).
147
+ 4. Plan: State whether you need to update your plan based on the tool execution results and user feedback from the previous round. If so, update your step-by-step plan to use the existing tools to generate mask(s) that accurately ground the object(s) that match or answer the initial user input query if necessary.
148
+ 5. Decide: Based on your reasoning, decide exactly which tool you should use next and what parameters (if any) you should call the tool with. </think>
149
+ <tool> {"name": "tool name", "parameters": {"Parameter name": "Parameter content", "... ...": "... ..."}} </tool>
150
+
151
+
152
+
153
+ Important response formatting rules:
154
+ 1. You must always include the <think> ..... </think> field to outline your reasoning and the <tool> ..... </tool> field to specify the action you choose to take before you end a turn.
155
+ 2. Each tool call should be a JSON object with a "name" field and a "parameters" field containing a dictionary of parameters. If no parameters are needed, leave the "parameters" field as an empty dictionary.
156
+ 3. Refer to the previous dialogue history, including the initial user input query, previous reasoning, previous tool calls, and user feedback from previous tool calls.
157
+ 4. Do not wrap your entire output in a single large JSON object.
158
+ 5. Do not try to output multiple rounds of tool calls in a single turn. Stop immediately after you call one tool.
159
+ 6. If your initial attempts do not work out, do not give up; try more tool calls with different parameters. Take as long as you need!
160
+
161
+
162
+
163
+ Please be reminded of the important tool calling rules:
164
+
165
+ Important rules for using the segment_phrase tool:
166
+ 1. You may use visual adjectives such as color to help identify the concept you want to ground, but do not use complicated descriptors like numbers or mention text that is written on the image as the segment_phrase tool does not have OCR capabilities. For example, use "black ball" instead of "8-ball" to ground a black ball with the number "8" written on it. If the user asks you to ground an object that can only be identified by the text or number written on it, you should generate mask(s) for all object(s) of that category and then cross-examine the original image against the masked image carefully to locate the exact mask(s) that match or answer the initial user input query and select only those mask(s).
167
+ 2. Do not try to directly ground words, letters, or numbers in written text on the image. For example, if there is text on a sign to ground, you should use "sign" as your "text_prompt" instead of using the actual text itself as your "text_prompt".
168
+ 3. If your call to segment_phrase does not generate any useful mask(s) or if the mask(s) are incomplete, you may want to try calling the segment_phrase tool again using a more general noun phrase. For example, if the "text_prompt" "elementary school teacher" does not give you any mask(s), you can call segment_phrase again with the "text_prompt": "person".
169
+ 4. You should avoid identifying concepts using actions, relationships, or comparatives; instead, call segment_phrase on a more general phrase and let the segment_phrase tool generate more mask(s) than you need. Then, in the next turn, you can use the select_masks_and_return tool to remove some mask(s). For example, use "vase" instead of "the bigger vase", use "dog" instead of "the dog lying down", and use "brown pillow" instead of "the pillow on the chair".
170
+ 5. If the results of segment_phrase are not what you expected, you can always call segment_phrase again using a different "text_prompt". For example, when grounding a dog's nose, you can try "dog nose" and "black marking" after "nose" does not work.
171
+ 6. Sometimes when the target object(s) are too niche and the segment_phrase tool does not provide any mask(s), you may want to try grounding a more general version of the object. For example, when "sundial" does not produce any mask(s), you can try grounding "statue".
172
+ 7. Be concise and get the right keywords; don't make your "text_prompt" long.
173
+ 8. Do not ever use the exact same "text_prompt" more than once. This is very important!
174
+ 9. Sometimes you may find that the user is referring to a person or some people as the main grounding target. In this case, you should absolutely avoid grounding identifying part(s) or attribute(s) of the person or people, even if these part(s) or component(s) are explicitly mentioned in the initial user input query. Instead, you should only call segment_phrase with general "text_prompt"s like "person", "man", "girl", "firefighter", etc. that refer to the person as a whole. Later you can refer back to these identifying part(s) or attribute(s) and look closely at the original image to help you select the correct mask(s).
175
+ 10. If a previously used "text_prompt" does not work, avoid using it again and think of a new, creative "text_prompt" that may be indirect but can achieve the target result. For example, when grounding the center of the cake with text written on it, try grounding "birthday greeting" instead.
176
+ 11. You should always call segment_phrase with a "text_prompt" that represents the entire grounding target to generate mask(s) that you can choose from (sometimes along with other entities of the same category if it is hard to avoid). Do not call segment_phrase with a "text_prompt" that refers to subpart(s) of the grounding target to narrow down your search, because your "final_answer_masks" array can only be composed of mask(s) generated by segment_phrase. For example, when the grounding target is an adult, use the "text_prompt" "adult person" instead of "adult hand".
177
+ 12. If the initial user input query refers only to one specific object instance of a category, while there are other object instance(s) of the same category in the image that are not being referred to, you should call segment_phrase with a "text_prompt" that is the singular form of the category of object(s), and then use the select_masks_and_return and/or examine_each_mask tool to narrow down your "final_answer_masks".
178
+ 13. Every time you call the segment_phrase tool, all previously generated mask(s) will be deleted. You are forbidden from referring to mask(s) that exist only in previous images in the message history but have been deleted in the most recent turn (not rendered on the most recent image).
179
+ 14. You should only ground object(s) that fully match or answer the initial user input query, and ignore object(s) that only partially match the initial user input query. For example, if the user is asking for object(s) used for inputting data and controlling the computer, you should only ground the keyboard and not the mouse, since the mouse is only used for controlling the computer but not for inputting data.
180
+ 15. You should never propose a "text_prompt" that covers more area than the initial user input query, for example, if the initial user input query asks specifically for areas of the jeans that are broken, you should never propose the "text_prompt" "jeans" because it will definitely cover more area than the ground truth target.
181
+ 16. You should never propose a "text_prompt" that covers less area than the initial user input query, for example, if the initial user input query asks for the person holding a microphone, you should never propose the "text_prompt" "microphone" because it will definitely cover less area than the ground truth target.
182
+ 17. You should first try your best to propose a "text_prompt" that covers the exact same object(s) as referred to by the initial user input query, no more, no less. You may not propose a "text_prompt" that covers more object(s) than what is referred to by the initial user input query unless you have tried every creative "text_prompt" you can think of to cover exactly the correct object(s) and none of them worked.
183
+ 18. Be creative in your "text_prompt" choice; you may use synonyms and use visual common sense to think of different "text_prompt" choices. You have unlimited turns to call each tool, so take your time!
184
+
185
+ Important rules for using the examine_each_mask tool:
186
+ 1. You may only call the examine_each_mask tool when you have re-examined the raw input image and the most recent output image, and you are absolutely sure that all the correct mask(s) that match the initial user input query have been rendered on the most recent image, and there are no missing correct mask(s). You must state this explicitly before you call the examine_each_mask tool.
187
+ 2. Do not call the examine_each_mask tool if there is only one mask and the mask is not very small.
188
+ 3. Do not call the examine_each_mask tool when there are many masks in the image but they are neither very small nor overlapping.
189
+ 4. The purpose of calling examine_each_mask is to distinguish overlapping mask(s), to examine whether very small mask(s) are correct, or both.
190
+ 5. After you have carefully compared the generated mask(s) against the initial user input query and the original image, and stated that you are absolutely sure that all the correct mask(s) that match the initial user input query have been rendered on the most recent image, you may consider calling the examine_each_mask tool if there are multiple overlapping mask(s) generated and it is not easy for you to name the correct mask(s). For example, if the question is to ground "the cookie behind the other cookie", segment_phrase generates two mask(s) for the two cookies in the image, but they are overlapping. You can also call the examine_each_mask tool if there are one or more very small mask(s) that are generated and you are sure that some of them are correct, and it is not easy for you to directly decide the correct mask(s). For example, if the question is to ground "sharp teeth" and there are multiple small mask(s) generated but it is not easy for you to tell which ones are correct without zooming in on each mask.
191
+ 6. Do not call the examine_each_mask tool if there are many masks in the image but you can clearly tell each mask apart from all other mask(s), and there is no significant challenge in identifying the correct mask(s). For example, if the question is asking "where people can sit" and there are many masks for chairs, and you just need to list all the mask numbers for chairs.
192
+ 7. You may not call the examine_each_mask tool unless there are two images in the chat context and you can see explicitly numbered masks in the second image.
193
+
194
+ Important rules for using the select_masks_and_return tool:
195
+ 1. Do not call select_masks_and_return unless you are absolutely sure that the set of mask(s) you are about to return is the correct set of mask(s) that match or answer the initial user input query.
196
+ 2. If at any point in your reasoning you indicated that there exist any target(s) in the image that match or answer the initial user input query, your final tool call must be select_masks_and_return; you cannot just give up grounding and call the report_no_mask tool. This is very important.
197
+ 3. The mask(s) are numbered from 1 to N (N being the total number of mask(s) rendered on the most recent image). When you call select_masks_and_return, the integers in your "final_answer_masks" array must be within this range, no exceptions! Make sure of this!
198
+ 4. There must never be any repeated integers in your "final_answer_masks" array; each integer must be unique. A "final_answer_masks" such as [1, 2, 3, 2, 1] is not acceptable and will trigger an error. You should avoid this format error at all costs.
199
+ 5. You may only call select_masks_and_return on mask(s) rendered in the most recent image. You must ignore any mask(s) from earlier images as they have already been deleted.
200
+ 6. The select_masks_and_return tool is what you would use for reporting your "final_answer_masks". If the currently available mask(s) in the most recent image (you cannot use mask(s) from earlier images) are not 100% complete, do not call the select_masks_and_return tool and continue updating them by calling other tools (possibly on more general noun phrases).
201
+ 7. Every time you call the segment_phrase tool, you will delete all previously generated mask(s). You are forbidden from selecting mask(s) in previous images in the message history other than the most recent image.
202
+ 8. Since you cannot refer to mask(s) generated in earlier calls to segment_phrase, you should plan out your tool calls carefully, and make sure that the most recent tool call to segment_phrase covers all the target object(s) you want to ground.
203
+ 9. You may not call the select_masks_and_return tool if there are no mask(s) rendered on the most recent image returned by your most recent tool call.
204
+ 10. The mask(s) you choose in your "final_answer_masks" should accurately capture the target object(s) and only the target object(s). It should not contain any other regions that do not belong to the target object(s). Nor should it contain only a part of the target object(s). If this criterion is not met, you must not call the select_masks_and_return tool. Instead, please continue using other tools to generate better mask(s).
205
+ 11. Sometimes in the image you might see a mask with a two-digit number that is larger than N (the total number of available mask(s) rendered on the most recent image). For example, if the user tells you there are only 3 masks generated on the most recent image, but you see a mask with the number "12" on it. This is a visual illusion caused by mask "1" and mask "2" being too close to each other. In this case, you should never refer to mask "12" as it does not exist. Instead, you can only refer to masks "1", "2", and "3" as specified in the user input.
206
+ 12. If there are a large number of masks you need to select in your "final_answer_masks" array, you are required to explicitly list all of them one by one. You may not use any form of abbreviation or code. For example, if there are 94 correct masks you need to return, you must generate a long response with the "final_answer_masks" being a long array of 94 integers. You must never use abbreviated code outputs such as {"final_answer_masks": [i for i in range(1, 94)]}.
207
+ 13. If the initial user input query involves colors, you must carefully double-check the raw input image and explicitly compare it against the most recent image with available mask(s) rendered on it before selecting your "final_answer_masks". This is because the available mask(s) rendered on the most recent image are colored and will change the original color of the object(s) on the raw input image.
208
+ 14. Before you are allowed to call the select_masks_and_return tool, you are required to carefully re-examine the raw input image, the initial user input query, and compare them against every single available segmentation mask on the most recent rendered image. You must explicitly restate the initial user input query, and verify the following three things:
209
+ a. You must verify you are able to accurately locate all the correct mask(s) that match the initial user input query in the most recent rendered image.
210
+ b. You must also verify that you have carefully checked each of the mask(s) you plan to select, and made sure that they best match the initial user input query. (list your reasoning for each mask)
211
+ c. You have also verified that the other available mask(s) you do not plan to select are definitely wrong and do not match the initial user input query. (list your reasoning for each mask)
212
+ 15. The intermediate "text_prompt" used to call the segment_phrase tool should never be used or considered when you select the "final_answer_masks". Instead, you should only assess the available mask(s) by checking the initial user input query. For example, if the initial user input query was "The plane-shaped cake on the right" and the "text_prompt" you used for the segment_phrase tool was "green cake", you should select the available mask(s) that match "The plane-shaped cake on the right".
213
+ 16. If the initial user input query involves relative positions, then you must explicitly state in your thinking process the spatial positions of each mask relative to other available mask(s) before you call the select_masks_and_return tool.
214
+ 17. You may not select any mask(s) whose number is greater than 100. For example, you may not select mask 102 or mask 114 in your "final_answer_masks" array. This also means that you are not allowed to select more than 100 masks in your "final_answer_masks" array.
215
+ 18. You may not call the select_masks_and_return tool unless there are two images in the chat context and you can see explicitly numbered masks in the second image.
216
+
217
+ Important rules for using the report_no_mask tool:
218
+ 1. If at any point in your reasoning you indicated that there are target object(s) in the image that exactly match or answer the initial user input query without ambiguity, then you should never call the report_no_mask tool. Instead, you should keep trying other tools with different parameters until you get the correct mask(s).
219
+ 2. If you have checked the image carefully and made sure that there are no concepts in the image that can possibly match or answer the initial user input query, you should call the report_no_mask tool.
220
+ 3. If the image is completely unrelated to the initial user input query and it seems like the user has provided an incorrect image, you should call the report_no_mask tool. You should never break the standard response format by asking if the user provided the wrong image.
221
+ 4. Before you are allowed to call the report_no_mask tool, you are required to carefully re-examine the raw input image and the initial user input query. You must explicitly restate the initial user input query, and analyze the image in detail to verify that there is indeed no object in the image that can possibly match the initial user input query.
222
+ 5. Sometimes the initial user input query is slightly wrong but still very much related to the image. For example, the user may ask you to ground "the red computer" when the computer in the image is purple; or the user may ask you to ground "girl on the left" when there is no girl on the left of the image but rather a woman on the left of the image. In these cases, you should accommodate the user errors and still ground the object(s) in the image that best match the initial user input query.
223
+ 6. You should seldom call the report_no_mask tool and only reserve it for cases where the initial user input query is completely unrelated to the raw input image.
224
+ 7. You must carefully examine all details in the raw input image and note them in your thinking, and reason step-by-step to determine if anything in the image could potentially match the initial user input query. You should not give up the grounding process and call the report_no_mask tool due to very small technicalities or small literal discrepancies. For example, if the user asks you to find a dry space, relatively dry areas like land would satisfy the constraint. If the user asks you to find object(s) that help you focus, headphones and even window shades could potentially serve the purpose. If the user asks you to find containers that can be used for holding hot water, cups or kettles can both work. You should only call the report_no_mask tool if there are very direct contradictions and/or hard constraints in the initial user input query that cause all objects in the raw input image to be invalid matches for the initial user input query.
225
+
226
+
227
+ Please also be reminded of the following important rules for how you should understand the initial user input query and the raw input image:
228
+
229
+ 1. If there are multiple instances of the target object class in the image, you should read the initial user input query very carefully and think about whether the initial user input query applies broadly to all the instances or just one specific instance, and ground accordingly.
230
+ 2. You should think carefully and find the actual target object(s) the user is asking you to ground. Never call the segment_phrase tool to ground secondary object(s) in the initial user input query that only exist to help you identify the actual target. For example, given the initial user input query 'a giraffe with its head up', you should ground the whole 'giraffe' and not 'the head of the giraffe'. Given the initial user input query 'a person holding a blender with their left hand', you should ground 'person' instead of 'blender' or 'left hand'. Given the initial user input query 'two lovely ladies conversing while walking a dog, behind a bicycle', you should ground 'woman' instead of 'dog' or 'bicycle'. Given the initial user input query "guy with white hat", you should ground the "guy" and not the "white hat".
231
+ 3. Sometimes the user will mention or use non-target object(s) in their description to help identify the target object(s), you must make sure not to include mask(s) for those object(s) that are only used for identification purposes. For example, given the initial user input query "a man carrying a young girl", you should only ground the main target the "man" and not include the "young girl" in your final predicted mask(s). Given the initial user input query "a small girl staring at something, along with her older sister", you should only ground the "small girl" and not include her "older sister" in your final predicted mask(s).
232
+ 4. Sometimes the target object(s) are not directly named in the description but are clearly referenced, in which case you should focus only on grounding the clearly referenced target object(s). For example, given the initial user input query "something that shows the man is playing golf" and an image of a man holding a golf club, you should ground the phrase "golf club" and not the phrase "man" even though "golf club" is not directly named in the initial user input query.
233
+ 5. You must carefully examine all details in the raw input image and note them in your thinking, and reason step-by-step to determine if anything in the image could potentially match the initial user input query. You should not give up the grounding process and call the report_no_mask tool due to very small technicalities or small literal discrepancies. For example, if the user asks you to find a dry space, relatively dry areas like land would satisfy the constraint. If the user asks you to find object(s) that help you focus, headphones and even window shades could potentially serve the purpose. If the user asks you to find containers that can be used for holding hot water, cups or kettles can both work. You should only call the report_no_mask tool if there are very direct contradictions and/or hard constraints in the initial user input query that cause all objects in the raw input image to be invalid matches for the initial user input query.
234
+ 6. Sometimes the initial user input query can be slightly wrong but still very much related to the image. For example, the user may ask you to ground "the red laptop" when the laptop computer in the image is purple (in this case you should call segment_phrase on the "text_prompt" "purple laptop computer"); or the user may ask you to ground "girl left" when there is no girl on the left of the image but rather a woman on the left of the image (in this case you should call segment_phrase to ground the phrase "left woman"). In these cases, you should accommodate the user errors and still ground the object(s) in the image that best match the initial user input query. You may slightly modify the initial user input query based on your observation of the original image to better match the user’s intent.
235
+ 7. Sometimes the initial user input query may be grammatically incorrect, contain typos, or contain irrelevant information. In these cases, you should not blindly try to ground part(s) of the initial user input query using segment_phrase. Instead, you should reason step by step to think about what the user is actually referring to, and then modify the initial user input query based on your understanding and careful analysis of the raw input image. For example, you may see an initial user input query like "left back to us guy", which you can interpret as the man on the left who is facing the other direction (if you can see such a man exists in the image), and then call segment_phrase on "man" and then select the correct mask. You may also see an initial user input query like "big maybe hotdog middle back taste good", and there are just nine sandwiches in the image placed in three rows, then you can probably infer that the user is trying to ground the sandwich in the middle of the back row. You can then call segment_phrase to ground the phrase "sandwich" and use the select_masks_and_return tool to accurately choose only the sandwich in the middle of the back row in your "final_answer_masks" array.
236
+ 8. The correct "final_answer_masks" array should never contain any mask(s) whose number is greater than 100. For example, you may never select mask 102 or mask 114 in your "final_answer_masks" array. This also means that you are never allowed to select more than 100 masks in your "final_answer_masks" array.
237
+ 9. Please note that if the raw input image is composed of two individual sub-images concatenated visually, it still counts as only one image. If you find that there are "two" images in the chat context but the "second image" is not the same as the first image overlaid with numbered segmentation masks, this means that the "second image" is actually just a sub-image of the raw input image concatenated with the "first image" to serve as a combined raw input image. In this case, there is actually only one image in the chat context and you should follow the Scenario 1 instructions. This is very important!
238
+
239
+
240
+ Begin!
241
+
242
+ Below are the raw input image and the initial user input query:
sam3/agent/system_prompts/system_prompt_iterative_checking.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ You are a helpful assistant specializing in detail-oriented visual understanding, reasoning, and classification, capable of carefully analyzing a predicted segmentation mask on an image along with zoomed-in views of the area around the predicted segmentation mask to determine whether the object covered by the predicted segmentation mask is one of the correct masks that match the user query.
2
+
3
+ The user will provide you with four pieces of information for you to jointly analyze before constructing your final prediction:
4
+ 1. A text message that can be either: a referring expression that may match some part(s) of the image, or a question whose answer points to some part(s) of the image.
5
+ 2. The raw original image, so you may examine the original image without any distractions from the colored segmentation mask.
6
+ 3. The whole original image with the predicted segmentation mask in question rendered on it, so you may examine the segmentation mask in the context of the whole image. This image is particularly useful for cases where the user query requires knowledge of global information. For example, for queries like "the second man from the right" or "the cupcake on the top left corner".
7
+ 4. A zoomed-in version of the predicted segmentation mask in question. This image consists of two sub-images connected together: one sub-image is the zoomed-in version of the predicted segmentation mask itself, and the other sub-image is a slightly zoomed-in view of the bounding-box area around the predicted segmentation mask.
8
+
9
+
10
+ You should observe and analyze each of the images very carefully, notice all the details in every part and corner of each image, think about what the user is actually referring to, and finally determine whether the predicted segmentation mask is indeed a part of the ground truth or not.
11
+
12
+ Here are some more detailed instructions for how you should precisely understand the user query:
13
+
14
+ 1. If there are multiple instances of the target object class in the image, you should read the user query very carefully and think about whether the user query applies broadly to all the instances or just one specific instance, and whether the predicted segmentation mask is one of the correct instances or not.
15
+ 2. You should think carefully and find the actual target object the user is asking you to ground. Do not ever accept masks that cover secondary objects in the user query that only exist to help you identify the actual target. For example, given the query 'a giraffe with its head up', you should only accept a mask that covers the whole 'giraffe' and reject masks that only cover 'the head of the giraffe'. Given the query 'a person holding a blender with their left hand', you should only accept a mask that covers the whole 'person' instead of a mask that covers 'blender' or 'left hand'. Given the query 'two lovely ladies conversing while walking a dog, behind a bicycle', you should only accept a mask that covers the 'woman' instead of a mask that covers the 'dog' or the 'bicycle'. Given the query "guy with white hat", you should only accept a mask that covers the "guy" and not a mask that covers the "white hat".
16
+ 3. Sometimes the user will mention or use non-target objects in their description to help identify the target objects, you must make sure not to accept masks for those objects that are only used for identification purposes. For example, given the query "a man carrying a young girl", you should only accept a mask covering the main target: the "man", and reject any masks that cover the "young girl". Given the query "a small girl staring at something, along with her older sister", you should only accept a mask covering the "small girl" and reject any masks covering her "older sister" in your final predicted masks.
17
+ 4. Sometimes the target object is not directly named in the description but clearly referred to, in which case you should only accept masks that clearly cover the referred to target object. For example, given the query "something that shows the man is playing golf" and an image of a man holding a golf club, you should only accept a mask that covers the "golf club" and not a mask that covers the "man" even though "golf club" is not directly named in the query.
18
+ 5. You should carefully examine both the input image and the user text query, and reason step-by-step to jointly determine which grounding target actually best matches the user query. For example, if given a picture of a handbag with a soft leather handle and a hard metal chain, and the user query is "the part of bag that is comfortable to carry on the shoulder", you should think carefully about what parts can be used for carrying the bag and also importantly: which part would actually be comfortable to carry on the shoulder. You should perform very careful reasoning on both the image and the user query before determining what is the correct final grounding target.
19
+
20
+
21
+ Now, please analyze the image and think about whether the predicted segmentation mask is one of the correct masks that match or answer the user query. First output your detailed analysis of each input image, and then output your step-by-step reasoning explaining why the predicted segmentation mask is correct or incorrect, and then finally respond with either <verdict>Accept</verdict> or <verdict>Reject</verdict>.
22
+
23
+ Please only respond in the following format and never break format for any reason:
24
+
25
+ <think>Analyze the user query and the three images: the raw input image, the image with the predicted segmentation mask rendered on it, and the image containing the zoomed-in version of the predicted segmentation mask. Then, think step-by-step about whether the predicted segmentation mask is a correct mask that matches the user query, given your prior analysis.</think>
26
+ <verdict>Accept</verdict> or <verdict>Reject</verdict>
sam3/agent/viz.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import pycocotools.mask as mask_utils
6
+ from PIL import Image
7
+
8
+ from .helpers.visualizer import Visualizer
9
+ from .helpers.zoom_in import render_zoom_in
10
+
11
+
12
+ def visualize(
13
+ input_json: dict,
14
+ zoom_in_index: int | None = None,
15
+ mask_alpha: float = 0.15,
16
+ label_mode: str = "1",
17
+ font_size_multiplier: float = 1.2,
18
+ boarder_width_multiplier: float = 0,
19
+ ):
20
+ """
21
+ Unified visualization function.
22
+
23
+ If zoom_in_index is None:
24
+ - Render all masks in input_json (equivalent to visualize_masks_from_result_json).
25
+ - Returns: PIL.Image
26
+
27
+ If zoom_in_index is provided:
28
+ - Returns two PIL.Images:
29
+ 1) Output identical to zoom_in_and_visualize(input_json, index).
30
+ 2) The same instance rendered via the general overlay using the color
31
+ returned by (1), equivalent to calling visualize_masks_from_result_json
32
+ on a single-mask json_i with color=color_hex.
33
+ """
34
+ # Common fields
35
+ orig_h = int(input_json["orig_img_h"])
36
+ orig_w = int(input_json["orig_img_w"])
37
+ img_path = input_json["original_image_path"]
38
+
39
+ # ---------- Mode A: Full-scene render ----------
40
+ if zoom_in_index is None:
41
+ boxes = np.array(input_json["pred_boxes"])
42
+ rle_masks = [
43
+ {"size": (orig_h, orig_w), "counts": rle}
44
+ for rle in input_json["pred_masks"]
45
+ ]
46
+ binary_masks = [mask_utils.decode(rle) for rle in rle_masks]
47
+
48
+ img_bgr = cv2.imread(img_path)
49
+ if img_bgr is None:
50
+ raise FileNotFoundError(f"Could not read image: {img_path}")
51
+ img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
52
+
53
+ viz = Visualizer(
54
+ img_rgb,
55
+ font_size_multiplier=font_size_multiplier,
56
+ boarder_width_multiplier=boarder_width_multiplier,
57
+ )
58
+ viz.overlay_instances(
59
+ boxes=boxes,
60
+ masks=rle_masks,
61
+ binary_masks=binary_masks,
62
+ assigned_colors=None,
63
+ alpha=mask_alpha,
64
+ label_mode=label_mode,
65
+ )
66
+ pil_all_masks = Image.fromarray(viz.output.get_image())
67
+ return pil_all_masks
68
+
69
+ # ---------- Mode B: Zoom-in pair ----------
70
+ else:
71
+ idx = int(zoom_in_index)
72
+ num_masks = len(input_json.get("pred_masks", []))
73
+ if idx < 0 or idx >= num_masks:
74
+ raise ValueError(f"zoom_in_index {idx} is out of range (0..{num_masks-1}).")
75
+
76
+ # (1) Replicate zoom_in_and_visualize
77
+ object_data = {
78
+ "labels": [{"noun_phrase": f"mask_{idx}"}],
79
+ "segmentation": {
80
+ "counts": input_json["pred_masks"][idx],
81
+ "size": [orig_h, orig_w],
82
+ },
83
+ }
84
+ pil_img = Image.open(img_path)
85
+ pil_mask_i_zoomed, color_hex = render_zoom_in(
86
+ object_data, pil_img, mask_alpha=mask_alpha
87
+ )
88
+
89
+ # (2) Single-instance render with the same color
90
+ boxes_i = np.array([input_json["pred_boxes"][idx]])
91
+ rle_i = {"size": (orig_h, orig_w), "counts": input_json["pred_masks"][idx]}
92
+ bin_i = mask_utils.decode(rle_i)
93
+
94
+ img_bgr_i = cv2.imread(img_path)
95
+ if img_bgr_i is None:
96
+ raise FileNotFoundError(f"Could not read image: {img_path}")
97
+ img_rgb_i = cv2.cvtColor(img_bgr_i, cv2.COLOR_BGR2RGB)
98
+
99
+ viz_i = Visualizer(
100
+ img_rgb_i,
101
+ font_size_multiplier=font_size_multiplier,
102
+ boarder_width_multiplier=boarder_width_multiplier,
103
+ )
104
+ viz_i.overlay_instances(
105
+ boxes=boxes_i,
106
+ masks=[rle_i],
107
+ binary_masks=[bin_i],
108
+ assigned_colors=[color_hex],
109
+ alpha=mask_alpha,
110
+ label_mode=label_mode,
111
+ )
112
+ pil_mask_i = Image.fromarray(viz_i.output.get_image())
113
+
114
+ return pil_mask_i, pil_mask_i_zoomed
sam3/eval/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
sam3/eval/cgf1_eval.py ADDED
@@ -0,0 +1,703 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import contextlib
4
+ import copy
5
+ import json
6
+ import os
7
+ import time
8
+ from collections import defaultdict
9
+ from dataclasses import dataclass
10
+ from typing import List, Union
11
+
12
+ import numpy as np
13
+ import pycocotools.mask as maskUtils
14
+ from pycocotools.coco import COCO
15
+ from pycocotools.cocoeval import COCOeval
16
+ from scipy.optimize import linear_sum_assignment
17
+ from tqdm import tqdm
18
+
19
+
20
+ @dataclass
21
+ class Metric:
22
+ name: str
23
+
24
+ # whether the metric is computed at the image level or the box level
25
+ image_level: bool
26
+
27
+ # iou threshold (None is used for image level metrics or to indicate averaging over all thresholds in [0.5:0.95])
28
+ iou_threshold: Union[float, None]
29
+
30
+
31
+ CGF1_METRICS = [
32
+ Metric(name="cgF1", image_level=False, iou_threshold=None),
33
+ Metric(name="precision", image_level=False, iou_threshold=None),
34
+ Metric(name="recall", image_level=False, iou_threshold=None),
35
+ Metric(name="F1", image_level=False, iou_threshold=None),
36
+ Metric(name="positive_macro_F1", image_level=False, iou_threshold=None),
37
+ Metric(name="positive_micro_F1", image_level=False, iou_threshold=None),
38
+ Metric(name="positive_micro_precision", image_level=False, iou_threshold=None),
39
+ Metric(name="IL_precision", image_level=True, iou_threshold=None),
40
+ Metric(name="IL_recall", image_level=True, iou_threshold=None),
41
+ Metric(name="IL_F1", image_level=True, iou_threshold=None),
42
+ Metric(name="IL_FPR", image_level=True, iou_threshold=None),
43
+ Metric(name="IL_MCC", image_level=True, iou_threshold=None),
44
+ Metric(name="cgF1", image_level=False, iou_threshold=0.5),
45
+ Metric(name="precision", image_level=False, iou_threshold=0.5),
46
+ Metric(name="recall", image_level=False, iou_threshold=0.5),
47
+ Metric(name="F1", image_level=False, iou_threshold=0.5),
48
+ Metric(name="positive_macro_F1", image_level=False, iou_threshold=0.5),
49
+ Metric(name="positive_micro_F1", image_level=False, iou_threshold=0.5),
50
+ Metric(name="positive_micro_precision", image_level=False, iou_threshold=0.5),
51
+ Metric(name="cgF1", image_level=False, iou_threshold=0.75),
52
+ Metric(name="precision", image_level=False, iou_threshold=0.75),
53
+ Metric(name="recall", image_level=False, iou_threshold=0.75),
54
+ Metric(name="F1", image_level=False, iou_threshold=0.75),
55
+ Metric(name="positive_macro_F1", image_level=False, iou_threshold=0.75),
56
+ Metric(name="positive_micro_F1", image_level=False, iou_threshold=0.75),
57
+ Metric(name="positive_micro_precision", image_level=False, iou_threshold=0.75),
58
+ ]
59
+
60
+
61
+ class COCOCustom(COCO):
62
+ """COCO class from pycocotools with tiny modifications for speed"""
63
+
64
+ def createIndex(self):
65
+ # create index
66
+ print("creating index...")
67
+ anns, cats, imgs = {}, {}, {}
68
+ imgToAnns, catToImgs = defaultdict(list), defaultdict(list)
69
+ if "annotations" in self.dataset:
70
+ for ann in self.dataset["annotations"]:
71
+ imgToAnns[ann["image_id"]].append(ann)
72
+ anns[ann["id"]] = ann
73
+
74
+ if "images" in self.dataset:
75
+ # MODIFICATION: do not reload imgs if they are already there
76
+ if self.imgs:
77
+ imgs = self.imgs
78
+ else:
79
+ for img in self.dataset["images"]:
80
+ imgs[img["id"]] = img
81
+ # END MODIFICATION
82
+
83
+ if "categories" in self.dataset:
84
+ for cat in self.dataset["categories"]:
85
+ cats[cat["id"]] = cat
86
+
87
+ if "annotations" in self.dataset and "categories" in self.dataset:
88
+ for ann in self.dataset["annotations"]:
89
+ catToImgs[ann["category_id"]].append(ann["image_id"])
90
+
91
+ print("index created!")
92
+
93
+ # create class members
94
+ self.anns = anns
95
+ self.imgToAnns = imgToAnns
96
+ self.catToImgs = catToImgs
97
+ self.imgs = imgs
98
+ self.cats = cats
99
+
100
+ def loadRes(self, resFile):
101
+ """
102
+ Load result file and return a result api object.
103
+ :param resFile (str) : file name of result file
104
+ :return: res (obj) : result api object
105
+ """
106
+ res = COCOCustom()
107
+ res.dataset["info"] = copy.deepcopy(self.dataset.get("info", {}))
108
+ # MODIFICATION: no copy
109
+ # res.dataset['images'] = [img for img in self.dataset['images']]
110
+ res.dataset["images"] = self.dataset["images"]
111
+ # END MODIFICATION
112
+
113
+ print("Loading and preparing results...")
114
+ tic = time.time()
115
+ if type(resFile) == str:
116
+ with open(resFile) as f:
117
+ anns = json.load(f)
118
+ elif type(resFile) == np.ndarray:
119
+ anns = self.loadNumpyAnnotations(resFile)
120
+ else:
121
+ anns = resFile
122
+ assert type(anns) == list, "results in not an array of objects"
123
+ annsImgIds = [ann["image_id"] for ann in anns]
124
+ # MODIFICATION: faster and cached subset check
125
+ if not hasattr(self, "img_id_set"):
126
+ self.img_id_set = set(self.getImgIds())
127
+ assert set(annsImgIds).issubset(
128
+ self.img_id_set
129
+ ), "Results do not correspond to current coco set"
130
+ # END MODIFICATION
131
+ if "caption" in anns[0]:
132
+ imgIds = set([img["id"] for img in res.dataset["images"]]) & set(
133
+ [ann["image_id"] for ann in anns]
134
+ )
135
+ res.dataset["images"] = [
136
+ img for img in res.dataset["images"] if img["id"] in imgIds
137
+ ]
138
+ for id, ann in enumerate(anns):
139
+ ann["id"] = id + 1
140
+ elif "bbox" in anns[0] and not anns[0]["bbox"] == []:
141
+ res.dataset["categories"] = copy.deepcopy(self.dataset["categories"])
142
+ for id, ann in enumerate(anns):
143
+ bb = ann["bbox"]
144
+ x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]]
145
+ if not "segmentation" in ann:
146
+ ann["segmentation"] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
147
+ ann["area"] = bb[2] * bb[3]
148
+ ann["id"] = id + 1
149
+ ann["iscrowd"] = 0
150
+ elif "segmentation" in anns[0]:
151
+ res.dataset["categories"] = copy.deepcopy(self.dataset["categories"])
152
+ for id, ann in enumerate(anns):
153
+ # now only support compressed RLE format as segmentation results
154
+ ann["area"] = maskUtils.area(ann["segmentation"])
155
+ if not "bbox" in ann:
156
+ ann["bbox"] = maskUtils.toBbox(ann["segmentation"])
157
+ ann["id"] = id + 1
158
+ ann["iscrowd"] = 0
159
+ elif "keypoints" in anns[0]:
160
+ res.dataset["categories"] = copy.deepcopy(self.dataset["categories"])
161
+ for id, ann in enumerate(anns):
162
+ s = ann["keypoints"]
163
+ x = s[0::3]
164
+ y = s[1::3]
165
+ x0, x1, y0, y1 = np.min(x), np.max(x), np.min(y), np.max(y)
166
+ ann["area"] = (x1 - x0) * (y1 - y0)
167
+ ann["id"] = id + 1
168
+ ann["bbox"] = [x0, y0, x1 - x0, y1 - y0]
169
+ print("DONE (t={:0.2f}s)".format(time.time() - tic))
170
+
171
+ res.dataset["annotations"] = anns
172
+ # MODIFICATION: inherit images
173
+ res.imgs = self.imgs
174
+ # END MODIFICATION
175
+ res.createIndex()
176
+ return res
177
+
178
+
179
class CGF1Eval(COCOeval):
    """
    This evaluator is based upon COCO evaluation, but evaluates the model in a more realistic setting
    for downstream applications.
    See SAM3 paper for the details on the CGF1 metric.

    Do not use this evaluator directly. Prefer the CGF1Evaluator wrapper.

    Notes:
    - This evaluator does not support per-category evaluation (in the way defined by pyCocotools)
    - In open vocabulary settings, we have different noun-phrases for each image. What we call an
      "image_id" here is actually an (image, noun-phrase) pair. So in every "image_id" there is only
      one category, implied by the noun-phrase. Thus we can ignore the usual coco "category" field
      of the predictions
    """

    def __init__(
        self,
        coco_gt=None,
        coco_dt=None,
        iouType="segm",
        threshold=0.5,
    ):
        """
        Args:
            coco_gt (COCO): ground truth COCO API
            coco_dt (COCO): detections COCO API
            iouType (str): type of IoU to evaluate ("segm" or "bbox")
            threshold (float): score threshold applied to predictions before matching
        """
        super().__init__(coco_gt, coco_dt, iouType)
        self.threshold = threshold

        # Category-agnostic evaluation, a single "all" area range, and an
        # effectively unlimited number of detections per image.
        self.params.useCats = False
        self.params.areaRng = [[0**2, 1e5**2]]
        self.params.areaRngLbl = ["all"]
        self.params.maxDets = [1000000]

    def computeIoU(self, imgId, catId):
        """Compute dt-vs-gt IoUs for one (image, category) pair.

        Same as the original COCOeval.computeIoU, but WITHOUT sorting detections
        by score: the row order must stay aligned with self._dts so that the
        keep_dt mask in evaluateImg indexes the IoU matrix correctly.
        """
        p = self.params
        if p.useCats:
            gt = self._gts[imgId, catId]
            dt = self._dts[imgId, catId]
        else:
            gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
            dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
        if len(gt) == 0 and len(dt) == 0:
            return []

        if p.iouType == "segm":
            g = [g["segmentation"] for g in gt]
            d = [d["segmentation"] for d in dt]
        elif p.iouType == "bbox":
            g = [g["bbox"] for g in gt]
            d = [d["bbox"] for d in dt]
        else:
            raise Exception("unknown iouType for iou computation")

        # compute iou between each dt and gt region
        iscrowd = [int(o["iscrowd"]) for o in gt]
        ious = maskUtils.iou(d, g, iscrowd)
        return ious

    def evaluateImg(self, imgId, catId, aRng, maxDet):
        """
        Perform evaluation for a single (image, noun-phrase) pair.

        Predictions below self.threshold and GTs flagged "ignore" are dropped,
        then detections are matched one-to-one to GTs with the Hungarian
        algorithm on IoU. Returns a dict with image-level (IL_*) indicators and,
        when defined, per-IoU-threshold TP/FP/FN counts and local F1 scores.
        """
        p = self.params
        assert not p.useCats, "This evaluator does not support per-category evaluation."
        assert catId == -1
        all_gts = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
        keep_gt = np.array([not g["ignore"] for g in all_gts], dtype=bool)
        gt = [g for g in all_gts if not g["ignore"]]
        all_dts = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
        keep_dt = np.array([d["score"] >= self.threshold for d in all_dts], dtype=bool)
        dt = [d for d in all_dts if d["score"] >= self.threshold]
        if len(gt) == 0 and len(dt) == 0:
            # This is a "true negative" case, where there are no GTs and no predictions
            # The box-level metrics are ill-defined, so we don't add them to this dict
            return {
                "image_id": imgId,
                "IL_TP": 0,
                "IL_TN": 1,
                "IL_FP": 0,
                "IL_FN": 0,
                "num_dt": len(dt),
            }

        if len(gt) > 0 and len(dt) == 0:
            # This is a "false negative" case, where there are GTs but no predictions
            return {
                "image_id": imgId,
                "IL_TP": 0,
                "IL_TN": 0,
                "IL_FP": 0,
                "IL_FN": 1,
                "TPs": np.zeros((len(p.iouThrs),), dtype=np.int64),
                "FPs": np.zeros((len(p.iouThrs),), dtype=np.int64),
                "FNs": np.ones((len(p.iouThrs),), dtype=np.int64) * len(gt),
                "local_F1s": np.zeros((len(p.iouThrs),), dtype=np.int64),
                "local_positive_F1s": np.zeros((len(p.iouThrs),), dtype=np.int64),
                "num_dt": len(dt),
            }

        # Load pre-computed ious
        ious = self.ious[(imgId, catId)]

        # compute matching
        if len(ious) == 0:
            ious = np.zeros((len(dt), len(gt)))
        else:
            # Restrict the IoU matrix to the kept detections/GTs; valid because
            # computeIoU preserved the _dts/_gts ordering (no score sort).
            ious = ious[keep_dt, :][:, keep_gt]
            assert ious.shape == (len(dt), len(gt))

        # Hungarian matching maximizing total IoU (hence the negation)
        matched_dt, matched_gt = linear_sum_assignment(-ious)

        match_scores = ious[matched_dt, matched_gt]

        TPs, FPs, FNs = [], [], []
        IL_perfect = []
        for thresh in p.iouThrs:
            TP = (match_scores >= thresh).sum()
            FP = len(dt) - TP
            FN = len(gt) - TP
            assert (
                FP >= 0 and FN >= 0
            ), f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}"
            TPs.append(TP)
            FPs.append(FP)
            FNs.append(FN)

            if FP == FN and FP == 0:
                IL_perfect.append(1)
            else:
                IL_perfect.append(0)

        TPs = np.array(TPs, dtype=np.int64)
        FPs = np.array(FPs, dtype=np.int64)
        FNs = np.array(FNs, dtype=np.int64)
        IL_perfect = np.array(IL_perfect, dtype=np.int64)

        # compute precision recall and F1 (epsilon avoids division by zero)
        precision = TPs / (TPs + FPs + 1e-4)
        assert np.all(precision <= 1)
        recall = TPs / (TPs + FNs + 1e-4)
        assert np.all(recall <= 1)
        F1 = 2 * precision * recall / (precision + recall + 1e-4)

        result = {
            "image_id": imgId,
            "TPs": TPs,
            "FPs": FPs,
            "FNs": FNs,
            "local_F1s": F1,
            "IL_TP": (len(gt) > 0) and (len(dt) > 0),
            "IL_FP": (len(gt) == 0) and (len(dt) > 0),
            "IL_TN": (len(gt) == 0) and (len(dt) == 0),
            "IL_FN": (len(gt) > 0) and (len(dt) == 0),
            "num_dt": len(dt),
        }
        if len(gt) > 0 and len(dt) > 0:
            result["local_positive_F1s"] = F1
        return result

    def accumulate(self, p=None):
        """
        Accumulate per image evaluation results and store the result in self.eval
        :param p: input params for evaluation
        :return: None
        """
        if self.evalImgs is None or len(self.evalImgs) == 0:
            # BUGFIX: previously this only printed the warning and fell through,
            # which then crashed on the coverage assertion below. Bail out so
            # summarize() can raise its own clear error instead.
            print("Please run evaluate() first")
            return
        # allows input customized parameters
        if p is None:
            p = self.params

        setImgIds = set(p.imgIds)

        # TPs, FPs, FNs
        TPs = np.zeros((len(p.iouThrs),), dtype=np.int64)
        FPs = np.zeros((len(p.iouThrs),), dtype=np.int64)
        pmFPs = np.zeros((len(p.iouThrs),), dtype=np.int64)
        FNs = np.zeros((len(p.iouThrs),), dtype=np.int64)
        local_F1s = np.zeros((len(p.iouThrs),), dtype=np.float64)

        # Image level metrics
        IL_TPs = 0
        IL_FPs = 0
        IL_TNs = 0
        IL_FNs = 0

        valid_img_count = 0
        valid_F1_count = 0
        evaledImgIds = set()
        for res in self.evalImgs:
            if res["image_id"] not in setImgIds:
                continue
            evaledImgIds.add(res["image_id"])
            IL_TPs += res["IL_TP"]
            IL_FPs += res["IL_FP"]
            IL_TNs += res["IL_TN"]
            IL_FNs += res["IL_FN"]

            # "true negative" images carry no box-level counts (see evaluateImg)
            if "TPs" not in res:
                continue

            TPs += res["TPs"]
            FPs += res["FPs"]
            FNs += res["FNs"]
            valid_img_count += 1

            if "local_positive_F1s" in res:
                local_F1s += res["local_positive_F1s"]
                pmFPs += res["FPs"]
            if res["num_dt"] > 0:
                valid_F1_count += 1

        assert len(setImgIds - evaledImgIds) == 0, (
            f"{len(setImgIds - evaledImgIds)} images not evaluated. "
            f"Here are the IDs of the first 3: {list(setImgIds - evaledImgIds)[:3]}"
        )

        # compute precision recall and F1
        precision = TPs / (TPs + FPs + 1e-4)
        positive_micro_precision = TPs / (TPs + pmFPs + 1e-4)
        assert np.all(precision <= 1)
        recall = TPs / (TPs + FNs + 1e-4)
        assert np.all(recall <= 1)
        F1 = 2 * precision * recall / (precision + recall + 1e-4)
        positive_micro_F1 = (
            2
            * positive_micro_precision
            * recall
            / (positive_micro_precision + recall + 1e-4)
        )

        IL_rec = IL_TPs / (IL_TPs + IL_FNs + 1e-6)
        IL_prec = IL_TPs / (IL_TPs + IL_FPs + 1e-6)
        IL_F1 = 2 * IL_prec * IL_rec / (IL_prec + IL_rec + 1e-6)
        IL_FPR = IL_FPs / (IL_FPs + IL_TNs + 1e-6)
        # Matthews correlation coefficient on image-level confusion counts
        IL_MCC = float(IL_TPs * IL_TNs - IL_FPs * IL_FNs) / (
            (
                float(IL_TPs + IL_FPs)
                * float(IL_TPs + IL_FNs)
                * float(IL_TNs + IL_FPs)
                * float(IL_TNs + IL_FNs)
            )
            ** 0.5
            + 1e-6
        )

        self.eval = {
            "params": p,
            "TPs": TPs,
            "FPs": FPs,
            "positive_micro_FPs": pmFPs,
            "FNs": FNs,
            "precision": precision,
            "positive_micro_precision": positive_micro_precision,
            "recall": recall,
            "F1": F1,
            "positive_micro_F1": positive_micro_F1,
            # BUGFIX: guard against division by zero (NaN) when no image
            # produced a positive F1 score.
            "positive_macro_F1": local_F1s / max(valid_F1_count, 1),
            "IL_recall": IL_rec,
            "IL_precision": IL_prec,
            "IL_F1": IL_F1,
            "IL_FPR": IL_FPR,
            "IL_MCC": IL_MCC,
        }
        self.eval["cgF1"] = self.eval["positive_micro_F1"] * self.eval["IL_MCC"]

    def summarize(self):
        """
        Compute and display summary metrics for evaluation results.

        Raises:
            Exception: if accumulate() has not been run yet.
        """
        if not self.eval:
            raise Exception("Please run accumulate() first")

        def _summarize(iouThr=None, metric=""):
            # Summarize a per-IoU-threshold metric, optionally at one threshold.
            p = self.params
            iStr = " {:<18} @[ IoU={:<9}] = {:0.3f}"
            titleStr = "Average " + metric
            iouStr = (
                "{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1])
                if iouThr is None
                else "{:0.2f}".format(iouThr)
            )

            s = self.eval[metric]
            # IoU
            if iouThr is not None:
                t = np.where(iouThr == p.iouThrs)[0]
                s = s[t]

            if len(s[s > -1]) == 0:
                mean_s = -1
            else:
                mean_s = np.mean(s[s > -1])
            print(iStr.format(titleStr, iouStr, mean_s))
            return mean_s

        def _summarize_single(metric=""):
            # Summarize a scalar (image-level) metric.
            titleStr = "Average " + metric
            iStr = " {:<35} = {:0.3f}"
            s = self.eval[metric]
            print(iStr.format(titleStr, s))
            return s

        def _summarizeDets():
            stats = []

            for metric in CGF1_METRICS:
                if metric.image_level:
                    stats.append(_summarize_single(metric=metric.name))
                else:
                    stats.append(
                        _summarize(iouThr=metric.iou_threshold, metric=metric.name)
                    )
            return np.asarray(stats)

        summarize = _summarizeDets
        self.stats = summarize()
+
501
+
502
def _evaluate(self):
    """
    Run per-image evaluation for a CGF1Eval instance.

    Unlike COCOeval.evaluate, this does not write into self.evalImgs; it
    returns (imgIds, evalImgs) with evalImgs shaped
    [numCats, numAreaRngs, numImgs].
    """
    p = self.params
    # Deduplicate image ids and force the category-agnostic setting.
    p.imgIds = list(np.unique(p.imgIds))
    p.useCats = False
    p.maxDets = sorted(p.maxDets)
    self.params = p

    self._prepare()

    # Only mask and box IoUs are supported.
    if p.iouType not in ("segm", "bbox"):
        raise RuntimeError(f"Unsupported iou {p.iouType}")
    iou_fn = self.computeIoU

    # Single pseudo-category (-1): evaluation is category-agnostic.
    cat_ids = [-1]

    self.ious = {}
    for img_id in p.imgIds:
        for cat_id in cat_ids:
            self.ious[(img_id, cat_id)] = iou_fn(img_id, cat_id)

    largest_max_det = p.maxDets[-1]
    per_image_results = []
    for cat_id in cat_ids:
        for area_rng in p.areaRng:
            for img_id in p.imgIds:
                per_image_results.append(
                    self.evaluateImg(img_id, cat_id, area_rng, largest_max_det)
                )

    # Reshape here (NOT done in pycocotools, where it happens outside).
    shaped = np.asarray(per_image_results).reshape(
        len(cat_ids), len(p.areaRng), len(p.imgIds)
    )
    return p.imgIds, shaped
537
+
538
+
539
class CGF1Evaluator:
    """
    Wrapper class for cgF1 evaluation.
    This supports the oracle setting (when several ground-truths are available per image)
    """

    def __init__(
        self,
        gt_path: Union[str, List[str]],
        iou_type="segm",
        verbose=False,
    ):
        """
        Args:
            gt_path (str or list of str): path(s) to ground truth COCO json file(s)
            iou_type (str): type of IoU to evaluate ("segm" or "bbox")
            verbose (bool): if True, print progress information during evaluate()
        """
        self.gt_paths = gt_path if isinstance(gt_path, list) else [gt_path]
        self.iou_type = iou_type

        self.coco_gts = [COCOCustom(gt) for gt in self.gt_paths]

        self.verbose = verbose

        self.coco_evals = []
        for coco_gt in self.coco_gts:
            coco_eval = CGF1Eval(
                coco_gt=coco_gt,
                iouType=iou_type,
            )
            # BUGFIX: pycocotools reads the flag from `params`; the previous
            # `coco_eval.useCats = False` set a dead attribute on the evaluator.
            coco_eval.params.useCats = False
            self.coco_evals.append(coco_eval)

        exclude_img_ids = set()
        # exclude_img_ids are the ids that are not exhaustively annotated in any of the other gts
        for coco_gt in self.coco_gts[1:]:
            exclude_img_ids = exclude_img_ids.union(
                {
                    img["id"]
                    for img in coco_gt.dataset["images"]
                    if not img["is_instance_exhaustive"]
                }
            )
        # we only eval on instance exhaustive queries
        self.eval_img_ids = [
            img["id"]
            for img in self.coco_gts[0].dataset["images"]
            if (img["is_instance_exhaustive"] and img["id"] not in exclude_img_ids)
        ]

    def evaluate(self, pred_file: str):
        """
        Evaluate the detections using cgF1 metric.

        Args:
            pred_file: path to the predictions COCO json file

        Returns:
            dict mapping "cgF1_eval_{iou_type}_{metric}" names to float values.
        """
        assert len(self.coco_gts) > 0, "No ground truth provided for evaluation."
        assert len(self.coco_gts) == len(
            self.coco_evals
        ), "Mismatch in number of ground truths and evaluators."

        if self.verbose:
            print(f"Loading predictions from {pred_file}")

        with open(pred_file, "r") as f:
            preds = json.load(f)

        if self.verbose:
            print(f"Loaded {len(preds)} predictions")

        # Group predictions by (image, noun-phrase) id for per-image scoring
        img2preds = defaultdict(list)
        for pred in preds:
            img2preds[pred["image_id"]].append(pred)

        all_eval_imgs = []
        for img_id in tqdm(self.eval_img_ids, disable=not self.verbose):
            results = img2preds[img_id]
            # Score this image against every available ground truth (oracle setting)
            all_scorings = []
            for cur_coco_gt, coco_eval in zip(self.coco_gts, self.coco_evals):
                # suppress pycocotools prints
                with open(os.devnull, "w") as devnull:
                    with contextlib.redirect_stdout(devnull):
                        coco_dt = (
                            cur_coco_gt.loadRes(results) if results else COCOCustom()
                        )

                        coco_eval.cocoDt = coco_dt
                        coco_eval.params.imgIds = [img_id]
                        coco_eval.params.useCats = False
                        img_ids, eval_imgs = _evaluate(coco_eval)
                        all_scorings.append(eval_imgs)
            selected = self._select_best_scoring(all_scorings)
            all_eval_imgs.append(selected)

        # After this point, we have selected the best scoring per image among several ground truths
        # we can now accumulate and summarize, using only the first coco_eval

        self.coco_evals[0].evalImgs = list(
            np.concatenate(all_eval_imgs, axis=2).flatten()
        )
        self.coco_evals[0].params.imgIds = self.eval_img_ids
        self.coco_evals[0]._paramsEval = copy.deepcopy(self.coco_evals[0].params)

        if self.verbose:
            print("Accumulating results")
        self.coco_evals[0].accumulate()
        print("cgF1 metric, IoU type={}".format(self.iou_type))
        self.coco_evals[0].summarize()
        print()

        # Flatten the stats array into a flat name -> value dict
        out = {}
        for i, value in enumerate(self.coco_evals[0].stats):
            name = CGF1_METRICS[i].name
            if CGF1_METRICS[i].iou_threshold is not None:
                name = f"{name}@{CGF1_METRICS[i].iou_threshold}"
            out[f"cgF1_eval_{self.iou_type}_{name}"] = float(value)

        return out

    @staticmethod
    def _select_best_scoring(scorings):
        # This function is used for "oracle" type evaluation.
        # It accepts the evaluation results with respect to several ground truths, and picks the best
        if len(scorings) == 1:
            return scorings[0]

        assert (
            scorings[0].ndim == 3
        ), f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}"
        assert (
            scorings[0].shape[0] == 1
        ), f"Expecting a single category, got {scorings[0].shape[0]}"

        for scoring in scorings:
            assert (
                scoring.shape == scorings[0].shape
            ), f"Shape mismatch: {scoring.shape}, {scorings[0].shape}"

        selected_imgs = []
        for img_id in range(scorings[0].shape[-1]):
            best = scorings[0][:, :, img_id]

            for scoring in scorings[1:]:
                current = scoring[:, :, img_id]
                if "local_F1s" in best[0, 0] and "local_F1s" in current[0, 0]:
                    # we were able to compute a F1 score for this particular image in both evaluations
                    # best["local_F1s"] contains the results at various IoU thresholds. We simply take the average for comparison
                    best_score = best[0, 0]["local_F1s"].mean()
                    current_score = current[0, 0]["local_F1s"].mean()
                    if current_score > best_score:
                        best = current

                else:
                    # If we're here, it means that in some evaluation we were not able to get a valid local F1
                    # This happens when both the predictions and targets are empty. In that case, we can assume it's a perfect prediction
                    if "local_F1s" not in current[0, 0]:
                        best = current
            selected_imgs.append(best)
        result = np.stack(selected_imgs, axis=-1)
        assert result.shape == scorings[0].shape
        return result
sam3/eval/coco_eval.py ADDED
@@ -0,0 +1,916 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ """
4
+ COCO evaluator that works in distributed mode.
5
+
6
+ Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py
7
+ The difference is that there is less copy-pasting from pycocotools
8
+ in the end of the file, as python3 can suppress prints with contextlib
9
+ """
10
+
11
+ import contextlib
12
+ import copy
13
+ import json
14
+ import logging
15
+ import os
16
+ import pickle
17
+ from collections import defaultdict
18
+ from pathlib import Path
19
+
20
+ from typing import Any, List, Optional
21
+
22
+ import numpy as np
23
+
24
+ import pycocotools.mask as mask_utils
25
+ import torch
26
+ from iopath.common.file_io import g_pathmgr
27
+ from pycocotools.coco import COCO
28
+ from pycocotools.cocoeval import COCOeval
29
+
30
+ from sam3.train.masks_ops import rle_encode
31
+
32
+ from sam3.train.utils.distributed import (
33
+ all_gather,
34
+ gather_to_rank_0_via_filesys,
35
+ get_rank,
36
+ is_main_process,
37
+ )
38
+
39
+ RARITY_BUCKETS = {0: "frequent", 1: "common", 2: "medium", 3: "rare"}
40
+
41
+
42
+ class CocoEvaluator:
43
+ def __init__(
44
+ self,
45
+ coco_gt,
46
+ iou_types: List[str],
47
+ useCats: bool,
48
+ dump_dir: Optional[str],
49
+ postprocessor,
50
+ average_by_rarity=False,
51
+ metrics_dump_dir: Optional[str] = None,
52
+ gather_pred_via_filesys=False,
53
+ use_normalized_areas=True,
54
+ maxdets=[1, 10, 100],
55
+ exhaustive_only=False,
56
+ all_exhaustive_only=True,
57
+ ):
58
+ """Online coco evaluator. It will evaluate images as they are generated by the model, then accumulate/summarize at the end
59
+
60
+ Args:
61
+ - coco_gt: COCO api object containing the gt
62
+ - iou_types: can be either "bbox" or "segm"
63
+ - useCats: If true, categories will be used for evaluation
64
+ - dump_dir: if non null, then the predictions will be dumped in that directory
65
+ - postprocessor: Module to convert the model's output into the coco format
66
+ - average_by_rarity: if true then we expect the images information in the gt dataset
67
+ to have a "rarity" field. Then the AP will be computed on all rarity buckets
68
+ individually, then averaged
69
+ - gather_pred_via_filesys: if true, we use the filesystem for collective gathers
70
+ - use_normalized_areas: if true, the areas of the objects in the GT are assumed to be
71
+ normalized by the area of the image. In that case, the size buckets are adjusted
72
+ - maxdets: maximal number of detections to be evaluated on each image.
73
+ - exhaustive_only: If true, we restrict eval only to exhaustive annotations
74
+ - all_exhaustive_only: If true, datapoints are restricted only to those with all exhaustive annotations
75
+
76
+ """
77
+ # coco_gt = copy.deepcopy(coco_gt)
78
+ self.coco_gts = [coco_gt] if not isinstance(coco_gt, list) else coco_gt
79
+ assert len(maxdets) == 3, f"expecting 3 detection threshold, got {len(maxdets)}"
80
+
81
+ self.use_normalized_areas = use_normalized_areas
82
+ self.iou_types = iou_types
83
+ self.useCats = useCats
84
+ self.maxdets = maxdets
85
+ self.dump = None
86
+ self.dump_dir = dump_dir
87
+ if self.dump_dir is not None:
88
+ self.dump = []
89
+ if is_main_process():
90
+ if not os.path.exists(self.dump_dir):
91
+ os.makedirs(self.dump_dir, exist_ok=True)
92
+ logging.info(f"Create the folder: {dump_dir}")
93
+
94
+ self.initialized = False
95
+
96
+ # Whether to gather predictions through filesystem (instead of torch
97
+ # collective ops; requiring a shared filesystem across all ranks)
98
+ self.gather_pred_via_filesys = gather_pred_via_filesys
99
+ self.use_self_evaluate = True # CPP version is disabled
100
+ self.postprocessor = postprocessor
101
+ self.average_by_rarity = average_by_rarity
102
+ self.exhaustive_only = exhaustive_only
103
+ self.all_exhaustive_only = all_exhaustive_only
104
+ self.metrics_dump_dir = metrics_dump_dir
105
+ if self.metrics_dump_dir is not None:
106
+ if is_main_process():
107
+ if not os.path.exists(self.metrics_dump_dir):
108
+ os.makedirs(self.metrics_dump_dir, exist_ok=True)
109
+ logging.info(f"Create the folder: {metrics_dump_dir}")
110
+
111
+ def _lazy_init(self, coco_cls=COCO):
112
+ if self.initialized:
113
+ return
114
+
115
+ self.initialized = True
116
+
117
+ self.coco_gts = [
118
+ coco_cls(g_pathmgr.get_local_path(gt)) if isinstance(gt, str) else gt
119
+ for gt in self.coco_gts
120
+ ]
121
+
122
+ self.reset()
123
+
124
+ self.eval_img_ids = None
125
+
126
+ if self.exhaustive_only:
127
+ exclude_img_ids = set()
128
+ # exclude_img_ids are the ids that are not exhaustively annotated in any of the other gts
129
+ if self.all_exhaustive_only:
130
+ for coco_gt in self.coco_gts[1:]:
131
+ exclude_img_ids = exclude_img_ids.union(
132
+ {
133
+ img["id"]
134
+ for img in coco_gt.dataset["images"]
135
+ if not img["is_instance_exhaustive"]
136
+ }
137
+ )
138
+ # we only eval on instance exhaustive queries
139
+ self.eval_img_ids = [
140
+ img["id"]
141
+ for img in self.coco_gts[0].dataset["images"]
142
+ if (img["is_instance_exhaustive"] and img["id"] not in exclude_img_ids)
143
+ ]
144
+
145
+ self.rarity_buckets = None
146
+ if self.average_by_rarity:
147
+ self.rarity_buckets = defaultdict(list)
148
+ eval_img_ids_set = (
149
+ set(self.eval_img_ids) if self.eval_img_ids is not None else None
150
+ )
151
+ for img in self.coco_gts[0].dataset["images"]:
152
+ if self.eval_img_ids is not None and img["id"] not in eval_img_ids_set:
153
+ continue
154
+ self.rarity_buckets[img["rarity"]].append(img["id"])
155
+ print("Rarity buckets sizes:")
156
+ for k, v in self.rarity_buckets.items():
157
+ print(f"{k}: {len(v)}")
158
+
159
+ def set_sync_device(self, device: torch.device) -> Any:
160
+ self._sync_device = device
161
+
162
+ def _evaluate(self, *args, **kwargs):
163
+ return evaluate(*args, **kwargs)
164
+
165
+ def _loadRes(self, *args, **kwargs):
166
+ return loadRes(*args, **kwargs)
167
+
168
+ def update(self, *args, **kwargs):
169
+ self._lazy_init()
170
+ predictions = self.postprocessor.process_results(*args, **kwargs)
171
+
172
+ img_ids = list(np.unique(list(predictions.keys())))
173
+ self.img_ids.extend(img_ids)
174
+
175
+ for iou_type in self.iou_types:
176
+ results = self.prepare(predictions, iou_type)
177
+ self._dump(results)
178
+
179
+ assert len(self.coco_gts) == len(self.coco_evals)
180
+ all_scorings = []
181
+ for cur_coco_gt, cur_coco_eval in zip(self.coco_gts, self.coco_evals):
182
+ # suppress pycocotools prints
183
+ with open(os.devnull, "w") as devnull:
184
+ with contextlib.redirect_stdout(devnull):
185
+ coco_dt = (
186
+ self._loadRes(cur_coco_gt, results) if results else COCO()
187
+ )
188
+
189
+ coco_eval = cur_coco_eval[iou_type]
190
+
191
+ coco_eval.cocoDt = coco_dt
192
+ coco_eval.params.imgIds = list(img_ids)
193
+ coco_eval.params.useCats = self.useCats
194
+ coco_eval.params.maxDets = self.maxdets
195
+ img_ids, eval_imgs = self._evaluate(coco_eval, self.use_self_evaluate)
196
+ all_scorings.append(eval_imgs)
197
+
198
+ selected = self.select_best_scoring(all_scorings)
199
+ self.eval_imgs[iou_type].append(selected)
200
+
201
+ def select_best_scoring(self, scorings):
202
+ # This function is used for "oracle" type evaluation.
203
+ # It accepts the evaluation results with respect to several ground truths, and picks the best
204
+ if len(scorings) == 1:
205
+ return scorings[0]
206
+
207
+ # Currently we don't support Oracle Phrase AP.
208
+ # To implement it, we likely need to modify the cpp code since the eval_image type is opaque
209
+ raise RuntimeError("Not implemented")
210
+
211
+ def _dump(self, results):
212
+ if self.dump is not None:
213
+ dumped_results = copy.deepcopy(results)
214
+ for r in dumped_results:
215
+ if "bbox" not in self.iou_types and "bbox" in r:
216
+ del r["bbox"]
217
+ elif "bbox" in r:
218
+ r["bbox"] = [round(coord, 5) for coord in r["bbox"]]
219
+ r["score"] = round(r["score"], 5)
220
+ self.dump.extend(dumped_results)
221
+
222
+ def synchronize_between_processes(self):
223
+ self._lazy_init()
224
+ logging.info("Coco evaluator: Synchronizing between processes")
225
+ for iou_type in self.iou_types:
226
+ if len(self.eval_imgs[iou_type]) > 0:
227
+ self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
228
+ else:
229
+ num_areas = len(self.coco_evals[0][iou_type].params.areaRng)
230
+ # assuming 1 class
231
+ assert not self.useCats
232
+ self.eval_imgs[iou_type] = np.empty((1, num_areas, 0))
233
+ create_common_coco_eval(
234
+ self.coco_evals[0][iou_type],
235
+ self.img_ids,
236
+ self.eval_imgs[iou_type],
237
+ use_self_evaluate=self.use_self_evaluate,
238
+ gather_pred_via_filesys=self.gather_pred_via_filesys,
239
+ metrics_dump_dir=self.metrics_dump_dir,
240
+ )
241
+ if self.dump is not None:
242
+ dumped_file = Path(self.dump_dir) / f"coco_predictions_{get_rank()}.json"
243
+ logging.info(f"COCO evaluator: Dumping local predictions to {dumped_file}")
244
+ with g_pathmgr.open(str(dumped_file), "w") as f:
245
+ json.dump(self.dump, f)
246
+
247
+ # if self.gather_pred_via_filesys:
248
+ # dump = gather_to_rank_0_via_filesys(self.dump)
249
+ # else:
250
+ # dump = all_gather(self.dump, force_cpu=True)
251
+ # self.dump = sum(dump, [])
252
+
253
+ def accumulate(self, imgIds=None):
254
+ self._lazy_init()
255
+ logging.info(
256
+ f"Coco evaluator: Accumulating on {len(imgIds) if imgIds is not None else 'all'} images"
257
+ )
258
+ if not is_main_process():
259
+ return
260
+
261
+ if imgIds is None:
262
+ for coco_eval in self.coco_evals[0].values():
263
+ accumulate(coco_eval, use_self_eval=self.use_self_evaluate)
264
+
265
+ if imgIds is not None:
266
+ imgIds = set(imgIds)
267
+ for coco_eval in self.coco_evals[0].values():
268
+ p = coco_eval.params
269
+ id_mask = np.array([(i in imgIds) for i in p.imgIds], dtype=bool)
270
+ old_img_ids = p.imgIds
271
+ coco_eval.params.imgIds = np.asarray(p.imgIds)[id_mask]
272
+ old_img_evals = coco_eval.evalImgs
273
+ catIds = p.catIds if p.useCats else [-1]
274
+ coco_eval.evalImgs = list(
275
+ np.asarray(coco_eval.evalImgs)
276
+ .reshape(len(catIds), len(p.areaRng), len(old_img_ids))[
277
+ ..., id_mask
278
+ ]
279
+ .flatten()
280
+ )
281
+ accumulate(coco_eval, use_self_eval=self.use_self_evaluate)
282
+ coco_eval.evalImgs = old_img_evals
283
+ coco_eval.params.imgIds = old_img_ids
284
+
285
+ def summarize(self):
286
+ self._lazy_init()
287
+ logging.info("Coco evaluator: Summarizing")
288
+ if not is_main_process():
289
+ return {}
290
+
291
+ outs = {}
292
+ if self.rarity_buckets is None:
293
+ self.accumulate(self.eval_img_ids)
294
+ for iou_type, coco_eval in self.coco_evals[0].items():
295
+ print("IoU metric: {}".format(iou_type))
296
+ summarize(coco_eval)
297
+
298
+ if "bbox" in self.coco_evals[0]:
299
+ for key, value in zip(*self.coco_evals[0]["bbox"].stats):
300
+ outs[f"coco_eval_bbox_{key}"] = value
301
+ if "segm" in self.coco_evals[0]:
302
+ for key, value in zip(*self.coco_evals[0]["segm"].stats):
303
+ outs[f"coco_eval_masks_{key}"] = value
304
+ else:
305
+ total_stats = {}
306
+ all_keys = {}
307
+ for bucket, img_list in self.rarity_buckets.items():
308
+ self.accumulate(imgIds=img_list)
309
+ bucket_name = RARITY_BUCKETS[bucket]
310
+ for iou_type, coco_eval in self.coco_evals[0].items():
311
+ print(f"IoU metric: {iou_type}. Rarity bucket: {bucket_name}")
312
+ summarize(coco_eval)
313
+
314
+ if "bbox" in self.coco_evals[0]:
315
+ if "bbox" not in total_stats:
316
+ total_stats["bbox"] = np.zeros_like(
317
+ self.coco_evals[0]["bbox"].stats[1]
318
+ )
319
+ all_keys["bbox"] = self.coco_evals[0]["bbox"].stats[0]
320
+ total_stats["bbox"] += self.coco_evals[0]["bbox"].stats[1]
321
+ for key, value in zip(*self.coco_evals[0]["bbox"].stats):
322
+ outs[f"coco_eval_bbox_{bucket_name}_{key}"] = value
323
+ if "segm" in self.coco_evals[0]:
324
+ if "segm" not in total_stats:
325
+ total_stats["segm"] = np.zeros_like(
326
+ self.coco_evals[0]["segm"].stats[1]
327
+ )
328
+ all_keys["segm"] = self.coco_evals[0]["segm"].stats[0]
329
+ total_stats["segm"] += self.coco_evals[0]["segm"].stats[1]
330
+ for key, value in zip(*self.coco_evals[0]["segm"].stats):
331
+ outs[f"coco_eval_masks_{bucket_name}_{key}"] = value
332
+
333
+ if "bbox" in total_stats:
334
+ total_stats["bbox"] /= len(self.rarity_buckets)
335
+ for key, value in zip(all_keys["bbox"], total_stats["bbox"]):
336
+ outs[f"coco_eval_bbox_{key}"] = value
337
+ if "segm" in total_stats:
338
+ total_stats["segm"] /= len(self.rarity_buckets)
339
+ for key, value in zip(all_keys["segm"], total_stats["segm"]):
340
+ outs[f"coco_eval_masks_{key}"] = value
341
+
342
+ # if self.dump is not None:
343
+ # assert self.dump_dir is not None
344
+ # logging.info("Coco evaluator: Dumping the global result file to disk")
345
+ # with g_pathmgr.open(str(Path(self.dump_dir) / "coco_eval.json"), "w") as f:
346
+ # json.dump(self.dump, f)
347
+ return outs
348
+
349
+ def compute_synced(self):
350
+ self._lazy_init()
351
+ self.synchronize_between_processes()
352
+ return self.summarize()
353
+
354
+ def compute(self):
355
+ self._lazy_init()
356
+ return {"": 0.0}
357
+
358
    def reset(self, cocoeval_cls=COCOeval):
        """Re-create fresh COCOeval objects for every GT set and iou type.

        Args:
            cocoeval_cls: the COCOeval implementation to instantiate; a
                parameter so callers can swap in a custom/faster evaluator.
        """
        self.coco_evals = [{} for _ in range(len(self.coco_gts))]
        for i, coco_gt in enumerate(self.coco_gts):
            for iou_type in self.iou_types:
                self.coco_evals[i][iou_type] = cocoeval_cls(coco_gt, iouType=iou_type)
                self.coco_evals[i][iou_type].params.useCats = self.useCats
                self.coco_evals[i][iou_type].params.maxDets = self.maxdets
                if self.use_normalized_areas:
                    # Area ranges expressed as fractions of the image area
                    # (object area / (h*w)) instead of absolute pixel counts;
                    # labels below must stay aligned with these ranges.
                    self.coco_evals[i][iou_type].params.areaRng = [
                        [0, 1e5],
                        [0, 0.001],
                        [0.001, 0.01],
                        [0.01, 0.1],
                        [0.1, 0.5],
                        [0.5, 0.95],
                        [0.95, 1e5],
                    ]
                    self.coco_evals[i][iou_type].params.areaRngLbl = [
                        "all",
                        "tiny",
                        "small",
                        "medium",
                        "large",
                        "huge",
                        "whole_image",
                    ]

        # Per-epoch accumulators, cleared on every reset.
        self.img_ids = []
        self.eval_imgs = {k: [] for k in self.iou_types}
        if self.dump is not None:
            self.dump = []
389
+
390
+ def write(self, stats):
391
+ self._lazy_init()
392
+ """Write the results in the stats dict"""
393
+ if "bbox" in self.coco_evals[0]:
394
+ stats["coco_eval_bbox"] = self.coco_evals[0]["bbox"].stats.tolist()
395
+ if "segm" in self.coco_evals[0]:
396
+ stats["coco_eval_masks"] = self.coco_evals[0]["segm"].stats.tolist()
397
+ return stats
398
+
399
+ def prepare(self, predictions, iou_type):
400
+ self._lazy_init()
401
+ if iou_type == "bbox":
402
+ return self.prepare_for_coco_detection(predictions)
403
+ elif iou_type == "segm":
404
+ return self.prepare_for_coco_segmentation(predictions)
405
+ elif iou_type == "keypoints":
406
+ return self.prepare_for_coco_keypoint(predictions)
407
+ else:
408
+ raise ValueError("Unknown iou type {}".format(iou_type))
409
+
410
    def prepare_for_coco_detection(self, predictions):
        """Convert ``{image_id: prediction_dict}`` into COCO detection results.

        Each prediction dict is expected to contain "boxes" (xyxy, converted
        here to COCO xywh), "scores" and "labels". Empty predictions are
        skipped. Returns a flat list of COCO result dicts.
        """
        self._lazy_init()
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            boxes = prediction["boxes"]
            boxes = convert_to_xywh(boxes).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "bbox": box,
                        "score": scores[k],
                    }
                    for k, box in enumerate(boxes)
                ]
            )
        return coco_results
434
+
435
    @torch.no_grad()
    def prepare_for_coco_segmentation(self, predictions):
        """Convert ``{image_id: prediction_dict}`` into COCO segmentation results.

        Accepts either pre-encoded RLEs ("masks_rle") or raw mask tensors
        ("masks", thresholded at 0.5 then RLE-encoded here). The "area"
        field is stored as a fraction of the image area (area / (h*w)).
        Optional per-instance boundary data is passed through when present.
        """
        self._lazy_init()
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()
            boundaries, dilated_boundaries = None, None
            if "boundaries" in prediction:
                boundaries = prediction["boundaries"]
                dilated_boundaries = prediction["dilated_boundaries"]
                assert dilated_boundaries is not None
                assert len(scores) == len(boundaries)

            if "masks_rle" in prediction:
                # Masks already RLE-encoded; derive normalized areas from the RLE.
                rles = prediction["masks_rle"]
                areas = []
                for rle in rles:
                    cur_area = mask_utils.area(rle)
                    h, w = rle["size"]
                    areas.append(cur_area / (h * w))
            else:
                masks = prediction["masks"]

                # Binarize soft masks before encoding.
                masks = masks > 0.5
                h, w = masks.shape[-2:]

                areas = masks.flatten(1).sum(1) / (h * w)
                areas = areas.tolist()

                # squeeze(1) drops the channel dim — assumes masks are
                # (N, 1, H, W); TODO confirm against the model's output shape.
                rles = rle_encode(masks.squeeze(1))

                # Free the (potentially large) mask tensors eagerly to keep
                # peak memory down while iterating over many images.
                del masks
                del prediction["masks"]

            assert len(areas) == len(rles) == len(scores)
            for k, rle in enumerate(rles):
                payload = {
                    "image_id": original_id,
                    "category_id": labels[k],
                    "segmentation": rle,
                    "score": scores[k],
                    "area": areas[k],
                }
                if boundaries is not None:
                    payload["boundary"] = boundaries[k]
                    payload["dilated_boundary"] = dilated_boundaries[k]

                coco_results.append(payload)

        return coco_results
490
+
491
    def prepare_for_coco_keypoint(self, predictions):
        """Convert ``{image_id: prediction_dict}`` into COCO keypoint results.

        Keypoints are flattened per instance into the COCO
        ``[x1, y1, v1, x2, y2, v2, ...]`` layout.
        """
        self._lazy_init()
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            # NOTE(review): boxes are converted but never added to the
            # payload below — kept for parity with the original code.
            boxes = prediction["boxes"]
            boxes = convert_to_xywh(boxes).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()
            keypoints = prediction["keypoints"]
            keypoints = keypoints.flatten(start_dim=1).tolist()

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "keypoints": keypoint,
                        "score": scores[k],
                    }
                    for k, keypoint in enumerate(keypoints)
                ]
            )
        return coco_results
517
+
518
+
519
def convert_to_xywh(boxes):
    """Convert boxes from (xmin, ymin, xmax, ymax) to (x, y, width, height)."""
    x_min, y_min, x_max, y_max = boxes.unbind(-1)
    widths = x_max - x_min
    heights = y_max - y_min
    return torch.stack((x_min, y_min, widths, heights), dim=-1)
522
+
523
+
524
def merge(img_ids, eval_imgs, gather_pred_via_filesys=False):
    """Gather per-rank image ids and per-image eval arrays onto rank 0.

    Returns (merged_img_ids, merged_eval_imgs) on the main process and
    (None, None) on all other ranks. Duplicate image ids across ranks are
    dropped via np.unique (which also sorts the ids).
    """
    if gather_pred_via_filesys:
        # only gather the predictions to rank 0 (other ranks will receive empty
        # lists for `all_img_ids` and `all_eval_imgs`, which should be OK as
        # merging and evaluation are only done on rank 0)
        all_img_ids = gather_to_rank_0_via_filesys(img_ids)
        all_eval_imgs = gather_to_rank_0_via_filesys(eval_imgs)
    else:
        all_img_ids = all_gather(img_ids, force_cpu=True)
        all_eval_imgs = all_gather(eval_imgs, force_cpu=True)
    if not is_main_process():
        return None, None

    merged_img_ids = []
    for p in all_img_ids:
        merged_img_ids.extend(p)

    # Keep each rank's array whole (append, not extend); they are joined
    # along axis 2 — presumably the per-image axis — below. TODO confirm.
    merged_eval_imgs = []
    for p in all_eval_imgs:
        merged_eval_imgs.append(p)

    merged_img_ids = np.array(merged_img_ids)
    merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)

    # keep only unique (and in sorted order) images
    merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
    merged_eval_imgs = merged_eval_imgs[..., idx]

    return merged_img_ids, merged_eval_imgs
553
+
554
+
555
def create_common_coco_eval(
    coco_eval,
    img_ids,
    eval_imgs,
    use_self_evaluate,
    gather_pred_via_filesys=False,
    metrics_dump_dir=None,
):
    """Merge distributed per-image results into ``coco_eval`` (rank 0 only).

    Gathers (img_ids, eval_imgs) from all ranks, optionally dumps the
    per-image metrics to disk, runs a dummy evaluation pass for GT images
    that received no predictions, and installs the merged arrays into
    ``coco_eval`` so that accumulate()/summarize() can run afterwards.
    Non-main processes return immediately after the gather.
    """
    img_ids, eval_imgs = merge(img_ids, eval_imgs, gather_pred_via_filesys)
    if not is_main_process():
        return
    if metrics_dump_dir is not None:
        dumped_file = (
            Path(metrics_dump_dir) / f"coco_eval_img_metrics_{get_rank()}.json"
        )
        logging.info(f"COCO evaluator: Dumping local predictions to {dumped_file}")
        with g_pathmgr.open(str(dumped_file), "w") as f:
            # default= handles numpy arrays nested in the eval structures
            json.dump(eval_imgs.squeeze(), f, default=lambda x: x.tolist())
    img_ids = list(img_ids)

    # If some images were not predicted, we need to create dummy detections for them
    missing_img_ids = set(coco_eval.cocoGt.getImgIds()) - set(img_ids)
    if len(missing_img_ids) > 0:
        print(f"WARNING: {len(missing_img_ids)} images were not predicted!")
        # Empty detection set: evaluating it yields all-negative entries
        # for the missing images.
        coco_eval.cocoDt = COCO()
        coco_eval.params.imgIds = list(missing_img_ids)
        new_img_ids, new_eval_imgs = evaluate(coco_eval, use_self_evaluate)
        img_ids.extend(new_img_ids)
        eval_imgs = np.concatenate((eval_imgs, new_eval_imgs), axis=2)

    eval_imgs = list(eval_imgs.flatten())
    assert len(img_ids) == len(coco_eval.cocoGt.getImgIds())

    coco_eval.evalImgs = eval_imgs
    coco_eval.params.imgIds = img_ids
    coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
591
+
592
+
593
+ #################################################################
594
+ # From pycocotools, just removed the prints and fixed
595
+ # a Python3 bug about unicode not defined
596
+ #################################################################
597
+
598
+
599
+ # Copy of COCO prepare, but doesn't convert anntoRLE
600
# Copy of COCO prepare, but doesn't convert annToRLE
def segmentation_prepare(self):
    """
    Prepare ._gts and ._dts for evaluation based on params.

    Copy of pycocotools' COCOeval._prepare without the annToRLE conversion
    (segmentations are assumed to already be in RLE form).
    :return: None
    """
    p = self.params
    if p.useCats:
        gts = self.cocoGt.loadAnns(
            self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)
        )
        dts = self.cocoDt.loadAnns(
            self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)
        )
    else:
        gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds))
        dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds))

    for gt in gts:
        # NOTE: matches upstream pycocotools verbatim — the first assignment
        # is immediately overwritten, so "ignore" effectively tracks iscrowd
        # (plus zero-keypoint GTs for the keypoints iou type).
        gt["ignore"] = gt["ignore"] if "ignore" in gt else 0
        gt["ignore"] = "iscrowd" in gt and gt["iscrowd"]
        if p.iouType == "keypoints":
            gt["ignore"] = (gt["num_keypoints"] == 0) or gt["ignore"]
    self._gts = defaultdict(list)  # gt for evaluation
    self._dts = defaultdict(list)  # dt for evaluation
    for gt in gts:
        self._gts[gt["image_id"], gt["category_id"]].append(gt)
    for dt in dts:
        self._dts[dt["image_id"], dt["category_id"]].append(dt)
    self.evalImgs = defaultdict(list)  # per-image per-category evaluation results
    self.eval = {}  # accumulated evaluation results
630
+
631
+
632
def evaluate(self, use_self_evaluate):
    """
    Run per image evaluation on given images and store results (a list of
    dict) in self.evalImgs.

    Near-verbatim copy of pycocotools' COCOeval.evaluate, except it returns
    ``(imgIds, evalImgs)`` with evalImgs reshaped to
    [n_cats x n_areaRng x n_imgs]. NOTE(review): when ``use_self_evaluate``
    is falsy the function falls through and implicitly returns None — the
    C++ fast path that used to handle that branch is commented out below.
    """
    p = self.params
    # add backward compatibility if useSegm is specified in params
    if p.useSegm is not None:
        p.iouType = "segm" if p.useSegm == 1 else "bbox"
        print(
            "useSegm (deprecated) is not None. Running {} evaluation".format(p.iouType)
        )
    p.imgIds = list(np.unique(p.imgIds))
    if p.useCats:
        p.catIds = list(np.unique(p.catIds))
    p.maxDets = sorted(p.maxDets)
    self.params = p

    self._prepare()
    # loop through images, area range, max detection number
    catIds = p.catIds if p.useCats else [-1]

    # select IoU computation; keypoints use OKS instead of box/mask IoU
    if p.iouType == "segm" or p.iouType == "bbox":
        computeIoU = self.computeIoU
    elif p.iouType == "keypoints":
        computeIoU = self.computeOks
    self.ious = {
        (imgId, catId): computeIoU(imgId, catId)
        for imgId in p.imgIds
        for catId in catIds
    }

    maxDet = p.maxDets[-1]
    if use_self_evaluate:
        evalImgs = [
            self.evaluateImg(imgId, catId, areaRng, maxDet)
            for catId in catIds
            for areaRng in p.areaRng
            for imgId in p.imgIds
        ]
        # this is NOT in the pycocotools code, but could be done outside
        evalImgs = np.asarray(evalImgs).reshape(
            len(catIds), len(p.areaRng), len(p.imgIds)
        )
        return p.imgIds, evalImgs
680
+
681
+ # <<<< Beginning of code differences with original COCO API
682
+ # def convert_instances_to_cpp(instances, is_det=False):
683
+ # # Convert annotations for a list of instances in an image to a format that's fast
684
+ # # to access in C++
685
+ # instances_cpp = []
686
+ # for instance in instances:
687
+ # instance_cpp = _CPP.InstanceAnnotation(
688
+ # int(instance["id"]),
689
+ # instance["score"] if is_det else instance.get("score", 0.0),
690
+ # instance["area"],
691
+ # bool(instance.get("iscrowd", 0)),
692
+ # bool(instance.get("ignore", 0)),
693
+ # )
694
+ # instances_cpp.append(instance_cpp)
695
+ # return instances_cpp
696
+
697
+ # # Convert GT annotations, detections, and IOUs to a format that's fast to access in C++
698
+ # ground_truth_instances = [
699
+ # [convert_instances_to_cpp(self._gts[imgId, catId]) for catId in p.catIds]
700
+ # for imgId in p.imgIds
701
+ # ]
702
+ # detected_instances = [
703
+ # [
704
+ # convert_instances_to_cpp(self._dts[imgId, catId], is_det=True)
705
+ # for catId in p.catIds
706
+ # ]
707
+ # for imgId in p.imgIds
708
+ # ]
709
+ # ious = [[self.ious[imgId, catId] for catId in catIds] for imgId in p.imgIds]
710
+
711
+ # if not p.useCats:
712
+ # # For each image, flatten per-category lists into a single list
713
+ # ground_truth_instances = [
714
+ # [[o for c in i for o in c]] for i in ground_truth_instances
715
+ # ]
716
+ # detected_instances = [[[o for c in i for o in c]] for i in detected_instances]
717
+
718
+ # # Call C++ implementation of self.evaluateImgs()
719
+ # _evalImgs_cpp = _CPP.COCOevalEvaluateImages(
720
+ # p.areaRng, maxDet, p.iouThrs, ious, ground_truth_instances, detected_instances
721
+ # )
722
+
723
+ # self._paramsEval = copy.deepcopy(self.params)
724
+ # evalImgs = np.asarray(_evalImgs_cpp).reshape(
725
+ # len(catIds), len(p.areaRng), len(p.imgIds)
726
+ # )
727
+ # return p.imgIds, evalImgs
728
+
729
+
730
+ #################################################################
731
+ # end of straight copy from pycocotools, just removing the prints
732
+ #################################################################
733
+
734
+
735
+ #################################################################
736
+ # From pycocotools, but disabled mask->box conversion which is
737
+ # pointless
738
+ #################################################################
739
def loadRes(self, resFile):
    """
    Load result file and return a result api object.
    :param resFile (str) : file name of result file; may also be a list of
        annotation dicts or a numpy-annotation ndarray
    :return: res (obj) : result api object

    Copy of pycocotools' COCO.loadRes with the mask->area/bbox conversion
    disabled for segmentation results (those fields are expected to be
    supplied by the caller).
    """
    res = COCO()
    res.dataset["images"] = [img for img in self.dataset["images"]]

    if type(resFile) == str:
        anns = json.load(open(resFile))
    elif type(resFile) == np.ndarray:
        anns = self.loadNumpyAnnotations(resFile)
    else:
        anns = resFile
    assert type(anns) == list, "results in not an array of objects"
    annsImgIds = [ann["image_id"] for ann in anns]
    assert set(annsImgIds) == (
        set(annsImgIds) & set(self.getImgIds())
    ), "Results do not correspond to current coco set"
    if "caption" in anns[0]:
        # caption results: keep only images that actually have predictions
        imgIds = set([img["id"] for img in res.dataset["images"]]) & set(
            [ann["image_id"] for ann in anns]
        )
        res.dataset["images"] = [
            img for img in res.dataset["images"] if img["id"] in imgIds
        ]
        for id, ann in enumerate(anns):
            ann["id"] = id + 1
    elif "bbox" in anns[0] and not anns[0]["bbox"] == []:
        res.dataset["categories"] = copy.deepcopy(self.dataset["categories"])
        for id, ann in enumerate(anns):
            bb = ann["bbox"]
            x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]]
            if "segmentation" not in ann:
                # synthesize a rectangular polygon covering the box
                ann["segmentation"] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
            ann["area"] = bb[2] * bb[3]
            ann["id"] = id + 1
            ann["iscrowd"] = 0
    elif "segmentation" in anns[0]:
        res.dataset["categories"] = copy.deepcopy(self.dataset["categories"])
        for id, ann in enumerate(anns):
            # now only support compressed RLE format as segmentation results
            # ann["area"] = mask_util.area(ann["segmentation"])
            # The following lines are disabled because they are pointless
            # if not 'bbox' in ann:
            #     ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
            ann["id"] = id + 1
            ann["iscrowd"] = 0
    elif "keypoints" in anns[0]:
        res.dataset["categories"] = copy.deepcopy(self.dataset["categories"])
        for id, ann in enumerate(anns):
            # derive bbox/area from the keypoint extents
            s = ann["keypoints"]
            x = s[0::3]
            y = s[1::3]
            x0, x1, y0, y1 = np.min(x), np.max(x), np.min(y), np.max(y)
            ann["area"] = (x1 - x0) * (y1 - y0)
            ann["id"] = id + 1
            ann["bbox"] = [x0, y0, x1 - x0, y1 - y0]

    res.dataset["annotations"] = anns
    res.createIndex()
    return res
802
+
803
+
804
+ #################################################################
805
+ # end of straight copy from pycocotools
806
+ #################################################################
807
+
808
+
809
+ #################################################################
810
+ # From pycocotools, but added handling of custom area rngs, and returns stat keys
811
+ #################################################################
812
def summarize(self):
    """
    Compute and display summary metrics for evaluation results.
    Note this function can *only* be applied on the default parameter
    setting, here extended to support custom area ranges. Unlike upstream
    pycocotools, ``self.stats`` is set to a ``(keys, stats)`` tuple so
    callers can recover metric names alongside their values.
    """

    def _summarize(ap=1, iouThr=None, areaRng="all", maxDets=100):
        # Mirrors pycocotools' inner _summarize: slice the accumulated
        # precision/recall tensors and average the valid (> -1) entries.
        p = self.params
        iStr = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}"
        titleStr = "Average Precision" if ap == 1 else "Average Recall"
        typeStr = "(AP)" if ap == 1 else "(AR)"
        iouStr = (
            "{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1])
            if iouThr is None
            else "{:0.2f}".format(iouThr)
        )

        aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
        mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
        if ap == 1:
            # dimension of precision: [TxRxKxAxM]
            s = self.eval["precision"]
            # IoU
            if iouThr is not None:
                t = np.where(iouThr == p.iouThrs)[0]
                s = s[t]
            s = s[:, :, :, aind, mind]
        else:
            # dimension of recall: [TxKxAxM]
            s = self.eval["recall"]
            if iouThr is not None:
                t = np.where(iouThr == p.iouThrs)[0]
                s = s[t]
            s = s[:, :, aind, mind]
        if len(s[s > -1]) == 0:
            mean_s = -1  # no valid entries for this setting
        else:
            mean_s = np.mean(s[s > -1])
        print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s))
        return mean_s

    def _summarizeDets():
        # 3 AP + 3 AR headline metrics, plus one AP and one AR per
        # non-"all" area range.
        nb_results = 6 + (len(self.params.areaRng) - 1) * 2
        assert len(self.params.areaRng) == len(self.params.areaRngLbl)
        stats = np.zeros((nb_results,))
        keys = ["AP", "AP_50", "AP_75"]
        stats[0] = _summarize(1, maxDets=self.params.maxDets[2])
        stats[1] = _summarize(1, iouThr=0.5, maxDets=self.params.maxDets[2])
        stats[2] = _summarize(1, iouThr=0.75, maxDets=self.params.maxDets[2])
        cur_id = 3
        for area in self.params.areaRngLbl[1:]:
            stats[cur_id] = _summarize(1, areaRng=area, maxDets=self.params.maxDets[2])
            cur_id += 1
            keys.append(f"AP_{area}")
        stats[cur_id] = _summarize(0, maxDets=self.params.maxDets[0])
        cur_id += 1
        stats[cur_id] = _summarize(0, maxDets=self.params.maxDets[1])
        cur_id += 1
        stats[cur_id] = _summarize(0, maxDets=self.params.maxDets[2])
        cur_id += 1
        # NOTE(review): these three values are AR at maxDets[0]/[1]/[2],
        # not AR at IoU 0.50/0.75 — the "AR_50"/"AR_75" labels are
        # misleading, but renaming them would break downstream consumers
        # of the metric names.
        keys += ["AR", "AR_50", "AR_75"]

        for area in self.params.areaRngLbl[1:]:
            stats[cur_id] = _summarize(0, areaRng=area, maxDets=self.params.maxDets[2])
            cur_id += 1
            keys.append(f"AR_{area}")
        assert len(stats) == len(keys)
        return keys, stats

    if not self.eval:
        raise Exception("Please run accumulate() first")
    self.stats = _summarizeDets()
884
+
885
+
886
+ #################################################################
887
+ # end of straight copy from pycocotools
888
+ #################################################################
889
+
890
+
891
+ #################################################################
892
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/evaluation/fast_eval_api.py
893
+ # with slight adjustments
894
+ #################################################################
895
def accumulate(self, use_self_eval=False):
    """
    Accumulate per image evaluation results and store the result in
    self.eval. Does not support changing parameter settings from those used
    by self.evaluate().

    NOTE(review): with ``use_self_eval=False`` this is currently a no-op —
    the fast C++ accumulation path below is commented out. Also, if this
    module-level function is ever installed as the instance's
    ``accumulate`` method, the ``self.accumulate()`` call would recurse;
    it is presumably meant to dispatch to the stock
    ``COCOeval.accumulate`` — verify at the call site.
    """
    if use_self_eval:
        # Fall back to pycocotools' own accumulation.
        self.accumulate()
        return
    # CPP code is disabled
    # self.eval = _CPP.COCOevalAccumulate(self.params, self.evalImgs)

    # # recall is num_iou_thresholds X num_categories X num_area_ranges X num_max_detections
    # self.eval["recall"] = np.array(self.eval["recall"]).reshape(
    #     self.eval["counts"][:1] + self.eval["counts"][2:]
    # )

    # # precision and scores are num_iou_thresholds X num_recall_thresholds X num_categories X
    # # num_area_ranges X num_max_detections
    # self.eval["precision"] = np.array(self.eval["precision"]).reshape(
    #     self.eval["counts"]
    # )
    # self.eval["scores"] = np.array(self.eval["scores"]).reshape(self.eval["counts"])
sam3/eval/coco_eval_offline.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ """
4
+ This evaluator is meant for regular COCO mAP evaluation, for example on the COCO val set.
5
+
6
+ For Category mAP, we need the model to make predictions for all the categories on every single image.
7
+ In general, since the number of classes can be big, and the API model makes predictions individually for each pair (image, class),
8
+ we may need to split the inference process for a given image in several chunks.
9
+ """
10
+
11
+ import logging
12
+ from collections import defaultdict
13
+
14
+ import torch
15
+ from pycocotools.coco import COCO
16
+ from pycocotools.cocoeval import COCOeval
17
+ from sam3.train.utils.distributed import is_main_process
18
+
19
+ try:
20
+ from tidecv import datasets, TIDE
21
+
22
+ HAS_TIDE = True
23
+ except ImportError:
24
+ HAS_TIDE = False
25
+ print("WARNING: TIDE not installed. Detailed analysis will not be available.")
26
+
27
+
28
+ # the COCO detection metrics (https://github.com/cocodataset/cocoapi/blob/8c9bcc3cf640524c4c20a9c40e89cb6a2f2fa0e9/PythonAPI/pycocotools/cocoeval.py#L460-L471)
29
+ COCO_METRICS = [
30
+ "AP",
31
+ "AP_50",
32
+ "AP_75",
33
+ "AP_small",
34
+ "AP_medium",
35
+ "AP_large",
36
+ "AR_maxDets@1",
37
+ "AR_maxDets@10",
38
+ "AR_maxDets@100",
39
+ "AR_small",
40
+ "AR_medium",
41
+ "AR_large",
42
+ ]
43
+
44
+
45
def convert_to_xywh(boxes):
    """Convert bounding boxes from xyxy format to xywh format."""
    # Slice the corner pairs and derive width/height from their difference.
    top_left = boxes[..., :2]
    bottom_right = boxes[..., 2:]
    return torch.cat((top_left, bottom_right - top_left), dim=-1)
49
+
50
+
51
class HeapElement:
    """Utility class to make a heap with a custom comparator"""

    def __init__(self, val):
        self.val = val

    def __lt__(self, other):
        # Min-heap ordering: compare the wrapped predictions by score.
        own_score = self.val["score"]
        other_score = other.val["score"]
        return own_score < other_score
59
+
60
+
61
class COCOevalCustom(COCOeval):
    """
    This is a slightly modified version of the original COCO API with added
    support for positive split evaluation: when ``dt_only_positive`` is set,
    detections whose category has no groundtruth in that image are dropped
    before matching.
    """

    def __init__(
        self, cocoGt=None, cocoDt=None, iouType="segm", dt_only_positive=False
    ):
        super().__init__(cocoGt, cocoDt, iouType)
        # When True, keep only detections for categories present in the
        # image's groundtruth ("positive" categories).
        self.dt_only_positive = dt_only_positive

    def _prepare(self):
        """
        Prepare ._gts and ._dts for evaluation based on params.

        Copy of pycocotools' COCOeval._prepare plus the positive-split
        detection filtering (see the MODIFICATION markers below).
        :return: None
        """

        def _toMask(anns, coco):
            # modify ann['segmentation'] by reference
            for ann in anns:
                rle = coco.annToRLE(ann)
                ann["segmentation"] = rle

        p = self.params
        if p.useCats:
            gts = self.cocoGt.loadAnns(
                self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)
            )
            dts = self.cocoDt.loadAnns(
                self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)
            )
        else:
            gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds))
            dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds))

        # convert ground truth to mask if iouType == 'segm'
        if p.iouType == "segm":
            _toMask(gts, self.cocoGt)
            _toMask(dts, self.cocoDt)
        # set ignore flag (matches upstream pycocotools behavior)
        for gt in gts:
            gt["ignore"] = gt["ignore"] if "ignore" in gt else 0
            gt["ignore"] = "iscrowd" in gt and gt["iscrowd"]
            if p.iouType == "keypoints":
                gt["ignore"] = (gt["num_keypoints"] == 0) or gt["ignore"]
        self._gts = defaultdict(list)  # gt for evaluation
        self._dts = defaultdict(list)  # dt for evaluation

        _gts_cat_ids = defaultdict(set)  # per-image set of GT category ids
        for gt in gts:
            self._gts[gt["image_id"], gt["category_id"]].append(gt)
            _gts_cat_ids[gt["image_id"]].add(gt["category_id"])

        #### BEGIN MODIFICATION ####
        for dt in dts:
            # Positive-split mode: drop detections for categories absent
            # from this image's groundtruth.
            if (
                self.dt_only_positive
                and dt["category_id"] not in _gts_cat_ids[dt["image_id"]]
            ):
                continue
            self._dts[dt["image_id"], dt["category_id"]].append(dt)
        #### END MODIFICATION ####
        self.evalImgs = defaultdict(list)  # per-image per-category evaluation results
        self.eval = {}  # accumulated evaluation results
125
+
126
+
127
class CocoEvaluatorOfflineWithPredFileEvaluators:
    """Offline COCO mAP evaluator that scores a dumped COCO prediction file.

    Optionally also runs TIDE error analysis when the ``tidecv`` package is
    installed. All work happens on the main process only.
    """

    def __init__(
        self,
        gt_path,
        tide: bool = True,
        iou_type: str = "bbox",
        positive_split=False,
    ):
        # gt_path: path to the COCO-format groundtruth json
        # tide: request TIDE error analysis (only honored if tidecv is installed)
        # iou_type: "bbox" or "segm"
        # positive_split: evaluate only detections whose category appears in
        #   the image's groundtruth (see COCOevalCustom)
        self.gt_path = gt_path
        self.tide_enabled = HAS_TIDE and tide
        self.positive_split = positive_split
        self.iou_type = iou_type

    def evaluate(self, dumped_file):
        """Run COCO (and optionally TIDE) evaluation on ``dumped_file``.

        Returns a flat {metric_name: value} dict on the main process and
        {} on all other ranks.
        """
        if not is_main_process():
            return {}

        logging.info("OfflineCoco evaluator: Loading groundtruth")
        self.gt = COCO(self.gt_path)

        # Creating the result file
        logging.info("Coco evaluator: Creating the result file")
        cocoDt = self.gt.loadRes(str(dumped_file))

        # Run the evaluation
        logging.info("Coco evaluator: Running evaluation")
        coco_eval = COCOevalCustom(
            self.gt, cocoDt, iouType=self.iou_type, dt_only_positive=self.positive_split
        )
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()

        # Name the 12 standard stats using COCO_METRICS, prefixed by iou type.
        outs = {}
        for i, value in enumerate(coco_eval.stats):
            outs[f"coco_eval_{self.iou_type}_{COCO_METRICS[i]}"] = value

        if self.tide_enabled:
            logging.info("Coco evaluator: Loading TIDE")
            self.tide_gt = datasets.COCO(self.gt_path)
            self.tide = TIDE(mode="mask" if self.iou_type == "segm" else "bbox")

            # Run TIDE
            logging.info("Coco evaluator: Running TIDE")
            self.tide.evaluate(
                self.tide_gt, datasets.COCOResult(str(dumped_file)), name="coco_eval"
            )
            self.tide.summarize()
            for k, v in self.tide.get_main_errors()["coco_eval"].items():
                outs[f"coco_eval_{self.iou_type}_TIDE_{k}"] = v

            for k, v in self.tide.get_special_errors()["coco_eval"].items():
                outs[f"coco_eval_{self.iou_type}_TIDE_{k}"] = v

        return outs
sam3/eval/coco_reindex.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ """
4
+ Self-contained COCO JSON re-indexing function that creates temporary files.
5
+ """
6
+
7
+ import json
8
+ import os
9
+ import tempfile
10
+ from pathlib import Path
11
+ from typing import Any, Dict, List, Optional, Tuple
12
+
13
+
14
def reindex_coco_to_temp(input_json_path: str) -> Optional[str]:
    """
    Convert a 0-indexed COCO JSON file to 1-indexed and save it to a temporary
    location.

    Args:
        input_json_path: Path to the input COCO JSON file.

    Returns:
        Path to a JSON file in a fresh temporary directory. If the input is
        already 1-indexed, its data is copied unchanged to the temp location
        (the original file is never modified).

    Raises:
        FileNotFoundError: If the input file doesn't exist.
        json.JSONDecodeError: If the input file is not valid JSON.
        ValueError: If the input file is not in a valid COCO format.
    """

    def is_coco_json(data: Dict[str, Any]) -> bool:
        """Check whether ``data`` appears to be a COCO-format dict."""
        if not isinstance(data, dict):
            return False
        # A COCO file should have at least one of these keys.
        return any(key in data for key in ("images", "annotations", "categories"))

    def check_zero_indexed(data: Dict[str, Any]) -> Tuple[bool, bool, bool]:
        """Return (annotations_zero, images_zero, categories_zero) flags.

        A section is flagged when any of its entries has id == 0.
        """

        def has_zero_id(items) -> bool:
            return any(item.get("id", -1) == 0 for item in items)

        return (
            has_zero_id(data.get("annotations", [])),
            has_zero_id(data.get("images", [])),
            has_zero_id(data.get("categories", [])),
        )

    def reindex_coco_data(data: Dict[str, Any]) -> Dict[str, Any]:
        """Shift every 0-based id up by one, keeping cross-references consistent."""
        modified_data = data.copy()
        annotations_zero, images_zero, categories_zero = check_zero_indexed(data)

        # old-id -> new-id maps, needed to fix annotation references below
        image_id_mapping = {}
        category_id_mapping = {}

        # Process images first (annotations reference image IDs).
        if images_zero and "images" in modified_data:
            for img in modified_data["images"]:
                old_id = img["id"]
                image_id_mapping[old_id] = old_id + 1
                img["id"] = old_id + 1

        # Process categories (annotations reference category IDs).
        if categories_zero and "categories" in modified_data:
            for cat in modified_data["categories"]:
                old_id = cat["id"]
                category_id_mapping[old_id] = old_id + 1
                cat["id"] = old_id + 1

        # Update annotation ids and their image/category references.
        for ann in modified_data.get("annotations", []):
            if annotations_zero:
                ann["id"] = ann["id"] + 1
            # `None not in mapping` also covers annotations lacking the field.
            if images_zero and ann.get("image_id") in image_id_mapping:
                ann["image_id"] = image_id_mapping[ann["image_id"]]
            if categories_zero and ann.get("category_id") in category_id_mapping:
                ann["category_id"] = category_id_mapping[ann["category_id"]]

        return modified_data

    def write_to_temp(data: Dict[str, Any]) -> str:
        """Dump ``data`` into a fresh temp dir and return the new file path."""
        input_path = Path(input_json_path)
        temp_dir = tempfile.mkdtemp()
        temp_path = os.path.join(
            temp_dir, f"{input_path.stem}_1_indexed{input_path.suffix}"
        )
        with open(temp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        return temp_path

    # Validate input path.
    if not os.path.exists(input_json_path):
        raise FileNotFoundError(f"Input file not found: {input_json_path}")

    # Load and validate JSON data.
    try:
        with open(input_json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except json.JSONDecodeError as e:
        # Re-raise with the offending file name. JSONDecodeError requires
        # (msg, doc, pos); the previous single-argument call raised TypeError
        # instead of the intended exception.
        raise json.JSONDecodeError(
            f"Invalid JSON in {input_json_path}: {e.msg}", e.doc, e.pos
        ) from e

    # Validate COCO format.
    if not is_coco_json(data):
        raise ValueError(
            f"File does not appear to be in COCO format: {input_json_path}"
        )

    # Check if reindexing is needed.
    annotations_zero, images_zero, categories_zero = check_zero_indexed(data)
    if not (annotations_zero or images_zero or categories_zero):
        # Already 1-indexed: just copy unchanged to the temp location.
        return write_to_temp(data)

    return write_to_temp(reindex_coco_data(data))
+ return temp_path
159
+
160
+
161
+ # Example usage and test function
162
def test_reindex_function():
    """Smoke test: build a 0-indexed COCO file, reindex it, print results.

    Creates a temporary input file, runs reindex_coco_to_temp, prints the
    converted ids, and cleans up both the temp output and the input file.
    """

    # Create a test COCO file with deliberately 0-based ids everywhere.
    test_data = {
        "info": {"description": "Test COCO dataset", "version": "1.0", "year": 2023},
        "images": [
            {"id": 0, "width": 640, "height": 480, "file_name": "test1.jpg"},
            {"id": 1, "width": 640, "height": 480, "file_name": "test2.jpg"},
        ],
        "categories": [
            {"id": 0, "name": "person", "supercategory": "person"},
            {"id": 1, "name": "car", "supercategory": "vehicle"},
        ],
        "annotations": [
            {
                "id": 0,
                "image_id": 0,
                "category_id": 0,
                "bbox": [100, 100, 50, 75],
                "area": 3750,
                "iscrowd": 0,
            },
            {
                "id": 1,
                "image_id": 1,
                "category_id": 1,
                "bbox": [200, 150, 120, 80],
                "area": 9600,
                "iscrowd": 0,
            },
        ],
    }

    # Create temporary test file (delete=False so it can be re-opened by path).
    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
        json.dump(test_data, f, indent=2)
        test_file_path = f.name

    try:
        # Test the function
        result_path = reindex_coco_to_temp(test_file_path)
        print(f"Original file: {test_file_path}")
        print(f"Converted file: {result_path}")

        # Load and display the result
        with open(result_path, "r") as f:
            result_data = json.load(f)

        print("\nConverted data sample:")
        print(f"First image ID: {result_data['images'][0]['id']}")
        print(f"First category ID: {result_data['categories'][0]['id']}")
        print(f"First annotation ID: {result_data['annotations'][0]['id']}")
        print(f"First annotation image_id: {result_data['annotations'][0]['image_id']}")
        print(
            f"First annotation category_id: {result_data['annotations'][0]['category_id']}"
        )

        # Clean up the converted file and its temp directory.
        os.unlink(result_path)
        os.rmdir(os.path.dirname(result_path))

    finally:
        # Clean up test file
        os.unlink(test_file_path)
+ os.unlink(test_file_path)
227
+
228
+
229
+ if __name__ == "__main__":
230
+ test_reindex_function()
sam3/eval/coco_writer.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ """
4
+ COCO prediction dumper for distributed training.
5
+
6
+ Handles collection and dumping of COCO-format predictions from models.
7
+ Supports distributed processing with multiple GPUs/processes.
8
+ """
9
+
10
+ import copy
11
+ import gc
12
+ import heapq
13
+ import json
14
+ import logging
15
+ import os
16
+ from collections import defaultdict
17
+ from pathlib import Path
18
+ from typing import Any, Optional
19
+
20
+ import pycocotools.mask as mask_utils
21
+ import torch
22
+ from iopath.common.file_io import g_pathmgr
23
+ from sam3.eval.coco_eval_offline import convert_to_xywh
24
+ from sam3.train.masks_ops import rle_encode
25
+ from sam3.train.utils.distributed import (
26
+ all_gather,
27
+ gather_to_rank_0_via_filesys,
28
+ get_rank,
29
+ is_main_process,
30
+ )
31
+
32
+
33
+ ### Helper functions and classes
34
+
35
+
36
class HeapElement:
    """Heap entry wrapper that orders prediction dicts by their "score" field.

    ``heapq`` only requires ``__lt__``, so this is the minimal adapter needed
    to keep a min-heap of COCO prediction records.
    """

    def __init__(self, val):
        # The wrapped prediction dict; accessed directly as ``elem.val``.
        self.val = val

    def __lt__(self, other):
        # Lower score sorts first, so the heap root is the weakest prediction.
        return self.val["score"] < other.val["score"]
44
+
45
+
46
class PredictionDumper:
    """
    Collects COCO-format predictions from a model and writes them to disk.

    Raw model outputs are converted by a postprocessor into COCO-style
    records. Supports distributed runs: each rank accumulates its own
    predictions, which are either written to a per-rank file or gathered
    and merged into a single file on rank 0.
    """

    def __init__(
        self,
        dump_dir: str,
        postprocessor,
        maxdets: int,
        iou_type: str,
        gather_pred_via_filesys: bool = False,
        merge_predictions: bool = False,
        pred_file_evaluators: Optional[Any] = None,
    ):
        """
        Args:
            dump_dir: Directory where prediction JSON files are written.
            postprocessor: Converts the model's raw outputs into COCO format.
            maxdets: Maximum number of detections kept per image when merging.
            iou_type: IoU type to evaluate; "bbox" or "segm".
            gather_pred_via_filesys: If True, gather across processes via a
                shared filesystem instead of torch collective ops.
            merge_predictions: If True, merge all ranks' predictions and dump
                them to a single file.
            pred_file_evaluators: Optional evaluators run on the merged
                prediction file; requires ``merge_predictions=True``.
        """
        self.iou_type = iou_type
        self.maxdets = maxdets
        self.dump_dir = dump_dir
        self.postprocessor = postprocessor
        self.gather_pred_via_filesys = gather_pred_via_filesys
        self.merge_predictions = merge_predictions
        self.pred_file_evaluators = pred_file_evaluators
        if self.pred_file_evaluators is not None:
            assert (
                merge_predictions
            ), "merge_predictions must be True if pred_file_evaluators are provided"
        assert self.dump_dir is not None, "dump_dir must be provided"

        if is_main_process():
            os.makedirs(self.dump_dir, exist_ok=True)
            logging.info(f"Created prediction dump directory: {self.dump_dir}")

        # Initialize the in-memory accumulation state.
        self.reset()

    def update(self, *args, **kwargs):
        """
        Convert one batch of model outputs and accumulate the predictions.

        Args:
            *args, **kwargs: Forwarded to ``postprocessor.process_results()``.
        """
        predictions = self.postprocessor.process_results(*args, **kwargs)
        self._dump(self.prepare(predictions, self.iou_type))

    def _dump(self, results):
        """
        Append COCO records to the in-memory dump, rounding floats to 5 d.p.

        Args:
            results: List of prediction dictionaries in COCO format.
        """
        # Deep-copy so rounding does not mutate the caller's records.
        rounded = copy.deepcopy(results)
        for rec in rounded:
            if "bbox" in rec:
                rec["bbox"] = [round(coord, 5) for coord in rec["bbox"]]
            rec["score"] = round(rec["score"], 5)
        self.dump.extend(rounded)

    def synchronize_between_processes(self):
        """
        Flush accumulated predictions to disk, optionally merging across ranks.

        In merge mode, predictions are gathered from all processes and rank 0
        writes one combined file; otherwise each rank writes its own file.
        The accumulation state is reset afterwards.

        Returns:
            Path of the file written for this rank (or the merged file).
        """
        logging.info("Prediction Dumper: Synchronizing between processes")

        if self.merge_predictions:
            self.dump = self.gather_and_merge_predictions()
            dumped_file = Path(self.dump_dir) / f"coco_predictions_{self.iou_type}.json"
            if is_main_process():
                logging.info(
                    f"Prediction Dumper: Dumping merged predictions to {dumped_file}"
                )
                with g_pathmgr.open(str(dumped_file), "w") as f:
                    json.dump(self.dump, f)
        else:
            dumped_file = (
                Path(self.dump_dir)
                / f"coco_predictions_{self.iou_type}_{get_rank()}.json"
            )
            logging.info(
                f"Prediction Dumper: Dumping local predictions to {dumped_file}"
            )
            with g_pathmgr.open(str(dumped_file), "w") as f:
                json.dump(self.dump, f)

        self.reset()
        return dumped_file

    def gather_and_merge_predictions(self):
        """
        Gather predictions from every process and merge them.

        Keeps at most ``maxdets`` top-scoring predictions per image (via a
        min-heap) and drops any (image_id, category_id) pair that was already
        contributed by an earlier rank's dump.

        Returns:
            List of merged prediction dictionaries.
        """
        logging.info("Prediction Dumper: Gathering predictions from all processes")
        gc.collect()

        if self.gather_pred_via_filesys:
            per_rank_dumps = gather_to_rank_0_via_filesys(self.dump)
        else:
            per_rank_dumps = all_gather(self.dump, force_cpu=True)

        top_preds = defaultdict(list)  # image_id -> min-heap of HeapElement
        seen_img_cat = set()

        for rank_dump in per_rank_dumps:
            rank_seen = set()
            for pred in rank_dump:
                key = (pred["image_id"], pred["category_id"])

                # Skip pairs already contributed by a previous rank's dump.
                if key in seen_img_cat:
                    continue
                rank_seen.add(key)

                heap = top_preds[pred["image_id"]]
                if len(heap) < self.maxdets:
                    heapq.heappush(heap, HeapElement(pred))
                else:
                    # Heap full: push, then evict the lowest-scoring entry.
                    heapq.heappushpop(heap, HeapElement(pred))

            # Only mark pairs as "seen" once the whole rank is processed, so
            # a single rank can contribute multiple preds per (image, cat).
            seen_img_cat.update(rank_seen)

        # Flatten the per-image heaps back into one list of prediction dicts.
        return [elem.val for heap in top_preds.values() for elem in heap]

    def compute_synced(self):
        """
        Synchronize/dump predictions, then run any configured file evaluators.

        Returns:
            Metric dict from the evaluators on rank 0; a placeholder
            ``{"": 0.0}`` on other ranks or when nothing was evaluated.
        """
        dumped_file = self.synchronize_between_processes()
        if not is_main_process():
            return {"": 0.0}

        meters = {}
        if self.pred_file_evaluators is not None:
            for evaluator in self.pred_file_evaluators:
                meters.update(evaluator.evaluate(dumped_file))

        return meters if len(meters) > 0 else {"": 0.0}

    def compute(self):
        """Compute without synchronization; returns a placeholder metric dict."""
        return {"": 0.0}

    def reset(self):
        """Clear accumulated predictions for a new evaluation round."""
        self.dump = []

    def prepare(self, predictions, iou_type):
        """
        Dispatch predictions to the converter matching ``iou_type``.

        Args:
            predictions: Maps image ids to per-image prediction dicts.
            iou_type: "bbox" or "segm".

        Returns:
            List of COCO-format prediction dictionaries.

        Raises:
            ValueError: If ``iou_type`` is not recognized.
        """
        if iou_type == "bbox":
            return self.prepare_for_coco_detection(predictions)
        if iou_type == "segm":
            return self.prepare_for_coco_segmentation(predictions)
        raise ValueError(f"Unknown iou type: {iou_type}")

    def prepare_for_coco_detection(self, predictions):
        """
        Convert per-image box predictions into COCO detection records.

        Args:
            predictions: Maps image ids to dicts with "boxes", "scores",
                and "labels".

        Returns:
            List of COCO-format detection dictionaries.
        """
        coco_results = []
        for image_id, pred in predictions.items():
            if len(pred) == 0:
                continue

            # COCO expects xywh boxes.
            xywh_boxes = convert_to_xywh(pred["boxes"]).tolist()
            scores = pred["scores"].tolist()
            labels = pred["labels"].tolist()

            for k, box in enumerate(xywh_boxes):
                coco_results.append(
                    {
                        "image_id": image_id,
                        "category_id": labels[k],
                        "bbox": box,
                        "score": scores[k],
                    }
                )
        return coco_results

    @torch.no_grad()
    def prepare_for_coco_segmentation(self, predictions):
        """
        Convert per-image mask predictions into COCO segmentation records.

        Accepts either pre-encoded "masks_rle" or dense "masks" (thresholded
        at 0.5 and RLE-encoded here). The "area" field is the mask area as a
        fraction of the image area.

        Args:
            predictions: Maps image ids to dicts with "masks" or "masks_rle",
                "scores", "labels", and optionally "boxes".

        Returns:
            List of COCO-format segmentation dictionaries with RLE masks.
        """
        coco_results = []
        for image_id, pred in predictions.items():
            if len(pred) == 0:
                continue

            scores = pred["scores"].tolist()
            labels = pred["labels"].tolist()

            xywh_boxes = None
            if "boxes" in pred:
                xywh_boxes = convert_to_xywh(pred["boxes"]).tolist()
                assert len(xywh_boxes) == len(scores)

            if "masks_rle" in pred:
                rles = pred["masks_rle"]
                areas = []
                for rle in rles:
                    h, w = rle["size"]
                    # Normalize the absolute RLE area by the image area.
                    areas.append(mask_utils.area(rle) / (h * w))
            else:
                masks = pred["masks"] > 0.5
                h, w = masks.shape[-2:]

                areas = (masks.flatten(1).sum(1) / (h * w)).tolist()
                rles = rle_encode(masks.squeeze(1))

                # Free the dense masks as early as possible.
                del masks
                del pred["masks"]

            assert len(areas) == len(rles) == len(scores)

            for k, rle in enumerate(rles):
                record = {
                    "image_id": image_id,
                    "category_id": labels[k],
                    "segmentation": rle,
                    "score": scores[k],
                    "area": areas[k],
                }
                if xywh_boxes is not None:
                    record["bbox"] = xywh_boxes[k]

                coco_results.append(record)

        return coco_results
sam3/eval/conversion_util.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+ import json
3
+ import os
4
+ from collections import defaultdict
5
+
6
+ from tqdm import tqdm
7
+
8
+
9
def convert_ytbvis_to_cocovid_gt(ann_json, save_path=None):
    """Convert YouTube VIS dataset to COCO-style video instance segmentation format.

    Fixes over the naive version: annotation files are opened via context
    managers (no leaked handles), and a directory is only created when
    ``save_path`` actually contains one (``os.makedirs("")`` raises
    FileNotFoundError for bare filenames).

    Args:
        ann_json (str): Path to YouTube VIS annotation JSON file.
        save_path (str): Optional path to save the converted COCO-style JSON.
            When None, the converted dict is only returned.

    Returns:
        dict: COCO-style structure with "videos", "images", "tracks",
        "annotations", and "categories" entries.
    """
    # Initialize COCO structure
    VIS = {
        "info": {},
        "images": [],
        "videos": [],
        "tracks": [],
        "annotations": [],
        "categories": [],
        "licenses": [],
    }

    # Load original annotations (context manager avoids leaking the handle)
    with open(ann_json) as f:
        official_anns = json.load(f)
    VIS["categories"] = official_anns["categories"]  # Direct copy categories

    # Running image / annotation id counters (1-based, COCO convention)
    records = dict(img_id=1, ann_id=1)

    # Create video-to-annotations mapping
    vid_to_anns = defaultdict(list)
    for ann in official_anns["annotations"]:
        vid_to_anns[ann["video_id"]].append(ann)

    # Each YTVIS annotation corresponds to exactly one track
    VIS["tracks"] = [
        {
            "id": ann["id"],
            "category_id": ann["category_id"],
            "video_id": ann["video_id"],
        }
        for ann in official_anns["annotations"]
    ]

    # Process videos
    for video_info in tqdm(official_anns["videos"]):
        # Create video entry
        video = {
            "id": video_info["id"],
            "name": os.path.dirname(video_info["file_names"][0]),
            "width": video_info["width"],
            "height": video_info["height"],
            "length": video_info["length"],
            "neg_category_ids": [],
            "not_exhaustive_category_ids": [],
        }
        VIS["videos"].append(video)

        # Process frames
        num_frames = len(video_info["file_names"])
        for frame_idx in range(num_frames):
            # Create image entry (one per frame)
            image = {
                "id": records["img_id"],
                "video_id": video_info["id"],
                "file_name": video_info["file_names"][frame_idx],
                "width": video_info["width"],
                "height": video_info["height"],
                "frame_index": frame_idx,
                "frame_id": frame_idx,
            }
            VIS["images"].append(image)

            # Process annotations for this frame
            if video_info["id"] in vid_to_anns:
                for ann in vid_to_anns[video_info["id"]]:
                    bbox = ann["bboxes"][frame_idx]
                    # YTVIS stores None for frames where the object is absent
                    if bbox is None:
                        continue

                    # Create annotation entry
                    annotation = {
                        "id": records["ann_id"],
                        "video_id": video_info["id"],
                        "image_id": records["img_id"],
                        "track_id": ann["id"],
                        "category_id": ann["category_id"],
                        "bbox": bbox,
                        "area": ann["areas"][frame_idx],
                        "segmentation": ann["segmentations"][frame_idx],
                        "iscrowd": ann["iscrowd"],
                    }
                    VIS["annotations"].append(annotation)
                    records["ann_id"] += 1

            records["img_id"] += 1

    # Print summary
    print(f"Converted {len(VIS['videos'])} videos")
    print(f"Converted {len(VIS['images'])} images")
    print(f"Created {len(VIS['tracks'])} tracks")
    print(f"Created {len(VIS['annotations'])} annotations")

    if save_path is None:
        return VIS

    # Save output; only create a directory when the path actually has one
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    with open(save_path, "w") as f:
        json.dump(VIS, f)

    return VIS
117
+
118
+
119
def convert_ytbvis_to_cocovid_pred(
    youtubevis_pred_path: str, converted_dataset_path: str, output_path: str
) -> None:
    """
    Convert YouTubeVIS predictions to COCO format with video_id preservation.

    Each YTVIS prediction (one track spanning many frames) is expanded into
    per-frame COCO annotations sharing a freshly assigned track id.

    Args:
        youtubevis_pred_path: Path to YouTubeVIS prediction JSON
        converted_dataset_path: Path to converted COCO dataset JSON
        output_path: Path to save COCO format predictions
    """

    # Load YouTubeVIS predictions
    with open(youtubevis_pred_path) as f:
        ytv_predictions = json.load(f)

    # Load converted dataset for image ID mapping
    with open(converted_dataset_path) as f:
        coco_dataset = json.load(f)

    # (video_id, frame_idx) -> image_id lookup built from the converted GT
    image_id_map = {
        (img["video_id"], img["frame_index"]): img["id"]
        for img in coco_dataset["images"]
    }

    coco_annotations = []
    next_track_id = 1  # Unique track ID generator

    for pred in tqdm(ytv_predictions):
        video_id = pred["video_id"]
        category_id = pred["category_id"]
        bboxes = pred["bboxes"]
        score = pred["score"]

        # Per-frame masks/areas are optional; pad with None to align lengths
        segmentations = pred.get("segmentations", [])
        if not segmentations:
            segmentations = [None] * len(bboxes)
        areas = pred.get("areas", [])
        if not areas:
            areas = [None] * len(bboxes)

        # One fresh track id per YTVIS prediction
        track_id = next_track_id
        next_track_id += 1

        for frame_idx, (bbox, segmentation, pred_area) in enumerate(
            zip(bboxes, segmentations, areas)
        ):
            # Skip frames with a missing object (None or all-zero bbox)
            if bbox is None or all(coord == 0 for coord in bbox):
                continue

            # Resolve the frame to its image id in the converted dataset
            image_id = image_id_map.get((video_id, frame_idx))
            if image_id is None:
                raise RuntimeError(
                    f"prediction {video_id=}, {frame_idx=} does not match any images in the converted COCO format"
                )

            x, y, w, h = bbox

            # Prefer the predicted mask area; fall back to the bbox area
            if pred_area is not None and pred_area > 0:
                area = pred_area
            else:
                area = w * h

            record = {
                "image_id": int(image_id),
                "video_id": video_id,  # preserved for video-level evaluation
                "track_id": track_id,
                "category_id": category_id,
                "bbox": [float(x), float(y), float(w), float(h)],
                "area": float(area),
                "iscrowd": 0,
                "score": float(score),
            }

            # Attach the mask only when the prediction carried one
            if segmentation is not None:
                record["segmentation"] = segmentation

            coco_annotations.append(record)

    # Save output
    with open(output_path, "w") as f:
        json.dump(coco_annotations, f)

    print(f"Converted {len(coco_annotations)} predictions to COCO format with video_id")
sam3/eval/demo_eval.py ADDED
@@ -0,0 +1,658 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ """
4
+ This evaluator is based upon COCO evaluation, but evaluates the model in a "demo" setting.
5
+ This means that the model's predictions are thresholded and evaluated as "hard" predictions.
6
+ """
7
+
8
+ import logging
9
+ from typing import Optional
10
+
11
+ import numpy as np
12
+ import pycocotools.mask as maskUtils
13
+ from pycocotools.cocoeval import COCOeval
14
+
15
+ from sam3.eval.coco_eval import CocoEvaluator
16
+ from sam3.train.masks_ops import compute_F_measure
17
+ from sam3.train.utils.distributed import is_main_process
18
+
19
+ from scipy.optimize import linear_sum_assignment
20
+
21
+
22
+ class DemoEval(COCOeval):
23
+ """
24
+ This evaluator is based upon COCO evaluation, but evaluates the model in a "demo" setting.
25
+ This means that the model's predictions are thresholded and evaluated as "hard" predictions.
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ coco_gt=None,
31
+ coco_dt=None,
32
+ iouType="bbox",
33
+ threshold=0.5,
34
+ compute_JnF=False,
35
+ ):
36
+ """
37
+ Args:
38
+ coco_gt (COCO): ground truth COCO API
39
+ coco_dt (COCO): detections COCO API
40
+ iou_type (str): type of IoU to evaluate
41
+ threshold (float): threshold for predictions
42
+ """
43
+ super().__init__(coco_gt, coco_dt, iouType)
44
+ self.threshold = threshold
45
+
46
+ self.params.useCats = False
47
+ self.params.areaRng = [[0**2, 1e5**2]]
48
+ self.params.areaRngLbl = ["all"]
49
+ self.params.maxDets = [100000]
50
+ self.compute_JnF = compute_JnF
51
+
52
+ def computeIoU(self, imgId, catId):
53
+ # Same as the original COCOeval.computeIoU, but without sorting
54
+ p = self.params
55
+ if p.useCats:
56
+ gt = self._gts[imgId, catId]
57
+ dt = self._dts[imgId, catId]
58
+ else:
59
+ gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
60
+ dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
61
+ if len(gt) == 0 and len(dt) == 0:
62
+ return []
63
+
64
+ if p.iouType == "segm":
65
+ g = [g["segmentation"] for g in gt]
66
+ d = [d["segmentation"] for d in dt]
67
+ elif p.iouType == "bbox":
68
+ g = [g["bbox"] for g in gt]
69
+ d = [d["bbox"] for d in dt]
70
+ else:
71
+ raise Exception("unknown iouType for iou computation")
72
+
73
+ # compute iou between each dt and gt region
74
+ iscrowd = [int(o["iscrowd"]) for o in gt]
75
+ ious = maskUtils.iou(d, g, iscrowd)
76
+ return ious
77
+
78
+ def evaluateImg(self, imgId, catId, aRng, maxDet):
79
+ """
80
+ perform evaluation for single category and image
81
+ :return: dict (single image results)
82
+ """
83
+ p = self.params
84
+ assert not p.useCats, "This evaluator does not support per-category evaluation."
85
+ assert catId == -1
86
+ all_gts = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
87
+ keep_gt = np.array([not g["ignore"] for g in all_gts], dtype=bool)
88
+ gt = [g for g in all_gts if not g["ignore"]]
89
+ all_dts = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
90
+ keep_dt = np.array([d["score"] >= self.threshold for d in all_dts], dtype=bool)
91
+ dt = [d for d in all_dts if d["score"] >= self.threshold]
92
+ if len(gt) == 0 and len(dt) == 0:
93
+ # This is a "true negative" case, where there are no GTs and no predictions
94
+ # The box-level metrics are ill-defined, so we don't add them to this dict
95
+ return {
96
+ "image_id": imgId,
97
+ "IL_TP": 0,
98
+ "IL_TN": 1,
99
+ "IL_FP": 0,
100
+ "IL_FN": 0,
101
+ "IL_perfect_neg": np.ones((len(p.iouThrs),), dtype=np.int64),
102
+ "num_dt": len(dt),
103
+ }
104
+
105
+ if len(gt) > 0 and len(dt) == 0:
106
+ # This is a "false negative" case, where there are GTs but no predictions
107
+ return {
108
+ "image_id": imgId,
109
+ "IL_TP": 0,
110
+ "IL_TN": 0,
111
+ "IL_FP": 0,
112
+ "IL_FN": 1,
113
+ "TPs": np.zeros((len(p.iouThrs),), dtype=np.int64),
114
+ "FPs": np.zeros((len(p.iouThrs),), dtype=np.int64),
115
+ "FNs": np.ones((len(p.iouThrs),), dtype=np.int64) * len(gt),
116
+ "local_F1s": np.zeros((len(p.iouThrs),), dtype=np.int64),
117
+ "local_positive_F1s": np.zeros((len(p.iouThrs),), dtype=np.int64),
118
+ "IL_perfect_pos": np.zeros((len(p.iouThrs),), dtype=np.int64),
119
+ "num_dt": len(dt),
120
+ }
121
+
122
+ # Load pre-computed ious
123
+ ious = self.ious[(imgId, catId)]
124
+
125
+ # compute matching
126
+ if len(ious) == 0:
127
+ ious = np.zeros((len(dt), len(gt)))
128
+ else:
129
+ ious = ious[keep_dt, :][:, keep_gt]
130
+ assert ious.shape == (len(dt), len(gt))
131
+
132
+ matched_dt, matched_gt = linear_sum_assignment(-ious)
133
+
134
+ match_scores = ious[matched_dt, matched_gt]
135
+
136
+ if self.compute_JnF and len(match_scores) > 0:
137
+ j_score = match_scores.mean()
138
+ f_measure = 0
139
+ for dt_id, gt_id in zip(matched_dt, matched_gt):
140
+ f_measure += compute_F_measure(
141
+ gt_boundary_rle=gt[gt_id]["boundary"],
142
+ gt_dilated_boundary_rle=gt[gt_id]["dilated_boundary"],
143
+ dt_boundary_rle=dt[dt_id]["boundary"],
144
+ dt_dilated_boundary_rle=dt[dt_id]["dilated_boundary"],
145
+ )
146
+ f_measure /= len(match_scores) + 1e-9
147
+ JnF = (j_score + f_measure) * 0.5
148
+ else:
149
+ j_score = f_measure = JnF = -1
150
+
151
+ TPs, FPs, FNs = [], [], []
152
+ IL_perfect = []
153
+ for thresh in p.iouThrs:
154
+ TP = (match_scores >= thresh).sum()
155
+ FP = len(dt) - TP
156
+ FN = len(gt) - TP
157
+ assert (
158
+ FP >= 0 and FN >= 0
159
+ ), f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}"
160
+ TPs.append(TP)
161
+ FPs.append(FP)
162
+ FNs.append(FN)
163
+
164
+ if FP == FN and FP == 0:
165
+ IL_perfect.append(1)
166
+ else:
167
+ IL_perfect.append(0)
168
+
169
+ TPs = np.array(TPs, dtype=np.int64)
170
+ FPs = np.array(FPs, dtype=np.int64)
171
+ FNs = np.array(FNs, dtype=np.int64)
172
+ IL_perfect = np.array(IL_perfect, dtype=np.int64)
173
+
174
+ # compute precision recall and F1
175
+ precision = TPs / (TPs + FPs + 1e-4)
176
+ assert np.all(precision <= 1)
177
+ recall = TPs / (TPs + FNs + 1e-4)
178
+ assert np.all(recall <= 1)
179
+ F1 = 2 * precision * recall / (precision + recall + 1e-4)
180
+
181
+ result = {
182
+ "image_id": imgId,
183
+ "TPs": TPs,
184
+ "FPs": FPs,
185
+ "FNs": FNs,
186
+ "local_F1s": F1,
187
+ "IL_TP": (len(gt) > 0) and (len(dt) > 0),
188
+ "IL_FP": (len(gt) == 0) and (len(dt) > 0),
189
+ "IL_TN": (len(gt) == 0) and (len(dt) == 0),
190
+ "IL_FN": (len(gt) > 0) and (len(dt) == 0),
191
+ ("IL_perfect_pos" if len(gt) > 0 else "IL_perfect_neg"): IL_perfect,
192
+ "F": f_measure,
193
+ "J": j_score,
194
+ "J&F": JnF,
195
+ "num_dt": len(dt),
196
+ }
197
+ if len(gt) > 0 and len(dt) > 0:
198
+ result["local_positive_F1s"] = F1
199
+ return result
200
+
201
+ def accumulate(self, p=None):
202
+ """
203
+ Accumulate per image evaluation results and store the result in self.eval
204
+ :param p: input params for evaluation
205
+ :return: None
206
+ """
207
+ if not self.evalImgs:
208
+ print("Please run evaluate() first")
209
+ # allows input customized parameters
210
+ if p is None:
211
+ p = self.params
212
+
213
+ setImgIds = set(p.imgIds)
214
+
215
+ # TPs, FPs, FNs
216
+ TPs = np.zeros((len(p.iouThrs),), dtype=np.int64)
217
+ FPs = np.zeros((len(p.iouThrs),), dtype=np.int64)
218
+ pmFPs = np.zeros((len(p.iouThrs),), dtype=np.int64)
219
+ FNs = np.zeros((len(p.iouThrs),), dtype=np.int64)
220
+ local_F1s = np.zeros((len(p.iouThrs),), dtype=np.float64)
221
+
222
+ # Image level metrics
223
+ IL_TPs = 0
224
+ IL_FPs = 0
225
+ IL_TNs = 0
226
+ IL_FNs = 0
227
+ IL_perfects_neg = np.zeros((len(p.iouThrs),), dtype=np.int64)
228
+ IL_perfects_pos = np.zeros((len(p.iouThrs),), dtype=np.int64)
229
+
230
+ # JnF metric
231
+ total_J = 0
232
+ total_F = 0
233
+ total_JnF = 0
234
+
235
+ valid_img_count = 0
236
+ total_pos_count = 0
237
+ total_neg_count = 0
238
+ valid_J_count = 0
239
+ valid_F1_count = 0
240
+ valid_F1_count_w0dt = 0
241
+ for res in self.evalImgs:
242
+ if res["image_id"] not in setImgIds:
243
+ continue
244
+ IL_TPs += res["IL_TP"]
245
+ IL_FPs += res["IL_FP"]
246
+ IL_TNs += res["IL_TN"]
247
+ IL_FNs += res["IL_FN"]
248
+ if "IL_perfect_neg" in res:
249
+ IL_perfects_neg += res["IL_perfect_neg"]
250
+ total_neg_count += 1
251
+ else:
252
+ assert "IL_perfect_pos" in res
253
+ IL_perfects_pos += res["IL_perfect_pos"]
254
+ total_pos_count += 1
255
+
256
+ if "TPs" not in res:
257
+ continue
258
+
259
+ TPs += res["TPs"]
260
+ FPs += res["FPs"]
261
+ FNs += res["FNs"]
262
+ valid_img_count += 1
263
+
264
+ if "local_positive_F1s" in res:
265
+ local_F1s += res["local_positive_F1s"]
266
+ pmFPs += res["FPs"]
267
+ valid_F1_count_w0dt += 1
268
+ if res["num_dt"] > 0:
269
+ valid_F1_count += 1
270
+
271
+ if "J" in res and res["J"] > -1e-9:
272
+ total_J += res["J"]
273
+ total_F += res["F"]
274
+ total_JnF += res["J&F"]
275
+ valid_J_count += 1
276
+
277
+ # compute precision recall and F1
278
+ precision = TPs / (TPs + FPs + 1e-4)
279
+ positive_micro_precision = TPs / (TPs + pmFPs + 1e-4)
280
+ assert np.all(precision <= 1)
281
+ recall = TPs / (TPs + FNs + 1e-4)
282
+ assert np.all(recall <= 1)
283
+ F1 = 2 * precision * recall / (precision + recall + 1e-4)
284
+ positive_micro_F1 = (
285
+ 2
286
+ * positive_micro_precision
287
+ * recall
288
+ / (positive_micro_precision + recall + 1e-4)
289
+ )
290
+
291
+ IL_rec = IL_TPs / (IL_TPs + IL_FNs + 1e-6)
292
+ IL_prec = IL_TPs / (IL_TPs + IL_FPs + 1e-6)
293
+ IL_F1 = 2 * IL_prec * IL_rec / (IL_prec + IL_rec + 1e-6)
294
+ IL_FPR = IL_FPs / (IL_FPs + IL_TNs + 1e-6)
295
+ IL_MCC = float(IL_TPs * IL_TNs - IL_FPs * IL_FNs) / (
296
+ (
297
+ float(IL_TPs + IL_FPs)
298
+ * float(IL_TPs + IL_FNs)
299
+ * float(IL_TNs + IL_FPs)
300
+ * float(IL_TNs + IL_FNs)
301
+ )
302
+ ** 0.5
303
+ + 1e-6
304
+ )
305
+ IL_perfect_pos = IL_perfects_pos / (total_pos_count + 1e-9)
306
+ IL_perfect_neg = IL_perfects_neg / (total_neg_count + 1e-9)
307
+
308
+ total_J = total_J / (valid_J_count + 1e-9)
309
+ total_F = total_F / (valid_J_count + 1e-9)
310
+ total_JnF = total_JnF / (valid_J_count + 1e-9)
311
+
312
+ self.eval = {
313
+ "params": p,
314
+ "TPs": TPs,
315
+ "FPs": FPs,
316
+ "positive_micro_FPs": pmFPs,
317
+ "FNs": FNs,
318
+ "precision": precision,
319
+ "positive_micro_precision": positive_micro_precision,
320
+ "recall": recall,
321
+ "F1": F1,
322
+ "positive_micro_F1": positive_micro_F1,
323
+ "positive_macro_F1": local_F1s / valid_F1_count,
324
+ "positive_w0dt_macro_F1": local_F1s / valid_F1_count_w0dt,
325
+ "IL_recall": IL_rec,
326
+ "IL_precision": IL_prec,
327
+ "IL_F1": IL_F1,
328
+ "IL_FPR": IL_FPR,
329
+ "IL_MCC": IL_MCC,
330
+ "IL_perfect_pos": IL_perfect_pos,
331
+ "IL_perfect_neg": IL_perfect_neg,
332
+ "J": total_J,
333
+ "F": total_F,
334
+ "J&F": total_JnF,
335
+ }
336
+ self.eval["CGF1"] = self.eval["positive_macro_F1"] * self.eval["IL_MCC"]
337
+ self.eval["CGF1_w0dt"] = (
338
+ self.eval["positive_w0dt_macro_F1"] * self.eval["IL_MCC"]
339
+ )
340
+ self.eval["CGF1_micro"] = self.eval["positive_micro_F1"] * self.eval["IL_MCC"]
341
+
342
+ def summarize(self):
343
+ """
344
+ Compute and display summary metrics for evaluation results.
345
+ Note this functin can *only* be applied on the default parameter setting
346
+ """
347
+ if not self.eval:
348
+ raise Exception("Please run accumulate() first")
349
+
350
+ def _summarize(iouThr=None, metric=""):
351
+ p = self.params
352
+ iStr = " {:<18} @[ IoU={:<9}] = {:0.3f}"
353
+ titleStr = "Average " + metric
354
+ iouStr = (
355
+ "{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1])
356
+ if iouThr is None
357
+ else "{:0.2f}".format(iouThr)
358
+ )
359
+
360
+ s = self.eval[metric]
361
+ # IoU
362
+ if iouThr is not None:
363
+ t = np.where(iouThr == p.iouThrs)[0]
364
+ s = s[t]
365
+
366
+ if len(s[s > -1]) == 0:
367
+ mean_s = -1
368
+ else:
369
+ mean_s = np.mean(s[s > -1])
370
+ print(iStr.format(titleStr, iouStr, mean_s))
371
+ return mean_s
372
+
373
+ def _summarize_single(metric=""):
374
+ titleStr = "Average " + metric
375
+ iStr = " {:<35} = {:0.3f}"
376
+ s = self.eval[metric]
377
+ print(iStr.format(titleStr, s))
378
+ return s
379
+
380
+ def _summarizeDets():
381
+ # note: the index of these metrics are also used in video Demo F1 evaluation
382
+ # when adding new metrics, please update the index in video Demo F1 evaluation
383
+ # in "evaluate" method of the "VideoDemoF1Evaluator" class
384
+ stats = np.zeros((len(DEMO_METRICS),))
385
+ stats[0] = _summarize(metric="CGF1")
386
+ stats[1] = _summarize(metric="precision")
387
+ stats[2] = _summarize(metric="recall")
388
+ stats[3] = _summarize(metric="F1")
389
+ stats[4] = _summarize(metric="positive_macro_F1")
390
+ stats[5] = _summarize_single(metric="IL_precision")
391
+ stats[6] = _summarize_single(metric="IL_recall")
392
+ stats[7] = _summarize_single(metric="IL_F1")
393
+ stats[8] = _summarize_single(metric="IL_FPR")
394
+ stats[9] = _summarize_single(metric="IL_MCC")
395
+ stats[10] = _summarize(metric="IL_perfect_pos")
396
+ stats[11] = _summarize(metric="IL_perfect_neg")
397
+ stats[12] = _summarize(iouThr=0.5, metric="CGF1")
398
+ stats[13] = _summarize(iouThr=0.5, metric="precision")
399
+ stats[14] = _summarize(iouThr=0.5, metric="recall")
400
+ stats[15] = _summarize(iouThr=0.5, metric="F1")
401
+ stats[16] = _summarize(iouThr=0.5, metric="positive_macro_F1")
402
+ stats[17] = _summarize(iouThr=0.5, metric="IL_perfect_pos")
403
+ stats[18] = _summarize(iouThr=0.5, metric="IL_perfect_neg")
404
+ stats[19] = _summarize(iouThr=0.75, metric="CGF1")
405
+ stats[20] = _summarize(iouThr=0.75, metric="precision")
406
+ stats[21] = _summarize(iouThr=0.75, metric="recall")
407
+ stats[22] = _summarize(iouThr=0.75, metric="F1")
408
+ stats[23] = _summarize(iouThr=0.75, metric="positive_macro_F1")
409
+ stats[24] = _summarize(iouThr=0.75, metric="IL_perfect_pos")
410
+ stats[25] = _summarize(iouThr=0.75, metric="IL_perfect_neg")
411
+ stats[26] = _summarize_single(metric="J")
412
+ stats[27] = _summarize_single(metric="F")
413
+ stats[28] = _summarize_single(metric="J&F")
414
+ stats[29] = _summarize(metric="CGF1_micro")
415
+ stats[30] = _summarize(metric="positive_micro_precision")
416
+ stats[31] = _summarize(metric="positive_micro_F1")
417
+ stats[32] = _summarize(iouThr=0.5, metric="CGF1_micro")
418
+ stats[33] = _summarize(iouThr=0.5, metric="positive_micro_precision")
419
+ stats[34] = _summarize(iouThr=0.5, metric="positive_micro_F1")
420
+ stats[35] = _summarize(iouThr=0.75, metric="CGF1_micro")
421
+ stats[36] = _summarize(iouThr=0.75, metric="positive_micro_precision")
422
+ stats[37] = _summarize(iouThr=0.75, metric="positive_micro_F1")
423
+ stats[38] = _summarize(metric="CGF1_w0dt")
424
+ stats[39] = _summarize(metric="positive_w0dt_macro_F1")
425
+ stats[40] = _summarize(iouThr=0.5, metric="CGF1_w0dt")
426
+ stats[41] = _summarize(iouThr=0.5, metric="positive_w0dt_macro_F1")
427
+ stats[42] = _summarize(iouThr=0.75, metric="CGF1_w0dt")
428
+ stats[43] = _summarize(iouThr=0.75, metric="positive_w0dt_macro_F1")
429
+ return stats
430
+
431
+ summarize = _summarizeDets
432
+ self.stats = summarize()
433
+
434
+
435
# Metric names reported by DemoEval.summarize(). The ORDER must match the
# stats indices written in DemoEval's _summarizeDets (and is also relied on
# by the video Demo F1 evaluation).
_MAIN_METRICS = ["CGF1", "Precision", "Recall", "F1", "Macro_F1"]
_IL_SINGLE_METRICS = ["IL_Precision", "IL_Recall", "IL_F1", "IL_FPR", "IL_MCC"]
_IL_PERFECT_METRICS = ["IL_perfect_pos", "IL_perfect_neg"]
_MICRO_METRICS = ["CGF1_micro", "positive_micro_Precision", "positive_micro_F1"]
_W0DT_METRICS = ["CGF1_w0dt", "positive_w0dt_macro_F1"]


def _with_iou_suffix(names, thr):
    # Suffix each metric name with an IoU threshold, e.g. "CGF1@0.5".
    return [name + "@" + thr for name in names]


DEMO_METRICS = (
    _MAIN_METRICS
    + _IL_SINGLE_METRICS
    + _IL_PERFECT_METRICS
    + _with_iou_suffix(_MAIN_METRICS + _IL_PERFECT_METRICS, "0.5")
    + _with_iou_suffix(_MAIN_METRICS + _IL_PERFECT_METRICS, "0.75")
    + ["J", "F", "J&F"]
    + _MICRO_METRICS
    + _with_iou_suffix(_MICRO_METRICS, "0.5")
    + _with_iou_suffix(_MICRO_METRICS, "0.75")
    + _W0DT_METRICS
    + _with_iou_suffix(_W0DT_METRICS, "0.5")
    + _with_iou_suffix(_W0DT_METRICS, "0.75")
)
481
+
482
+
483
class DemoEvaluator(CocoEvaluator):
    """COCO-style evaluator that scores predictions with :class:`DemoEval`
    (CGF1-family metrics) instead of standard COCO AP."""

    def __init__(
        self,
        coco_gt,
        iou_types,
        dump_dir: Optional[str],
        postprocessor,
        threshold=0.5,
        average_by_rarity=False,
        gather_pred_via_filesys=False,
        exhaustive_only=False,
        all_exhaustive_only=True,
        compute_JnF=False,
        metrics_dump_dir: Optional[str] = None,
    ):
        self.iou_types = iou_types
        self.threshold = threshold
        super().__init__(
            coco_gt=coco_gt,
            iou_types=iou_types,
            useCats=False,
            dump_dir=dump_dir,
            postprocessor=postprocessor,
            # average_by_rarity=average_by_rarity,
            gather_pred_via_filesys=gather_pred_via_filesys,
            exhaustive_only=exhaustive_only,
            all_exhaustive_only=all_exhaustive_only,
            metrics_dump_dir=metrics_dump_dir,
        )
        self.use_self_evaluate = True
        self.compute_JnF = compute_JnF

    def _lazy_init(self):
        # Initialize once; re-assert self-evaluation mode after the parent's
        # lazy initialization, then rebuild the per-GT evaluators.
        if self.initialized:
            return
        super()._lazy_init()
        self.use_self_evaluate = True
        self.reset()

    def select_best_scoring(self, scorings):
        # "Oracle" evaluation: given the evaluation results with respect to
        # several ground truths, keep the best-scoring result per image.
        if len(scorings) == 1:
            return scorings[0]

        assert (
            scorings[0].ndim == 3
        ), f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}"
        assert (
            scorings[0].shape[0] == 1
        ), f"Expecting a single category, got {scorings[0].shape[0]}"
        for scoring in scorings:
            assert (
                scoring.shape == scorings[0].shape
            ), f"Shape mismatch: {scoring.shape}, {scorings[0].shape}"

        picked = []
        for img_idx in range(scorings[0].shape[-1]):
            winner = scorings[0][:, :, img_idx]
            for candidate in scorings[1:]:
                challenger = candidate[:, :, img_idx]
                if "local_F1s" in winner[0, 0] and "local_F1s" in challenger[0, 0]:
                    # Both evaluations produced a valid F1 for this image.
                    # "local_F1s" holds results at several IoU thresholds; the
                    # mean over thresholds is used for the comparison.
                    if (
                        challenger[0, 0]["local_F1s"].mean()
                        > winner[0, 0]["local_F1s"].mean()
                    ):
                        winner = challenger
                elif "local_F1s" not in challenger[0, 0]:
                    # No valid local F1 means both predictions and targets
                    # were empty; treat it as a perfect prediction.
                    winner = challenger
            picked.append(winner)
        stacked = np.stack(picked, axis=-1)
        assert stacked.shape == scorings[0].shape
        return stacked

    def summarize(self):
        """Accumulate, print and return the Demo metrics (main process only)."""
        self._lazy_init()
        logging.info("Demo evaluator: Summarizing")
        if not is_main_process():
            return {}
        outs = {}
        # Prefix metrics with "oracle_" when scoring against multiple GTs.
        prefix = "oracle_" if len(self.coco_evals) > 1 else ""
        self.accumulate(self.eval_img_ids)
        for iou_type, coco_eval in self.coco_evals[0].items():
            print("Demo metric, IoU type={}".format(iou_type))
            coco_eval.summarize()

        # Export per-metric scalars, keyed by IoU type and DEMO_METRICS name.
        for iou_type, key_stub in (
            ("bbox", "coco_eval_bbox_"),
            ("segm", "coco_eval_masks_"),
        ):
            if iou_type not in self.coco_evals[0]:
                continue
            for i, value in enumerate(self.coco_evals[0][iou_type].stats):
                outs[f"{key_stub}{prefix}{DEMO_METRICS[i]}"] = value
        return outs

    def accumulate(self, imgIds=None):
        """Run DemoEval accumulation, optionally restricted to imgIds."""
        self._lazy_init()
        logging.info(
            f"demo evaluator: Accumulating on {len(imgIds) if imgIds is not None else 'all'} images"
        )
        if not is_main_process():
            return

        for coco_eval in self.coco_evals[0].values():
            if imgIds is not None:
                coco_eval.params.imgIds = list(imgIds)
            coco_eval.accumulate()

    def reset(self):
        # One DemoEval per (ground truth, IoU type) pair.
        self.coco_evals = [{} for _ in range(len(self.coco_gts))]
        for gt_idx, coco_gt in enumerate(self.coco_gts):
            for iou_type in self.iou_types:
                evaluator = DemoEval(
                    coco_gt=coco_gt,
                    iouType=iou_type,
                    threshold=self.threshold,
                    compute_JnF=self.compute_JnF,
                )
                evaluator.useCats = False
                self.coco_evals[gt_idx][iou_type] = evaluator
        self.img_ids = []
        self.eval_imgs = {iou_type: [] for iou_type in self.iou_types}
        if self.dump is not None:
            self.dump = []
sam3/eval/hota_eval_toolkit/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # flake8: noqa
sam3/eval/hota_eval_toolkit/run_ytvis_eval.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+
3
+ """run_youtube_vis.py
4
+ Run example:
5
+ run_youtube_vis.py --USE_PARALLEL False --METRICS HOTA --TRACKERS_TO_EVAL STEm_Seg
6
+ Command Line Arguments: Defaults, # Comments
7
+ Eval arguments:
8
+ 'USE_PARALLEL': False,
9
+ 'NUM_PARALLEL_CORES': 8,
10
+ 'BREAK_ON_ERROR': True, # Raises exception and exits with error
11
+ 'RETURN_ON_ERROR': False, # if not BREAK_ON_ERROR, then returns from function on error
12
+ 'LOG_ON_ERROR': os.path.join(code_path, 'error_log.txt'), # if not None, save any errors into a log file.
13
+ 'PRINT_RESULTS': True,
14
+ 'PRINT_ONLY_COMBINED': False,
15
+ 'PRINT_CONFIG': True,
16
+ 'TIME_PROGRESS': True,
17
+ 'DISPLAY_LESS_PROGRESS': True,
18
+ 'OUTPUT_SUMMARY': True,
19
+ 'OUTPUT_EMPTY_CLASSES': True, # If False, summary files are not output for classes with no detections
20
+ 'OUTPUT_DETAILED': True,
21
+ 'PLOT_CURVES': True,
22
+ Dataset arguments:
23
+ 'GT_FOLDER': os.path.join(code_path, 'data/gt/youtube_vis/youtube_vis_training'), # Location of GT data
24
+ 'TRACKERS_FOLDER': os.path.join(code_path, 'data/trackers/youtube_vis/youtube_vis_training'),
25
+ # Trackers location
26
+ 'OUTPUT_FOLDER': None, # Where to save eval results (if None, same as TRACKERS_FOLDER)
27
+ 'TRACKERS_TO_EVAL': None, # Filenames of trackers to eval (if None, all in folder)
28
+ 'CLASSES_TO_EVAL': None, # Classes to eval (if None, all classes)
29
+ 'SPLIT_TO_EVAL': 'training', # Valid: 'training', 'val'
30
+ 'PRINT_CONFIG': True, # Whether to print current config
31
+ 'OUTPUT_SUB_FOLDER': '', # Output files are saved in OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER
32
+ 'TRACKER_SUB_FOLDER': 'data', # Tracker files are in TRACKER_FOLDER/tracker_name/TRACKER_SUB_FOLDER
33
+ 'TRACKER_DISPLAY_NAMES': None, # Names of trackers to display, if None: TRACKERS_TO_EVAL
34
+ Metric arguments:
35
+ 'METRICS': ['TrackMAP', 'HOTA', 'CLEAR', 'Identity']
36
+ """
37
+
38
+ import argparse
39
+ import os
40
+ import sys
41
+ from multiprocessing import freeze_support
42
+
43
+ from . import trackeval
44
+
45
+
46
def run_ytvis_eval(args=None, gt_json=None, dt_json=None):
    """Run HOTA evaluation on YouTube-VIS style data.

    Args:
        args: optional argv-style list overriding default config values
            (every config key is exposed as a ``--KEY`` flag).
        gt_json: optional ground-truth annotations as an already-loaded JSON
            object, bypassing file reads from GT_FOLDER.
        dt_json: optional tracker results as an already-loaded JSON object.

    Returns:
        ``(output_res, output_msg)`` as produced by ``trackeval.Evaluator.evaluate``.
    """
    # Merge the evaluator / dataset / metric default configs.
    default_eval_config = trackeval.Evaluator.get_default_eval_config()
    # print only combined since TrackMAP is undefined for per sequence breakdowns
    default_eval_config["PRINT_ONLY_COMBINED"] = True
    default_dataset_config = trackeval.datasets.YouTubeVIS.get_default_dataset_config()
    default_metrics_config = {"METRICS": ["HOTA"]}
    config = {
        **default_eval_config,
        **default_dataset_config,
        **default_metrics_config,
    }  # Merge default configs

    # Expose every config key as a command-line flag; list-valued (or None)
    # defaults accept multiple values via nargs="+".
    parser = argparse.ArgumentParser()
    for setting, default in config.items():
        if isinstance(default, list) or default is None:
            parser.add_argument("--" + setting, nargs="+")
        else:
            parser.add_argument("--" + setting)
    parsed = vars(parser.parse_args(args))

    # Coerce string arguments back to the type of each default value.
    # (The former dead "type(None)" branch was removed: it was unreachable
    # under the "is not None" guard.)
    for setting, value in parsed.items():
        if value is None:
            # Flag not supplied on the command line; keep the default.
            continue
        # Check bool before int: bool is a subclass of int in Python.
        if isinstance(config[setting], bool):
            if value == "True":
                config[setting] = True
            elif value == "False":
                config[setting] = False
            else:
                # BUGFIX: the original message was missing the space before
                # "must", producing e.g. "parameter FOOmust be True or False".
                raise Exception(
                    "Command line parameter " + setting + " must be True or False"
                )
        elif isinstance(config[setting], int):
            config[setting] = int(value)
        else:
            config[setting] = value

    eval_config = {k: v for k, v in config.items() if k in default_eval_config.keys()}
    dataset_config = {
        k: v for k, v in config.items() if k in default_dataset_config.keys()
    }
    metrics_config = {
        k: v for k, v in config.items() if k in default_metrics_config.keys()
    }

    # Run code
    evaluator = trackeval.Evaluator(eval_config)
    # allow directly specifying the GT JSON data and Tracker (result)
    # JSON data as Python objects, without reading from files.
    dataset_config["GT_JSON_OBJECT"] = gt_json
    dataset_config["TRACKER_JSON_OBJECT"] = dt_json
    dataset_list = [trackeval.datasets.YouTubeVIS(dataset_config)]
    metrics_list = [
        metric()
        for metric in [trackeval.metrics.HOTA]
        if metric.get_name() in metrics_config["METRICS"]
    ]
    if len(metrics_list) == 0:
        raise Exception("No metrics selected for evaluation")
    output_res, output_msg = evaluator.evaluate(dataset_list, metrics_list)
    return output_res, output_msg
108
+
109
+
110
if __name__ == "__main__":
    import sys

    # Needed on Windows when the script is bundled into a frozen executable.
    freeze_support()
    run_ytvis_eval(args=sys.argv[1:])
sam3/eval/hota_eval_toolkit/trackeval/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # flake8: noqa
2
+
3
+ from . import datasets, metrics, utils
4
+ from .eval import Evaluator
sam3/eval/hota_eval_toolkit/trackeval/_timing.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+
3
+ import inspect
4
+ from functools import wraps
5
+ from time import perf_counter
6
+
7
# Global switches: trackeval flips these from its config before evaluating.
DO_TIMING = False
DISPLAY_LESS_PROGRESS = False
# Accumulated wall-clock seconds per timed function/method name.
timer_dict = {}
# Count of timed free-function (non-method) calls printed so far.
counter = 0


def time(f):
    """Decorator that optionally times *f* and reports the result.

    With ``DO_TIMING`` False (the default) the wrapped function runs with no
    measurement overhead. With ``DO_TIMING`` True, each call is timed with
    ``perf_counter``; per-name totals are accumulated in ``timer_dict`` and a
    summary table is printed once ``Evaluator.evaluate`` finishes.
    """

    @wraps(f)
    def wrap(*args, **kw):
        if not DO_TIMING:
            # Timing disabled (or running in parallel): call through directly.
            return f(*args, **kw)

        ts = perf_counter()
        result = f(*args, **kw)
        tt = perf_counter() - ts

        # Derive a printable name ("Class.method" for instance methods).
        # BUGFIX: guard arg_names[0] so zero-argument functions don't raise
        # IndexError when timing is enabled.
        arg_names = inspect.getfullargspec(f)[0]
        is_method = bool(arg_names) and arg_names[0] == "self"
        if is_method and DISPLAY_LESS_PROGRESS:
            return result
        if is_method:
            method_name = type(args[0]).__name__ + "." + f.__name__
        else:
            method_name = f.__name__

        # Record accumulative time in each function for analysis.
        if method_name in timer_dict:
            timer_dict[method_name] += tt
        else:
            timer_dict[method_name] = tt

        if method_name == "Evaluator.evaluate":
            # Evaluation is finished: display the timing summary.
            print("")
            print("Timing analysis:")
            for key, value in timer_dict.items():
                print("%-70s %2.4f sec" % (key, value))
        else:
            # Show selected positional arguments of interest.
            # BUGFIX: the original code crashed with IndexError when
            # tracker/seq/cls were passed as keywords (i >= len(args)) and
            # with TypeError on join for non-string values; guard and str().
            arg_titles = ["tracker", "seq", "cls"]
            arg_vals = []
            for i, a in enumerate(arg_names):
                if a in arg_titles and i < len(args):
                    arg_vals.append(str(args[i]))
            arg_text = "(" + ", ".join(arg_vals) + ")"

            # Methods and plain functions print with different indentation.
            if is_method:
                print("%-74s %2.4f sec" % (" " * 4 + method_name + arg_text, tt))
            elif bool(arg_names) and arg_names[0] == "test":
                pass
            else:
                global counter
                counter += 1
                print("%i %-70s %2.4f sec" % (counter, method_name + arg_text, tt))

        return result

    return wrap
sam3/eval/hota_eval_toolkit/trackeval/datasets/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # flake8: noqa
2
+
3
+ from .tao_ow import TAO_OW
4
+ from .youtube_vis import YouTubeVIS
sam3/eval/hota_eval_toolkit/trackeval/datasets/_base_dataset.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+
3
+ import csv
4
+ import io
5
+ import os
6
+ import traceback
7
+ import zipfile
8
+ from abc import ABC, abstractmethod
9
+ from copy import deepcopy
10
+
11
+ import numpy as np
12
+
13
+ from .. import _timing
14
+ from ..utils import TrackEvalException
15
+
16
+
17
class _BaseDataset(ABC):
    """Abstract base class for trackeval datasets.

    Concrete datasets implement raw GT/tracker loading, per-class
    preprocessing and the similarity measure used for matching; this base
    class provides shared helpers (text-file parsing, box/mask IoU,
    euclidean similarity, unique-ID validation).
    """

    @abstractmethod
    def __init__(self):
        self.tracker_list = None
        self.seq_list = None
        self.class_list = None
        self.output_fol = None
        self.output_sub_fol = None
        self.should_classes_combine = True
        self.use_super_categories = False

    # Functions to implement:

    @staticmethod
    @abstractmethod
    def get_default_dataset_config(): ...

    @abstractmethod
    def _load_raw_file(self, tracker, seq, is_gt): ...

    @_timing.time
    @abstractmethod
    def get_preprocessed_seq_data(self, raw_data, cls): ...

    @abstractmethod
    def _calculate_similarities(self, gt_dets_t, tracker_dets_t): ...

    # Helper functions for all datasets:

    @classmethod
    def get_class_name(cls):
        return cls.__name__

    def get_name(self):
        return self.get_class_name()

    def get_output_fol(self, tracker):
        return os.path.join(self.output_fol, tracker, self.output_sub_fol)

    def get_display_name(self, tracker):
        """Can be overwritten if the trackers name (in files) is different to how it should be displayed.
        By default this method just returns the trackers name as is.
        """
        return tracker

    def get_eval_info(self):
        """Return info about the dataset needed for the Evaluator"""
        return self.tracker_list, self.seq_list, self.class_list

    @_timing.time
    def get_raw_seq_data(self, tracker, seq):
        """Loads raw data (tracker and ground-truth) for a single tracker on a single sequence.

        Returns a dict with fields:
            [num_timesteps]: integer
            [gt_ids, tracker_ids, gt_classes, tracker_classes, tracker_confidences]:
                list (per timestep) of 1D NDArrays (per det).
            [gt_dets, tracker_dets, gt_crowd_ignore_regions]: list (per timestep) of lists of dets.
            [similarity_scores]: list (per timestep) of 2D NDArrays.
            [gt_extras]: dict (per extra) of lists (per timestep) of 1D NDArrays (per det).

        Similarities are computed here (before preprocessing) because both
        preprocessing and evaluation need them, and across all classes so
        that e.g. class confusion matrices can be derived.
        """
        raw_gt_data = self._load_raw_file(tracker, seq, is_gt=True)
        raw_tracker_data = self._load_raw_file(tracker, seq, is_gt=False)
        raw_data = {**raw_tracker_data, **raw_gt_data}  # Merges dictionaries

        # Calculate similarities for each timestep.
        similarity_scores = []
        for gt_dets_t, tracker_dets_t in zip(
            raw_data["gt_dets"], raw_data["tracker_dets"]
        ):
            similarity_scores.append(
                self._calculate_similarities(gt_dets_t, tracker_dets_t)
            )
        raw_data["similarity_scores"] = similarity_scores
        return raw_data

    @staticmethod
    def _load_simple_text_file(
        file,
        time_col=0,
        id_col=None,
        remove_negative_ids=False,
        valid_filter=None,
        crowd_ignore_filter=None,
        convert_filter=None,
        is_zipped=False,
        zip_file=None,
        force_delimiters=None,
    ):
        """Load detections from a common one-row-per-det text format.

        ``time_col`` gives the timestep column. ``valid_filter`` /
        ``crowd_ignore_filter`` are dicts {column: allowed lowercase values}
        selecting normal / crowd-ignore rows; ``convert_filter`` maps a
        column's lowercase value to a replacement (e.g. class name -> id).
        If ``remove_negative_ids``, rows with a negative ``id_col`` value are
        excluded (not from ignore data). ``is_zipped``/``zip_file`` allow
        reading from inside a zip archive; the CSV dialect is auto-detected
        (optionally constrained by ``force_delimiters``).

        Returns (read_data, crowd_ignore_data): dicts keyed by timestep
        (string) of lists of rows (all values remain strings). Timesteps
        without dets are absent from the keys.
        """
        if remove_negative_ids and id_col is None:
            raise TrackEvalException(
                "remove_negative_ids is True, but id_col is not given."
            )
        if crowd_ignore_filter is None:
            crowd_ignore_filter = {}
        if convert_filter is None:
            convert_filter = {}
        try:
            if is_zipped:  # Either open file directly or within a zip.
                if zip_file is None:
                    raise TrackEvalException(
                        "is_zipped set to True, but no zip_file is given."
                    )
                archive = zipfile.ZipFile(os.path.join(zip_file), "r")
                fp = io.TextIOWrapper(archive.open(file, "r"))
            else:
                fp = open(file)
            read_data = {}
            crowd_ignore_data = {}
            fp.seek(0, os.SEEK_END)
            # check if file is empty
            if fp.tell():
                fp.seek(0)
                dialect = csv.Sniffer().sniff(
                    fp.readline(), delimiters=force_delimiters
                )  # Auto determine structure.
                dialect.skipinitialspace = (
                    True  # Deal with extra spaces between columns
                )
                fp.seek(0)
                reader = csv.reader(fp, dialect)
                for row in reader:
                    try:
                        # Deal with extra trailing spaces at the end of rows
                        if row[-1] == "":
                            row = row[:-1]
                        timestep = str(int(float(row[time_col])))
                        # Read ignore regions separately.
                        is_ignored = False
                        for ignore_key, ignore_value in crowd_ignore_filter.items():
                            if row[ignore_key].lower() in ignore_value:
                                # Convert values in one column (e.g. string to id)
                                for (
                                    convert_key,
                                    convert_value,
                                ) in convert_filter.items():
                                    row[convert_key] = convert_value[
                                        row[convert_key].lower()
                                    ]
                                # Save data separated by timestep.
                                crowd_ignore_data.setdefault(timestep, []).append(
                                    row
                                )
                                is_ignored = True
                        if is_ignored:
                            # An ignore-region det cannot also be a normal det.
                            continue
                        # Exclude some dets if not valid.
                        # BUGFIX: previously `continue` targeted the inner
                        # filter loop, so invalid rows were never skipped.
                        if valid_filter is not None and any(
                            row[key].lower() not in value
                            for key, value in valid_filter.items()
                        ):
                            continue
                        if remove_negative_ids and int(float(row[id_col])) < 0:
                            continue
                        # Convert values in one column (e.g. string to id)
                        for convert_key, convert_value in convert_filter.items():
                            row[convert_key] = convert_value[row[convert_key].lower()]
                        # Save data separated by timestep.
                        read_data.setdefault(timestep, []).append(row)
                    except Exception:
                        exc_str_init = (
                            "In file %s the following line cannot be read correctly: \n"
                            % os.path.basename(file)
                        )
                        exc_str = " ".join([exc_str_init] + row)
                        raise TrackEvalException(exc_str)
            fp.close()
        except Exception:
            print("Error loading file: %s, printing traceback." % file)
            traceback.print_exc()
            raise TrackEvalException(
                "File %s cannot be read because it is either not present or invalidly formatted"
                % os.path.basename(file)
            )
        return read_data, crowd_ignore_data

    @staticmethod
    def _calculate_mask_ious(masks1, masks2, is_encoded=False, do_ioa=False):
        """IoU (or IoA if do_ioa) between two sets of segmentation masks.

        Masks are either pycocotools RLE-encoded (``is_encoded``) or numpy
        arrays of shape (num_masks, height, width), which are encoded here.
        IoA (intersection over the area of masks1) is commonly used to test
        whether detections lie within a crowd-ignore region.
        """
        # Only loaded when run to reduce minimum requirements
        from pycocotools import mask as mask_utils

        # use pycocotools for run length encoding of masks
        if not is_encoded:
            masks1 = mask_utils.encode(
                np.array(np.transpose(masks1, (1, 2, 0)), order="F")
            )
            masks2 = mask_utils.encode(
                np.array(np.transpose(masks2, (1, 2, 0)), order="F")
            )

        # use pycocotools for iou computation of rle encoded masks
        ious = mask_utils.iou(masks1, masks2, [do_ioa] * len(masks2))
        if len(masks1) == 0 or len(masks2) == 0:
            # pycocotools returns [] for empty inputs; keep a 2D shape.
            ious = np.asarray(ious).reshape(len(masks1), len(masks2))
        assert (ious >= 0 - np.finfo("float").eps).all()
        assert (ious <= 1 + np.finfo("float").eps).all()

        return ious

    @staticmethod
    def _calculate_box_ious(bboxes1, bboxes2, box_format="xywh", do_ioa=False):
        """IoU (or IoA if do_ioa) between two arrays of boxes.

        Supported formats: 'xywh' and 'x0y0x1y1'. IoA (intersection over the
        area of boxes1) is commonly used to test whether detections lie
        within a crowd-ignore region.
        """
        # BUGFIX: the format check previously used substring membership
        # (`box_format in "xywh"`), silently accepting inputs like "" or "x".
        if box_format == "xywh":
            # layout: (x0, y0, w, h) -> convert copies to corner form.
            bboxes1 = deepcopy(bboxes1)
            bboxes2 = deepcopy(bboxes2)

            bboxes1[:, 2] = bboxes1[:, 0] + bboxes1[:, 2]
            bboxes1[:, 3] = bboxes1[:, 1] + bboxes1[:, 3]
            bboxes2[:, 2] = bboxes2[:, 0] + bboxes2[:, 2]
            bboxes2[:, 3] = bboxes2[:, 1] + bboxes2[:, 3]
        elif box_format != "x0y0x1y1":
            raise TrackEvalException("box_format %s is not implemented" % box_format)

        # layout: (x0, y0, x1, y1)
        min_ = np.minimum(bboxes1[:, np.newaxis, :], bboxes2[np.newaxis, :, :])
        max_ = np.maximum(bboxes1[:, np.newaxis, :], bboxes2[np.newaxis, :, :])
        intersection = np.maximum(min_[..., 2] - max_[..., 0], 0) * np.maximum(
            min_[..., 3] - max_[..., 1], 0
        )
        area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (
            bboxes1[..., 3] - bboxes1[..., 1]
        )

        if do_ioa:
            ioas = np.zeros_like(intersection)
            valid_mask = area1 > 0 + np.finfo("float").eps
            ioas[valid_mask, :] = (
                intersection[valid_mask, :] / area1[valid_mask][:, np.newaxis]
            )

            return ioas
        else:
            area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (
                bboxes2[..., 3] - bboxes2[..., 1]
            )
            union = area1[:, np.newaxis] + area2[np.newaxis, :] - intersection
            # Degenerate boxes and empty unions contribute zero IoU.
            intersection[area1 <= 0 + np.finfo("float").eps, :] = 0
            intersection[:, area2 <= 0 + np.finfo("float").eps] = 0
            intersection[union <= 0 + np.finfo("float").eps] = 0
            union[union <= 0 + np.finfo("float").eps] = 1
            ious = intersection / union
            return ious

    @staticmethod
    def _calculate_euclidean_similarity(dets1, dets2, zero_distance=2.0):
        """Euclidean distance converted to a similarity in [0, 1].

        sim = max(0, 1 - dist/zero_distance). The default zero_distance of
        2.0 matches MOT15_3D, where a 0.5 similarity threshold corresponds
        to a 1m distance threshold for TPs.
        """
        dist = np.linalg.norm(dets1[:, np.newaxis] - dets2[np.newaxis, :], axis=2)
        sim = np.maximum(0, 1 - dist / zero_distance)
        return sim

    @staticmethod
    def _check_unique_ids(data, after_preproc=False):
        """Check the requirement that the tracker_ids and gt_ids are unique per timestep"""

        def _assert_unique(ids_t, source, t):
            # Raise a descriptive TrackEvalException on duplicate IDs.
            if len(ids_t) == 0:
                return
            unique_ids, counts = np.unique(ids_t, return_counts=True)
            if np.max(counts) == 1:
                return
            duplicate_ids = unique_ids[counts > 1]
            exc_str_init = (
                "%s the same ID more than once in a single timestep "
                "(seq: %s, frame: %i, ids:" % (source, data["seq"], t + 1)
            )
            exc_str = " ".join([exc_str_init] + [str(d) for d in duplicate_ids]) + ")"
            if after_preproc:
                # BUGFIX: this note was previously appended to exc_str_init
                # AFTER exc_str was built, so it never reached the exception.
                exc_str += (
                    "\n Note that this error occurred after preprocessing (but not before), "
                    "so ids may not be as in file, and something seems wrong with preproc."
                )
            raise TrackEvalException(exc_str)

        for t, (gt_ids_t, tracker_ids_t) in enumerate(
            zip(data["gt_ids"], data["tracker_ids"])
        ):
            _assert_unique(tracker_ids_t, "Tracker predicts", t)
            _assert_unique(gt_ids_t, "Ground-truth has", t)
sam3/eval/hota_eval_toolkit/trackeval/datasets/tao_ow.py ADDED
@@ -0,0 +1,891 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+
3
+ import itertools
4
+ import json
5
+ import os
6
+ from collections import defaultdict
7
+
8
+ import numpy as np
9
+ from scipy.optimize import linear_sum_assignment
10
+
11
+ from .. import _timing, utils
12
+ from ..utils import TrackEvalException
13
+ from ._base_dataset import _BaseDataset
14
+
15
+
16
+ class TAO_OW(_BaseDataset):
17
+ """Dataset class for TAO tracking"""
18
+
19
@staticmethod
def get_default_dataset_config():
    """Return the default configuration dict for the TAO-OW dataset class."""
    base = utils.get_code_path()
    return {
        # Location of GT data
        "GT_FOLDER": os.path.join(base, "data/gt/tao/tao_training"),
        # Trackers location
        "TRACKERS_FOLDER": os.path.join(base, "data/trackers/tao/tao_training"),
        # Where to save eval results (if None, same as TRACKERS_FOLDER)
        "OUTPUT_FOLDER": None,
        # Filenames of trackers to eval (if None, all in folder)
        "TRACKERS_TO_EVAL": None,
        # Classes to eval (if None, all classes)
        "CLASSES_TO_EVAL": None,
        # Valid: 'training', 'val'
        "SPLIT_TO_EVAL": "training",
        # Whether to print current config
        "PRINT_CONFIG": True,
        # Tracker files are in TRACKER_FOLDER/tracker_name/TRACKER_SUB_FOLDER
        "TRACKER_SUB_FOLDER": "data",
        # Output files are saved in OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER
        "OUTPUT_SUB_FOLDER": "",
        # Names of trackers to display, if None: TRACKERS_TO_EVAL
        "TRACKER_DISPLAY_NAMES": None,
        # Number of maximal allowed detections per image (0 for unlimited)
        "MAX_DETECTIONS": 300,
        "SUBSET": "all",
    }
42
+
43
def __init__(self, config=None):
    """Initialise dataset, checking that all required files are present.

    Loads the single GT json under GT_FOLDER, optionally restricts it to a
    category subset (known/unknown/distractor), precomputes per-video track and
    image mappings, and loads + normalizes each tracker's single result json.
    """
    super().__init__()
    # Fill non-given config values with defaults
    self.config = utils.init_config(
        config, self.get_default_dataset_config(), self.get_name()
    )
    self.gt_fol = self.config["GT_FOLDER"]
    self.tracker_fol = self.config["TRACKERS_FOLDER"]
    self.should_classes_combine = True
    self.use_super_categories = False

    self.tracker_sub_fol = self.config["TRACKER_SUB_FOLDER"]
    self.output_fol = self.config["OUTPUT_FOLDER"]
    if self.output_fol is None:
        self.output_fol = self.tracker_fol
    self.output_sub_fol = self.config["OUTPUT_SUB_FOLDER"]

    # The GT folder must contain exactly one json annotation file.
    gt_dir_files = [
        file for file in os.listdir(self.gt_fol) if file.endswith(".json")
    ]
    if len(gt_dir_files) != 1:
        raise TrackEvalException(
            self.gt_fol + " does not contain exactly one json file."
        )

    with open(os.path.join(self.gt_fol, gt_dir_files[0])) as f:
        self.gt_data = json.load(f)

    self.subset = self.config["SUBSET"]
    if self.subset != "all":
        # Split GT data into `known`, `unknown` or `distractor`
        self._split_known_unknown_distractor()
        self.gt_data = self._filter_gt_data(self.gt_data)

    # merge categories marked with a merged tag in TAO dataset
    self._merge_categories(self.gt_data["annotations"] + self.gt_data["tracks"])

    # Get sequences to eval and sequence information; '/' in video names is
    # replaced so sequence names are filesystem-safe.
    self.seq_list = [
        vid["name"].replace("/", "-") for vid in self.gt_data["videos"]
    ]
    self.seq_name_to_seq_id = {
        vid["name"].replace("/", "-"): vid["id"] for vid in self.gt_data["videos"]
    }
    # compute mappings from videos to annotation data
    self.videos_to_gt_tracks, self.videos_to_gt_images = self._compute_vid_mappings(
        self.gt_data["annotations"]
    )
    # compute sequence lengths (number of images per video)
    self.seq_lengths = {vid["id"]: 0 for vid in self.gt_data["videos"]}
    for img in self.gt_data["images"]:
        self.seq_lengths[img["video_id"]] += 1
    self.seq_to_images_to_timestep = self._compute_image_to_timestep_mappings()
    # Per-video category bookkeeping: categories present in GT, negative
    # categories, and categories not exhaustively labeled.
    self.seq_to_classes = {
        vid["id"]: {
            "pos_cat_ids": list(
                {
                    track["category_id"]
                    for track in self.videos_to_gt_tracks[vid["id"]]
                }
            ),
            "neg_cat_ids": vid["neg_category_ids"],
            "not_exhaustively_labeled_cat_ids": vid["not_exhaustive_category_ids"],
        }
        for vid in self.gt_data["videos"]
    }

    # Get classes to eval
    considered_vid_ids = [self.seq_name_to_seq_id[vid] for vid in self.seq_list]
    seen_cats = set(
        [
            cat_id
            for vid_id in considered_vid_ids
            for cat_id in self.seq_to_classes[vid_id]["pos_cat_ids"]
        ]
    )
    # only classes with ground truth are evaluated in TAO
    self.valid_classes = [
        cls["name"] for cls in self.gt_data["categories"] if cls["id"] in seen_cats
    ]

    # NOTE: this open-world variant is class-agnostic — regardless of
    # CLASSES_TO_EVAL, everything is collapsed into a single "object" class,
    # so the invalid-class branch below can never trigger.
    if self.config["CLASSES_TO_EVAL"]:
        self.class_list = ["object"]  # class-agnostic
        if not all(self.class_list):
            raise TrackEvalException(
                "Attempted to evaluate an invalid class. Only classes "
                + ", ".join(self.valid_classes)
                + " are valid (classes present in ground truth data)."
            )
    else:
        self.class_list = ["object"]  # class-agnostic
    self.class_name_to_class_id = {"object": 1}  # class-agnostic

    # Get trackers to eval
    if self.config["TRACKERS_TO_EVAL"] is None:
        self.tracker_list = os.listdir(self.tracker_fol)
    else:
        self.tracker_list = self.config["TRACKERS_TO_EVAL"]

    if self.config["TRACKER_DISPLAY_NAMES"] is None:
        self.tracker_to_disp = dict(zip(self.tracker_list, self.tracker_list))
    elif (self.config["TRACKERS_TO_EVAL"] is not None) and (
        len(self.config["TRACKER_DISPLAY_NAMES"]) == len(self.tracker_list)
    ):
        self.tracker_to_disp = dict(
            zip(self.tracker_list, self.config["TRACKER_DISPLAY_NAMES"])
        )
    else:
        raise TrackEvalException(
            "List of tracker files and tracker display names do not match."
        )

    self.tracker_data = {tracker: dict() for tracker in self.tracker_list}

    for tracker in self.tracker_list:
        # Each tracker folder must also contain exactly one json result file.
        tr_dir_files = [
            file
            for file in os.listdir(
                os.path.join(self.tracker_fol, tracker, self.tracker_sub_fol)
            )
            if file.endswith(".json")
        ]
        if len(tr_dir_files) != 1:
            raise TrackEvalException(
                os.path.join(self.tracker_fol, tracker, self.tracker_sub_fol)
                + " does not contain exactly one json file."
            )
        with open(
            os.path.join(
                self.tracker_fol, tracker, self.tracker_sub_fol, tr_dir_files[0]
            )
        ) as f:
            curr_data = json.load(f)

        # limit detections if MAX_DETECTIONS > 0
        if self.config["MAX_DETECTIONS"]:
            curr_data = self._limit_dets_per_image(curr_data)

        # fill missing video ids
        self._fill_video_ids_inplace(curr_data)

        # make track ids unique over whole evaluation set
        self._make_track_ids_unique(curr_data)

        # merge categories marked with a merged tag in TAO dataset
        self._merge_categories(curr_data)

        # get tracker sequence information
        curr_videos_to_tracker_tracks, curr_videos_to_tracker_images = (
            self._compute_vid_mappings(curr_data)
        )
        self.tracker_data[tracker]["vids_to_tracks"] = curr_videos_to_tracker_tracks
        self.tracker_data[tracker]["vids_to_images"] = curr_videos_to_tracker_images
202
+
203
def get_display_name(self, tracker):
    """Return the display name configured for the given tracker folder name."""
    display_names = self.tracker_to_disp
    return display_names[tracker]
205
+
206
def _load_raw_file(self, tracker, seq, is_gt):
    """Load a file (gt or tracker) in the TAO format

    If is_gt, this returns a dict which contains the fields:
    [gt_ids, gt_classes] : list (for each timestep) of 1D NDArrays (for each det).
    [gt_dets]: list (for each timestep) of lists of detections.
    [classes_to_gt_tracks]: dictionary with class values as keys and list of dictionaries (with frame indices as
    keys and corresponding segmentations as values) for each track
    [classes_to_gt_track_ids, classes_to_gt_track_areas, classes_to_gt_track_lengths]: dictionary with class values
    as keys and lists (for each track) as values

    if not is_gt, this returns a dict which contains the fields:
    [tracker_ids, tracker_classes, tracker_confidences] : list (for each timestep) of 1D NDArrays (for each det).
    [tracker_dets]: list (for each timestep) of lists of detections.
    [classes_to_dt_tracks]: dictionary with class values as keys and list of dictionaries (with frame indices as
    keys and corresponding segmentations as values) for each track
    [classes_to_dt_track_ids, classes_to_dt_track_areas, classes_to_dt_track_lengths]: dictionary with class values
    as keys and lists as values
    [classes_to_dt_track_scores]: dictionary with class values as keys and 1D numpy arrays as values
    """
    seq_id = self.seq_name_to_seq_id[seq]
    # Select the per-video image list for either GT or this tracker.
    if is_gt:
        imgs = self.videos_to_gt_images[seq_id]
    else:
        imgs = self.tracker_data[tracker]["vids_to_images"][seq_id]

    # Convert data to required format: one slot per timestep, filled below.
    num_timesteps = self.seq_lengths[seq_id]
    img_to_timestep = self.seq_to_images_to_timestep[seq_id]
    data_keys = ["ids", "classes", "dets"]
    if not is_gt:
        data_keys += ["tracker_confidences"]
    raw_data = {key: [None] * num_timesteps for key in data_keys}
    for img in imgs:
        # some tracker data contains images without any ground truth information, these are ignored
        try:
            t = img_to_timestep[img["id"]]
        except KeyError:
            continue
        annotations = img["annotations"]
        raw_data["dets"][t] = np.atleast_2d(
            [ann["bbox"] for ann in annotations]
        ).astype(float)
        raw_data["ids"][t] = np.atleast_1d(
            [ann["track_id"] for ann in annotations]
        ).astype(int)
        # All detections are mapped to a single class id (class-agnostic eval).
        raw_data["classes"][t] = np.atleast_1d([1 for _ in annotations]).astype(
            int
        )  # class-agnostic
        if not is_gt:
            raw_data["tracker_confidences"][t] = np.atleast_1d(
                [ann["score"] for ann in annotations]
            ).astype(float)

    # Timesteps with no image data get correctly-shaped empty arrays.
    for t, d in enumerate(raw_data["dets"]):
        if d is None:
            raw_data["dets"][t] = np.empty((0, 4)).astype(float)
            raw_data["ids"][t] = np.empty(0).astype(int)
            raw_data["classes"][t] = np.empty(0).astype(int)
            if not is_gt:
                raw_data["tracker_confidences"][t] = np.empty(0)

    # Rename keys to the gt_*/tracker_* names the evaluator expects.
    if is_gt:
        key_map = {"ids": "gt_ids", "classes": "gt_classes", "dets": "gt_dets"}
    else:
        key_map = {
            "ids": "tracker_ids",
            "classes": "tracker_classes",
            "dets": "tracker_dets",
        }
    for k, v in key_map.items():
        raw_data[v] = raw_data.pop(k)

    all_classes = [1]  # class-agnostic

    if is_gt:
        classes_to_consider = all_classes
        all_tracks = self.videos_to_gt_tracks[seq_id]
    else:
        classes_to_consider = all_classes  # class-agnostic
        all_tracks = self.tracker_data[tracker]["vids_to_tracks"][seq_id]

    # Every track belongs to the single agnostic class.
    classes_to_tracks = {
        cls: [track for track in all_tracks] if cls in classes_to_consider else []
        for cls in all_classes
    }  # class-agnostic

    # mapping from classes to track information (per-track {image_id: bbox})
    raw_data["classes_to_tracks"] = {
        cls: [
            {
                det["image_id"]: np.atleast_1d(det["bbox"])
                for det in track["annotations"]
            }
            for track in tracks
        ]
        for cls, tracks in classes_to_tracks.items()
    }
    raw_data["classes_to_track_ids"] = {
        cls: [track["id"] for track in tracks]
        for cls, tracks in classes_to_tracks.items()
    }
    raw_data["classes_to_track_areas"] = {
        cls: [track["area"] for track in tracks]
        for cls, tracks in classes_to_tracks.items()
    }
    raw_data["classes_to_track_lengths"] = {
        cls: [len(track["annotations"]) for track in tracks]
        for cls, tracks in classes_to_tracks.items()
    }

    if not is_gt:
        # Track score = mean of its per-detection scores.
        raw_data["classes_to_dt_track_scores"] = {
            cls: np.array(
                [
                    np.mean([float(x["score"]) for x in track["annotations"]])
                    for track in tracks
                ]
            )
            for cls, tracks in classes_to_tracks.items()
        }

    if is_gt:
        key_map = {
            "classes_to_tracks": "classes_to_gt_tracks",
            "classes_to_track_ids": "classes_to_gt_track_ids",
            "classes_to_track_lengths": "classes_to_gt_track_lengths",
            "classes_to_track_areas": "classes_to_gt_track_areas",
        }
    else:
        key_map = {
            "classes_to_tracks": "classes_to_dt_tracks",
            "classes_to_track_ids": "classes_to_dt_track_ids",
            "classes_to_track_lengths": "classes_to_dt_track_lengths",
            "classes_to_track_areas": "classes_to_dt_track_areas",
        }
    for k, v in key_map.items():
        raw_data[v] = raw_data.pop(k)

    raw_data["num_timesteps"] = num_timesteps
    raw_data["neg_cat_ids"] = self.seq_to_classes[seq_id]["neg_cat_ids"]
    raw_data["not_exhaustively_labeled_cls"] = self.seq_to_classes[seq_id][
        "not_exhaustively_labeled_cat_ids"
    ]
    raw_data["seq"] = seq
    return raw_data
358
+
359
@_timing.time
def get_preprocessed_seq_data(self, raw_data, cls):
    """Preprocess data for a single sequence for a single class ready for evaluation.
    Inputs:
        - raw_data is a dict containing the data for the sequence already read in by get_raw_seq_data().
        - cls is the class to be evaluated.
    Outputs:
        - data is a dict containing all of the information that metrics need to perform evaluation.
            It contains the following fields:
                [num_timesteps, num_gt_ids, num_tracker_ids, num_gt_dets, num_tracker_dets] : integers.
                [gt_ids, tracker_ids, tracker_confidences]: list (for each timestep) of 1D NDArrays (for each det).
                [gt_dets, tracker_dets]: list (for each timestep) of lists of detections.
                [similarity_scores]: list (for each timestep) of 2D NDArrays.
    Notes:
        General preprocessing (preproc) occurs in 4 steps. Some datasets may not use all of these steps.
            1) Extract only detections relevant for the class to be evaluated (including distractor detections).
            2) Match gt dets and tracker dets. Remove tracker dets that are matched to a gt det that is of a
                distractor class, or otherwise marked as to be removed.
            3) Remove unmatched tracker dets if they fall within a crowd ignore region or don't meet a certain
                other criteria (e.g. are too small).
            4) Remove gt dets that were only useful for preprocessing and not for actual evaluation.
        After the above preprocessing steps, this function also calculates the number of gt and tracker detections
            and unique track ids. It also relabels gt and tracker ids to be contiguous and checks that ids are
            unique within each timestep.
    TAO:
        In TAO, the 4 preproc steps are as follow:
            1) All classes present in the ground truth data are evaluated separately.
            2) No matched tracker detections are removed.
            3) Unmatched tracker detections are removed if there is not ground truth data and the class does not
                belong to the categories marked as negative for this sequence. Additionally, unmatched tracker
                detections for classes which are marked as not exhaustively labeled are removed.
            4) No gt detections are removed.
        Further, for TrackMAP computation track representations for the given class are accessed from a dictionary
        and the tracks from the tracker data are sorted according to the tracker confidence.
    """
    cls_id = self.class_name_to_class_id[cls]
    is_not_exhaustively_labeled = cls_id in raw_data["not_exhaustively_labeled_cls"]
    is_neg_category = cls_id in raw_data["neg_cat_ids"]

    data_keys = [
        "gt_ids",
        "tracker_ids",
        "gt_dets",
        "tracker_dets",
        "tracker_confidences",
        "similarity_scores",
    ]
    data = {key: [None] * raw_data["num_timesteps"] for key in data_keys}
    unique_gt_ids = []
    unique_tracker_ids = []
    num_gt_dets = 0
    num_tracker_dets = 0
    for t in range(raw_data["num_timesteps"]):
        # Only extract relevant dets for this class for preproc and eval (cls)
        gt_class_mask = np.atleast_1d(raw_data["gt_classes"][t] == cls_id)
        gt_class_mask = gt_class_mask.astype(bool)
        gt_ids = raw_data["gt_ids"][t][gt_class_mask]
        gt_dets = raw_data["gt_dets"][t][gt_class_mask]

        tracker_class_mask = np.atleast_1d(raw_data["tracker_classes"][t] == cls_id)
        tracker_class_mask = tracker_class_mask.astype(bool)
        tracker_ids = raw_data["tracker_ids"][t][tracker_class_mask]
        tracker_dets = raw_data["tracker_dets"][t][tracker_class_mask]
        tracker_confidences = raw_data["tracker_confidences"][t][tracker_class_mask]
        similarity_scores = raw_data["similarity_scores"][t][gt_class_mask, :][
            :, tracker_class_mask
        ]

        # Match tracker and gt dets (with hungarian algorithm).
        # Similarities below 0.5 are zeroed out so they cannot form a match.
        unmatched_indices = np.arange(tracker_ids.shape[0])
        if gt_ids.shape[0] > 0 and tracker_ids.shape[0] > 0:
            matching_scores = similarity_scores.copy()
            matching_scores[matching_scores < 0.5 - np.finfo("float").eps] = 0
            match_rows, match_cols = linear_sum_assignment(-matching_scores)
            # Drop assignment pairs whose zeroed score means "no real match".
            actually_matched_mask = (
                matching_scores[match_rows, match_cols] > 0 + np.finfo("float").eps
            )
            match_cols = match_cols[actually_matched_mask]
            unmatched_indices = np.delete(unmatched_indices, match_cols, axis=0)

        # Unmatched tracker dets are discarded when there is no GT for this
        # frame (and the class is not marked negative), or when the class is
        # not exhaustively labeled.
        if gt_ids.shape[0] == 0 and not is_neg_category:
            to_remove_tracker = unmatched_indices
        elif is_not_exhaustively_labeled:
            to_remove_tracker = unmatched_indices
        else:
            to_remove_tracker = np.array([], dtype=int)

        # remove all unwanted unmatched tracker detections
        data["tracker_ids"][t] = np.delete(tracker_ids, to_remove_tracker, axis=0)
        data["tracker_dets"][t] = np.delete(tracker_dets, to_remove_tracker, axis=0)
        data["tracker_confidences"][t] = np.delete(
            tracker_confidences, to_remove_tracker, axis=0
        )
        similarity_scores = np.delete(similarity_scores, to_remove_tracker, axis=1)

        data["gt_ids"][t] = gt_ids
        data["gt_dets"][t] = gt_dets
        data["similarity_scores"][t] = similarity_scores

        unique_gt_ids += list(np.unique(data["gt_ids"][t]))
        unique_tracker_ids += list(np.unique(data["tracker_ids"][t]))
        num_tracker_dets += len(data["tracker_ids"][t])
        num_gt_dets += len(data["gt_ids"][t])

    # Re-label IDs such that there are no empty IDs (contiguous 0..N-1 via a
    # lookup array indexed by the original id).
    if len(unique_gt_ids) > 0:
        unique_gt_ids = np.unique(unique_gt_ids)
        gt_id_map = np.nan * np.ones((np.max(unique_gt_ids) + 1))
        gt_id_map[unique_gt_ids] = np.arange(len(unique_gt_ids))
        for t in range(raw_data["num_timesteps"]):
            if len(data["gt_ids"][t]) > 0:
                data["gt_ids"][t] = gt_id_map[data["gt_ids"][t]].astype(int)
    if len(unique_tracker_ids) > 0:
        unique_tracker_ids = np.unique(unique_tracker_ids)
        tracker_id_map = np.nan * np.ones((np.max(unique_tracker_ids) + 1))
        tracker_id_map[unique_tracker_ids] = np.arange(len(unique_tracker_ids))
        for t in range(raw_data["num_timesteps"]):
            if len(data["tracker_ids"][t]) > 0:
                data["tracker_ids"][t] = tracker_id_map[
                    data["tracker_ids"][t]
                ].astype(int)

    # Record overview statistics.
    data["num_tracker_dets"] = num_tracker_dets
    data["num_gt_dets"] = num_gt_dets
    data["num_tracker_ids"] = len(unique_tracker_ids)
    data["num_gt_ids"] = len(unique_gt_ids)
    data["num_timesteps"] = raw_data["num_timesteps"]
    data["seq"] = raw_data["seq"]

    # get track representations
    data["gt_tracks"] = raw_data["classes_to_gt_tracks"][cls_id]
    data["gt_track_ids"] = raw_data["classes_to_gt_track_ids"][cls_id]
    data["gt_track_lengths"] = raw_data["classes_to_gt_track_lengths"][cls_id]
    data["gt_track_areas"] = raw_data["classes_to_gt_track_areas"][cls_id]
    data["dt_tracks"] = raw_data["classes_to_dt_tracks"][cls_id]
    data["dt_track_ids"] = raw_data["classes_to_dt_track_ids"][cls_id]
    data["dt_track_lengths"] = raw_data["classes_to_dt_track_lengths"][cls_id]
    data["dt_track_areas"] = raw_data["classes_to_dt_track_areas"][cls_id]
    data["dt_track_scores"] = raw_data["classes_to_dt_track_scores"][cls_id]
    data["not_exhaustively_labeled"] = is_not_exhaustively_labeled
    data["iou_type"] = "bbox"

    # sort tracker data tracks by tracker confidence scores (stable sort keeps
    # the original order among equal scores)
    if data["dt_tracks"]:
        idx = np.argsort(
            [-score for score in data["dt_track_scores"]], kind="mergesort"
        )
        data["dt_track_scores"] = [data["dt_track_scores"][i] for i in idx]
        data["dt_tracks"] = [data["dt_tracks"][i] for i in idx]
        data["dt_track_ids"] = [data["dt_track_ids"][i] for i in idx]
        data["dt_track_lengths"] = [data["dt_track_lengths"][i] for i in idx]
        data["dt_track_areas"] = [data["dt_track_areas"][i] for i in idx]
    # Ensure that ids are unique per timestep.
    self._check_unique_ids(data)

    return data
516
+
517
def _calculate_similarities(self, gt_dets_t, tracker_dets_t):
    """GT/tracker similarity for TAO-OW is plain box IoU."""
    return self._calculate_box_ious(gt_dets_t, tracker_dets_t)
520
+
521
+ def _merge_categories(self, annotations):
522
+ """
523
+ Merges categories with a merged tag. Adapted from https://github.com/TAO-Dataset
524
+ :param annotations: the annotations in which the classes should be merged
525
+ :return: None
526
+ """
527
+ merge_map = {}
528
+ for category in self.gt_data["categories"]:
529
+ if "merged" in category:
530
+ for to_merge in category["merged"]:
531
+ merge_map[to_merge["id"]] = category["id"]
532
+
533
+ for ann in annotations:
534
+ ann["category_id"] = merge_map.get(ann["category_id"], ann["category_id"])
535
+
536
+ def _compute_vid_mappings(self, annotations):
537
+ """
538
+ Computes mappings from Videos to corresponding tracks and images.
539
+ :param annotations: the annotations for which the mapping should be generated
540
+ :return: the video-to-track-mapping, the video-to-image-mapping
541
+ """
542
+ vids_to_tracks = {}
543
+ vids_to_imgs = {}
544
+ vid_ids = [vid["id"] for vid in self.gt_data["videos"]]
545
+
546
+ # compute an mapping from image IDs to images
547
+ images = {}
548
+ for image in self.gt_data["images"]:
549
+ images[image["id"]] = image
550
+
551
+ for ann in annotations:
552
+ ann["area"] = ann["bbox"][2] * ann["bbox"][3]
553
+
554
+ vid = ann["video_id"]
555
+ if ann["video_id"] not in vids_to_tracks.keys():
556
+ vids_to_tracks[ann["video_id"]] = list()
557
+ if ann["video_id"] not in vids_to_imgs.keys():
558
+ vids_to_imgs[ann["video_id"]] = list()
559
+
560
+ # Fill in vids_to_tracks
561
+ tid = ann["track_id"]
562
+ exist_tids = [track["id"] for track in vids_to_tracks[vid]]
563
+ try:
564
+ index1 = exist_tids.index(tid)
565
+ except ValueError:
566
+ index1 = -1
567
+ if tid not in exist_tids:
568
+ curr_track = {
569
+ "id": tid,
570
+ "category_id": ann["category_id"],
571
+ "video_id": vid,
572
+ "annotations": [ann],
573
+ }
574
+ vids_to_tracks[vid].append(curr_track)
575
+ else:
576
+ vids_to_tracks[vid][index1]["annotations"].append(ann)
577
+
578
+ # Fill in vids_to_imgs
579
+ img_id = ann["image_id"]
580
+ exist_img_ids = [img["id"] for img in vids_to_imgs[vid]]
581
+ try:
582
+ index2 = exist_img_ids.index(img_id)
583
+ except ValueError:
584
+ index2 = -1
585
+ if index2 == -1:
586
+ curr_img = {"id": img_id, "annotations": [ann]}
587
+ vids_to_imgs[vid].append(curr_img)
588
+ else:
589
+ vids_to_imgs[vid][index2]["annotations"].append(ann)
590
+
591
+ # sort annotations by frame index and compute track area
592
+ for vid, tracks in vids_to_tracks.items():
593
+ for track in tracks:
594
+ track["annotations"] = sorted(
595
+ track["annotations"],
596
+ key=lambda x: images[x["image_id"]]["frame_index"],
597
+ )
598
+ # Computer average area
599
+ track["area"] = sum(x["area"] for x in track["annotations"]) / len(
600
+ track["annotations"]
601
+ )
602
+
603
+ # Ensure all videos are present
604
+ for vid_id in vid_ids:
605
+ if vid_id not in vids_to_tracks.keys():
606
+ vids_to_tracks[vid_id] = []
607
+ if vid_id not in vids_to_imgs.keys():
608
+ vids_to_imgs[vid_id] = []
609
+
610
+ return vids_to_tracks, vids_to_imgs
611
+
612
+ def _compute_image_to_timestep_mappings(self):
613
+ """
614
+ Computes a mapping from images to the corresponding timestep in the sequence.
615
+ :return: the image-to-timestep-mapping
616
+ """
617
+ images = {}
618
+ for image in self.gt_data["images"]:
619
+ images[image["id"]] = image
620
+
621
+ seq_to_imgs_to_timestep = {vid["id"]: dict() for vid in self.gt_data["videos"]}
622
+ for vid in seq_to_imgs_to_timestep:
623
+ curr_imgs = [img["id"] for img in self.videos_to_gt_images[vid]]
624
+ curr_imgs = sorted(curr_imgs, key=lambda x: images[x]["frame_index"])
625
+ seq_to_imgs_to_timestep[vid] = {
626
+ curr_imgs[i]: i for i in range(len(curr_imgs))
627
+ }
628
+
629
+ return seq_to_imgs_to_timestep
630
+
631
+ def _limit_dets_per_image(self, annotations):
632
+ """
633
+ Limits the number of detections for each image to config['MAX_DETECTIONS']. Adapted from
634
+ https://github.com/TAO-Dataset/
635
+ :param annotations: the annotations in which the detections should be limited
636
+ :return: the annotations with limited detections
637
+ """
638
+ max_dets = self.config["MAX_DETECTIONS"]
639
+ img_ann = defaultdict(list)
640
+ for ann in annotations:
641
+ img_ann[ann["image_id"]].append(ann)
642
+
643
+ for img_id, _anns in img_ann.items():
644
+ if len(_anns) <= max_dets:
645
+ continue
646
+ _anns = sorted(_anns, key=lambda x: x["score"], reverse=True)
647
+ img_ann[img_id] = _anns[:max_dets]
648
+
649
+ return [ann for anns in img_ann.values() for ann in anns]
650
+
651
+ def _fill_video_ids_inplace(self, annotations):
652
+ """
653
+ Fills in missing video IDs inplace. Adapted from https://github.com/TAO-Dataset/
654
+ :param annotations: the annotations for which the videos IDs should be filled inplace
655
+ :return: None
656
+ """
657
+ missing_video_id = [x for x in annotations if "video_id" not in x]
658
+ if missing_video_id:
659
+ image_id_to_video_id = {
660
+ x["id"]: x["video_id"] for x in self.gt_data["images"]
661
+ }
662
+ for x in missing_video_id:
663
+ x["video_id"] = image_id_to_video_id[x["image_id"]]
664
+
665
+ @staticmethod
666
+ def _make_track_ids_unique(annotations):
667
+ """
668
+ Makes the track IDs unqiue over the whole annotation set. Adapted from https://github.com/TAO-Dataset/
669
+ :param annotations: the annotation set
670
+ :return: the number of updated IDs
671
+ """
672
+ track_id_videos = {}
673
+ track_ids_to_update = set()
674
+ max_track_id = 0
675
+ for ann in annotations:
676
+ t = ann["track_id"]
677
+ if t not in track_id_videos:
678
+ track_id_videos[t] = ann["video_id"]
679
+
680
+ if ann["video_id"] != track_id_videos[t]:
681
+ # Track id is assigned to multiple videos
682
+ track_ids_to_update.add(t)
683
+ max_track_id = max(max_track_id, t)
684
+
685
+ if track_ids_to_update:
686
+ print("true")
687
+ next_id = itertools.count(max_track_id + 1)
688
+ new_track_ids = defaultdict(lambda: next(next_id))
689
+ for ann in annotations:
690
+ t = ann["track_id"]
691
+ v = ann["video_id"]
692
+ if t in track_ids_to_update:
693
+ ann["track_id"] = new_track_ids[t, v]
694
+ return len(track_ids_to_update)
695
+
696
def _split_known_unknown_distractor(self):
    """Partition all TAO-OW category ids into known / distractor / unknown sets.

    Populates self.knowns, self.distractors and self.unknowns, which
    _filter_gt_data uses when SUBSET != 'all'.
    """
    all_ids = set(
        [i for i in range(1, 2000)]
    )  # 2000 is larger than the max category id in TAO-OW.
    # `knowns` includes 78 TAO_category_ids that corresponds to 78 COCO classes.
    # (The other 2 COCO classes do not have corresponding classes in TAO).
    self.knowns = {
        4, 13, 1038, 544, 1057, 34, 35, 36, 41, 45, 58, 60, 579, 1091, 1097,
        1099, 78, 79, 81, 91, 1115, 1117, 95, 1122, 99, 1132, 621, 1135, 625,
        118, 1144, 126, 642, 1155, 133, 1162, 139, 154, 174, 185, 699, 1215,
        714, 717, 1229, 211, 729, 221, 229, 747, 235, 237, 779, 276, 805, 299,
        829, 852, 347, 371, 382, 896, 392, 926, 937, 428, 429, 961, 452, 979,
        980, 982, 475, 480, 993, 1001, 502, 1018,
    }
    # `distractors` is defined as in the paper "Opening up Open-World Tracking"
    self.distractors = {
        20, 63, 108, 180, 188, 204, 212, 247, 303, 403, 407, 415, 490, 504,
        507, 513, 529, 567, 569, 588, 672, 691, 702, 708, 711, 720, 736, 737,
        798, 813, 815, 827, 831, 851, 877, 883, 912, 971, 976, 1130, 1133,
        1134, 1169, 1184, 1220,
    }
    # Everything that is neither known nor a distractor counts as unknown.
    self.unknowns = all_ids.difference(self.knowns.union(self.distractors))
831
+
832
+ def _filter_gt_data(self, raw_gt_data):
833
+ """
834
+ Filter out irrelevant data in the raw_gt_data
835
+ Args:
836
+ raw_gt_data: directly loaded from json.
837
+
838
+ Returns:
839
+ filtered gt_data
840
+ """
841
+ valid_cat_ids = list()
842
+ if self.subset == "known":
843
+ valid_cat_ids = self.knowns
844
+ elif self.subset == "distractor":
845
+ valid_cat_ids = self.distractors
846
+ elif self.subset == "unknown":
847
+ valid_cat_ids = self.unknowns
848
+ # elif self.subset == "test_only_unknowns":
849
+ # valid_cat_ids = test_only_unknowns
850
+ else:
851
+ raise Exception("The parameter `SUBSET` is incorrect")
852
+
853
+ filtered = dict()
854
+ filtered["videos"] = raw_gt_data["videos"]
855
+ # filtered["videos"] = list()
856
+ unwanted_vid = set()
857
+ # for video in raw_gt_data["videos"]:
858
+ # datasrc = video["name"].split('/')[1]
859
+ # if datasrc in data_srcs:
860
+ # filtered["videos"].append(video)
861
+ # else:
862
+ # unwanted_vid.add(video["id"])
863
+
864
+ filtered["annotations"] = list()
865
+ for ann in raw_gt_data["annotations"]:
866
+ if (ann["video_id"] not in unwanted_vid) and (
867
+ ann["category_id"] in valid_cat_ids
868
+ ):
869
+ filtered["annotations"].append(ann)
870
+
871
+ filtered["tracks"] = list()
872
+ for track in raw_gt_data["tracks"]:
873
+ if (track["video_id"] not in unwanted_vid) and (
874
+ track["category_id"] in valid_cat_ids
875
+ ):
876
+ filtered["tracks"].append(track)
877
+
878
+ filtered["images"] = list()
879
+ for image in raw_gt_data["images"]:
880
+ if image["video_id"] not in unwanted_vid:
881
+ filtered["images"].append(image)
882
+
883
+ filtered["categories"] = list()
884
+ for cat in raw_gt_data["categories"]:
885
+ if cat["id"] in valid_cat_ids:
886
+ filtered["categories"].append(cat)
887
+
888
+ filtered["info"] = raw_gt_data["info"]
889
+ filtered["licenses"] = raw_gt_data["licenses"]
890
+
891
+ return filtered
sam3/eval/hota_eval_toolkit/trackeval/datasets/youtube_vis.py ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+
3
+ # note: this file has been modified from its original version in TrackEval in
4
+ # https://github.com/JonathonLuiten/TrackEval/blob/master/trackeval/datasets/youtube_vis.py
5
+ # to support the following:
6
+ # 1) bbox evaluation (via `IOU_TYPE`)
7
+ # 2) passing GT and prediction data as Python objects (via `GT_JSON_OBJECT` and `TRACKER_JSON_OBJECT`)
8
+ # 3) specifying a custom dataset name (via `DATASET_NAME`)
9
+
10
+ import json
11
+ import os
12
+
13
+ import numpy as np
14
+
15
+ from .. import _timing, utils
16
+ from ..utils import TrackEvalException
17
+ from ._base_dataset import _BaseDataset
18
+
19
+
20
+ class YouTubeVIS(_BaseDataset):
21
+ """Dataset class for YouTubeVIS tracking"""
22
+
23
+ @staticmethod
24
+ def get_default_dataset_config():
25
+ """Default class config values"""
26
+ code_path = utils.get_code_path()
27
+ default_config = {
28
+ "GT_FOLDER": os.path.join(
29
+ code_path, "data/gt/youtube_vis/"
30
+ ), # Location of GT data
31
+ "TRACKERS_FOLDER": os.path.join(code_path, "data/trackers/youtube_vis/"),
32
+ # Trackers location
33
+ "OUTPUT_FOLDER": None, # Where to save eval results (if None, same as TRACKERS_FOLDER)
34
+ "TRACKERS_TO_EVAL": None, # Filenames of trackers to eval (if None, all in folder)
35
+ "CLASSES_TO_EVAL": None, # Classes to eval (if None, all classes)
36
+ "SPLIT_TO_EVAL": "train_sub_split", # Valid: 'train', 'val', 'train_sub_split'
37
+ "PRINT_CONFIG": True, # Whether to print current config
38
+ "OUTPUT_SUB_FOLDER": "", # Output files are saved in OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER
39
+ "TRACKER_SUB_FOLDER": "data", # Tracker files are in TRACKER_FOLDER/tracker_name/TRACKER_SUB_FOLDER
40
+ "TRACKER_DISPLAY_NAMES": None, # Names of trackers to display, if None: TRACKERS_TO_EVAL
41
+ # Added for video phrase AP evaluation -- allow directly specifying the GT JSON data and Tracker (result)
42
+ # JSON data as Python objects, without reading from files.
43
+ "GT_JSON_OBJECT": None,
44
+ "TRACKER_JSON_OBJECT": None,
45
+ "IOU_TYPE": "segm",
46
+ "DATASET_NAME": "video",
47
+ }
48
+ return default_config
49
+
50
    def __init__(self, config=None):
        """Initialise dataset, checking that all required files are present"""
        super().__init__()
        # Fill non-given config values with defaults
        self.config = utils.init_config(config, self.get_default_dataset_config())
        # GT/tracker folders are "<base>youtube_vis_<split>"
        self.gt_fol = (
            self.config["GT_FOLDER"] + "youtube_vis_" + self.config["SPLIT_TO_EVAL"]
        )
        self.tracker_fol = (
            self.config["TRACKERS_FOLDER"]
            + "youtube_vis_"
            + self.config["SPLIT_TO_EVAL"]
        )
        self.use_super_categories = False
        self.should_classes_combine = True
        # iou_type selects mask IoU ("segm") or box IoU ("bbox") similarity
        assert self.config["IOU_TYPE"] in ["segm", "bbox"]
        self.iou_type = self.config["IOU_TYPE"]
        print("=" * 100)
        print(f"Evaluate annotation type *{self.iou_type}*")
        self.dataset_name = self.config["DATASET_NAME"]

        self.output_fol = self.config["OUTPUT_FOLDER"]
        if self.output_fol is None:
            self.output_fol = self.tracker_fol
        self.output_sub_fol = self.config["OUTPUT_SUB_FOLDER"]
        self.tracker_sub_fol = self.config["TRACKER_SUB_FOLDER"]

        if self.config["GT_JSON_OBJECT"] is not None:
            # allow directly specifying the GT JSON data without reading from files
            gt_json = self.config["GT_JSON_OBJECT"]
            assert isinstance(gt_json, dict)
            assert "videos" in gt_json
            assert "categories" in gt_json
            assert "annotations" in gt_json
            self.gt_data = gt_json
        else:
            if not os.path.exists(self.gt_fol):
                print("GT folder not found: " + self.gt_fol)
                raise TrackEvalException(
                    "GT folder not found: " + os.path.basename(self.gt_fol)
                )
            # the GT folder must contain exactly one annotation JSON file
            gt_dir_files = [
                file for file in os.listdir(self.gt_fol) if file.endswith(".json")
            ]
            if len(gt_dir_files) != 1:
                raise TrackEvalException(
                    self.gt_fol + " does not contain exactly one json file."
                )

            with open(os.path.join(self.gt_fol, gt_dir_files[0])) as f:
                self.gt_data = json.load(f)

        # Get classes to eval
        self.valid_classes = [cls["name"] for cls in self.gt_data["categories"]]
        cls_name_to_cls_id_map = {
            cls["name"]: cls["id"] for cls in self.gt_data["categories"]
        }

        if self.config["CLASSES_TO_EVAL"]:
            # invalid class names map to None so the `all` check below fails
            self.class_list = [
                cls.lower() if cls.lower() in self.valid_classes else None
                for cls in self.config["CLASSES_TO_EVAL"]
            ]
            if not all(self.class_list):
                raise TrackEvalException(
                    "Attempted to evaluate an invalid class. Only classes "
                    + ", ".join(self.valid_classes)
                    + " are valid."
                )
        else:
            self.class_list = [cls["name"] for cls in self.gt_data["categories"]]
        self.class_name_to_class_id = {
            k: v for k, v in cls_name_to_cls_id_map.items() if k in self.class_list
        }

        # Get sequences to eval and check gt files exist
        # sequence name = top-level directory of the video's first frame path
        self.seq_list = [
            vid["file_names"][0].split("/")[0] for vid in self.gt_data["videos"]
        ]
        self.seq_name_to_seq_id = {
            vid["file_names"][0].split("/")[0]: vid["id"]
            for vid in self.gt_data["videos"]
        }
        # number of timesteps per video = number of frame file names
        self.seq_lengths = {
            vid["id"]: len(vid["file_names"]) for vid in self.gt_data["videos"]
        }

        # encode masks and compute track areas
        self._prepare_gt_annotations()

        # Get trackers to eval
        if self.config["TRACKER_JSON_OBJECT"] is not None:
            # allow directly specifying the tracker JSON data without reading from files
            tracker_json = self.config["TRACKER_JSON_OBJECT"]
            assert isinstance(tracker_json, list)
            self.tracker_list = ["tracker"]
        elif self.config["TRACKERS_TO_EVAL"] is None:
            self.tracker_list = os.listdir(self.tracker_fol)
        else:
            self.tracker_list = self.config["TRACKERS_TO_EVAL"]

        if self.config["TRACKER_DISPLAY_NAMES"] is None:
            # default: display each tracker under its own folder name
            self.tracker_to_disp = dict(zip(self.tracker_list, self.tracker_list))
        elif (self.config["TRACKERS_TO_EVAL"] is not None) and (
            len(self.config["TRACKER_DISPLAY_NAMES"]) == len(self.tracker_list)
        ):
            self.tracker_to_disp = dict(
                zip(self.tracker_list, self.config["TRACKER_DISPLAY_NAMES"])
            )
        else:
            raise TrackEvalException(
                "List of tracker files and tracker display names do not match."
            )

        # counter for globally unique track IDs
        self.global_tid_counter = 0

        self.tracker_data = dict()
        if self.config["TRACKER_JSON_OBJECT"] is not None:
            # allow directly specifying the tracker JSON data without reading from files
            tracker = self.tracker_list[0]
            self.tracker_data[tracker] = tracker_json
        else:
            # each tracker folder must contain exactly one result JSON file
            for tracker in self.tracker_list:
                tracker_dir_path = os.path.join(
                    self.tracker_fol, tracker, self.tracker_sub_fol
                )
                tr_dir_files = [
                    file
                    for file in os.listdir(tracker_dir_path)
                    if file.endswith(".json")
                ]
                if len(tr_dir_files) != 1:
                    raise TrackEvalException(
                        tracker_dir_path + " does not contain exactly one json file."
                    )

                with open(os.path.join(tracker_dir_path, tr_dir_files[0])) as f:
                    curr_data = json.load(f)

                self.tracker_data[tracker] = curr_data
192
+ def get_display_name(self, tracker):
193
+ return self.tracker_to_disp[tracker]
194
+
195
    def _load_raw_file(self, tracker, seq, is_gt):
        """Load a file (gt or tracker) in the YouTubeVIS format
        If is_gt, this returns a dict which contains the fields:
        [gt_ids, gt_classes] : list (for each timestep) of 1D NDArrays (for each det).
        [gt_dets]: list (for each timestep) of lists of detections.
        [classes_to_gt_tracks]: dictionary with class values as keys and list of dictionaries (with frame indices as
        keys and corresponding segmentations as values) for each track
        [classes_to_gt_track_ids, classes_to_gt_track_areas, classes_to_gt_track_iscrowd]: dictionary with class values
        as keys and lists (for each track) as values

        if not is_gt, this returns a dict which contains the fields:
        [tracker_ids, tracker_classes, tracker_confidences] : list (for each timestep) of 1D NDArrays (for each det).
        [tracker_dets]: list (for each timestep) of lists of detections.
        [classes_to_dt_tracks]: dictionary with class values as keys and list of dictionaries (with frame indices as
        keys and corresponding segmentations as values) for each track
        [classes_to_dt_track_ids, classes_to_dt_track_areas]: dictionary with class values as keys and lists as values
        [classes_to_dt_track_scores]: dictionary with class values as keys and 1D numpy arrays as values
        """
        # select sequence tracks
        seq_id = self.seq_name_to_seq_id[seq]
        if is_gt:
            tracks = [
                ann for ann in self.gt_data["annotations"] if ann["video_id"] == seq_id
            ]
        else:
            # tracker tracks also get areas and globally unique ids assigned here
            tracks = self._get_tracker_seq_tracks(tracker, seq_id)

        # Convert data to required format
        num_timesteps = self.seq_lengths[seq_id]
        data_keys = ["ids", "classes", "dets"]
        if not is_gt:
            data_keys += ["tracker_confidences"]
        raw_data = {key: [None] * num_timesteps for key in data_keys}
        # per-frame detections come from "segmentations" (segm) or "bboxes" (bbox)
        result_key = "segmentations" if self.iou_type == "segm" else "bboxes"
        for t in range(num_timesteps):
            # only tracks with a non-empty detection at frame t contribute
            raw_data["dets"][t] = [
                track[result_key][t] for track in tracks if track[result_key][t]
            ]
            raw_data["ids"][t] = np.atleast_1d(
                [track["id"] for track in tracks if track[result_key][t]]
            ).astype(int)
            raw_data["classes"][t] = np.atleast_1d(
                [track["category_id"] for track in tracks if track[result_key][t]]
            ).astype(int)
            if not is_gt:
                raw_data["tracker_confidences"][t] = np.atleast_1d(
                    [track["score"] for track in tracks if track[result_key][t]]
                ).astype(float)

        # rename generic keys to gt_/tracker_-prefixed keys
        if is_gt:
            key_map = {"ids": "gt_ids", "classes": "gt_classes", "dets": "gt_dets"}
        else:
            key_map = {
                "ids": "tracker_ids",
                "classes": "tracker_classes",
                "dets": "tracker_dets",
            }
        for k, v in key_map.items():
            raw_data[v] = raw_data.pop(k)

        # group tracks by class id (including classes with no tracks)
        all_cls_ids = {self.class_name_to_class_id[cls] for cls in self.class_list}
        classes_to_tracks = {
            cls: [track for track in tracks if track["category_id"] == cls]
            for cls in all_cls_ids
        }

        # mapping from classes to track representations and track information
        raw_data["classes_to_tracks"] = {
            cls: [
                {i: track[result_key][i] for i in range(len(track[result_key]))}
                for track in tracks
            ]
            for cls, tracks in classes_to_tracks.items()
        }
        raw_data["classes_to_track_ids"] = {
            cls: [track["id"] for track in tracks]
            for cls, tracks in classes_to_tracks.items()
        }
        raw_data["classes_to_track_areas"] = {
            cls: [track["area"] for track in tracks]
            for cls, tracks in classes_to_tracks.items()
        }

        if is_gt:
            raw_data["classes_to_gt_track_iscrowd"] = {
                cls: [track["iscrowd"] for track in tracks]
                for cls, tracks in classes_to_tracks.items()
            }
        else:
            raw_data["classes_to_dt_track_scores"] = {
                cls: np.array([track["score"] for track in tracks])
                for cls, tracks in classes_to_tracks.items()
            }

        # rename class-level keys to gt_/dt_-prefixed keys
        if is_gt:
            key_map = {
                "classes_to_tracks": "classes_to_gt_tracks",
                "classes_to_track_ids": "classes_to_gt_track_ids",
                "classes_to_track_areas": "classes_to_gt_track_areas",
            }
        else:
            key_map = {
                "classes_to_tracks": "classes_to_dt_tracks",
                "classes_to_track_ids": "classes_to_dt_track_ids",
                "classes_to_track_areas": "classes_to_dt_track_areas",
            }
        for k, v in key_map.items():
            raw_data[v] = raw_data.pop(k)

        raw_data["num_timesteps"] = num_timesteps
        raw_data["seq"] = seq
        return raw_data
307
+
308
    @_timing.time
    def get_preprocessed_seq_data(self, raw_data, cls):
        """Preprocess data for a single sequence for a single class ready for evaluation.
        Inputs:
             - raw_data is a dict containing the data for the sequence already read in by get_raw_seq_data().
             - cls is the class to be evaluated.
        Outputs:
             - data is a dict containing all of the information that metrics need to perform evaluation.
                It contains the following fields:
                    [num_timesteps, num_gt_ids, num_tracker_ids, num_gt_dets, num_tracker_dets] : integers.
                    [gt_ids, tracker_ids, tracker_confidences]: list (for each timestep) of 1D NDArrays (for each det).
                    [gt_dets, tracker_dets]: list (for each timestep) of lists of detections.
                    [similarity_scores]: list (for each timestep) of 2D NDArrays.
        Notes:
            General preprocessing (preproc) occurs in 4 steps. Some datasets may not use all of these steps.
                1) Extract only detections relevant for the class to be evaluated (including distractor detections).
                2) Match gt dets and tracker dets. Remove tracker dets that are matched to a gt det that is of a
                    distractor class, or otherwise marked as to be removed.
                3) Remove unmatched tracker dets if they fall within a crowd ignore region or don't meet a certain
                    other criteria (e.g. are too small).
                4) Remove gt dets that were only useful for preprocessing and not for actual evaluation.
            After the above preprocessing steps, this function also calculates the number of gt and tracker detections
                and unique track ids. It also relabels gt and tracker ids to be contiguous and checks that ids are
                unique within each timestep.
        YouTubeVIS:
            In YouTubeVIS, the 4 preproc steps are as follow:
                1) There are 40 classes which are evaluated separately.
                2) No matched tracker dets are removed.
                3) No unmatched tracker dets are removed.
                4) No gt dets are removed.
            Further, for TrackMAP computation track representations for the given class are accessed from a dictionary
            and the tracks from the tracker data are sorted according to the tracker confidence.
        """
        cls_id = self.class_name_to_class_id[cls]

        data_keys = [
            "gt_ids",
            "tracker_ids",
            "gt_dets",
            "tracker_dets",
            "similarity_scores",
        ]
        data = {key: [None] * raw_data["num_timesteps"] for key in data_keys}
        unique_gt_ids = []
        unique_tracker_ids = []
        num_gt_dets = 0
        num_tracker_dets = 0

        for t in range(raw_data["num_timesteps"]):
            # Only extract relevant dets for this class for eval (cls)
            gt_class_mask = np.atleast_1d(raw_data["gt_classes"][t] == cls_id)
            gt_class_mask = gt_class_mask.astype(bool)
            gt_ids = raw_data["gt_ids"][t][gt_class_mask]
            gt_dets = [
                raw_data["gt_dets"][t][ind]
                for ind in range(len(gt_class_mask))
                if gt_class_mask[ind]
            ]

            tracker_class_mask = np.atleast_1d(raw_data["tracker_classes"][t] == cls_id)
            tracker_class_mask = tracker_class_mask.astype(bool)
            tracker_ids = raw_data["tracker_ids"][t][tracker_class_mask]
            tracker_dets = [
                raw_data["tracker_dets"][t][ind]
                for ind in range(len(tracker_class_mask))
                if tracker_class_mask[ind]
            ]
            # restrict the similarity matrix to the rows/cols of this class
            similarity_scores = raw_data["similarity_scores"][t][gt_class_mask, :][
                :, tracker_class_mask
            ]

            data["tracker_ids"][t] = tracker_ids
            data["tracker_dets"][t] = tracker_dets
            data["gt_ids"][t] = gt_ids
            data["gt_dets"][t] = gt_dets
            data["similarity_scores"][t] = similarity_scores

            unique_gt_ids += list(np.unique(data["gt_ids"][t]))
            unique_tracker_ids += list(np.unique(data["tracker_ids"][t]))
            num_tracker_dets += len(data["tracker_ids"][t])
            num_gt_dets += len(data["gt_ids"][t])

        # Re-label IDs such that there are no empty IDs
        # (map original ids to a contiguous 0..N-1 range via a lookup array)
        if len(unique_gt_ids) > 0:
            unique_gt_ids = np.unique(unique_gt_ids)
            gt_id_map = np.nan * np.ones((np.max(unique_gt_ids) + 1))
            gt_id_map[unique_gt_ids] = np.arange(len(unique_gt_ids))
            for t in range(raw_data["num_timesteps"]):
                if len(data["gt_ids"][t]) > 0:
                    data["gt_ids"][t] = gt_id_map[data["gt_ids"][t]].astype(int)
        if len(unique_tracker_ids) > 0:
            unique_tracker_ids = np.unique(unique_tracker_ids)
            tracker_id_map = np.nan * np.ones((np.max(unique_tracker_ids) + 1))
            tracker_id_map[unique_tracker_ids] = np.arange(len(unique_tracker_ids))
            for t in range(raw_data["num_timesteps"]):
                if len(data["tracker_ids"][t]) > 0:
                    data["tracker_ids"][t] = tracker_id_map[
                        data["tracker_ids"][t]
                    ].astype(int)

        # Ensure that ids are unique per timestep.
        self._check_unique_ids(data)

        # Record overview statistics.
        data["num_tracker_dets"] = num_tracker_dets
        data["num_gt_dets"] = num_gt_dets
        data["num_tracker_ids"] = len(unique_tracker_ids)
        data["num_gt_ids"] = len(unique_gt_ids)
        data["num_timesteps"] = raw_data["num_timesteps"]
        data["seq"] = raw_data["seq"]

        # get track representations
        data["gt_tracks"] = raw_data["classes_to_gt_tracks"][cls_id]
        data["gt_track_ids"] = raw_data["classes_to_gt_track_ids"][cls_id]
        data["gt_track_areas"] = raw_data["classes_to_gt_track_areas"][cls_id]
        data["gt_track_iscrowd"] = raw_data["classes_to_gt_track_iscrowd"][cls_id]
        data["dt_tracks"] = raw_data["classes_to_dt_tracks"][cls_id]
        data["dt_track_ids"] = raw_data["classes_to_dt_track_ids"][cls_id]
        data["dt_track_areas"] = raw_data["classes_to_dt_track_areas"][cls_id]
        data["dt_track_scores"] = raw_data["classes_to_dt_track_scores"][cls_id]
        # NOTE(review): hardcoded to "mask" even when IOU_TYPE is "bbox" —
        # confirm downstream TrackMAP handles this as intended for bbox eval.
        data["iou_type"] = "mask"

        # sort tracker data tracks by tracker confidence scores
        if data["dt_tracks"]:
            # stable (mergesort) descending sort by score
            idx = np.argsort(
                [-score for score in data["dt_track_scores"]], kind="mergesort"
            )
            data["dt_track_scores"] = [data["dt_track_scores"][i] for i in idx]
            data["dt_tracks"] = [data["dt_tracks"][i] for i in idx]
            data["dt_track_ids"] = [data["dt_track_ids"][i] for i in idx]
            data["dt_track_areas"] = [data["dt_track_areas"][i] for i in idx]

        return data
441
+
442
+ def _calculate_similarities(self, gt_dets_t, tracker_dets_t):
443
+ if self.iou_type == "segm":
444
+ similarity_scores = self._calculate_mask_ious(
445
+ gt_dets_t, tracker_dets_t, is_encoded=True, do_ioa=False
446
+ )
447
+ else:
448
+ gt_dets_t = np.array(gt_dets_t, dtype=np.float32).reshape(-1, 4)
449
+ tracker_dets_t = np.array(tracker_dets_t, dtype=np.float32).reshape(-1, 4)
450
+ similarity_scores = self._calculate_box_ious(
451
+ gt_dets_t, tracker_dets_t, box_format="xywh", do_ioa=False
452
+ )
453
+ return similarity_scores
454
+
455
+ def _prepare_gt_annotations(self):
456
+ """
457
+ Prepares GT data by rle encoding segmentations and computing the average track area.
458
+ :return: None
459
+ """
460
+ if self.iou_type == "segm":
461
+ # only loaded when needed to reduce minimum requirements
462
+ from pycocotools import mask as mask_utils
463
+
464
+ for track in self.gt_data["annotations"]:
465
+ h = track["height"]
466
+ w = track["width"]
467
+ for i, seg in enumerate(track["segmentations"]):
468
+ if seg is not None and isinstance(seg["counts"], list):
469
+ track["segmentations"][i] = mask_utils.frPyObjects(seg, h, w)
470
+ areas = [a for a in track["areas"] if a]
471
+ if len(areas) == 0:
472
+ track["area"] = 0
473
+ else:
474
+ track["area"] = np.array(areas).mean()
475
+ else:
476
+ for track in self.gt_data["annotations"]:
477
+ # For bbox eval, compute areas from bboxes if not already available
478
+ areas = [a for a in track.get("areas", []) if a]
479
+ if not areas:
480
+ areas = []
481
+ for bbox in track.get("bboxes", []):
482
+ if bbox is not None:
483
+ areas.append(bbox[2] * bbox[3])
484
+ track["area"] = np.array(areas).mean() if areas else 0
485
+
486
+ def _get_tracker_seq_tracks(self, tracker, seq_id):
487
+ """
488
+ Prepares tracker data for a given sequence. Extracts all annotations for given sequence ID, computes
489
+ average track area and assigns a track ID.
490
+ :param tracker: the given tracker
491
+ :param seq_id: the sequence ID
492
+ :return: the extracted tracks
493
+ """
494
+ # only loaded when needed to reduce minimum requirements
495
+ from pycocotools import mask as mask_utils
496
+
497
+ tracks = [
498
+ ann for ann in self.tracker_data[tracker] if ann["video_id"] == seq_id
499
+ ]
500
+ for track in tracks:
501
+ if "areas" not in track:
502
+ if self.iou_type == "segm":
503
+ for seg in track["segmentations"]:
504
+ if seg:
505
+ track["areas"].append(mask_utils.area(seg))
506
+ else:
507
+ track["areas"].append(None)
508
+ else:
509
+ for bbox in track["bboxes"]:
510
+ if bbox:
511
+ track["areas"].append(bbox[2] * bbox[3])
512
+ else:
513
+ track["areas"].append(None)
514
+ areas = [a for a in track["areas"] if a]
515
+ if len(areas) == 0:
516
+ track["area"] = 0
517
+ else:
518
+ track["area"] = np.array(areas).mean()
519
+ track["id"] = self.global_tid_counter
520
+ self.global_tid_counter += 1
521
+ return tracks
522
+
523
+ def get_name(self):
524
+ return self.dataset_name
sam3/eval/hota_eval_toolkit/trackeval/eval.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+
3
+ import os
4
+ import time
5
+ import traceback
6
+ from functools import partial
7
+ from multiprocessing.pool import Pool
8
+
9
+ import numpy as np
10
+
11
+ from . import _timing, utils
12
+ from .metrics import Count
13
+ from .utils import TrackEvalException
14
+
15
+ try:
16
+ import tqdm
17
+
18
+ TQDM_IMPORTED = True
19
+ except ImportError as _:
20
+ TQDM_IMPORTED = False
21
+
22
+
23
+ class Evaluator:
24
+ """Evaluator class for evaluating different metrics for different datasets"""
25
+
26
+ @staticmethod
27
+ def get_default_eval_config():
28
+ """Returns the default config values for evaluation"""
29
+ code_path = utils.get_code_path()
30
+ default_config = {
31
+ "USE_PARALLEL": False,
32
+ "NUM_PARALLEL_CORES": 8,
33
+ "BREAK_ON_ERROR": True, # Raises exception and exits with error
34
+ "RETURN_ON_ERROR": False, # if not BREAK_ON_ERROR, then returns from function on error
35
+ "LOG_ON_ERROR": os.path.join(
36
+ code_path, "error_log.txt"
37
+ ), # if not None, save any errors into a log file.
38
+ "PRINT_RESULTS": True,
39
+ "PRINT_ONLY_COMBINED": False,
40
+ "PRINT_CONFIG": True,
41
+ "TIME_PROGRESS": True,
42
+ "DISPLAY_LESS_PROGRESS": True,
43
+ "OUTPUT_SUMMARY": True,
44
+ "OUTPUT_EMPTY_CLASSES": True, # If False, summary files are not output for classes with no detections
45
+ "OUTPUT_DETAILED": True,
46
+ "PLOT_CURVES": True,
47
+ }
48
+ return default_config
49
+
50
+ def __init__(self, config=None):
51
+ """Initialise the evaluator with a config file"""
52
+ self.config = utils.init_config(config, self.get_default_eval_config(), "Eval")
53
+ # Only run timing analysis if not run in parallel.
54
+ if self.config["TIME_PROGRESS"] and not self.config["USE_PARALLEL"]:
55
+ _timing.DO_TIMING = True
56
+ if self.config["DISPLAY_LESS_PROGRESS"]:
57
+ _timing.DISPLAY_LESS_PROGRESS = True
58
+
59
    def _combine_results(
        self,
        res,
        metrics_list,
        metric_names,
        dataset,
        res_field="COMBINED_SEQ",
        target_tag=None,
    ):
        """Combine per-sequence results into `res[res_field]`, then combine classes.

        When `target_tag` is given, only sequences whose GT annotations carry
        that tag contribute to the combination (used e.g. for a separate
        "challenging" aggregate). Returns (res, combined_cls_keys).
        """
        assert res_field.startswith("COMBINED_SEQ")
        # collecting combined cls keys (cls averaged, det averaged, super classes)
        tracker_list, seq_list, class_list = dataset.get_eval_info()
        combined_cls_keys = []
        res[res_field] = {}

        # narrow the target for evaluation
        if target_tag is not None:
            target_video_ids = [
                annot["video_id"]
                for annot in dataset.gt_data["annotations"]
                if target_tag in annot["tags"]
            ]
            # sequence name = top-level directory of the video's first frame path
            vid2name = {
                video["id"]: video["file_names"][0].split("/")[0]
                for video in dataset.gt_data["videos"]
            }
            target_video_ids = set(target_video_ids)
            target_video = [vid2name[video_id] for video_id in target_video_ids]

            if len(target_video) == 0:
                raise TrackEvalException(
                    "No sequences found with the tag %s" % target_tag
                )

            # sanity check: the tag must apply to every annotation of the
            # selected sequences (sequence-level tags only)
            target_annotations = [
                annot
                for annot in dataset.gt_data["annotations"]
                if annot["video_id"] in target_video_ids
            ]
            assert all(target_tag in annot["tags"] for annot in target_annotations), (
                f"Not all annotations in the target sequences have the target tag {target_tag}. "
                "We currently only support a target tag at the sequence level, not at the annotation level."
            )
        else:
            target_video = seq_list

        # combine sequences for each class
        for c_cls in class_list:
            res[res_field][c_cls] = {}
            for metric, metric_name in zip(metrics_list, metric_names):
                # gather per-sequence results, excluding any previously
                # computed COMBINED_SEQ* aggregates
                curr_res = {
                    seq_key: seq_value[c_cls][metric_name]
                    for seq_key, seq_value in res.items()
                    if not seq_key.startswith("COMBINED_SEQ")
                    and seq_key in target_video
                }
                res[res_field][c_cls][metric_name] = metric.combine_sequences(curr_res)
        # combine classes
        if dataset.should_classes_combine:
            combined_cls_keys += [
                "cls_comb_cls_av",
                "cls_comb_det_av",
                "all",
            ]
            res[res_field]["cls_comb_cls_av"] = {}
            res[res_field]["cls_comb_det_av"] = {}
            for metric, metric_name in zip(metrics_list, metric_names):
                cls_res = {
                    cls_key: cls_value[metric_name]
                    for cls_key, cls_value in res[res_field].items()
                    if cls_key not in combined_cls_keys
                }
                res[res_field]["cls_comb_cls_av"][metric_name] = (
                    metric.combine_classes_class_averaged(cls_res)
                )
                res[res_field]["cls_comb_det_av"][metric_name] = (
                    metric.combine_classes_det_averaged(cls_res)
                )
        # combine classes to super classes
        if dataset.use_super_categories:
            for cat, sub_cats in dataset.super_categories.items():
                combined_cls_keys.append(cat)
                res[res_field][cat] = {}
                for metric, metric_name in zip(metrics_list, metric_names):
                    cat_res = {
                        cls_key: cls_value[metric_name]
                        for cls_key, cls_value in res[res_field].items()
                        if cls_key in sub_cats
                    }
                    res[res_field][cat][metric_name] = (
                        metric.combine_classes_det_averaged(cat_res)
                    )
        return res, combined_cls_keys
152
+
153
    def _summarize_results(
        self,
        res,
        tracker,
        metrics_list,
        metric_names,
        dataset,
        res_field,
        combined_cls_keys,
    ):
        """Print, plot and write summary/detailed result files for one tracker.

        Iterates over every class key under `res[res_field]` (the per-class
        results plus any combined-class aggregates) and dispatches to each
        metric's print/summary/detail/plot helpers per the eval config.
        """
        config = self.config
        output_fol = dataset.get_output_fol(tracker)
        tracker_display_name = dataset.get_display_name(tracker)
        for c_cls in res[
            res_field
        ].keys():  # class_list + combined classes if calculated
            summaries = []
            details = []
            num_dets = res[res_field][c_cls]["Count"]["Dets"]
            # optionally skip classes with zero detections
            if config["OUTPUT_EMPTY_CLASSES"] or num_dets > 0:
                for metric, metric_name in zip(metrics_list, metric_names):
                    # for combined classes there is no per sequence evaluation
                    if c_cls in combined_cls_keys:
                        table_res = {res_field: res[res_field][c_cls][metric_name]}
                    else:
                        table_res = {
                            seq_key: seq_value[c_cls][metric_name]
                            for seq_key, seq_value in res.items()
                        }

                    if config["PRINT_RESULTS"] and config["PRINT_ONLY_COMBINED"]:
                        # only print the combined-class rows in this mode
                        dont_print = (
                            dataset.should_classes_combine
                            and c_cls not in combined_cls_keys
                        )
                        if not dont_print:
                            metric.print_table(
                                {res_field: table_res[res_field]},
                                tracker_display_name,
                                c_cls,
                                res_field,
                                res_field,
                            )
                    elif config["PRINT_RESULTS"]:
                        metric.print_table(
                            table_res, tracker_display_name, c_cls, res_field, res_field
                        )
                    if config["OUTPUT_SUMMARY"]:
                        summaries.append(metric.summary_results(table_res))
                    if config["OUTPUT_DETAILED"]:
                        details.append(metric.detailed_results(table_res))
                    if config["PLOT_CURVES"]:
                        metric.plot_single_tracker_results(
                            table_res,
                            tracker_display_name,
                            c_cls,
                            output_fol,
                        )
                if config["OUTPUT_SUMMARY"]:
                    utils.write_summary_results(summaries, c_cls, output_fol)
                if config["OUTPUT_DETAILED"]:
                    utils.write_detailed_results(details, c_cls, output_fol)
215
+
216
+ @_timing.time
217
+ def evaluate(self, dataset_list, metrics_list, show_progressbar=False):
218
+ """Evaluate a set of metrics on a set of datasets"""
219
+ config = self.config
220
+ metrics_list = metrics_list + [Count()] # Count metrics are always run
221
+ metric_names = utils.validate_metrics_list(metrics_list)
222
+ dataset_names = [dataset.get_name() for dataset in dataset_list]
223
+ output_res = {}
224
+ output_msg = {}
225
+
226
+ for dataset, dataset_name in zip(dataset_list, dataset_names):
227
+ # Get dataset info about what to evaluate
228
+ output_res[dataset_name] = {}
229
+ output_msg[dataset_name] = {}
230
+ tracker_list, seq_list, class_list = dataset.get_eval_info()
231
+ print(
232
+ "\nEvaluating %i tracker(s) on %i sequence(s) for %i class(es) on %s dataset using the following "
233
+ "metrics: %s\n"
234
+ % (
235
+ len(tracker_list),
236
+ len(seq_list),
237
+ len(class_list),
238
+ dataset_name,
239
+ ", ".join(metric_names),
240
+ )
241
+ )
242
+
243
+ # Evaluate each tracker
244
+ for tracker in tracker_list:
245
+ # if not config['BREAK_ON_ERROR'] then go to next tracker without breaking
246
+ try:
247
+ # Evaluate each sequence in parallel or in series.
248
+ # returns a nested dict (res), indexed like: res[seq][class][metric_name][sub_metric field]
249
+ # e.g. res[seq_0001][pedestrian][hota][DetA]
250
+ print("\nEvaluating %s\n" % tracker)
251
+ time_start = time.time()
252
+ if config["USE_PARALLEL"]:
253
+ if show_progressbar and TQDM_IMPORTED:
254
+ seq_list_sorted = sorted(seq_list)
255
+
256
+ with Pool(config["NUM_PARALLEL_CORES"]) as pool, tqdm.tqdm(
257
+ total=len(seq_list)
258
+ ) as pbar:
259
+ _eval_sequence = partial(
260
+ eval_sequence,
261
+ dataset=dataset,
262
+ tracker=tracker,
263
+ class_list=class_list,
264
+ metrics_list=metrics_list,
265
+ metric_names=metric_names,
266
+ )
267
+ results = []
268
+ for r in pool.imap(
269
+ _eval_sequence, seq_list_sorted, chunksize=20
270
+ ):
271
+ results.append(r)
272
+ pbar.update()
273
+ res = dict(zip(seq_list_sorted, results))
274
+
275
+ else:
276
+ with Pool(config["NUM_PARALLEL_CORES"]) as pool:
277
+ _eval_sequence = partial(
278
+ eval_sequence,
279
+ dataset=dataset,
280
+ tracker=tracker,
281
+ class_list=class_list,
282
+ metrics_list=metrics_list,
283
+ metric_names=metric_names,
284
+ )
285
+ results = pool.map(_eval_sequence, seq_list)
286
+ res = dict(zip(seq_list, results))
287
+ else:
288
+ res = {}
289
+ if show_progressbar and TQDM_IMPORTED:
290
+ seq_list_sorted = sorted(seq_list)
291
+ for curr_seq in tqdm.tqdm(seq_list_sorted):
292
+ res[curr_seq] = eval_sequence(
293
+ curr_seq,
294
+ dataset,
295
+ tracker,
296
+ class_list,
297
+ metrics_list,
298
+ metric_names,
299
+ )
300
+ else:
301
+ for curr_seq in sorted(seq_list):
302
+ res[curr_seq] = eval_sequence(
303
+ curr_seq,
304
+ dataset,
305
+ tracker,
306
+ class_list,
307
+ metrics_list,
308
+ metric_names,
309
+ )
310
+
311
+ # Combine results over all sequences and then over all classes
312
+ res, combined_cls_keys = self._combine_results(
313
+ res, metrics_list, metric_names, dataset, "COMBINED_SEQ"
314
+ )
315
+
316
+ if np.all(
317
+ ["tags" in annot for annot in dataset.gt_data["annotations"]]
318
+ ):
319
+ # Combine results over the challenging sequences and then over all classes
320
+ # currently only support "tracking_challenging_pair"
321
+ res, _ = self._combine_results(
322
+ res,
323
+ metrics_list,
324
+ metric_names,
325
+ dataset,
326
+ "COMBINED_SEQ_CHALLENGING",
327
+ "tracking_challenging_pair",
328
+ )
329
+
330
+ # Print and output results in various formats
331
+ if config["TIME_PROGRESS"]:
332
+ print(
333
+ "\nAll sequences for %s finished in %.2f seconds"
334
+ % (tracker, time.time() - time_start)
335
+ )
336
+
337
+ self._summarize_results(
338
+ res,
339
+ tracker,
340
+ metrics_list,
341
+ metric_names,
342
+ dataset,
343
+ "COMBINED_SEQ",
344
+ combined_cls_keys,
345
+ )
346
+ if "COMBINED_SEQ_CHALLENGING" in res:
347
+ self._summarize_results(
348
+ res,
349
+ tracker,
350
+ metrics_list,
351
+ metric_names,
352
+ dataset,
353
+ "COMBINED_SEQ_CHALLENGING",
354
+ combined_cls_keys,
355
+ )
356
+
357
+ # Output for returning from function
358
+ output_res[dataset_name][tracker] = res
359
+ output_msg[dataset_name][tracker] = "Success"
360
+
361
+ except Exception as err:
362
+ output_res[dataset_name][tracker] = None
363
+ if type(err) == TrackEvalException:
364
+ output_msg[dataset_name][tracker] = str(err)
365
+ else:
366
+ output_msg[dataset_name][tracker] = "Unknown error occurred."
367
+ print("Tracker %s was unable to be evaluated." % tracker)
368
+ print(err)
369
+ traceback.print_exc()
370
+ if config["LOG_ON_ERROR"] is not None:
371
+ with open(config["LOG_ON_ERROR"], "a") as f:
372
+ print(dataset_name, file=f)
373
+ print(tracker, file=f)
374
+ print(traceback.format_exc(), file=f)
375
+ print("\n\n\n", file=f)
376
+ if config["BREAK_ON_ERROR"]:
377
+ raise err
378
+ elif config["RETURN_ON_ERROR"]:
379
+ return output_res, output_msg
380
+
381
+ return output_res, output_msg
382
+
383
+
384
+ @_timing.time
385
+ def eval_sequence(seq, dataset, tracker, class_list, metrics_list, metric_names):
386
+ """Function for evaluating a single sequence"""
387
+
388
+ raw_data = dataset.get_raw_seq_data(tracker, seq)
389
+ seq_res = {}
390
+ for cls in class_list:
391
+ seq_res[cls] = {}
392
+ data = dataset.get_preprocessed_seq_data(raw_data, cls)
393
+ for metric, met_name in zip(metrics_list, metric_names):
394
+ seq_res[cls][met_name] = metric.eval_sequence(data)
395
+ return seq_res
sam3/eval/hota_eval_toolkit/trackeval/metrics/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # flake8: noqa
2
+
3
+ from .count import Count
4
+ from .hota import HOTA
sam3/eval/hota_eval_toolkit/trackeval/metrics/_base_metric.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+
3
+ from abc import ABC, abstractmethod
4
+
5
+ import numpy as np
6
+
7
+ from .. import _timing
8
+ from ..utils import TrackEvalException
9
+
10
+
11
+ class _BaseMetric(ABC):
12
+ @abstractmethod
13
+ def __init__(self):
14
+ self.plottable = False
15
+ self.integer_fields = []
16
+ self.float_fields = []
17
+ self.array_labels = []
18
+ self.integer_array_fields = []
19
+ self.float_array_fields = []
20
+ self.fields = []
21
+ self.summary_fields = []
22
+ self.registered = False
23
+
24
+ #####################################################################
25
+ # Abstract functions for subclasses to implement
26
+
27
+ @_timing.time
28
+ @abstractmethod
29
+ def eval_sequence(self, data): ...
30
+
31
+ @abstractmethod
32
+ def combine_sequences(self, all_res): ...
33
+
34
+ @abstractmethod
35
+ def combine_classes_class_averaged(self, all_res, ignore_empty_classes=False): ...
36
+
37
+ @abstractmethod
38
+ def combine_classes_det_averaged(self, all_res): ...
39
+
40
+ def plot_single_tracker_results(self, all_res, tracker, output_folder, cls):
41
+ """Plot results of metrics, only valid for metrics with self.plottable"""
42
+ if self.plottable:
43
+ raise NotImplementedError(
44
+ "plot_results is not implemented for metric %s" % self.get_name()
45
+ )
46
+ else:
47
+ pass
48
+
49
+ #####################################################################
50
+ # Helper functions which are useful for all metrics:
51
+
52
+ @classmethod
53
+ def get_name(cls):
54
+ return cls.__name__
55
+
56
+ @staticmethod
57
+ def _combine_sum(all_res, field):
58
+ """Combine sequence results via sum"""
59
+ return sum([all_res[k][field] for k in all_res.keys()])
60
+
61
+ @staticmethod
62
+ def _combine_weighted_av(all_res, field, comb_res, weight_field):
63
+ """Combine sequence results via weighted average"""
64
+ return sum(
65
+ [all_res[k][field] * all_res[k][weight_field] for k in all_res.keys()]
66
+ ) / np.maximum(1.0, comb_res[weight_field])
67
+
68
+ def print_table(
69
+ self, table_res, tracker, cls, res_field="COMBINED_SEQ", output_lable="COMBINED"
70
+ ):
71
+ """Prints table of results for all sequences"""
72
+ print("")
73
+ metric_name = self.get_name()
74
+ self._row_print(
75
+ [metric_name + ": " + tracker + "-" + cls] + self.summary_fields
76
+ )
77
+ for seq, results in sorted(table_res.items()):
78
+ if seq.startswith("COMBINED_SEQ"):
79
+ continue
80
+ summary_res = self._summary_row(results)
81
+ self._row_print([seq] + summary_res)
82
+ summary_res = self._summary_row(table_res[res_field])
83
+ self._row_print([output_lable] + summary_res)
84
+
85
+ def _summary_row(self, results_):
86
+ vals = []
87
+ for h in self.summary_fields:
88
+ if h in self.float_array_fields:
89
+ vals.append("{0:1.5g}".format(100 * np.mean(results_[h])))
90
+ elif h in self.float_fields:
91
+ vals.append("{0:1.5g}".format(100 * float(results_[h])))
92
+ elif h in self.integer_fields:
93
+ vals.append("{0:d}".format(int(results_[h])))
94
+ else:
95
+ raise NotImplementedError(
96
+ "Summary function not implemented for this field type."
97
+ )
98
+ return vals
99
+
100
+ @staticmethod
101
+ def _row_print(*argv):
102
+ """Prints results in an evenly spaced rows, with more space in first row"""
103
+ if len(argv) == 1:
104
+ argv = argv[0]
105
+ to_print = "%-35s" % argv[0]
106
+ for v in argv[1:]:
107
+ to_print += "%-10s" % str(v)
108
+ print(to_print)
109
+
110
+ def summary_results(self, table_res):
111
+ """Returns a simple summary of final results for a tracker"""
112
+ return dict(
113
+ zip(self.summary_fields, self._summary_row(table_res["COMBINED_SEQ"]))
114
+ )
115
+
116
+ def detailed_results(self, table_res):
117
+ """Returns detailed final results for a tracker"""
118
+ # Get detailed field information
119
+ detailed_fields = self.float_fields + self.integer_fields
120
+ for h in self.float_array_fields + self.integer_array_fields:
121
+ for alpha in [int(100 * x) for x in self.array_labels]:
122
+ detailed_fields.append(h + "___" + str(alpha))
123
+ detailed_fields.append(h + "___AUC")
124
+
125
+ # Get detailed results
126
+ detailed_results = {}
127
+ for seq, res in table_res.items():
128
+ detailed_row = self._detailed_row(res)
129
+ if len(detailed_row) != len(detailed_fields):
130
+ raise TrackEvalException(
131
+ "Field names and data have different sizes (%i and %i)"
132
+ % (len(detailed_row), len(detailed_fields))
133
+ )
134
+ detailed_results[seq] = dict(zip(detailed_fields, detailed_row))
135
+ return detailed_results
136
+
137
+ def _detailed_row(self, res):
138
+ detailed_row = []
139
+ for h in self.float_fields + self.integer_fields:
140
+ detailed_row.append(res[h])
141
+ for h in self.float_array_fields + self.integer_array_fields:
142
+ for i, alpha in enumerate([int(100 * x) for x in self.array_labels]):
143
+ detailed_row.append(res[h][i])
144
+ detailed_row.append(np.mean(res[h]))
145
+ return detailed_row
sam3/eval/hota_eval_toolkit/trackeval/metrics/count.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+
3
+ from .. import _timing
4
+ from ._base_metric import _BaseMetric
5
+
6
+
7
+ class Count(_BaseMetric):
8
+ """Class which simply counts the number of tracker and gt detections and ids."""
9
+
10
+ def __init__(self, config=None):
11
+ super().__init__()
12
+ self.integer_fields = ["Dets", "GT_Dets", "IDs", "GT_IDs"]
13
+ self.fields = self.integer_fields
14
+ self.summary_fields = self.fields
15
+
16
+ @_timing.time
17
+ def eval_sequence(self, data):
18
+ """Returns counts for one sequence"""
19
+ # Get results
20
+ res = {
21
+ "Dets": data["num_tracker_dets"],
22
+ "GT_Dets": data["num_gt_dets"],
23
+ "IDs": data["num_tracker_ids"],
24
+ "GT_IDs": data["num_gt_ids"],
25
+ "Frames": data["num_timesteps"],
26
+ }
27
+ return res
28
+
29
+ def combine_sequences(self, all_res):
30
+ """Combines metrics across all sequences"""
31
+ res = {}
32
+ for field in self.integer_fields:
33
+ res[field] = self._combine_sum(all_res, field)
34
+ return res
35
+
36
+ def combine_classes_class_averaged(self, all_res, ignore_empty_classes=None):
37
+ """Combines metrics across all classes by averaging over the class values"""
38
+ res = {}
39
+ for field in self.integer_fields:
40
+ res[field] = self._combine_sum(all_res, field)
41
+ return res
42
+
43
+ def combine_classes_det_averaged(self, all_res):
44
+ """Combines metrics across all classes by averaging over the detection values"""
45
+ res = {}
46
+ for field in self.integer_fields:
47
+ res[field] = self._combine_sum(all_res, field)
48
+ return res
sam3/eval/hota_eval_toolkit/trackeval/metrics/hota.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+
3
+ import os
4
+
5
+ import numpy as np
6
+ from scipy.optimize import linear_sum_assignment
7
+
8
+ from .. import _timing
9
+ from ._base_metric import _BaseMetric
10
+
11
+
12
+ class HOTA(_BaseMetric):
13
+ """Class which implements the HOTA metrics.
14
+ See: https://link.springer.com/article/10.1007/s11263-020-01375-2
15
+ """
16
+
17
+ def __init__(self, config=None):
18
+ super().__init__()
19
+ self.plottable = True
20
+ self.array_labels = np.arange(0.05, 0.99, 0.05)
21
+ self.integer_array_fields = ["HOTA_TP", "HOTA_FN", "HOTA_FP"]
22
+ self.float_array_fields = [
23
+ "HOTA",
24
+ "DetA",
25
+ "AssA",
26
+ "DetRe",
27
+ "DetPr",
28
+ "AssRe",
29
+ "AssPr",
30
+ "LocA",
31
+ "OWTA",
32
+ ]
33
+ self.float_fields = ["HOTA(0)", "LocA(0)", "HOTALocA(0)"]
34
+ self.fields = (
35
+ self.float_array_fields + self.integer_array_fields + self.float_fields
36
+ )
37
+ self.summary_fields = self.float_array_fields + self.float_fields
38
+
39
+ @_timing.time
40
+ def eval_sequence(self, data):
41
+ """Calculates the HOTA metrics for one sequence"""
42
+
43
+ # Initialise results
44
+ res = {}
45
+ for field in self.float_array_fields + self.integer_array_fields:
46
+ res[field] = np.zeros((len(self.array_labels)), dtype=float)
47
+ for field in self.float_fields:
48
+ res[field] = 0
49
+
50
+ # Return result quickly if tracker or gt sequence is empty
51
+ if data["num_tracker_dets"] == 0:
52
+ res["HOTA_FN"] = data["num_gt_dets"] * np.ones(
53
+ (len(self.array_labels)), dtype=float
54
+ )
55
+ res["LocA"] = np.ones((len(self.array_labels)), dtype=float)
56
+ res["LocA(0)"] = 1.0
57
+ return res
58
+ if data["num_gt_dets"] == 0:
59
+ res["HOTA_FP"] = data["num_tracker_dets"] * np.ones(
60
+ (len(self.array_labels)), dtype=float
61
+ )
62
+ res["LocA"] = np.ones((len(self.array_labels)), dtype=float)
63
+ res["LocA(0)"] = 1.0
64
+ return res
65
+
66
+ # Variables counting global association
67
+ potential_matches_count = np.zeros(
68
+ (data["num_gt_ids"], data["num_tracker_ids"])
69
+ )
70
+ gt_id_count = np.zeros((data["num_gt_ids"], 1))
71
+ tracker_id_count = np.zeros((1, data["num_tracker_ids"]))
72
+
73
+ # First loop through each timestep and accumulate global track information.
74
+ for t, (gt_ids_t, tracker_ids_t) in enumerate(
75
+ zip(data["gt_ids"], data["tracker_ids"])
76
+ ):
77
+ # Count the potential matches between ids in each timestep
78
+ # These are normalised, weighted by the match similarity.
79
+ similarity = data["similarity_scores"][t]
80
+ sim_iou_denom = (
81
+ similarity.sum(0)[np.newaxis, :]
82
+ + similarity.sum(1)[:, np.newaxis]
83
+ - similarity
84
+ )
85
+ sim_iou = np.zeros_like(similarity)
86
+ sim_iou_mask = sim_iou_denom > 0 + np.finfo("float").eps
87
+ sim_iou[sim_iou_mask] = (
88
+ similarity[sim_iou_mask] / sim_iou_denom[sim_iou_mask]
89
+ )
90
+ potential_matches_count[
91
+ gt_ids_t[:, np.newaxis], tracker_ids_t[np.newaxis, :]
92
+ ] += sim_iou
93
+
94
+ # Calculate the total number of dets for each gt_id and tracker_id.
95
+ gt_id_count[gt_ids_t] += 1
96
+ tracker_id_count[0, tracker_ids_t] += 1
97
+
98
+ # Calculate overall jaccard alignment score (before unique matching) between IDs
99
+ global_alignment_score = potential_matches_count / (
100
+ gt_id_count + tracker_id_count - potential_matches_count
101
+ )
102
+ matches_counts = [
103
+ np.zeros_like(potential_matches_count) for _ in self.array_labels
104
+ ]
105
+
106
+ # Calculate scores for each timestep
107
+ for t, (gt_ids_t, tracker_ids_t) in enumerate(
108
+ zip(data["gt_ids"], data["tracker_ids"])
109
+ ):
110
+ # Deal with the case that there are no gt_det/tracker_det in a timestep.
111
+ if len(gt_ids_t) == 0:
112
+ for a, alpha in enumerate(self.array_labels):
113
+ res["HOTA_FP"][a] += len(tracker_ids_t)
114
+ continue
115
+ if len(tracker_ids_t) == 0:
116
+ for a, alpha in enumerate(self.array_labels):
117
+ res["HOTA_FN"][a] += len(gt_ids_t)
118
+ continue
119
+
120
+ # Get matching scores between pairs of dets for optimizing HOTA
121
+ similarity = data["similarity_scores"][t]
122
+ score_mat = (
123
+ global_alignment_score[
124
+ gt_ids_t[:, np.newaxis], tracker_ids_t[np.newaxis, :]
125
+ ]
126
+ * similarity
127
+ )
128
+
129
+ # Hungarian algorithm to find best matches
130
+ match_rows, match_cols = linear_sum_assignment(-score_mat)
131
+
132
+ # Calculate and accumulate basic statistics
133
+ for a, alpha in enumerate(self.array_labels):
134
+ actually_matched_mask = (
135
+ similarity[match_rows, match_cols] >= alpha - np.finfo("float").eps
136
+ )
137
+ alpha_match_rows = match_rows[actually_matched_mask]
138
+ alpha_match_cols = match_cols[actually_matched_mask]
139
+ num_matches = len(alpha_match_rows)
140
+ res["HOTA_TP"][a] += num_matches
141
+ res["HOTA_FN"][a] += len(gt_ids_t) - num_matches
142
+ res["HOTA_FP"][a] += len(tracker_ids_t) - num_matches
143
+ if num_matches > 0:
144
+ res["LocA"][a] += sum(
145
+ similarity[alpha_match_rows, alpha_match_cols]
146
+ )
147
+ matches_counts[a][
148
+ gt_ids_t[alpha_match_rows], tracker_ids_t[alpha_match_cols]
149
+ ] += 1
150
+
151
+ # Calculate association scores (AssA, AssRe, AssPr) for the alpha value.
152
+ # First calculate scores per gt_id/tracker_id combo and then average over the number of detections.
153
+ for a, alpha in enumerate(self.array_labels):
154
+ matches_count = matches_counts[a]
155
+ ass_a = matches_count / np.maximum(
156
+ 1, gt_id_count + tracker_id_count - matches_count
157
+ )
158
+ res["AssA"][a] = np.sum(matches_count * ass_a) / np.maximum(
159
+ 1, res["HOTA_TP"][a]
160
+ )
161
+ ass_re = matches_count / np.maximum(1, gt_id_count)
162
+ res["AssRe"][a] = np.sum(matches_count * ass_re) / np.maximum(
163
+ 1, res["HOTA_TP"][a]
164
+ )
165
+ ass_pr = matches_count / np.maximum(1, tracker_id_count)
166
+ res["AssPr"][a] = np.sum(matches_count * ass_pr) / np.maximum(
167
+ 1, res["HOTA_TP"][a]
168
+ )
169
+
170
+ # Calculate final scores
171
+ res["LocA"] = np.maximum(1e-10, res["LocA"]) / np.maximum(1e-10, res["HOTA_TP"])
172
+ res = self._compute_final_fields(res)
173
+ return res
174
+
175
+ def combine_sequences(self, all_res):
176
+ """Combines metrics across all sequences"""
177
+ res = {}
178
+ for field in self.integer_array_fields:
179
+ res[field] = self._combine_sum(all_res, field)
180
+ for field in ["AssRe", "AssPr", "AssA"]:
181
+ res[field] = self._combine_weighted_av(
182
+ all_res, field, res, weight_field="HOTA_TP"
183
+ )
184
+ loca_weighted_sum = sum(
185
+ [all_res[k]["LocA"] * all_res[k]["HOTA_TP"] for k in all_res.keys()]
186
+ )
187
+ res["LocA"] = np.maximum(1e-10, loca_weighted_sum) / np.maximum(
188
+ 1e-10, res["HOTA_TP"]
189
+ )
190
+ res = self._compute_final_fields(res)
191
+ return res
192
+
193
+ def combine_classes_class_averaged(self, all_res, ignore_empty_classes=False):
194
+ """Combines metrics across all classes by averaging over the class values.
195
+ If 'ignore_empty_classes' is True, then it only sums over classes with at least one gt or predicted detection.
196
+ """
197
+ res = {}
198
+ for field in self.integer_array_fields:
199
+ if ignore_empty_classes:
200
+ res[field] = self._combine_sum(
201
+ {
202
+ k: v
203
+ for k, v in all_res.items()
204
+ if (
205
+ v["HOTA_TP"] + v["HOTA_FN"] + v["HOTA_FP"]
206
+ > 0 + np.finfo("float").eps
207
+ ).any()
208
+ },
209
+ field,
210
+ )
211
+ else:
212
+ res[field] = self._combine_sum(
213
+ {k: v for k, v in all_res.items()}, field
214
+ )
215
+
216
+ for field in self.float_fields + self.float_array_fields:
217
+ if ignore_empty_classes:
218
+ res[field] = np.mean(
219
+ [
220
+ v[field]
221
+ for v in all_res.values()
222
+ if (
223
+ v["HOTA_TP"] + v["HOTA_FN"] + v["HOTA_FP"]
224
+ > 0 + np.finfo("float").eps
225
+ ).any()
226
+ ],
227
+ axis=0,
228
+ )
229
+ else:
230
+ res[field] = np.mean([v[field] for v in all_res.values()], axis=0)
231
+ return res
232
+
233
+ def combine_classes_det_averaged(self, all_res):
234
+ """Combines metrics across all classes by averaging over the detection values"""
235
+ res = {}
236
+ for field in self.integer_array_fields:
237
+ res[field] = self._combine_sum(all_res, field)
238
+ for field in ["AssRe", "AssPr", "AssA"]:
239
+ res[field] = self._combine_weighted_av(
240
+ all_res, field, res, weight_field="HOTA_TP"
241
+ )
242
+ loca_weighted_sum = sum(
243
+ [all_res[k]["LocA"] * all_res[k]["HOTA_TP"] for k in all_res.keys()]
244
+ )
245
+ res["LocA"] = np.maximum(1e-10, loca_weighted_sum) / np.maximum(
246
+ 1e-10, res["HOTA_TP"]
247
+ )
248
+ res = self._compute_final_fields(res)
249
+ return res
250
+
251
+ @staticmethod
252
+ def _compute_final_fields(res):
253
+ """Calculate sub-metric ('field') values which only depend on other sub-metric values.
254
+ This function is used both for both per-sequence calculation, and in combining values across sequences.
255
+ """
256
+ res["DetRe"] = res["HOTA_TP"] / np.maximum(1, res["HOTA_TP"] + res["HOTA_FN"])
257
+ res["DetPr"] = res["HOTA_TP"] / np.maximum(1, res["HOTA_TP"] + res["HOTA_FP"])
258
+ res["DetA"] = res["HOTA_TP"] / np.maximum(
259
+ 1, res["HOTA_TP"] + res["HOTA_FN"] + res["HOTA_FP"]
260
+ )
261
+ res["HOTA"] = np.sqrt(res["DetA"] * res["AssA"])
262
+ res["OWTA"] = np.sqrt(res["DetRe"] * res["AssA"])
263
+
264
+ res["HOTA(0)"] = res["HOTA"][0]
265
+ res["LocA(0)"] = res["LocA"][0]
266
+ res["HOTALocA(0)"] = res["HOTA(0)"] * res["LocA(0)"]
267
+ return res
268
+
269
+ def plot_single_tracker_results(self, table_res, tracker, cls, output_folder):
270
+ """Create plot of results"""
271
+
272
+ # Only loaded when run to reduce minimum requirements
273
+ from matplotlib import pyplot as plt
274
+
275
+ res = table_res["COMBINED_SEQ"]
276
+ styles_to_plot = ["r", "b", "g", "b--", "b:", "g--", "g:", "m"]
277
+ for name, style in zip(self.float_array_fields, styles_to_plot):
278
+ plt.plot(self.array_labels, res[name], style)
279
+ plt.xlabel("alpha")
280
+ plt.ylabel("score")
281
+ plt.title(tracker + " - " + cls)
282
+ plt.axis([0, 1, 0, 1])
283
+ legend = []
284
+ for name in self.float_array_fields:
285
+ legend += [name + " (" + str(np.round(np.mean(res[name]), 2)) + ")"]
286
+ plt.legend(legend, loc="lower left")
287
+ out_file = os.path.join(output_folder, cls + "_plot.pdf")
288
+ os.makedirs(os.path.dirname(out_file), exist_ok=True)
289
+ plt.savefig(out_file)
290
+ plt.savefig(out_file.replace(".pdf", ".png"))
291
+ plt.clf()
sam3/eval/hota_eval_toolkit/trackeval/utils.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+
3
+ import argparse
4
+ import csv
5
+ import os
6
+ from collections import OrderedDict
7
+
8
+
9
+ def init_config(config, default_config, name=None):
10
+ """Initialise non-given config values with defaults"""
11
+ if config is None:
12
+ config = default_config
13
+ else:
14
+ for k in default_config.keys():
15
+ if k not in config.keys():
16
+ config[k] = default_config[k]
17
+ if name and config["PRINT_CONFIG"]:
18
+ print("\n%s Config:" % name)
19
+ for c in config.keys():
20
+ print("%-20s : %-30s" % (c, config[c]))
21
+ return config
22
+
23
+
24
+ def update_config(config):
25
+ """
26
+ Parse the arguments of a script and updates the config values for a given value if specified in the arguments.
27
+ :param config: the config to update
28
+ :return: the updated config
29
+ """
30
+ parser = argparse.ArgumentParser()
31
+ for setting in config.keys():
32
+ if type(config[setting]) == list or type(config[setting]) == type(None):
33
+ parser.add_argument("--" + setting, nargs="+")
34
+ else:
35
+ parser.add_argument("--" + setting)
36
+ args = parser.parse_args().__dict__
37
+ for setting in args.keys():
38
+ if args[setting] is not None:
39
+ if type(config[setting]) == type(True):
40
+ if args[setting] == "True":
41
+ x = True
42
+ elif args[setting] == "False":
43
+ x = False
44
+ else:
45
+ raise Exception(
46
+ "Command line parameter " + setting + "must be True or False"
47
+ )
48
+ elif type(config[setting]) == type(1):
49
+ x = int(args[setting])
50
+ elif type(args[setting]) == type(None):
51
+ x = None
52
+ else:
53
+ x = args[setting]
54
+ config[setting] = x
55
+ return config
56
+
57
+
58
+ def get_code_path():
59
+ """Get base path where code is"""
60
+ return os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
61
+
62
+
63
+ def validate_metrics_list(metrics_list):
64
+ """Get names of metric class and ensures they are unique, further checks that the fields within each metric class
65
+ do not have overlapping names.
66
+ """
67
+ metric_names = [metric.get_name() for metric in metrics_list]
68
+ # check metric names are unique
69
+ if len(metric_names) != len(set(metric_names)):
70
+ raise TrackEvalException(
71
+ "Code being run with multiple metrics of the same name"
72
+ )
73
+ fields = []
74
+ for m in metrics_list:
75
+ fields += m.fields
76
+ # check metric fields are unique
77
+ if len(fields) != len(set(fields)):
78
+ raise TrackEvalException(
79
+ "Code being run with multiple metrics with fields of the same name"
80
+ )
81
+ return metric_names
82
+
83
+
84
+ def write_summary_results(summaries, cls, output_folder):
85
+ """Write summary results to file"""
86
+
87
+ fields = sum([list(s.keys()) for s in summaries], [])
88
+ values = sum([list(s.values()) for s in summaries], [])
89
+
90
+ # In order to remain consistent upon new fields being adding, for each of the following fields if they are present
91
+ # they will be output in the summary first in the order below. Any further fields will be output in the order each
92
+ # metric family is called, and within each family either in the order they were added to the dict (python >= 3.6) or
93
+ # randomly (python < 3.6).
94
+ default_order = [
95
+ "HOTA",
96
+ "DetA",
97
+ "AssA",
98
+ "DetRe",
99
+ "DetPr",
100
+ "AssRe",
101
+ "AssPr",
102
+ "LocA",
103
+ "OWTA",
104
+ "HOTA(0)",
105
+ "LocA(0)",
106
+ "HOTALocA(0)",
107
+ "MOTA",
108
+ "MOTP",
109
+ "MODA",
110
+ "CLR_Re",
111
+ "CLR_Pr",
112
+ "MTR",
113
+ "PTR",
114
+ "MLR",
115
+ "CLR_TP",
116
+ "CLR_FN",
117
+ "CLR_FP",
118
+ "IDSW",
119
+ "MT",
120
+ "PT",
121
+ "ML",
122
+ "Frag",
123
+ "sMOTA",
124
+ "IDF1",
125
+ "IDR",
126
+ "IDP",
127
+ "IDTP",
128
+ "IDFN",
129
+ "IDFP",
130
+ "Dets",
131
+ "GT_Dets",
132
+ "IDs",
133
+ "GT_IDs",
134
+ ]
135
+ default_ordered_dict = OrderedDict(
136
+ zip(default_order, [None for _ in default_order])
137
+ )
138
+ for f, v in zip(fields, values):
139
+ default_ordered_dict[f] = v
140
+ for df in default_order:
141
+ if default_ordered_dict[df] is None:
142
+ del default_ordered_dict[df]
143
+ fields = list(default_ordered_dict.keys())
144
+ values = list(default_ordered_dict.values())
145
+
146
+ out_file = os.path.join(output_folder, cls + "_summary.txt")
147
+ os.makedirs(os.path.dirname(out_file), exist_ok=True)
148
+ with open(out_file, "w", newline="") as f:
149
+ writer = csv.writer(f, delimiter=" ")
150
+ writer.writerow(fields)
151
+ writer.writerow(values)
152
+
153
+
154
+ def write_detailed_results(details, cls, output_folder):
155
+ """Write detailed results to file"""
156
+ sequences = details[0].keys()
157
+ fields = ["seq"] + sum([list(s["COMBINED_SEQ"].keys()) for s in details], [])
158
+ out_file = os.path.join(output_folder, cls + "_detailed.csv")
159
+ os.makedirs(os.path.dirname(out_file), exist_ok=True)
160
+ with open(out_file, "w", newline="") as f:
161
+ writer = csv.writer(f)
162
+ writer.writerow(fields)
163
+ for seq in sorted(sequences):
164
+ if seq == "COMBINED_SEQ":
165
+ continue
166
+ writer.writerow([seq] + sum([list(s[seq].values()) for s in details], []))
167
+ writer.writerow(
168
+ ["COMBINED"] + sum([list(s["COMBINED_SEQ"].values()) for s in details], [])
169
+ )
170
+
171
+
172
+ def load_detail(file):
173
+ """Loads detailed data for a tracker."""
174
+ data = {}
175
+ with open(file) as f:
176
+ for i, row_text in enumerate(f):
177
+ row = row_text.replace("\r", "").replace("\n", "").split(",")
178
+ if i == 0:
179
+ keys = row[1:]
180
+ continue
181
+ current_values = row[1:]
182
+ seq = row[0]
183
+ if seq == "COMBINED":
184
+ seq = "COMBINED_SEQ"
185
+ if (len(current_values) == len(keys)) and seq != "":
186
+ data[seq] = {}
187
+ for key, value in zip(keys, current_values):
188
+ data[seq][key] = float(value)
189
+ return data
190
+
191
+
192
+ class TrackEvalException(Exception):
193
+ """Custom exception for catching expected errors."""
194
+
195
+ ...
sam3/eval/postprocessors.py ADDED
@@ -0,0 +1,648 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ """Postprocessors class to transform MDETR output according to the downstream task"""
4
+
5
+ import dataclasses
6
+ import logging
7
+ from collections import defaultdict
8
+ from typing import Dict, List, Optional
9
+
10
+ import numpy as np
11
+ import torch
12
+ from sam3.model import box_ops
13
+ from sam3.model.data_misc import BatchedInferenceMetadata, interpolate
14
+ from sam3.train.masks_ops import rle_encode, robust_rle_encode
15
+ from torch import nn
16
+
17
+
18
class PostProcessNullOp(nn.Module):
    """No-op postprocessor: forward does nothing and process_results passes
    the model's find stages through unchanged."""

    def __init__(self, **kwargs):
        # BUG FIX: the original called super(PostProcessNullOp).__init__(),
        # which builds an unbound super object and never runs
        # nn.Module.__init__, leaving the module uninitialised
        # (_parameters/_buffers/_modules missing).
        super().__init__()

    def forward(self, input):
        """Intentionally does nothing and returns None."""
        pass

    def process_results(self, **kwargs):
        """Return kwargs["find_stages"] untouched."""
        return kwargs["find_stages"]
+
29
+
30
class PostProcessImage(nn.Module):
    """This module converts the model's output into the format expected by the coco api"""

    def __init__(
        self,
        max_dets_per_img: int,
        iou_type="bbox",
        to_cpu: bool = True,
        use_original_ids: bool = False,
        use_original_sizes_box: bool = False,
        use_original_sizes_mask: bool = False,
        convert_mask_to_rle: bool = False,
        always_interpolate_masks_on_gpu: bool = True,
        use_presence: bool = True,
        detection_threshold: float = -1.0,
    ) -> None:
        """
        Args:
            max_dets_per_img: cap on detections kept per image in
                ``process_results`` (<= 0 disables pruning).
            iou_type: "bbox" or "segm"; masks are only post-processed when "segm".
            to_cpu: move boxes/scores/masks to CPU before returning.
            use_original_ids: key results by original image ids and force the
                original category id as the predicted label.
            use_original_sizes_box: scale boxes to the original image size
                (otherwise sizes of 1 are used, i.e. normalized coordinates).
            use_original_sizes_mask: same as above but for mask interpolation.
            convert_mask_to_rle: encode interpolated masks as RLE dicts.
            always_interpolate_masks_on_gpu: move predicted masks to the CUDA
                device of the target sizes before interpolation.
            use_presence: multiply query probabilities by the sigmoid of the
                decoder presence logit.
            detection_threshold: if > 0, drop predictions scoring at or below
                this value.
        """
        super().__init__()
        self.max_dets_per_img = max_dets_per_img
        self.iou_type = iou_type
        self.to_cpu = to_cpu
        self.convert_mask_to_rle = convert_mask_to_rle
        self.always_interpolate_masks_on_gpu = always_interpolate_masks_on_gpu

        self.use_presence = use_presence
        self.detection_threshold = detection_threshold
        self.use_original_ids = use_original_ids
        self.use_original_sizes_box = use_original_sizes_box
        self.use_original_sizes_mask = use_original_sizes_mask

    @torch.no_grad()
    def forward(
        self,
        outputs,
        target_sizes_boxes,
        target_sizes_masks,
        forced_labels=None,
        consistent=False,
        ret_tensordict: bool = False,  # This is experimental
    ):
        """Perform the computation
        Parameters:
            outputs: raw outputs of the model
            target_sizes_boxes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
                For evaluation, this must be the original image size (before any data augmentation)
                For visualization, this should be the image size after data augment, but before padding
            target_sizes_masks: same but used to resize masks
            forced_labels: tensor of dimension [batch_size] containing the label to force for each image of the batch
                This is useful when evaluating the model using standard metrics (eg on COCO, LVIS). In that case,
                we query the model with every possible class label, so we when we pass the predictions to the evaluator,
                we want to make sure that the predicted "class" matches the one that was queried.
            consistent: whether all target sizes are equal
            ret_tensordict: Experimental argument. If true, return a tensordict.TensorDict instead of a list of dictionaries for easier manipulation.
        """
        if ret_tensordict:
            assert (
                consistent is True
            ), "We don't support returning TensorDict if the outputs have different shapes"  # NOTE: It's possible but we don't support it.
            assert self.detection_threshold <= 0.0, "TODO: implement?"
            # tensordict is an optional dependency; degrade gracefully if missing.
            try:
                from tensordict import TensorDict
            except ImportError:
                logging.info(
                    "tensordict is not installed. Install by running `pip install tensordict --no-deps`. Falling back by setting `ret_tensordict=False`"
                )
                ret_tensordict = False

        out_bbox = outputs["pred_boxes"] if "pred_boxes" in outputs else None
        out_logits = outputs["pred_logits"]
        pred_masks = outputs["pred_masks"] if self.iou_type == "segm" else None
        out_probs = out_logits.sigmoid()
        if self.use_presence:
            # Gate per-query probabilities by the image-level presence score.
            presence_score = outputs["presence_logit_dec"].sigmoid().unsqueeze(1)
            out_probs = out_probs * presence_score

        assert target_sizes_boxes.shape[1] == 2
        assert target_sizes_masks.shape[1] == 2
        batch_size = target_sizes_boxes.shape[0]

        boxes, scores, labels, keep = self._process_boxes_and_labels(
            target_sizes_boxes, forced_labels, out_bbox, out_probs
        )
        assert boxes is None or len(boxes) == batch_size
        out_masks = self._process_masks(
            target_sizes_masks, pred_masks, consistent=consistent, keep=keep
        )
        # Free the (potentially large) raw mask tensor as early as possible.
        del pred_masks

        if boxes is None:
            # Mask-only outputs: pad the box-related fields with None placeholders.
            assert out_masks is not None
            assert not ret_tensordict, "We don't support returning TensorDict if the output does not contain boxes"
            B = len(out_masks)
            boxes = [None] * B
            scores = [None] * B
            labels = [None] * B

        results = {
            "scores": scores,
            "labels": labels,
            "boxes": boxes,
        }
        if out_masks is not None:
            if self.convert_mask_to_rle:
                results.update(masks_rle=out_masks)
            else:
                results.update(masks=out_masks)

        if ret_tensordict:
            results = TensorDict(results).auto_batch_size_()
            if self.to_cpu:
                results = results.cpu()
        else:
            # Convert a dictionary of lists/tensors to a list of per-image dictionaries.
            results = [
                dict(zip(results.keys(), res_tuple))
                for res_tuple in zip(*results.values())
            ]

        return results

    def _process_masks(self, target_sizes, pred_masks, consistent=True, keep=None):
        """Interpolate predicted mask logits to ``target_sizes`` and binarize at 0.5.

        Returns a tensor (consistent=True) or a per-image list of tensors/RLEs,
        or None when there are no masks to process.
        """
        if pred_masks is None:
            return None
        if self.always_interpolate_masks_on_gpu:
            gpu_device = target_sizes.device
            assert gpu_device.type == "cuda"
            pred_masks = pred_masks.to(device=gpu_device)
        if consistent:
            assert keep is None, "TODO: implement?"
            # All masks should have the same shape, expected when processing a batch of size 1
            target_size = target_sizes.unique(dim=0)
            assert target_size.size(0) == 1, "Expecting all target sizes to be equal"
            out_masks = (
                interpolate(
                    pred_masks,
                    target_size.squeeze().tolist(),
                    mode="bilinear",
                    align_corners=False,
                ).sigmoid()
                > 0.5
            )
            if self.convert_mask_to_rle:
                raise RuntimeError("TODO: implement?")
            if self.to_cpu:
                out_masks = out_masks.cpu()
        else:
            # Per-image path: each image may have a different target size.
            out_masks = [[]] * len(pred_masks)

            assert keep is None or len(keep) == len(pred_masks)
            for i, mask in enumerate(pred_masks):
                h, w = target_sizes[i]
                if keep is not None:
                    mask = mask[keep[i]]
                # Try the GPU interpolation first; fall back to CPU on failure
                # (e.g. out-of-memory for very large masks).
                try:
                    interpolated = (
                        interpolate(
                            mask.unsqueeze(1),
                            (h, w),
                            mode="bilinear",
                            align_corners=False,
                        ).sigmoid()
                        > 0.5
                    )
                except Exception as e:
                    logging.info("Issue found, reverting to CPU mode!")
                    mask_device = mask.device
                    mask = mask.cpu()
                    interpolated = (
                        interpolate(
                            mask.unsqueeze(1),
                            (h, w),
                            mode="bilinear",
                            align_corners=False,
                        ).sigmoid()
                        > 0.5
                    )
                    # Move the result back to the original device for consistency.
                    interpolated = interpolated.to(mask_device)

                if self.convert_mask_to_rle:
                    out_masks[i] = robust_rle_encode(interpolated.squeeze(1))
                else:
                    out_masks[i] = interpolated
                    if self.to_cpu:
                        out_masks[i] = out_masks[i].cpu()

        return out_masks

    def _process_boxes_and_labels(
        self, target_sizes, forced_labels, out_bbox, out_probs
    ):
        """Scale boxes to image size and derive (scores, labels, keep-mask).

        Returns (boxes, scores, labels, keep); all None when there are no boxes.
        ``keep`` is None unless ``detection_threshold`` > 0, in which case
        boxes/scores/labels become per-image lists of filtered tensors.
        """
        if out_bbox is None:
            return None, None, None, None
        assert len(out_probs) == len(target_sizes)
        if self.to_cpu:
            out_probs = out_probs.cpu()
        scores, labels = out_probs.max(-1)
        if forced_labels is None:
            labels = torch.ones_like(labels)
        else:
            # Force the queried category id so the evaluator sees a matching class.
            labels = forced_labels[:, None].expand_as(labels)

        # convert to [x0, y0, x1, y1] format
        boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)

        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes * scale_fct[:, None, :]

        if self.to_cpu:
            boxes = boxes.cpu()

        keep = None
        if self.detection_threshold > 0:
            # Filter out the boxes with scores below the detection threshold
            keep = scores > self.detection_threshold
            assert len(keep) == len(boxes) == len(scores) == len(labels)

            boxes = [b[k.to(b.device)] for b, k in zip(boxes, keep)]
            scores = [s[k.to(s.device)] for s, k in zip(scores, keep)]
            labels = [l[k.to(l.device)] for l, k in zip(labels, keep)]

        return boxes, scores, labels, keep

    def process_results(
        self, find_stages, find_metadatas: List[BatchedInferenceMetadata], **kwargs
    ):
        """Post-process all stages and aggregate per-image results.

        Results for the same image id across stages are concatenated, then
        optionally pruned to ``max_dets_per_img`` by top-k score.
        """
        if find_stages.loss_stages is not None:
            # Keep only the metadata entries corresponding to the loss stages.
            find_metadatas = [find_metadatas[i] for i in find_stages.loss_stages]
        assert len(find_stages) == len(find_metadatas)
        results = {}
        for outputs, meta in zip(find_stages, find_metadatas):
            # Size of 1 means "leave coordinates normalized".
            img_size_for_boxes = (
                meta.original_size
                if self.use_original_sizes_box
                else torch.ones_like(meta.original_size)
            )
            img_size_for_masks = (
                meta.original_size
                if self.use_original_sizes_mask
                else torch.ones_like(meta.original_size)
            )
            detection_results = self(
                outputs,
                img_size_for_boxes,
                img_size_for_masks,
                forced_labels=(
                    meta.original_category_id if self.use_original_ids else None
                ),
            )
            ids = (
                meta.original_image_id if self.use_original_ids else meta.coco_image_id
            )
            assert len(detection_results) == len(ids)
            for img_id, result in zip(ids, detection_results):
                if img_id.item() not in results:
                    results[img_id.item()] = result
                else:
                    # Same image seen in multiple stages: merge field-by-field.
                    assert set(results[img_id.item()].keys()) == set(result.keys())
                    for k in result.keys():
                        if isinstance(result[k], torch.Tensor):
                            results[img_id.item()][k] = torch.cat(
                                [results[img_id.item()][k], result[k]], dim=0
                            )
                        elif isinstance(result[k], list):
                            results[img_id.item()][k] += result[k]
                        else:
                            raise NotImplementedError(
                                f"Unexpected type {type(result[k])} in result."
                            )
        # Prune the results to the max number of detections per image.
        for img_id, result in results.items():
            if (
                self.max_dets_per_img > 0
                and len(result["scores"]) > self.max_dets_per_img
            ):
                _, topk_indexes = torch.topk(
                    result["scores"], self.max_dets_per_img, dim=0
                )
                if self.to_cpu:
                    topk_indexes = topk_indexes.cpu()
                for k in result.keys():
                    if isinstance(results[img_id][k], list):
                        results[img_id][k] = [
                            results[img_id][k][i] for i in topk_indexes.tolist()
                        ]
                    else:
                        results[img_id][k] = results[img_id][k].to(topk_indexes.device)[
                            topk_indexes
                        ]

        return results
321
+
322
+
323
class PostProcessAPIVideo(PostProcessImage):
    """This module converts the video model's output into the format expected by the YT-VIS api"""

    def __init__(
        self,
        *args,
        to_cpu: bool = True,
        convert_mask_to_rle: bool = False,
        always_interpolate_masks_on_gpu: bool = True,
        prob_thresh: float = 0.5,
        use_presence: bool = False,
        **kwargs,
    ):
        """
        Args:
            to_cpu: move per-frame results to CPU after (optional) RLE encoding.
            convert_mask_to_rle: encode video masklets as RLE dicts (handled here,
                not in the base class).
            always_interpolate_masks_on_gpu: interpolate masks on the CUDA device.
            prob_thresh: per-frame probability threshold used to drop low-confidence
                detections from the tracklets (None disables).
            use_presence: forwarded to the base class presence gating.
        """
        super().__init__(
            *args,
            # Here we always set `convert_mask_to_rle=False` in the base `PostProcessAPI` class
            # (so that its `_process_masks` won't return a list of RLEs). If we want to return
            # RLEs for video masklets, we handle it in this `PostProcessAPIVideo` class instead.
            convert_mask_to_rle=False,
            # Here we always set `to_cpu=False` in the base `PostProcessAPI` class (so that
            # the interpolated masks won't be automatically moved back to CPU). We will handle
            # it in this `PostProcessAPIVideo` class instead.
            # NOTE(review): contrary to the comment above, `to_cpu` is NOT forwarded to the
            # base class here, so the base class keeps its own default (`to_cpu=True`) unless
            # a caller passes it via **kwargs — confirm whether `to_cpu=False` should be passed.
            always_interpolate_masks_on_gpu=always_interpolate_masks_on_gpu,
            use_presence=use_presence,
            **kwargs,
        )
        # Expected keys in the output dict to postprocess
        self.EXPECTED_KEYS = [
            "pred_logits",
            "pred_boxes",
            "pred_masks",
        ]
        # Whether to post-process video masklets (under packed representation) into RLE format
        self.convert_mask_to_rle_for_video = convert_mask_to_rle
        self.to_cpu_for_video = to_cpu
        self.prob_thresh = prob_thresh

    def process_results(
        self, find_stages, find_metadatas: List[BatchedInferenceMetadata], **kwargs
    ):
        """
        Tracking Postprocessor for SAM 3 video model.
        This function takes in the output of the SAM 3 video model and processes it to extract all the tracklet predictions.
        Args:
            find_stages: A list of tensors representing the output of the SAM 3 video model.
            find_metadatas: A list of BatchedInferenceMetadata objects containing metadata about each frame.
            **kwargs: Additional keyword arguments.
        Returns:
            A dictionary of predictions with video_id as key.
        """

        # Import tensordict here to avoid global dependency.
        try:
            from tensordict import TensorDict
        except ImportError as e:
            logging.error(
                "tensordict is not installed, please install by running `pip install tensordict --no-deps`"
            )
            raise e
        # Notes and assumptions:
        # 1- This postprocessor assumes results only for a single video.
        # 2- There are N stage outputs corresponding to N video frames
        # 3- Each stage outputs contains PxQ preds, where P is number of prompts and Q is number of object queries. The output should also contain the tracking object ids corresponding to each object query.
        # 4- The tracking object id has a default value of -1, indicating that the object query is not tracking any object in the frame, and hence its predictions can be ingored for a given frame.
        # 5- Some objects may be tracked in a subset of frames only. So, we first extract the predictions in a packed representation (for efficient postprocessing -- specially memory)
        #    and then we convert the packed representation into a padded one, where we zero pad boxes/masks for objects that are not tracked in some frames.
        # 6- We refer to objects by an object id, which is a tuple (prompt_idx, obj_id)

        assert len(find_stages) > 0, "There is nothing to postprocess?"
        PROMPT_AXIS, OBJ_QUERY_AXIS = (0, 1)
        NO_OBJ_ID = -1
        # Maps object ID -> [indices in packed tensor]
        tracked_objects_packed_idx = defaultdict(list)
        # Maps object ID -> [indices in padded tensor (abs frame index)]
        tracked_objects_frame_idx = defaultdict(list)
        total_num_preds = 0
        # This will hold the packed representation of predictions.
        vid_preds_packed: List[TensorDict] = []
        vid_masklets_rle_packed: List[Optional[Dict]] = []
        video_id = -1  # We assume single video postprocessing, this ID should be unique in the datapoint.

        for frame_idx, (frame_outs, meta) in enumerate(
            zip(find_stages, find_metadatas)
        ):
            # only store keys we need to extract the results
            frame_outs_td = TensorDict(
                {k: frame_outs[k] for k in self.EXPECTED_KEYS}
            ).auto_batch_size_()  # Shape is [P,Q,...]
            meta_td = TensorDict(
                dataclasses.asdict(meta)
            ).auto_batch_size_()  # Shape is [P,...]
            unique_vid_id = meta.original_image_id.unique()
            assert unique_vid_id.size(0) == 1
            if video_id == -1:
                video_id = unique_vid_id.item()
            else:
                assert (
                    video_id == unique_vid_id.item()
                ), "We can only postprocess one video per datapoint"
            # keeping track of which objects appear in the current frame
            obj_ids_per_frame = frame_outs["pred_object_ids"]
            assert obj_ids_per_frame.size(-1) == frame_outs["pred_logits"].size(-2)
            if self.prob_thresh is not None:
                # only keep the predictions on this frame with probability above the threshold
                # (remove those predictions during the keep-alive period of a tracking query,
                # where its "pred_object_ids" is still the tracked object ID rather than -1)
                pred_probs = frame_outs["pred_logits"].sigmoid().squeeze(-1)
                obj_ids_per_frame = torch.where(
                    pred_probs >= self.prob_thresh, obj_ids_per_frame, NO_OBJ_ID
                )
            tracked_obj_ids_idx = torch.where(obj_ids_per_frame != NO_OBJ_ID)
            # Object id is a tuple of (prompt_idx, obj_id). This is because the model can assign same obj_id for two different prompts.
            tracked_obj_ids = [
                (p_id.item(), obj_ids_per_frame[p_id, q_id].item())
                for p_id, q_id in zip(
                    tracked_obj_ids_idx[PROMPT_AXIS],
                    tracked_obj_ids_idx[OBJ_QUERY_AXIS],
                )
            ]
            if len(tracked_obj_ids) == 0:
                continue
            # For each object, we keep track of the packed and padded (frame index) indices
            for oid in tracked_obj_ids:
                tracked_objects_packed_idx[oid].append(total_num_preds)
                tracked_objects_frame_idx[oid].append(frame_idx)
                total_num_preds += 1

            # Since we have P*Q masks per frame, mask interpolation is the GPU memory bottleneck or time bottleneck in case of cpu processing.
            # Instead, we first extract results only for tracked objects, reducing the number of masks to K = sum_i(tracked_objs_per_ith_prompt), hopefully <<< P*Q
            tracked_objs_outs_td = frame_outs_td[
                tracked_obj_ids_idx
            ]  # [P,Q,...] --> [K,...]
            meta_td = meta_td[tracked_obj_ids_idx[PROMPT_AXIS].cpu()]
            if self.always_interpolate_masks_on_gpu:
                gpu_device = meta_td["original_size"].device
                assert gpu_device.type == "cuda"
                tracked_objs_outs_td = tracked_objs_outs_td.to(device=gpu_device)
            # NOTE(review): `PostProcessImage.forward` takes `target_sizes_boxes` AND
            # `target_sizes_masks` as separate positional arguments, but only one size
            # tensor is passed here; also `self.use_original_sizes` does not appear to be
            # defined by `PostProcessImage.__init__` (which sets `use_original_sizes_box`
            # and `use_original_sizes_mask`) — confirm against the intended base class.
            frame_results_td = self(
                tracked_objs_outs_td.unsqueeze(1),
                (
                    meta_td["original_size"]
                    if self.use_original_sizes
                    else torch.ones_like(meta_td["original_size"])
                ),
                forced_labels=(
                    meta_td["original_category_id"] if self.use_original_ids else None
                ),
                consistent=True,
                ret_tensordict=True,
            ).squeeze(1)
            del tracked_objs_outs_td

            # Optionally, remove "masks" from output tensor dict and directly encode them
            # to RLE format under packed representations
            if self.convert_mask_to_rle_for_video:
                interpolated_binary_masks = frame_results_td.pop("masks")
                rle_list = rle_encode(interpolated_binary_masks, return_areas=True)
                vid_masklets_rle_packed.extend(rle_list)
            # Optionally, move output TensorDict to CPU (do this after RLE encoding step above)
            if self.to_cpu_for_video:
                frame_results_td = frame_results_td.cpu()
            vid_preds_packed.append(frame_results_td)

        if len(vid_preds_packed) == 0:
            logging.debug(f"Video {video_id} has no predictions")
            return {video_id: []}

        vid_preds_packed = torch.cat(vid_preds_packed, dim=0)
        ############### Construct a padded representation of the predictions ###############
        num_preds = len(tracked_objects_packed_idx)
        num_frames = len(find_stages)
        # We zero pad any missing prediction
        # NOTE: here, we also have padded tensors for "scores" and "labels", but we overwrite them later.
        padded_frames_results = TensorDict(
            {
                k: torch.zeros(
                    num_preds, num_frames, *v.shape[1:], device=v.device, dtype=v.dtype
                )
                for k, v in vid_preds_packed.items()
            },
            batch_size=[
                num_preds,
                num_frames,
            ],
        )
        padded_frames_results["scores"][...] = -1e8  # a very low score for empty object
        # Track scores and labels of each pred tracklet, only for frames where the model was able to track that object
        tracklet_scores = []
        tracklet_labels = []
        # Optionally, fill the list of RLEs for masklets
        # note: only frames with actual predicted masks (in packed format) will be
        # filled with RLEs; the rest will remains None in results["masks_rle"]
        if self.convert_mask_to_rle_for_video:
            vid_masklets_rle_padded = [[None] * num_frames for _ in range(num_preds)]
        for o_idx, oid in enumerate(tracked_objects_packed_idx):
            oid2packed_idx = tracked_objects_packed_idx[oid]
            oid2padded_idx = tracked_objects_frame_idx[oid]
            obj_packed_results = vid_preds_packed[oid2packed_idx]
            padded_frames_results[o_idx][oid2padded_idx] = obj_packed_results
            if self.convert_mask_to_rle_for_video:
                for packed_idx, padded_idx in zip(oid2packed_idx, oid2padded_idx):
                    vid_masklets_rle_padded[o_idx][padded_idx] = (
                        vid_masklets_rle_packed[packed_idx]
                    )
            # NOTE: We need a single confidence score per tracklet for the mAP metric.
            # We use the average confidence score across time. (How does this impact AP?)
            tracklet_scores.append(obj_packed_results["scores"].mean())
            # We also need to have a unique category Id per tracklet.
            # This is not a problem for phrase AP, however, for mAP we do majority voting across time.
            tracklet_labels.append(obj_packed_results["labels"].mode()[0])

        results = padded_frames_results.to_dict()
        results["scores"] = torch.stack(tracklet_scores, dim=0)
        results["labels"] = torch.stack(tracklet_labels, dim=0)
        if self.convert_mask_to_rle_for_video:
            results["masks_rle"] = vid_masklets_rle_padded
        # we keep the frame-level scores since it's needed by some evaluation scripts
        results["per_frame_scores"] = padded_frames_results["scores"]

        return {video_id: results}
543
+
544
+
545
class PostProcessTracking(PostProcessImage):
    """Post-processor for tracking outputs, keyed by (media, object, frame).

    Results are returned as a dict mapping
    ``(media_id, object_id, frame_index)`` tuples to per-frame detections.
    """

    def __init__(
        self,
        max_dets_per_img: int,
        iou_type="bbox",
        force_single_mask: bool = False,
        **kwargs,
    ) -> None:
        """
        Args:
            max_dets_per_img: forwarded to the base class.
            iou_type: "bbox" or "segm"; forwarded to the base class.
            force_single_mask: keep only the highest-scoring mask per sample
                before post-processing.
        """
        super().__init__(max_dets_per_img=max_dets_per_img, iou_type=iou_type, **kwargs)
        self.force_single_mask = force_single_mask

    def process_results(
        self, find_stages, find_metadatas: BatchedInferenceMetadata, **kwargs
    ):
        """Post-process each stage and key results by (media, object, frame)."""
        assert len(find_stages) == len(find_metadatas)
        results = {}
        for outputs, meta in zip(find_stages, find_metadatas):
            if self.force_single_mask:
                # Reduce to the single best mask per sample (argmax over scores).
                scores, labels = outputs["pred_logits"].max(-1)
                m = []
                for i in range(len(outputs["pred_masks"])):
                    score, idx = scores[i].max(0)
                    m.append(outputs["pred_masks"][i][idx])
                outputs["pred_masks"] = torch.stack(m, 0).unsqueeze(1)
            # NOTE(review): `PostProcessImage.forward` declares separate positional
            # arguments `target_sizes_boxes` and `target_sizes_masks`, but only one
            # size tensor is passed here (so `consistent=False` binds to
            # `target_sizes_masks`?) — confirm the intended call signature.
            detection_results = self(outputs, meta.original_size, consistent=False)
            assert len(detection_results) == len(meta.coco_image_id)
            results.update(
                {
                    (media_id.item(), object_id.item(), frame_index.item()): result
                    for media_id, object_id, frame_index, result in zip(
                        meta.original_image_id,
                        meta.object_id,
                        meta.frame_index,
                        detection_results,
                    )
                }
            )
        return results
585
+
586
+
587
+ class PostProcessCounting(nn.Module):
588
+ """This module converts the model's output to be evaluated for counting tasks"""
589
+
590
+ def __init__(
591
+ self,
592
+ use_original_ids: bool = False,
593
+ threshold: float = 0.5,
594
+ use_presence: bool = False,
595
+ ) -> None:
596
+ """
597
+ Args:
598
+ use_original_ids: whether to use the original image ids or the coco ids
599
+ threshold: threshold for counting (values above this are counted)
600
+ """
601
+ super().__init__()
602
+ self.use_original_ids = use_original_ids
603
+ self.threshold = threshold
604
+ self.use_presence = use_presence
605
+
606
+ def forward(self, outputs, target_sizes):
607
+ """Perform the computation
608
+ Parameters:
609
+ outputs: raw outputs of the model
610
+ target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
611
+ """
612
+ # Extract scores from model outputs and apply sigmoid
613
+ scores = torch.sigmoid(outputs["pred_logits"]).squeeze(-1) # [B, N]
614
+ if self.use_presence:
615
+ presence_score = outputs["presence_logit_dec"].sigmoid()
616
+ if presence_score.ndim == 1:
617
+ presence_score = presence_score.unsqueeze(1) # [B, 1]
618
+ scores = scores * presence_score # [B, N]
619
+
620
+ # Calculate counts by summing values above threshold
621
+ counts = (scores > self.threshold).float().sum(dim=1)
622
+
623
+ assert len(counts) == len(target_sizes)
624
+ results = []
625
+ for count in counts:
626
+ results.append({"count": count.item()})
627
+
628
+ return results
629
+
630
+ @torch.no_grad()
631
+ def process_results(
632
+ self, find_stages, find_metadatas: List[BatchedInferenceMetadata], **kwargs
633
+ ):
634
+ assert len(find_stages) == len(find_metadatas)
635
+ results = {}
636
+ for outputs, meta in zip(find_stages, find_metadatas):
637
+ detection_results = self(
638
+ outputs,
639
+ meta.original_size,
640
+ )
641
+ ids = (
642
+ meta.original_image_id if self.use_original_ids else meta.coco_image_id
643
+ )
644
+ assert len(detection_results) == len(ids)
645
+ for img_id, result in zip(ids, detection_results):
646
+ results[img_id.item()] = result
647
+
648
+ return results
sam3/eval/saco_veval_eval.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+ import argparse
3
+ import json
4
+ import os
5
+ from collections import defaultdict
6
+
7
+ from iopath.common.file_io import g_pathmgr
8
+ from sam3.eval.saco_veval_evaluators import (
9
+ VideoCGF1Evaluator,
10
+ VideoPhraseApEvaluator,
11
+ VideoPhraseHotaEvaluator,
12
+ VideoTetaEvaluator,
13
+ YTVISPredFileEvaluator,
14
+ )
15
+
16
+
17
class VEvalEvaluator:
    """Runs the full suite of SA-Co video evaluators on one prediction file.

    Aggregates dataset-level metrics from mAP, phrase AP, TETA, HOTA and cgF1
    evaluators, plus per-(video, category) results, and writes everything to
    ``eval_res_file`` as JSON.
    """

    def __init__(self, gt_annot_file: str, eval_res_file: str):
        # gt_annot_file: path to the ground-truth annotation JSON.
        # eval_res_file: path where the combined metrics JSON will be written.
        self.gt_annot_file = gt_annot_file
        self.eval_res_file = eval_res_file
        self.evaluators = [
            # mAP
            YTVISPredFileEvaluator(gt_annot_file),
            # Phrase AP
            VideoPhraseApEvaluator(gt_annot_file),
            # TETA
            VideoTetaEvaluator(gt_annot_file, use_mask=True, is_exhaustive=True),
            # HOTA
            VideoPhraseHotaEvaluator(gt_annot_file),
            # cgF1
            VideoCGF1Evaluator(gt_annot_file),
        ]

    def run_eval(self, pred_file: str):
        """Evaluate ``pred_file`` with every evaluator and persist the metrics.

        Returns:
            A dict with "dataset_results" (flat metric name -> value) and
            "video_np_results" (a list of per-(video_id, category_id) dicts).
        """
        dataset_results = {}
        # (video_id, category_id) -> merged per-evaluator metric dicts.
        video_np_results = defaultdict(dict)
        for evaluator in self.evaluators:
            d_res, v_np_res = evaluator.evaluate(pred_file)
            dataset_results.update(d_res)
            for (video_id, category_id), res in v_np_res.items():
                video_np_results[(video_id, category_id)].update(res)

        # Keep the output schema non-empty even when no evaluator reports anything.
        if len(dataset_results) == 0:
            dataset_results = {"": 0.0}

        # Flatten the tuple keys into JSON-serializable records.
        formatted_video_np_results = [
            {"video_id": video_id, "category_id": category_id, **res}
            for (video_id, category_id), res in video_np_results.items()
        ]
        eval_metrics = {
            "dataset_results": dataset_results,
            "video_np_results": formatted_video_np_results,
        }

        with g_pathmgr.open(self.eval_res_file, "w") as f:
            json.dump(eval_metrics, f)

        return eval_metrics
59
+
60
+
61
def run_main_all(dataset_name, args):
    """Evaluate a single named dataset using the directory-layout arguments.

    File paths are derived from ``args.gt_annot_dir`` / ``args.pred_dir`` /
    ``args.eval_res_dir`` by appending dataset-specific suffixes.
    """
    gt_annot_file = os.path.join(args.gt_annot_dir, f"{dataset_name}.json")
    pred_file = os.path.join(args.pred_dir, f"{dataset_name}_preds.json")
    eval_res_file = os.path.join(args.eval_res_dir, f"{dataset_name}_eval_res.json")

    print(f"=== Running evaluation for Pred {pred_file} vs GT {gt_annot_file} ===")
    evaluator = VEvalEvaluator(gt_annot_file=gt_annot_file, eval_res_file=eval_res_file)
    evaluator.run_eval(pred_file=pred_file)
    print(f"=== Results saved to {eval_res_file} ===")
72
+
73
+
74
def main_all(args):
    """Run the evaluation for every SA-Co VEval dataset split, one at a time."""
    # multiprocessing may not really work as inner evaluator also using multiprocessing
    # so we just for loop
    dataset_names = (
        "saco_veval_sav_test",
        "saco_veval_sav_val",
        "saco_veval_yt1b_test",
        "saco_veval_yt1b_val",
        "saco_veval_smartglasses_test",
        "saco_veval_smartglasses_val",
    )
    for name in dataset_names:
        print(f"=== Running evaluation for dataset {name} ===")
        run_main_all(dataset_name=name, args=args)
89
+
90
+
91
def main_one(args):
    """Evaluate a single dataset given explicit GT / prediction / output paths."""
    print(
        f"=== Running evaluation for Pred {args.pred_file} vs GT {args.gt_annot_file} ==="
    )
    evaluator = VEvalEvaluator(
        gt_annot_file=args.gt_annot_file, eval_res_file=args.eval_res_file
    )
    evaluator.run_eval(pred_file=args.pred_file)
    print(f"=== Results saved to {args.eval_res_file} ===")
103
+
104
+
105
def main():
    """CLI entry point: dispatch to `main_all` or `main_one` via subcommands."""
    parser = argparse.ArgumentParser(description="Run video grounding evaluators")

    # Create subparsers for different commands
    subparsers = parser.add_subparsers(dest="command", required=True)

    # Run evaluation for all datasets
    all_parser = subparsers.add_parser("all", help="Run evaluation for all datasets")
    all_parser.add_argument(
        "--gt_annot_dir",
        type=str,
        help="Directory that contains the ground truth annotation files",
    )
    all_parser.add_argument(
        "--pred_dir",
        type=str,
        help="Directory that contains the prediction files",
    )
    all_parser.add_argument(
        "--eval_res_dir",
        type=str,
        help="Directory that contains the eval results files",
    )
    all_parser.set_defaults(func=main_all)

    # Run evaluation for one dataset
    one_parser = subparsers.add_parser("one", help="Run evaluation for one dataset")
    one_parser.add_argument(
        "--gt_annot_file",
        type=str,
        help="Path to the ground truth annotation file",
    )
    one_parser.add_argument(
        "--pred_file",
        type=str,
        help="Path to the prediction file",
    )
    one_parser.add_argument(
        "--eval_res_file",
        type=str,
        help="Path to the eval results file",
    )
    one_parser.set_defaults(func=main_one)

    # Parse and dispatch to the handler selected via set_defaults(func=...)
    args = parser.parse_args()
    args.func(args)
152
+
153
+
154
+ if __name__ == "__main__":
155
+ main()
sam3/eval/saco_veval_evaluators.py ADDED
@@ -0,0 +1,838 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+ import json
3
+ import os
4
+ import tempfile
5
+ from collections import defaultdict
6
+ from typing import Dict, Optional, Sequence, Tuple
7
+
8
+ import numpy as np
9
+ import pycocotools.mask
10
+ from sam3.eval.cgf1_eval import CGF1_METRICS
11
+ from sam3.eval.conversion_util import (
12
+ convert_ytbvis_to_cocovid_gt,
13
+ convert_ytbvis_to_cocovid_pred,
14
+ )
15
+ from sam3.eval.hota_eval_toolkit.run_ytvis_eval import run_ytvis_eval
16
+ from sam3.eval.teta_eval_toolkit import config, Evaluator, metrics
17
+ from sam3.eval.teta_eval_toolkit.datasets import COCO, TAO
18
+ from sam3.eval.ytvis_coco_wrapper import YTVIS
19
+ from sam3.eval.ytvis_eval import VideoDemoF1Eval, YTVISeval
20
+ from sam3.train.nms_helper import process_frame_level_nms, process_track_level_nms
21
+
22
+
23
+ def _get_metric_index(metric_name: str, iou_threshold: Optional[float] = None) -> int:
24
+ """
25
+ Find the index of a metric in CGF1_METRICS by name and IoU threshold.
26
+
27
+ Args:
28
+ metric_name: Name of the metric (e.g., "cgF1", "precision", "recall")
29
+ iou_threshold: IoU threshold (None for average over 0.5:0.95, or specific value like 0.5, 0.75)
30
+
31
+ Returns:
32
+ Index of the metric in CGF1_METRICS
33
+
34
+ Raises:
35
+ ValueError: If metric not found
36
+ """
37
+ for idx, metric in enumerate(CGF1_METRICS):
38
+ if metric.name == metric_name and metric.iou_threshold == iou_threshold:
39
+ return idx
40
+ raise ValueError(
41
+ f"Metric '{metric_name}' with IoU threshold {iou_threshold} not found in CGF1_METRICS"
42
+ )
43
+
44
+
45
class BasePredFileEvaluator:
    """A base class for evaluating a prediction file.

    Concrete subclasses in this module implement ``evaluate(pred_file)``.
    """
49
+
50
+
51
class YTVISPredFileEvaluator(BasePredFileEvaluator):
    """Evaluate class mAP for YT-VIS prediction files."""

    def __init__(
        self,
        gt_ann_file: str,
        dataset_name: str = "video",
        iou_types: Optional[Sequence[str]] = None,
    ):
        """
        Args:
            gt_ann_file: path to the YT-VIS format ground-truth annotation JSON.
            dataset_name: prefix used in the returned metric keys.
            iou_types: subset of ["bbox", "segm"]; defaults to both.
        """
        self.gt_ann_file = gt_ann_file
        self.dataset_name = dataset_name
        self.iou_types = list(iou_types) if iou_types is not None else ["bbox", "segm"]
        assert all(iou_type in ["bbox", "segm"] for iou_type in self.iou_types)

    # NOTE: return annotation fixed -- this method returns a 2-tuple, not a dict.
    def evaluate(self, pred_file: str) -> Tuple[Dict[str, float], Dict]:
        """Run YT-VIS mAP evaluation on `pred_file`.

        Returns:
            Tuple of (metrics dict keyed "<dataset>_<bbox|mask>_mAP_50_95",
            per-video-NP results dict -- always empty for this evaluator).
        """
        # use our internal video evaluation toolkit for YT-VIS pred file
        # (i.e. the same one we're using for video phrase AP)
        results = {}
        use_cats = True  # YT-VIS mAP evaluation uses categories
        ytvisGT = YTVIS(self.gt_ann_file, ignore_gt_cats=not use_cats)
        # the original YT-VIS GT annotations have uncompressed RLEs ("counts" is an integer list)
        # rather than compressed RLEs ("counts" is a string), so we first convert them here.
        if "segm" in self.iou_types:
            for ann in ytvisGT.dataset["annotations"]:
                ann["segmentations"] = [
                    _compress_rle(rle) for rle in ann["segmentations"]
                ]

        with open(pred_file) as f:
            dt = json.load(f)
        # Our prediction file saves "video_id" and absolute (unnormalized) boxes.
        # Note that we should use the official (original) YT-VIS annotations (i.e. the one
        # saved via "scripts/datasets/training/ytvis_split.py", instead of the one saved
        # via "scripts/api_db_to_ytvis_json.py") in this evaluator, which contain absolute
        # boxes coordinates in its GT annotations.
        for d in dt:
            d["image_id"] = d["video_id"]
        ytvisDT = ytvisGT.loadRes(dt)

        for iou_type in self.iou_types:
            ytvisEval = YTVISeval(ytvisGT, ytvisDT, iou_type)

            # set the area ranges for small, medium, and large objects (using
            # absolute pixel areas) as in the official YT-VIS evaluation toolkit:
            # https://github.com/achalddave/ytvosapi/blob/eca601117c9f86bad084cb91f1d918e9ab665a75/PythonAPI/ytvostools/ytvoseval.py#L538
            ytvisEval.params.areaRng = [
                [0**2, 1e5**2],
                [0**2, 128**2],
                [128**2, 256**2],
                [256**2, 1e5**2],
            ]
            ytvisEval.params.areaRngLbl = ["all", "small", "medium", "large"]
            ytvisEval.params.useCats = use_cats

            ytvisEval.evaluate()
            ytvisEval.accumulate()
            ytvisEval.summarize()
            result_key = f"{self.dataset_name}_{'mask' if iou_type == 'segm' else 'bbox'}_mAP_50_95"
            results[result_key] = ytvisEval.stats[0]

        # video-NP level results not supported for `YTVISPredFileEvaluator` yet
        video_np_level_results = {}
        return results, video_np_level_results
114
+
115
+
116
class VideoPhraseApEvaluator(BasePredFileEvaluator):
    """Evaluate Video Phrase AP with YT-VIS format prediction and GT files."""

    def __init__(
        self,
        gt_ann_file: str,
        dataset_name: str = "video",
        iou_types: Optional[Sequence[str]] = None,
    ):
        """
        Args:
            gt_ann_file: path to the YT-VIS format ground-truth annotation JSON.
            dataset_name: prefix used in the returned metric keys.
            iou_types: subset of ["bbox", "segm"]; defaults to both.
        """
        self.gt_ann_file = gt_ann_file
        self.dataset_name = dataset_name
        self.iou_types = list(iou_types) if iou_types is not None else ["bbox", "segm"]
        assert all(iou_type in ["bbox", "segm"] for iou_type in self.iou_types)

    # NOTE: return annotation fixed -- this method returns a 2-tuple, not a dict.
    def evaluate(self, pred_file: str) -> Tuple[Dict[str, float], Dict]:
        """Run category-agnostic phrase-AP evaluation on `pred_file`.

        Returns:
            Tuple of (metrics dict, per-video-NP results dict -- always empty
            for this evaluator).
        """
        with open(self.gt_ann_file) as f:
            gt = json.load(f)
        with open(pred_file) as f:
            dt = json.load(f)
        # For phrase AP and demo F1 evaluation, we need to remap each pair of (video_id, category_id) to
        # a new unique video_id, so that we don't mix detections from different categories under `useCat=False`
        gt, dt = remap_video_category_pairs_to_unique_video_ids(gt, dt)
        if "segm" in self.iou_types:
            for ann in gt["annotations"]:
                ann["segmentations"] = [
                    _compress_rle(rle) for rle in ann["segmentations"]
                ]
        for d in dt:
            d["image_id"] = d["video_id"]

        results = {}
        use_cats = False  # Phrase AP evaluation does not use categories
        ytvisGT = YTVIS(annotation_file=None, ignore_gt_cats=not use_cats)
        ytvisGT.dataset = gt
        ytvisGT.createIndex()
        ytvisDT = ytvisGT.loadRes(dt)

        for iou_type in self.iou_types:
            phraseApEval = YTVISeval(ytvisGT, ytvisDT, iou_type)

            # set the area ranges for small, medium, and large objects (using
            # absolute pixel areas) as in the official YT-VIS evaluation toolkit:
            # https://github.com/achalddave/ytvosapi/blob/eca601117c9f86bad084cb91f1d918e9ab665a75/PythonAPI/ytvostools/ytvoseval.py#L538
            phraseApEval.params.areaRng = [
                [0**2, 1e5**2],
                [0**2, 128**2],
                [128**2, 256**2],
                [256**2, 1e5**2],
            ]
            phraseApEval.params.areaRngLbl = ["all", "small", "medium", "large"]
            phraseApEval.params.useCats = use_cats

            phraseApEval.evaluate()
            phraseApEval.accumulate()
            phraseApEval.summarize()
            result_prefix = f"{self.dataset_name}"
            result_prefix += f"_{'mask' if iou_type == 'segm' else 'bbox'}_phrase_ap"
            # fetch Phrase AP results from the corresponding indices in `phraseApEval.stats`
            # (see `_summarizeDets` in https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py)
            results[result_prefix + "_50_95"] = phraseApEval.stats[0]  # IoU=0.5:0.95
            results[result_prefix + "_50"] = phraseApEval.stats[1]  # IoU=0.5
            results[result_prefix + "_75"] = phraseApEval.stats[2]  # IoU=0.75

        # video-NP level results not supported for `VideoPhraseApEvaluator` yet
        video_np_level_results = {}
        return results, video_np_level_results
182
+
183
+
184
class VideoCGF1Evaluator(BasePredFileEvaluator):
    """Evaluate Video Demo F1 (and CG-F1 / IL_MCC when possible) with YT-VIS
    format prediction and GT files."""

    def __init__(
        self,
        gt_ann_file: str,
        dataset_name: str = "video",
        prob_thresh: float = 0.5,
        iou_types: Optional[Sequence[str]] = None,
    ):
        """
        Args:
            gt_ann_file: path to the YT-VIS format ground-truth annotation JSON.
            dataset_name: prefix used in the returned metric keys.
            prob_thresh: probability threshold used by the Demo F1 evaluator.
            iou_types: subset of ["bbox", "segm"]; defaults to both.
        """
        self.gt_ann_file = gt_ann_file
        self.dataset_name = dataset_name
        self.prob_thresh = prob_thresh
        self.iou_types = list(iou_types) if iou_types is not None else ["bbox", "segm"]
        assert all(iou_type in ["bbox", "segm"] for iou_type in self.iou_types)

    # NOTE: return annotation fixed -- this method returns a 2-tuple, not a dict.
    def evaluate(self, pred_file: str) -> Tuple[Dict[str, float], Dict]:
        """Run Demo-F1 / CG-F1 evaluation on `pred_file`.

        Returns:
            Tuple of (aggregate metrics dict, per-(video_id, category_id)
            results dict filled by `extract_video_np_level_results`).
        """
        with open(self.gt_ann_file) as f:
            gt = json.load(f)
        with open(pred_file) as f:
            dt = json.load(f)
        # IL_MCC and CG-F1 can only be computed if we have "video_np_pairs" keys in the GT JSON
        compute_ilmcc_and_cgf1 = "video_np_pairs" in gt
        if not compute_ilmcc_and_cgf1:
            print(
                f"Warning: IL_MCC and CG-F1 are not computed for {pred_file=} as it does not have 'video_np_pairs' keys in the GT JSON"
            )
        # For phrase AP and demo F1 evaluation, we need to remap each pair of (video_id, category_id) to
        # a new unique video_id, so that we don't mix detections from different categories under `useCat=False`
        gt, dt = remap_video_category_pairs_to_unique_video_ids(
            gt, dt, add_negative_np_pairs=compute_ilmcc_and_cgf1
        )
        if "segm" in self.iou_types:
            for ann in gt["annotations"]:
                ann["segmentations"] = [
                    _compress_rle(rle) for rle in ann["segmentations"]
                ]
        for d in dt:
            d["image_id"] = d["video_id"]

        results = {}
        use_cats = False  # Demo F1 evaluation does not use categories
        ytvisGT = YTVIS(annotation_file=None, ignore_gt_cats=not use_cats)
        ytvisGT.dataset = gt
        ytvisGT.createIndex()
        ytvisDT = ytvisGT.loadRes(dt)

        video_np_level_results = {}
        for iou_type in self.iou_types:
            demoF1Eval = VideoDemoF1Eval(ytvisGT, ytvisDT, iou_type, self.prob_thresh)

            demoF1Eval.params.useCats = use_cats
            demoF1Eval.params.areaRng = [[0**2, 1e5**2]]
            demoF1Eval.params.areaRngLbl = ["all"]
            demoF1Eval.params.maxDets = [100000]

            demoF1Eval.evaluate()
            demoF1Eval.accumulate()
            demoF1Eval.summarize()
            result_prefix = f"{self.dataset_name}"
            result_prefix += f"_{'mask' if iou_type == 'segm' else 'bbox'}_demo"

            stats = demoF1Eval.stats

            if compute_ilmcc_and_cgf1:
                # Average IoU threshold (0.5:0.95)
                cgf1_micro_avg_idx = _get_metric_index("cgF1", None)
                positive_micro_f1_avg_idx = _get_metric_index("positive_micro_F1", None)
                ilmcc_avg_idx = _get_metric_index("IL_MCC", None)
                results[result_prefix + "_cgf1_micro_50_95"] = stats[cgf1_micro_avg_idx]
                results[result_prefix + "_ilmcc_50_95"] = stats[ilmcc_avg_idx]
                results[result_prefix + "_positive_micro_f1_50_95"] = stats[
                    positive_micro_f1_avg_idx
                ]

                # IoU = 0.5; IL_MCC at a fixed IoU is derived as cgF1 / positive_micro_F1
                cgf1_micro_50_idx = _get_metric_index("cgF1", 0.5)
                positive_micro_f1_50_idx = _get_metric_index("positive_micro_F1", 0.5)
                results[result_prefix + "_cgf1_micro_50"] = stats[cgf1_micro_50_idx]
                results[result_prefix + "_ilmcc_50"] = float(
                    np.array(stats[cgf1_micro_50_idx])
                    / np.array(stats[positive_micro_f1_50_idx])
                )
                results[result_prefix + "_positive_micro_f1_50"] = stats[
                    positive_micro_f1_50_idx
                ]

                # IoU = 0.75
                cgf1_micro_75_idx = _get_metric_index("cgF1", 0.75)
                positive_micro_f1_75_idx = _get_metric_index("positive_micro_F1", 0.75)
                results[result_prefix + "_cgf1_micro_75"] = stats[cgf1_micro_75_idx]
                results[result_prefix + "_ilmcc_75"] = float(
                    np.array(stats[cgf1_micro_75_idx])
                    / np.array(stats[positive_micro_f1_75_idx])
                )
                results[result_prefix + "_positive_micro_f1_75"] = stats[
                    positive_micro_f1_75_idx
                ]

            self.extract_video_np_level_results(demoF1Eval, video_np_level_results)

        return results, video_np_level_results

    def extract_video_np_level_results(self, demoF1Eval, video_np_level_results):
        """Aggregate per-(video, NP) TP/FP/FN/F1 statistics into `video_np_level_results`.

        Keys are (orig_video_id, orig_category_id) pairs; values are dicts of
        per-IoU-threshold counts and F1 scores.
        """
        num_iou_thrs = len(demoF1Eval.params.iouThrs)
        iou_50_index = int(np.where(demoF1Eval.params.iouThrs == 0.5)[0])
        iou_75_index = int(np.where(demoF1Eval.params.iouThrs == 0.75)[0])

        result_prefix = "mask" if demoF1Eval.params.iouType == "segm" else "bbox"

        assert len(demoF1Eval.evalImgs) == len(demoF1Eval.cocoGt.dataset["images"])
        for i, video in enumerate(demoF1Eval.cocoGt.dataset["images"]):
            # the original video id and category id before remapping
            video_id = video["orig_video_id"]
            category_id = video["orig_category_id"]
            eval_img_dict = demoF1Eval.evalImgs[i]

            TPs = eval_img_dict.get("TPs", np.zeros(num_iou_thrs, dtype=np.int64))
            FPs = eval_img_dict.get("FPs", np.zeros(num_iou_thrs, dtype=np.int64))
            FNs = eval_img_dict.get("FNs", np.zeros(num_iou_thrs, dtype=np.int64))
            assert len(TPs) == len(FPs) == len(FNs) == num_iou_thrs
            # F1 = 2*TP / (2*TP + FP + FN), and we set F1 to 1.0 if denominator is 0
            denominator = 2 * TPs + FPs + FNs
            F1s = np.where(denominator > 0, 2 * TPs / np.maximum(denominator, 1), 1.0)
            local_results = {
                f"{result_prefix}_TP_50_95": float(TPs.mean()),
                f"{result_prefix}_FP_50_95": float(FPs.mean()),
                f"{result_prefix}_FN_50_95": float(FNs.mean()),
                f"{result_prefix}_F1_50_95": float(F1s.mean()),
                f"{result_prefix}_TP_50": float(TPs[iou_50_index]),
                f"{result_prefix}_FP_50": float(FPs[iou_50_index]),
                f"{result_prefix}_FN_50": float(FNs[iou_50_index]),
                f"{result_prefix}_F1_50": float(F1s[iou_50_index]),
                f"{result_prefix}_TP_75": float(TPs[iou_75_index]),
                f"{result_prefix}_FP_75": float(FPs[iou_75_index]),
                f"{result_prefix}_FN_75": float(FNs[iou_75_index]),
                f"{result_prefix}_F1_75": float(F1s[iou_75_index]),
            }
            if (video_id, category_id) not in video_np_level_results:
                video_np_level_results[(video_id, category_id)] = {}
            video_np_level_results[(video_id, category_id)].update(local_results)
326
+
327
+
328
class VideoTetaEvaluator(BasePredFileEvaluator):
    """Evaluate TETA metric using YouTubeVIS format prediction and GT files."""

    def __init__(
        self,
        gt_ann_file: str,
        dataset_name: str = "video",
        tracker_name: str = "Sam3",
        nms_threshold: float = 0.5,
        nms_strategy: str = "none",  # "track", "frame", or "none"
        prob_thresh: float = 0.5,
        is_exhaustive: bool = False,
        use_mask: bool = False,
        num_parallel_cores: int = 8,
    ):
        """Store configuration and validate the requested NMS strategy."""
        self.gt_ann_file = gt_ann_file
        self.dataset_name = dataset_name
        self.tracker_name = tracker_name
        self.nms_threshold = nms_threshold
        self.nms_strategy = nms_strategy.lower()  # normalize case for comparisons
        self.prob_thresh = prob_thresh
        self.metric_prefix = "TETA"
        self.is_exhaustive = is_exhaustive
        self.use_mask = use_mask
        self.num_parallel_cores = num_parallel_cores

        # Reject anything outside the supported strategies up front.
        valid_strategies = ["track", "frame", "none"]
        print("current nms_strategy:", self.nms_strategy)
        if self.nms_strategy not in valid_strategies:
            raise ValueError(
                f"Invalid NMS strategy: {self.nms_strategy}. Must be one of {valid_strategies}"
            )

        # Echo the effective configuration.
        for label, value in [
            ("Initialized VideoTetaEvaluator with NMS strategy", self.nms_strategy),
            ("Probability threshold set to", self.prob_thresh),
            ("Dataset exhaustivity set to", self.is_exhaustive),
            ("Tracker name set to", self.tracker_name),
            ("Dataset name set to", self.dataset_name),
            ("Use mask set to", self.use_mask),
        ]:
            print(f"{label}: {value}")

    def process_predictions(self, pred_file: str, tmp_dir: str) -> str:
        """Filter predictions by score, optionally apply NMS, and write the
        processed predictions into `tmp_dir`; returns the output path."""
        with open(pred_file, "r") as f:
            predictions = json.load(f)
        print(f"Processing predictions with {self.nms_strategy} NMS strategy")

        # Drop low-scoring predictions first.
        if self.prob_thresh > 0:
            predictions = [p for p in predictions if p["score"] >= self.prob_thresh]
            print(
                f"Filtered to {len(predictions)} predictions with score >= {self.prob_thresh}"
            )
        # Bucket the predictions per video so NMS runs independently per video.
        per_video = defaultdict(list)
        for pred in predictions:
            per_video[pred["video_id"]].append(pred)
        # Apply the configured NMS strategy (validated in __init__).
        if self.nms_strategy == "track":
            process_track_level_nms(per_video, nms_threshold=self.nms_threshold)
        elif self.nms_strategy == "frame":
            process_frame_level_nms(per_video, nms_threshold=self.nms_threshold)
        else:
            print("Skipping NMS processing as strategy is set to 'none'")
        # Flatten back into one list and persist.
        flattened = [track for tracks in per_video.values() for track in tracks]
        out_path = os.path.join(tmp_dir, "processed_preds.json")
        with open(out_path, "w") as f:
            json.dump(flattened, f)

        print(f"Saved processed predictions to {out_path}")
        return out_path

    def evaluate(self, pred_file: str) -> Tuple[Dict[str, float], Dict]:
        """Convert GT and predictions into COCO-vid format, run the TETA
        toolkit, and return the collected scores."""
        print(f"Evaluating TETA Metric with {self.nms_strategy.upper()} NMS strategy")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Score filtering + NMS first
            processed_pred_file = self.process_predictions(pred_file, tmp_dir)

            # Ground truth in COCO-vid format
            gt_dir = os.path.join(tmp_dir, "gt")
            os.makedirs(gt_dir, exist_ok=True)
            gt_coco_path = os.path.join(gt_dir, "annotations.json")
            convert_ytbvis_to_cocovid_gt(self.gt_ann_file, gt_coco_path)

            # Predictions in COCO-vid format, under a per-tracker folder
            pred_dir = os.path.join(tmp_dir, "predictions")
            tracker_dir = os.path.join(pred_dir, self.tracker_name)
            os.makedirs(tracker_dir, exist_ok=True)
            pred_coco_path = os.path.join(tracker_dir, "track_results_cocofmt.json")
            convert_ytbvis_to_cocovid_pred(
                youtubevis_pred_path=processed_pred_file,
                converted_dataset_path=gt_coco_path,
                output_path=pred_coco_path,
            )
            # Configure the TETA toolkit
            eval_cfg = config.get_default_eval_config()
            eval_cfg["PRINT_ONLY_COMBINED"] = True
            eval_cfg["DISPLAY_LESS_PROGRESS"] = True
            eval_cfg["OUTPUT_TEMP_RAW_DATA"] = True
            eval_cfg["NUM_PARALLEL_CORES"] = self.num_parallel_cores
            dataset_cfg = config.get_default_dataset_config()
            dataset_cfg["TRACKERS_TO_EVAL"] = [self.tracker_name]
            dataset_cfg["GT_FOLDER"] = gt_dir
            dataset_cfg["OUTPUT_FOLDER"] = pred_dir
            dataset_cfg["TRACKER_SUB_FOLDER"] = tracker_dir
            dataset_cfg["USE_MASK"] = self.use_mask

            evaluator = Evaluator(eval_cfg)
            # Exhaustively-annotated data uses the COCO-style dataset wrapper;
            # federated (non-exhaustive) data uses the TAO-style one.
            if self.is_exhaustive:
                dataset_list = [COCO(dataset_cfg)]
                dataset_parsing_key = "COCO"
            else:
                dataset_list = [TAO(dataset_cfg)]
                dataset_parsing_key = "TAO"

            # Run evaluation
            eval_results, _ = evaluator.evaluate(
                dataset_list, [metrics.TETA(exhaustive=self.is_exhaustive)]
            )

            # The TETA row reports these ten scores in this fixed order.
            score_names = [
                "teta",
                "loc_a",
                "assoc_a",
                "cls_a",
                "loc_re",
                "loc_pr",
                "assoc_re",
                "assoc_pr",
                "cls_re",
                "cls_pr",
            ]
            teta_row = eval_results[dataset_parsing_key]["TETA"]
            key_prefix = f"{self.dataset_name}_{'mask' if self.use_mask else 'bbox'}"
            results = {
                f"{key_prefix}_{name}": float(teta_row[idx])
                for idx, name in enumerate(score_names)
            }

        # video-NP level results not supported for `VideoTetaEvaluator` yet
        video_np_level_results = {}
        return results, video_np_level_results
491
+
492
+
493
class VideoPhraseHotaEvaluator(BasePredFileEvaluator):
    """Evaluate Video Phrase HOTA with YT-VIS format prediction and GT files."""

    def __init__(
        self,
        gt_ann_file: str,
        dataset_name: str = "video",
        prob_thresh: float = 0.5,
        iou_types: Optional[Sequence[str]] = None,
        compute_video_mot_hota: bool = False,
    ):
        """
        Args:
            gt_ann_file: path to the YT-VIS format ground-truth annotation JSON.
            dataset_name: prefix used in the returned metric keys.
            prob_thresh: keep only predictions with score strictly above this value.
            iou_types: subset of ["bbox", "segm"]; defaults to both.
            compute_video_mot_hota: if True, compute video MOT HOTA, aggregating
                predictions/GT from all categories (class-agnostic).
        """
        self.gt_ann_file = gt_ann_file
        self.dataset_name = dataset_name
        self.prob_thresh = prob_thresh
        self.metric_prefix = "phrase"
        # the list of metrics to collect from the HOTA evaluation results
        self.metric_to_collect = [
            "HOTA",
            "DetA",
            "AssA",
            "DetRe",
            "DetPr",
            "AssRe",
            "AssPr",
            "LocA",
            "OWTA",
        ]
        self.iou_types = list(iou_types) if iou_types is not None else ["bbox", "segm"]
        assert all(iou_type in ["bbox", "segm"] for iou_type in self.iou_types)

        # If True, compute video MOT HOTA, aggregating predictions/GT from all categories.
        self.compute_video_mot_hota = compute_video_mot_hota

    # NOTE: return annotation fixed -- this method returns a 2-tuple, not a dict.
    def evaluate(self, pred_file: str) -> Tuple[Dict[str, float], Dict]:
        """Run HOTA evaluation on `pred_file` via the TrackEval YT-VIS toolkit.

        Returns:
            Tuple of (aggregate metrics dict, per-(video_id, category_id)
            results dict filled by `extract_video_np_level_results`).
        """
        with open(self.gt_ann_file) as f:
            gt = json.load(f)
        with open(pred_file) as f:
            dt = json.load(f)
        # keep only predictions with score above the probability threshold
        dt = [d for d in dt if d["score"] > self.prob_thresh]
        for d in dt:
            assert len(d["areas"]) == len(d["bboxes"])
            assert len(d["areas"]) == len(d["segmentations"])
            # remove empty boxes (otherwise they will count as false positives for during
            # per-frame detection accuracy in HOTA evaluation)
            for t in range(len(d["bboxes"])):
                bbox = d["bboxes"][t]
                if d["areas"][t] == 0 or bbox is None or all(x == 0 for x in bbox):
                    d["segmentations"][t] = None
                    d["bboxes"][t] = None
                    d["areas"][t] = None
            # check that box occurence and mask occurence are consistent
            for bbox, mask, area in zip(d["bboxes"], d["segmentations"], d["areas"]):
                assert (area is None) == (bbox is None)
                assert (area is None) == (mask is None)
            # set all scores to 1.0 for HOTA evaluation (just like Demo F1, the exact score
            # value is not used in HOTA metrics; it will be treated as a detection prediction
            # as long as its score is above the threshold)
            d["score"] = 1.0

        gt = _fill_in_ann_height_width(gt)
        if not self.compute_video_mot_hota:
            # remap the GT and DT annotations for phrase HOTA evaluation
            gt, dt = self._remap_gt_dt(gt, dt)
        else:
            # Compute video-level MOT HOTA: first apply track-level NMS
            video_groups = defaultdict(list)
            for pred in dt:
                video_groups[pred["video_id"]].append(pred)
            process_track_level_nms(video_groups, nms_threshold=0.5)
            dt = [track for tracks in video_groups.values() for track in tracks]

            # Remap GT track ids for class-agnostic HOTA
            gt, dt = remap_gt_dt_class_agnostic(gt, dt)

        # run the HOTA evaluation using TrackEval on the remapped (video_id, category_id) pairs
        out_dict = {}
        video_np_level_results = {}
        for iou_type in self.iou_types:
            output_res, _ = run_ytvis_eval(
                args=[
                    "--METRICS",
                    "HOTA",
                    "--IOU_TYPE",
                    iou_type,
                    "--DATASET_NAME",
                    self.dataset_name,
                    "--USE_PARALLEL",
                    "True",
                    "--NUM_PARALLEL_CORES",
                    "8",
                    "--PLOT_CURVES",
                    "False",
                    "--LOG_ON_ERROR",
                    "None",
                    "--PRINT_ONLY_COMBINED",
                    "True",
                    "--OUTPUT_SUMMARY",
                    "False",
                    "--OUTPUT_DETAILED",
                    "False",
                    "--TIME_PROGRESS",
                    "False",
                    "--PRINT_CONFIG",
                    "False",
                ],
                gt_json=gt,
                dt_json=dt,
            )
            self.extract_video_np_level_results(
                iou_type=iou_type,
                remapped_gt=gt,
                raw_results=output_res[self.dataset_name]["tracker"],
                video_np_level_results=video_np_level_results,
            )

            def _summarize_results(output_res, iou_type, field, suffix):
                """Collect combined-sequence HOTA metrics for `field` into `out_dict`."""
                eval_res = output_res[self.dataset_name]["tracker"][field]
                result_prefix = f"{self.dataset_name}_{'mask' if iou_type == 'segm' else 'bbox'}_{suffix}"
                # all categories were remapped to one, so take the combined-class average
                # (hoisted out of the metric loop -- it does not depend on metric_name)
                eval_res_hota = eval_res["cls_comb_cls_av"]["HOTA"]
                for metric_name in self.metric_to_collect:
                    result_key = f"{result_prefix}_{self.metric_prefix}_{metric_name}"
                    out_dict[result_key] = float(np.mean(eval_res_hota[metric_name]))

            _summarize_results(output_res, iou_type, "COMBINED_SEQ", "all")
            if "COMBINED_SEQ_CHALLENGING" in output_res[self.dataset_name]["tracker"]:
                _summarize_results(
                    output_res, iou_type, "COMBINED_SEQ_CHALLENGING", "challenging"
                )

        # (stale comment removed: per-video-NP results ARE collected above via
        # `extract_video_np_level_results`)
        return out_dict, video_np_level_results

    def _remap_gt_dt(self, gt, dt):
        """Remap GT/DT so each (video_id, category_id) pair becomes its own video
        under a single shared category (phrase HOTA is category-agnostic)."""
        # For phrase HOTA evaluation, we need to remap each pair of (video_id, category_id) to
        # a new unique video_id, so that we don't mix detections from different categories
        gt, dt = remap_video_category_pairs_to_unique_video_ids(gt, dt)
        # We further map all the categories to category_id=1 in HOTA evaluation toolkit
        # for phrase HOTA (similar to "useCat=False" for video phrase AP)
        remapped_category_id = 1
        gt["categories"] = [
            {
                "supercategory": "object",
                "id": remapped_category_id,
                "name": "_REMAPPED_FOR_PHRASE_METRICS_",
            }
        ]
        for ann in gt["annotations"]:
            ann["category_id"] = remapped_category_id
        for d in dt:
            d["category_id"] = remapped_category_id
        # To be compatible with the TrackEval YT-VIS evaluation toolkit, we need to give
        # unique filenames to each remapped video, so we add remapped video_id as prefix.
        for video in gt["videos"]:
            new_video_id = video["id"]
            video["file_names"] = [
                f"remapped_vid_{new_video_id:012d}/{name}"
                for name in video["file_names"]
            ]
        return gt, dt

    def extract_video_np_level_results(
        self, iou_type, remapped_gt, raw_results, video_np_level_results
    ):
        """Aggregate per-(video, NP) HOTA statistics into `video_np_level_results`,
        keyed by the original (video_id, category_id) pairs."""
        result_prefix = "mask" if iou_type == "segm" else "bbox"
        for video in remapped_gt["videos"]:
            # the original video id and category id before remapping
            video_id = video["orig_video_id"]
            category_id = video["orig_category_id"]
            video_key = f"remapped_vid_{video['id']:012d}"
            results = raw_results[video_key]["_REMAPPED_FOR_PHRASE_METRICS_"]["HOTA"]

            local_results = {}
            for metric_name in self.metric_to_collect:
                result_key = f"{result_prefix}_{metric_name}"
                local_results[result_key] = float(results[metric_name].mean())
            if (video_id, category_id) not in video_np_level_results:
                video_np_level_results[(video_id, category_id)] = {}
            video_np_level_results[(video_id, category_id)].update(local_results)
678
+
679
+
680
class VideoClassBasedHotaEvaluator(VideoPhraseHotaEvaluator):
    """Class-based (category-aware) HOTA: reuses the phrase-HOTA pipeline but
    keeps the original categories instead of remapping them."""

    def __init__(
        self,
        gt_ann_file: str,
        dataset_name: str = "video",
        prob_thresh: float = 0.5,
        iou_types: Optional[Sequence[str]] = None,
    ):
        """Same arguments as `VideoPhraseHotaEvaluator`; `iou_types` is now
        forwarded to the parent (previously it was silently fixed to the
        parent's default of ["bbox", "segm"])."""
        super().__init__(gt_ann_file, dataset_name, prob_thresh, iou_types)
        self.metric_prefix = "class"

    def _remap_gt_dt(self, gt, dt):
        # no remapping needed for class-based HOTA evaluation
        return gt, dt

    def extract_video_np_level_results(self, *args, **kwargs):
        # no video-NP level results for class-based HOTA evaluation
        pass
695
+
696
+
697
+ def _compress_rle(rle):
698
+ """Convert RLEs from uncompressed (integer list) to compressed (string) format."""
699
+ if rle is None:
700
+ return None
701
+ if isinstance(rle["counts"], list):
702
+ rle = pycocotools.mask.frPyObjects(rle, rle["size"][0], rle["size"][1])
703
+ rle["counts"] = rle["counts"].decode()
704
+ return rle
705
+
706
+
707
def remap_video_category_pairs_to_unique_video_ids(
    gt_json, dt_json, add_negative_np_pairs=False
):
    """
    Remap each pair of (video_id, category_id) to a new unique video_id. This is useful
    for phrase AP and demo F1 evaluation on videos, where we have `useCat=False` and
    rely on separating different NPs (from the same video) into different new video ids,
    so that we don't mix detections from different categories in computeIoU under `useCat=False`.

    This is consistent with how do we phrase AP and demo F1 evaluation on images, where we
    use a remapped unique coco_image_id for each image-NP pair (based in its query["id"] in
    CustomCocoDetectionAPI.load_queries in modulated_detection_api.py)
    """
    videos_by_id = {video["id"]: video for video in gt_json["videos"]}

    # Gather every (video_id, category_id) pair seen in predictions or GT.
    observed_pairs = {(pred["video_id"], pred["category_id"]) for pred in dt_json}
    observed_pairs.update(
        (ann["video_id"], ann["category_id"]) for ann in gt_json["annotations"]
    )

    # Deterministically assign each pair a fresh 1-based video id.
    pair_to_new_id = {
        pair: new_id for new_id, pair in enumerate(sorted(observed_pairs), start=1)
    }
    # Also map the negative NP pairs -- this is needed for IL_MCC and CG-F1 evaluation.
    if add_negative_np_pairs:
        for vnp in gt_json["video_np_pairs"]:
            pair = (vnp["video_id"], vnp["category_id"])
            if pair not in pair_to_new_id:
                pair_to_new_id[pair] = len(pair_to_new_id) + 1

    # Rewrite "video_id" in the predictions ...
    for pred in dt_json:
        pred["video_id"] = pair_to_new_id[(pred["video_id"], pred["category_id"])]
    # ... and in the GT annotations.
    for ann in gt_json["annotations"]:
        ann["video_id"] = pair_to_new_id[(ann["video_id"], ann["category_id"])]

    # Duplicate each video entry once per mapped pair, recording the original
    # ids so sample-level metrics can be traced back to the video-NP pairs.
    remapped_videos = []
    for (orig_video_id, orig_category_id), new_id in pair_to_new_id.items():
        entry = videos_by_id[orig_video_id].copy()
        entry["id"] = new_id
        entry["orig_video_id"] = orig_video_id
        entry["orig_category_id"] = orig_category_id
        remapped_videos.append(entry)
    gt_json["videos"] = remapped_videos

    return gt_json, dt_json
768
+
769
+
770
def remap_gt_dt_class_agnostic(gt, dt):
    """
    For class-agnostic HOTA, merge all GT tracks for each video (across NPs),
    ensure unique track_ids, and set all category_id to 1.
    Also, add orig_video_id and orig_category_id for compatibility.

    Bug fix: orig_category_id is now recorded from the annotations BEFORE the
    category ids are overwritten to 1 (previously it was read afterwards, so it
    was always 1 and the original category was lost).
    """
    # Record each video's original category_id (first annotation wins) BEFORE
    # any overwriting happens below.
    orig_cat_by_video = {}
    for ann in gt["annotations"]:
        orig_cat_by_video.setdefault(ann["video_id"], ann["category_id"])

    # 1. Remap all GT track_ids to be unique per video
    gt_anns_by_video = defaultdict(list)
    for ann in gt["annotations"]:
        gt_anns_by_video[ann["video_id"]].append(ann)

    # Ensure unique track ids across tracks of all videos
    next_tid = 1
    for _, anns in gt_anns_by_video.items():
        # Map old track_ids to new unique ones
        old_to_new_tid = {}
        for ann in anns:
            old_tid = ann["id"]
            if old_tid not in old_to_new_tid:
                old_to_new_tid[old_tid] = next_tid
                next_tid += 1
            ann["id"] = old_to_new_tid[old_tid]
            # Set category_id to 1 for class-agnostic
            ann["category_id"] = 1

    # Set all GT categories to a single category
    gt["categories"] = [
        {
            "supercategory": "object",
            "id": 1,
            "name": "_REMAPPED_FOR_PHRASE_METRICS_",
        }
    ]

    # Add orig_video_id and orig_category_id to each video for compatibility
    for video in gt["videos"]:
        video["orig_video_id"] = video["id"]
        # The first annotation's original category_id if available, else None
        video["orig_category_id"] = orig_cat_by_video.get(video["id"])
        # Unique per-video filename prefix required by the TrackEval toolkit
        video["file_names"] = [
            f"remapped_vid_{video['id']:012d}/{name}" for name in video["file_names"]
        ]

    # Set all DT category_id to 1
    for d in dt:
        d["category_id"] = 1
    return gt, dt
825
+
826
+
827
+ def _fill_in_ann_height_width(gt_json):
828
+ """Fill in missing height/width in GT annotations from its video info."""
829
+ video_id_to_video = {v["id"]: v for v in gt_json["videos"]}
830
+ for ann in gt_json["annotations"]:
831
+ if "height" not in ann or "width" not in ann:
832
+ video = video_id_to_video[ann["video_id"]]
833
+ if "height" not in ann:
834
+ ann["height"] = video["height"]
835
+ if "width" not in ann:
836
+ ann["width"] = video["width"]
837
+
838
+ return gt_json
sam3/eval/teta_eval_toolkit/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # fmt: off
2
+ # flake8: noqa
3
+
4
+ from . import config, datasets, metrics, utils
5
+ from .eval import Evaluator
sam3/eval/teta_eval_toolkit/_timing.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # fmt: off
2
+ # flake8: noqa
3
+
4
+ import inspect
5
+ from functools import wraps
6
+ from time import perf_counter
7
+
8
# Master switch: when False, @time adds no timing and just calls through.
DO_TIMING = False
# When True, suppress the per-call print line for instance methods.
DISPLAY_LESS_PROGRESS = False
# Accumulated wall-clock time per function/method display name.
timer_dict = {}
# Running count of timed top-level (non-method) function calls.
counter = 0


def time(f):
    """Decorator that times ``f`` with ``perf_counter`` when ``DO_TIMING`` is on.

    Per-call times are printed and accumulated into the module-level
    ``timer_dict``; a final summary is printed once a call to
    ``Evaluator.evaluate`` completes. When ``DO_TIMING`` is False the wrapped
    function runs with no overhead beyond a single flag check.

    Args:
        f: the function or method to wrap.

    Returns:
        The wrapped callable, with ``f``'s metadata preserved via ``wraps``.
    """

    @wraps(f)
    def wrap(*args, **kw):
        if not DO_TIMING:
            # Timing disabled (e.g. config["TIME_PROGRESS"] is false or
            # config["USE_PARALLEL"] is true): run normally, untimed.
            return f(*args, **kw)

        ts = perf_counter()
        result = f(*args, **kw)
        te = perf_counter()
        tt = te - ts

        # Derive a display name; methods are prefixed with their class name.
        # Guarding on bool(arg_names) fixes an IndexError the original raised
        # for zero-argument functions.
        arg_names = inspect.getfullargspec(f)[0]
        is_method = bool(arg_names) and arg_names[0] == "self"
        if is_method and DISPLAY_LESS_PROGRESS:
            return result
        if is_method:
            method_name = type(args[0]).__name__ + "." + f.__name__
        else:
            method_name = f.__name__

        # Accumulate total time per name for the final summary.
        if method_name in timer_dict:
            timer_dict[method_name] += tt
        else:
            timer_dict[method_name] = tt

        if method_name == "Evaluator.evaluate":
            # Evaluation finished: dump the accumulated timing summary.
            print("")
            print("Timing analysis:")
            for key, value in timer_dict.items():
                print("%-70s %2.4f sec" % (key, value))
        else:
            # Show selected arguments of interest next to the name.
            arg_titles = ["tracker", "seq", "cls"]
            arg_vals = []
            for i, a in enumerate(arg_names):
                if a not in arg_titles:
                    continue
                # Robustness fix: the original indexed args[i] unconditionally
                # (IndexError when the argument was passed by keyword) and
                # joined raw values (TypeError for non-str arguments).
                if i < len(args):
                    arg_vals.append(str(args[i]))
                elif a in kw:
                    arg_vals.append(str(kw[a]))
            arg_text = "(" + ", ".join(arg_vals) + ")"

            # Indent methods differently from free functions; functions whose
            # first parameter is named "test" are intentionally not printed.
            if is_method:
                print("%-74s %2.4f sec" % (" " * 4 + method_name + arg_text, tt))
            elif arg_names and arg_names[0] == "test":
                pass
            else:
                global counter
                counter += 1
                print("%i %-70s %2.4f sec" % (counter, method_name + arg_text, tt))

        return result

    return wrap