Stephentao-30 commited on
Commit ·
069a1f6
1
Parent(s): d74d0ca
Interaction view: always number patches 1-16, even if upstream rename skipped
Browse files
visualization/plotting/benchmark_interaction.py
CHANGED
|
@@ -91,35 +91,43 @@ def create_benchmark_interaction_html(
|
|
| 91 |
n_segs = 0
|
| 92 |
|
| 93 |
if clip_summary:
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
# seg_6 → 6 (UnSAM)
|
| 98 |
# patch_1_2 → 1*grid+2 (raw patch-grid; assumes grid=sqrt(n))
|
| 99 |
# "7" → 7 - 1 = 6 (post-rename patch-grid)
|
| 100 |
-
|
| 101 |
seg_num = n_segs # fallback: sequential
|
| 102 |
-
if
|
| 103 |
try:
|
| 104 |
-
seg_num = int(
|
| 105 |
except (ValueError, IndexError):
|
| 106 |
pass
|
| 107 |
-
elif
|
| 108 |
try:
|
| 109 |
-
_, r_str, c_str =
|
| 110 |
-
|
| 111 |
-
grid = int(round(total_regions ** 0.5)) or 4
|
| 112 |
-
seg_num = int(r_str) * grid + int(c_str)
|
| 113 |
except (ValueError, IndexError):
|
| 114 |
pass
|
| 115 |
-
elif
|
| 116 |
try:
|
| 117 |
-
seg_num = int(
|
| 118 |
except ValueError:
|
| 119 |
pass
|
|
|
|
|
|
|
|
|
|
| 120 |
regions.append({
|
| 121 |
"index": seg_num,
|
| 122 |
-
"label":
|
| 123 |
"value": item["value"],
|
| 124 |
"type": "segment",
|
| 125 |
})
|
|
@@ -173,9 +181,22 @@ def create_benchmark_interaction_html(
|
|
| 173 |
# Map subword token labels to whole caption words.
|
| 174 |
from .medical_charts import _tok_to_word
|
| 175 |
cross_source = all_cross_modal_pairs or clip_summary.get("cross_modal_interactions", [])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
for item in cross_source:
|
| 177 |
cross_interactions.append({
|
| 178 |
-
"seg": item["pair"][0],
|
| 179 |
"tok": _tok_to_word(item["pair"][1], caption) if caption else item["pair"][1].replace("tok:", "").lstrip("#"),
|
| 180 |
"value": item["value"],
|
| 181 |
})
|
|
|
|
| 91 |
n_segs = 0
|
| 92 |
|
| 93 |
if clip_summary:
|
| 94 |
+
raw_items = clip_summary.get("image_region_values", [])
|
| 95 |
+
total_regions = len(raw_items)
|
| 96 |
+
grid_guess = int(round(total_regions ** 0.5))
|
| 97 |
+
looks_like_patch_grid = (grid_guess * grid_guess == total_regions) and all(
|
| 98 |
+
str(it.get("label", "")).startswith("patch_")
|
| 99 |
+
or str(it.get("label", "")).isdigit()
|
| 100 |
+
for it in raw_items
|
| 101 |
+
)
|
| 102 |
+
for item in raw_items:
|
| 103 |
+
# Resolve the segment number:
|
| 104 |
# seg_6 → 6 (UnSAM)
|
| 105 |
# patch_1_2 → 1*grid+2 (raw patch-grid; assumes grid=sqrt(n))
|
| 106 |
# "7" → 7 - 1 = 6 (post-rename patch-grid)
|
| 107 |
+
raw_label = str(item["label"])
|
| 108 |
seg_num = n_segs # fallback: sequential
|
| 109 |
+
if raw_label.startswith("seg_"):
|
| 110 |
try:
|
| 111 |
+
seg_num = int(raw_label.split("_", 1)[1])
|
| 112 |
except (ValueError, IndexError):
|
| 113 |
pass
|
| 114 |
+
elif raw_label.startswith("patch_"):
|
| 115 |
try:
|
| 116 |
+
_, r_str, c_str = raw_label.split("_", 2)
|
| 117 |
+
seg_num = int(r_str) * grid_guess + int(c_str)
|
|
|
|
|
|
|
| 118 |
except (ValueError, IndexError):
|
| 119 |
pass
|
| 120 |
+
elif raw_label.isdigit():
|
| 121 |
try:
|
| 122 |
+
seg_num = int(raw_label) - 1
|
| 123 |
except ValueError:
|
| 124 |
pass
|
| 125 |
+
# Display label: in patch-grid mode always show "1".."N" in reading
|
| 126 |
+
# order so the overlay doesn't leak raw "patch_r_c" text.
|
| 127 |
+
display_label = str(seg_num + 1) if looks_like_patch_grid else raw_label
|
| 128 |
regions.append({
|
| 129 |
"index": seg_num,
|
| 130 |
+
"label": display_label,
|
| 131 |
"value": item["value"],
|
| 132 |
"type": "segment",
|
| 133 |
})
|
|
|
|
| 181 |
# Map subword token labels to whole caption words.
|
| 182 |
from .medical_charts import _tok_to_word
|
| 183 |
cross_source = all_cross_modal_pairs or clip_summary.get("cross_modal_interactions", [])
|
| 184 |
+
|
| 185 |
+
def _seg_display(seg_raw: str) -> str:
|
| 186 |
+
# Normalize cross-pair segment labels the same way we normalized
|
| 187 |
+
# region labels above — otherwise arrows can't match regions.
|
| 188 |
+
s = str(seg_raw)
|
| 189 |
+
if looks_like_patch_grid and s.startswith("patch_"):
|
| 190 |
+
try:
|
| 191 |
+
_, rr, cc = s.split("_", 2)
|
| 192 |
+
return str(int(rr) * grid_guess + int(cc) + 1)
|
| 193 |
+
except (ValueError, IndexError):
|
| 194 |
+
return s
|
| 195 |
+
return s
|
| 196 |
+
|
| 197 |
for item in cross_source:
|
| 198 |
cross_interactions.append({
|
| 199 |
+
"seg": _seg_display(item["pair"][0]),
|
| 200 |
"tok": _tok_to_word(item["pair"][1], caption) if caption else item["pair"][1].replace("tok:", "").lstrip("#"),
|
| 201 |
"value": item["value"],
|
| 202 |
})
|