Stephentao-30 commited on
Commit
069a1f6
·
1 Parent(s): d74d0ca

Interaction view: always number patches 1-16, even if upstream rename skipped

Browse files
visualization/plotting/benchmark_interaction.py CHANGED
@@ -91,35 +91,43 @@ def create_benchmark_interaction_html(
91
  n_segs = 0
92
 
93
  if clip_summary:
94
- for item in clip_summary.get("image_region_values", []):
95
- # Use the actual segment number so the index matches segment_bboxes
96
- # and the label-map canvas (always in color-order 0, 1, 2, …).
 
 
 
 
 
 
 
97
  # seg_6 → 6 (UnSAM)
98
  # patch_1_2 → 1*grid+2 (raw patch-grid; assumes grid=sqrt(n))
99
  # "7" → 7 - 1 = 6 (post-rename patch-grid)
100
- label = item["label"]
101
  seg_num = n_segs # fallback: sequential
102
- if label.startswith("seg_"):
103
  try:
104
- seg_num = int(label.split("_", 1)[1])
105
  except (ValueError, IndexError):
106
  pass
107
- elif label.startswith("patch_"):
108
  try:
109
- _, r_str, c_str = label.split("_", 2)
110
- total_regions = len(clip_summary.get("image_region_values", []))
111
- grid = int(round(total_regions ** 0.5)) or 4
112
- seg_num = int(r_str) * grid + int(c_str)
113
  except (ValueError, IndexError):
114
  pass
115
- elif label.isdigit():
116
  try:
117
- seg_num = int(label) - 1
118
  except ValueError:
119
  pass
 
 
 
120
  regions.append({
121
  "index": seg_num,
122
- "label": label,
123
  "value": item["value"],
124
  "type": "segment",
125
  })
@@ -173,9 +181,22 @@ def create_benchmark_interaction_html(
173
  # Map subword token labels to whole caption words.
174
  from .medical_charts import _tok_to_word
175
  cross_source = all_cross_modal_pairs or clip_summary.get("cross_modal_interactions", [])
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  for item in cross_source:
177
  cross_interactions.append({
178
- "seg": item["pair"][0],
179
  "tok": _tok_to_word(item["pair"][1], caption) if caption else item["pair"][1].replace("tok:", "").lstrip("#"),
180
  "value": item["value"],
181
  })
 
91
  n_segs = 0
92
 
93
  if clip_summary:
94
+ raw_items = clip_summary.get("image_region_values", [])
95
+ total_regions = len(raw_items)
96
+ grid_guess = int(round(total_regions ** 0.5))
97
+ looks_like_patch_grid = (grid_guess * grid_guess == total_regions) and all(
98
+ str(it.get("label", "")).startswith("patch_")
99
+ or str(it.get("label", "")).isdigit()
100
+ for it in raw_items
101
+ )
102
+ for item in raw_items:
103
+ # Resolve the segment number:
104
  # seg_6 → 6 (UnSAM)
105
  # patch_1_2 → 1*grid+2 (raw patch-grid; assumes grid=sqrt(n))
106
  # "7" → 7 - 1 = 6 (post-rename patch-grid)
107
+ raw_label = str(item["label"])
108
  seg_num = n_segs # fallback: sequential
109
+ if raw_label.startswith("seg_"):
110
  try:
111
+ seg_num = int(raw_label.split("_", 1)[1])
112
  except (ValueError, IndexError):
113
  pass
114
+ elif raw_label.startswith("patch_"):
115
  try:
116
+ _, r_str, c_str = raw_label.split("_", 2)
117
+ seg_num = int(r_str) * grid_guess + int(c_str)
 
 
118
  except (ValueError, IndexError):
119
  pass
120
+ elif raw_label.isdigit():
121
  try:
122
+ seg_num = int(raw_label) - 1
123
  except ValueError:
124
  pass
125
+ # Display label: in patch-grid mode always show "1".."N" in reading
126
+ # order so the overlay doesn't leak raw "patch_r_c" text.
127
+ display_label = str(seg_num + 1) if looks_like_patch_grid else raw_label
128
  regions.append({
129
  "index": seg_num,
130
+ "label": display_label,
131
  "value": item["value"],
132
  "type": "segment",
133
  })
 
181
  # Map subword token labels to whole caption words.
182
  from .medical_charts import _tok_to_word
183
  cross_source = all_cross_modal_pairs or clip_summary.get("cross_modal_interactions", [])
184
+
185
+ def _seg_display(seg_raw: str) -> str:
186
+ # Normalize cross-pair segment labels the same way we normalized
187
+ # region labels above — otherwise arrows can't match regions.
188
+ s = str(seg_raw)
189
+ if looks_like_patch_grid and s.startswith("patch_"):
190
+ try:
191
+ _, rr, cc = s.split("_", 2)
192
+ return str(int(rr) * grid_guess + int(cc) + 1)
193
+ except (ValueError, IndexError):
194
+ return s
195
+ return s
196
+
197
  for item in cross_source:
198
  cross_interactions.append({
199
+ "seg": _seg_display(item["pair"][0]),
200
  "tok": _tok_to_word(item["pair"][1], caption) if caption else item["pair"][1].replace("tok:", "").lstrip("#"),
201
  "value": item["value"],
202
  })