griddev committed on
Commit
ce25d0a
·
verified ·
1 Parent(s): 7c69cda

Deploy Streamlit Space app

Browse files
Files changed (3) hide show
  1. app.py +235 -3
  2. models/attention_flow.py +328 -0
  3. requirements.txt +2 -0
app.py CHANGED
@@ -14,6 +14,7 @@ Features:
14
  import os
15
  import warnings
16
  import torch
 
17
  import streamlit as st
18
  from PIL import Image
19
  from models.blip_tuner import generate_with_mask
@@ -476,6 +477,42 @@ def load_toxicity_filter():
476
  return tok, mdl
477
 
478
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
479
  # ─────────────────────────────────────────────────────────────────────────────
480
  # Toxicity Check
481
  # ─────────────────────────────────────────────────────────────────────────────
@@ -744,8 +781,8 @@ def render_caption_card(model_name, caption, weight_src, num_beams, length_penal
744
  # Tabs
745
  # ─────────────────────────────────────────────────────────────────────────────
746
 
747
- tab_caption, tab_compare, tab_results = st.tabs([
748
- "🖼️ Caption", "🔀 Compare All Models", "📊 Experiment Results"
749
  ])
750
 
751
 
@@ -961,7 +998,202 @@ with tab_compare:
961
 
962
 
963
  # ═══════════════════════════════════════════════════════════════════════════
964
- # Tab 3 — Experiment Results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
965
  # ═══════════════════════════════════════════════════════════════════════════
966
 
967
  with tab_results:
 
14
  import os
15
  import warnings
16
  import torch
17
+ import numpy as np
18
  import streamlit as st
19
  from PIL import Image
20
  from models.blip_tuner import generate_with_mask
 
477
  return tok, mdl
478
 
479
 
480
@st.cache_resource(show_spinner=False)
def load_blip_attention_model(weight_source="base"):
    """Load a BLIP captioning model for attention analysis, cached per weight source.

    Args:
        weight_source: "base" for the pretrained checkpoint, or a fine-tuned
            tag (e.g. "best" / "latest") resolved against the project output dir.

    Returns:
        (processor, model, device) with the model in eval mode on the device.
    """
    # Lazy import so transformers is only loaded when this feature is used.
    from transformers import BlipForConditionalGeneration, BlipProcessor
    device = get_device()
    processor = BlipProcessor.from_pretrained(
        "Salesforce/blip-image-captioning-base", use_fast=True
    )
    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )

    if weight_source != "base":
        output_root, _, _ = _resolve_weight_paths(
            need_outputs=True, need_shakespeare=False
        )
        ckpt = _ckpt_path(output_root, "blip", weight_source)
        # Overlay fine-tuned weights only when the checkpoint dir exists and is
        # non-empty; strict=False tolerates missing/extra keys in the state dict.
        if os.path.isdir(ckpt) and os.listdir(ckpt):
            loaded = BlipForConditionalGeneration.from_pretrained(ckpt)
            model.load_state_dict(loaded.state_dict(), strict=False)
            del loaded  # release the temporary model copy immediately

    # NOTE(review): checkpointing and KV caching are disabled here, presumably
    # so the per-token attention/gradient hooks downstream see full attention
    # on every forward pass — confirm against models/attention_flow.py usage.
    try:
        model.gradient_checkpointing_disable()
    except Exception:
        pass
    model.config.use_cache = False
    model.to(device).eval()
    return processor, model, device
508
+
509
+
510
@st.cache_resource(show_spinner=False)
def load_alignment_detector():
    """Return a cached OWL-ViT zero-shot detector used for IoU grounding checks."""
    # Lazy import keeps the heavy detection stack off the critical startup path.
    from models.attention_flow import load_owlvit_detector
    return load_owlvit_detector(get_device())
514
+
515
+
516
  # ─────────────────────────────────────────────────────────────────────────────
517
  # Toxicity Check
518
  # ─────────────────────────────────────────────────────────────────────────────
 
781
  # Tabs
782
  # ─────────────────────────────────────────────────────────────────────────────
783
 
784
# Four top-level views: single-image captioning, model comparison,
# attention analysis, and experiment results.
tab_caption, tab_compare, tab_attention, tab_results = st.tabs([
    "🖼️ Caption", "🔀 Compare All Models", "🧠 Attention Explorer", "📊 Experiment Results"
])
787
 
788
 
 
998
 
999
 
1000
  # ═══════════════════════════════════════════════════════════════════════════
1001
+ # Tab 3 — Attention Explorer (Task 2)
1002
+ # ═══════════════════════════════════════════════════════════════════════════
1003
+
1004
with tab_attention:
    # ── Header ───────────────────────────────────────────────────────────
    st.markdown("### 🧠 BLIP Attention Explorer")
    st.caption(
        "Step-by-step cross-attention analysis with rollout across decoder layers, "
        "2x5 heatmap grid, IoU grounding score, and caption-length summary."
    )

    # Left column: image upload/preview. Right column: analysis controls.
    attn_col_left, attn_col_right = st.columns([1, 1], gap="large")
    with attn_col_left:
        attn_file = st.file_uploader(
            "Upload an image for attention analysis",
            type=["jpg", "jpeg", "png", "webp"],
            key="attention_uploader",
        )
        if attn_file:
            attn_image = Image.open(attn_file).convert("RGB")
            st.image(attn_image, caption="Attention Input Image", use_column_width=True)

    with attn_col_right:
        # Offer fine-tuned weight choices only when those checkpoints exist.
        _ensure_model_outputs_available("blip")
        attn_weight_options = {"Base (Pretrained)": "base"}
        if _has_finetuned("blip", "best"):
            attn_weight_options["Fine-tuned (Best)"] = "best"
        if _has_finetuned("blip", "latest"):
            attn_weight_options["Fine-tuned (Latest)"] = "latest"
        attn_weight_choice = st.selectbox(
            "BLIP Weight Source",
            list(attn_weight_options.keys()),
            index=0,
            key="attn_weight_choice",
        )
        attn_weight_source = attn_weight_options[attn_weight_choice]

        # Trace either the model's own greedy caption or a user-supplied phrase.
        token_mode = st.radio(
            "Token Source",
            ["Generated Caption", "Custom Text Prompt"],
            horizontal=True,
            key="attn_token_mode",
        )
        custom_text = ""
        if token_mode == "Custom Text Prompt":
            custom_text = st.text_input(
                "Enter custom text/words for heatmap tracing",
                value="a dog playing with a ball",
                key="attn_custom_text",
            )

        max_attn_steps = st.slider(
            "Caption Steps to Analyze",
            min_value=3,
            max_value=12,
            value=9,
            key="attn_steps",
        )
        run_iou = st.toggle(
            "Compute IoU Alignment with OWL-ViT (slower)",
            value=True,
            key="attn_iou_toggle",
        )

    # Button stays disabled until an image is uploaded (and, in custom mode,
    # the text prompt is non-empty).
    run_attention_btn = st.button(
        "Run Step-by-Step Attention Analysis",
        disabled=(attn_file is None or (token_mode == "Custom Text Prompt" and not custom_text.strip())),
        key="attn_run_btn",
    )

    if run_attention_btn and attn_file:
        # Imported lazily so the attention stack only loads on demand.
        from models.attention_flow import (
            build_attention_grid_figure,
            decode_custom_text_with_flow,
            decode_generated_caption_with_flow,
            encode_image_for_flow,
            grade_alignment_with_detector,
            summarize_caption_alignment,
        )

        attn_image = Image.open(attn_file).convert("RGB")
        iou_results = []

        with st.status("Running attention pipeline...", expanded=True) as status:
            st.write("Step 1/5: Loading BLIP model and selected weights")
            attn_processor, attn_model, attn_device = load_blip_attention_model(attn_weight_source)

            st.write("Step 2/5: Encoding image through ViT")
            image_224, enc_hidden, enc_mask = encode_image_for_flow(
                attn_model, attn_processor, attn_device, attn_image
            )

            st.write("Step 3/5: Extracting rollout heatmaps token-by-token")
            if token_mode == "Custom Text Prompt":
                tokens, heatmaps = decode_custom_text_with_flow(
                    attn_model,
                    attn_processor,
                    attn_device,
                    enc_hidden,
                    enc_mask,
                    custom_text,
                    max_tokens=max_attn_steps,
                )
            else:
                tokens, heatmaps = decode_generated_caption_with_flow(
                    attn_model,
                    attn_processor,
                    attn_device,
                    enc_hidden,
                    enc_mask,
                    max_tokens=max_attn_steps,
                )

            st.write("Step 4/5: Building 2x5 attention grid")
            fig_grid = build_attention_grid_figure(image_224, tokens, heatmaps, n_rows=2, n_cols=5)

            if run_iou:
                st.write("Step 5/5: Computing IoU alignment using OWL-ViT detections")
                detector = load_alignment_detector()
                iou_results = grade_alignment_with_detector(attn_image, tokens, heatmaps, detector)
            else:
                st.write("Step 5/5: IoU grading skipped by user")

            status.update(label="Attention pipeline complete", state="complete", expanded=False)

        # ── Results: heatmap grid and decoded tokens ─────────────────────
        st.pyplot(fig_grid, use_container_width=True)
        caption_tokens = " ".join(tokens) if tokens else "[No tokens generated]"
        st.markdown(f"**Decoded tokens:** `{caption_tokens}`")

        summary = summarize_caption_alignment(iou_results, len(tokens))
        st.markdown(
            f"**Caption length:** `{summary['caption_length']}` | "
            f"**Mean alignment IoU:** `{summary['mean_alignment_iou']:.4f}`"
        )

        if run_iou:
            st.markdown("#### Word-level Alignment (IoU)")
            if iou_results:
                table_rows = [
                    {
                        "word": item["word"],
                        "position": item["position"],
                        "iou": round(item["iou"], 4),
                        "det_score": round(item["det_score"], 4),
                        "box": [int(x) for x in item["box"]],
                    }
                    for item in iou_results
                ]
                st.dataframe(table_rows, use_container_width=True)

                # Heuristic thresholds for what counts as grounded.
                strong = [item["word"] for item in iou_results if item["iou"] >= 0.30]
                weak = [item["word"] for item in iou_results if item["iou"] < 0.10]
                if strong:
                    st.success("Strongly grounded words: " + ", ".join(strong))
                if weak:
                    st.warning("Weakly grounded words: " + ", ".join(weak))
            else:
                st.info("No detectable object-word matches found for IoU grading on this run.")

        # Accumulate one summary row per run; session-scoped only (not persisted).
        if "alignment_history" not in st.session_state:
            st.session_state["alignment_history"] = []
        st.session_state["alignment_history"].append(
            {
                "caption_length": int(summary["caption_length"]),
                "mean_alignment_iou": float(summary["mean_alignment_iou"]),
                "mode": token_mode,
                "weights": attn_weight_source,
            }
        )

        st.markdown("#### Caption Length -> Mean Alignment IoU")
        history = st.session_state["alignment_history"]
        if history:
            # Plotting is best-effort: any matplotlib failure must not break the tab.
            try:
                import matplotlib.pyplot as plt

                x_vals = [item["caption_length"] for item in history]
                y_vals = [item["mean_alignment_iou"] for item in history]
                fig_summary, ax_summary = plt.subplots(figsize=(6, 3.2))
                ax_summary.scatter(x_vals, y_vals, color="#58a6ff", alpha=0.85)
                if len(x_vals) > 1:
                    # Least-squares linear trend over the run history.
                    z = np.polyfit(x_vals, y_vals, 1)
                    trend = np.poly1d(z)
                    xs = sorted(x_vals)
                    ax_summary.plot(xs, [trend(v) for v in xs], linestyle="--", color="#ff7b72")
                ax_summary.set_xlabel("Caption length")
                ax_summary.set_ylabel("Mean IoU")
                ax_summary.set_title("Alignment Trend")
                ax_summary.grid(alpha=0.35, linestyle="--")
                st.pyplot(fig_summary, use_container_width=True)
            except Exception:
                pass
            st.dataframe(history[-20:], use_container_width=True)
1193
+
1194
+
1195
+ # ═══════════════════════════════════════════════════════════════════════════
1196
+ # Tab 4 — Experiment Results
1197
  # ═══════════════════════════════════════════════════════════════════════════
1198
 
1199
  with tab_results:
models/attention_flow.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List, Tuple
3
+
4
+ import cv2
5
+ import matplotlib
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from PIL import Image
10
+
11
matplotlib.use("Agg")  # headless backend: figures are rendered off-screen for the web app
import matplotlib.pyplot as plt


# Words skipped during IoU grading: function words, punctuation, and the
# WordPiece continuation marker — none are groundable as detectable objects.
STOP_WORDS = {
    "a", "an", "the", "and", "or", "but", "is", "are", "was", "were",
    "in", "on", "at", "to", "for", "with", "by", "it", "this", "that",
    "there", "here", "of", "up", "out", ".", ",", "!", "##",
}
20
+
21
+
22
class FlowExtractor:
    """Capture cross-attention probabilities and their gradients from each
    decoder layer of a BLIP text decoder via forward hooks.

    Each entry of ``self.layers`` is a dict holding the most recent forward
    attention tensor (``"fwd"``) and its backward gradient (``"grad"``) for
    one ``crossattention`` sub-module.
    """

    def __init__(self, model):
        # Assumes a BlipForConditionalGeneration-style model exposing
        # model.text_decoder.bert.encoder.layer — TODO confirm for variants.
        self.model = model
        self._hooks = []   # forward-hook handles, detached in remove()
        self.layers = []   # one {"fwd": ..., "grad": ...} holder per cross-attn layer

        for layer in model.text_decoder.bert.encoder.layer:
            if hasattr(layer, "crossattention"):
                holder = {"fwd": None, "grad": None}
                self.layers.append(holder)

                def _make_hook(h):
                    # Factory binds the holder so each layer writes its own slot
                    # (avoids the late-binding-closure pitfall in the loop).
                    def _fwd(module, inputs, outputs):
                        # outputs[1] holds the attention probabilities when the
                        # module is run with output_attentions=True.
                        if len(outputs) > 1 and outputs[1] is not None:
                            h["fwd"] = outputs[1]
                            if h["fwd"].requires_grad:
                                # Tensor hook stores the incoming gradient during backward.
                                h["fwd"].register_hook(
                                    lambda g, _h=h: _h.update({"grad": g.detach()})
                                )
                    return _fwd

                target = layer.crossattention.self
                self._hooks.append(target.register_forward_hook(_make_hook(holder)))

    def clear(self):
        """Drop captured tensors before the next forward/backward pass."""
        for holder in self.layers:
            holder["fwd"] = None
            holder["grad"] = None

    def remove(self):
        """Detach all forward hooks from the model."""
        for hook in self._hooks:
            hook.remove()
        self._hooks = []
55
+
56
+
57
def encode_image_for_flow(model, processor, device, image_pil: Image.Image):
    """Resize the image to 224x224 and run it through the BLIP vision tower.

    Returns:
        A tuple ``(resized_image, encoder_hidden, encoder_mask)``: the resized
        PIL image, the frozen vision encoder hidden states, and an all-ones
        attention mask matching those states.
    """
    resized = image_pil.resize((224, 224), Image.LANCZOS)
    batch = processor(images=resized, return_tensors="pt").to(device)
    # The vision tower is frozen for this analysis: no gradients flow into it.
    with torch.no_grad():
        features = model.vision_model(pixel_values=batch["pixel_values"])
    hidden = features[0].detach().requires_grad_(False)
    mask = torch.ones(hidden.size()[:-1], dtype=torch.long, device=device)
    return resized, hidden, mask
65
+
66
+
67
+ def _single_layer_gradcam(holder, token_idx: int = -1) -> torch.Tensor:
68
+ attn = holder["fwd"][:, :, token_idx, :]
69
+ grad = holder["grad"][:, :, token_idx, :]
70
+ cam = (attn * grad).mean(dim=1).squeeze()
71
+ return torch.clamp(cam, min=0.0)
72
+
73
+
74
+ def _normalize1d(tensor: torch.Tensor) -> torch.Tensor:
75
+ denom = tensor.sum()
76
+ if denom > 0:
77
+ return tensor / denom
78
+ return tensor
79
+
80
+
81
def compute_attention_flow(
    extractor: FlowExtractor,
    num_image_tokens: int = 197,
    residual_weight: float = 0.05,
    out_resolution: int = 224,
) -> np.ndarray:
    """Fuse per-layer Grad-CAM vectors into one spatial heatmap via rollout.

    Args:
        extractor: holds the captured attention/gradient pairs per layer.
        num_image_tokens: length of the image-token axis (197 = 14x14 ViT
            patches + CLS — presumably matches the BLIP-base vision tower;
            TODO confirm for other backbones).
        residual_weight: weight of the uniform term mixed in at each rollout
            step to keep mass from collapsing to zero.
        out_resolution: side length of the square output heatmap.

    Returns:
        float32 array of shape (out_resolution, out_resolution), min-max
        normalized to [0, 1]; all zeros when no layer captured gradients.
    """
    # Only layers where both the forward pass and backward pass fired are usable.
    valid_cams = []
    for holder in extractor.layers:
        if holder["fwd"] is None or holder["grad"] is None:
            continue
        valid_cams.append(_single_layer_gradcam(holder).detach())

    if not valid_cams:
        return np.zeros((out_resolution, out_resolution), dtype=np.float32)

    # Rollout: multiply normalized per-layer maps, mixing in a uniform
    # distribution each step as a residual/smoothing term.
    uniform = torch.ones(num_image_tokens, device=valid_cams[0].device) / num_image_tokens
    rollout = _normalize1d(valid_cams[0])
    for cam in valid_cams[1:]:
        rollout = _normalize1d(rollout) * _normalize1d(cam) + residual_weight * uniform
        rollout = torch.clamp(rollout, min=0.0)

    # Drop the CLS token, reshape the patch axis into a square grid, and
    # upsample to image resolution.
    spatial = rollout[1:]
    grid_size = int(math.sqrt(spatial.numel()))
    hm_tensor = spatial.detach().cpu().reshape(1, 1, grid_size, grid_size).float()
    hm_up = F.interpolate(
        hm_tensor,
        size=(out_resolution, out_resolution),
        mode="bicubic",
        align_corners=False,
    ).squeeze()
    # Min-max normalize to [0, 1]; degenerate (constant) maps become all zeros.
    hm_np = hm_up.numpy()
    lo, hi = hm_np.min(), hm_np.max()
    if hi > lo:
        hm_np = (hm_np - lo) / (hi - lo)
    else:
        hm_np = np.zeros_like(hm_np)
    return hm_np.astype(np.float32)
118
+
119
+
120
def decode_generated_caption_with_flow(
    model,
    processor,
    device,
    encoder_hidden,
    encoder_mask,
    max_tokens: int = 20,
) -> Tuple[List[str], List[np.ndarray]]:
    """Greedy-decode a caption token by token, capturing a rollout heatmap per step.

    Each step runs a full forward over the prefix with attentions enabled,
    backpropagates the chosen token's logit to populate the gradient hooks,
    then snapshots the attention-flow heatmap.

    Returns:
        (tokens, heatmaps): decoded token strings and one heatmap per token,
        in generation order. Stops early at the SEP (end) token.
    """
    extractor = FlowExtractor(model)
    # Seed the decoder with the BOS token.
    input_ids = torch.LongTensor([[model.config.text_config.bos_token_id]]).to(device)
    tokens, heatmaps = [], []

    for _ in range(max_tokens):
        # Reset gradients and captured tensors before each step.
        model.zero_grad()
        extractor.clear()
        outputs = model.text_decoder(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden,
            encoder_attention_mask=encoder_mask,
            output_attentions=True,
            return_dict=True,
        )
        # Greedy choice over the last position's vocabulary distribution.
        logits = outputs.logits[:, -1, :]
        next_token = torch.argmax(logits, dim=-1)
        if next_token.item() == model.config.text_config.sep_token_id:
            break

        # Backward on the selected logit fills the gradient holders in the hooks.
        logits[0, next_token.item()].backward(retain_graph=False)
        heatmaps.append(compute_attention_flow(extractor))
        tokens.append(processor.tokenizer.decode([next_token.item()]).strip())
        input_ids = torch.cat([input_ids, next_token.reshape(1, 1)], dim=-1)

    extractor.remove()
    return tokens, heatmaps
154
+
155
+
156
def decode_custom_text_with_flow(
    model,
    processor,
    device,
    encoder_hidden,
    encoder_mask,
    text: str,
    max_tokens: int = 20,
) -> Tuple[List[str], List[np.ndarray]]:
    """Teacher-force *text* through the decoder, capturing a heatmap per token.

    Unlike the generated-caption variant, the next token is not argmax-chosen:
    each token of the user's text is forced in turn and its logit is
    backpropagated, so the heatmaps show where the model "looks" for that word.

    Returns:
        (tokens, heatmaps) aligned lists, truncated to ``max_tokens`` tokens.
    """
    extractor = FlowExtractor(model)
    # Tokenize without special tokens; cap the number of traced tokens.
    token_ids = processor.tokenizer(
        text,
        add_special_tokens=False,
        return_attention_mask=False,
    )["input_ids"][:max_tokens]

    input_ids = torch.LongTensor([[model.config.text_config.bos_token_id]]).to(device)
    tokens, heatmaps = [], []

    for target_token_id in token_ids:
        # Reset gradients and captured tensors before each forced step.
        model.zero_grad()
        extractor.clear()
        outputs = model.text_decoder(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden,
            encoder_attention_mask=encoder_mask,
            output_attentions=True,
            return_dict=True,
        )
        # Backward on the forced token's logit populates the gradient hooks.
        logits = outputs.logits[:, -1, :]
        score = logits[0, target_token_id]
        score.backward(retain_graph=False)

        heatmaps.append(compute_attention_flow(extractor))
        tokens.append(processor.tokenizer.decode([target_token_id]).strip())
        next_tensor = torch.LongTensor([[target_token_id]]).to(device)
        input_ids = torch.cat([input_ids, next_tensor], dim=-1)

    extractor.remove()
    return tokens, heatmaps
196
+
197
+
198
def overlay_heatmap_on_image(
    image_pil: Image.Image,
    heatmap_np: np.ndarray,
    alpha: float = 0.5,
    hot_threshold: float = 0.1,
) -> Image.Image:
    """Blend an INFERNO-colored heatmap onto the image.

    Only pixels whose heatmap value exceeds ``hot_threshold`` are tinted;
    the rest of the image is left untouched.
    """
    height, width = heatmap_np.shape
    base = np.array(image_pil.resize((width, height), Image.LANCZOS))
    # OpenCV colormaps produce BGR; convert to RGB for PIL.
    heat_rgb = cv2.cvtColor(
        cv2.applyColorMap(np.uint8(255.0 * heatmap_np), cv2.COLORMAP_INFERNO),
        cv2.COLOR_BGR2RGB,
    )
    hot = (heatmap_np > hot_threshold).astype(np.float32)[..., None]
    composite = base * (1 - hot * alpha) + heat_rgb * (hot * alpha)
    return Image.fromarray(composite.astype(np.uint8))
212
+
213
+
214
def build_attention_grid_figure(
    image_pil: Image.Image,
    tokens: List[str],
    heatmaps: List[np.ndarray],
    n_rows: int = 2,
    n_cols: int = 5,
):
    """Compose a matplotlib grid: the original image in panel 0, followed by
    one heatmap overlay per decoded token.

    Args:
        image_pil: image to display and to blend the heatmaps onto.
        tokens, heatmaps: aligned per-token results; only the first
            ``n_rows * n_cols - 1`` entries fit in the grid.
        n_rows, n_cols: grid shape (default 2x5).

    Returns:
        The matplotlib Figure (caller is responsible for closing it).
    """
    n_panels = n_rows * n_cols
    n_words = min(n_panels - 1, len(tokens))
    # squeeze=False guarantees axes is a 2-D ndarray even for a 1x1 grid,
    # where plt.subplots would otherwise return a bare Axes and .flatten()
    # would raise AttributeError.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(n_cols * 3.2, n_rows * 3.2), squeeze=False
    )
    axes = axes.ravel()

    axes[0].imshow(image_pil)
    axes[0].set_title("Original", fontsize=11, fontweight="bold")
    axes[0].axis("off")

    for index in range(n_words):
        overlay = overlay_heatmap_on_image(image_pil, heatmaps[index])
        axes[index + 1].imshow(overlay)
        axes[index + 1].set_title(f"'{tokens[index]}'", fontsize=10, fontweight="bold")
        axes[index + 1].axis("off")

    # Hide any unused trailing panels.
    for index in range(n_words + 1, n_panels):
        axes[index].axis("off")

    caption_preview = " ".join(tokens[:12])
    fig.suptitle(
        f"Cross-Attention Flow (2x5)\nCaption Tokens: {caption_preview}",
        fontsize=12,
        fontweight="bold",
        y=1.02,
    )
    plt.tight_layout()
    return fig
248
+
249
+
250
def load_owlvit_detector(device):
    """Create a zero-shot object-detection pipeline backed by OWL-ViT."""
    from transformers import pipeline

    # Hugging Face pipelines take a CUDA ordinal, or -1 for CPU.
    target = 0 if str(device).startswith("cuda") else -1
    return pipeline(
        model="google/owlvit-base-patch32",
        task="zero-shot-object-detection",
        device=target,
    )
258
+
259
+
260
def binarize_heatmap(heatmap_np: np.ndarray, target_hw: tuple) -> np.ndarray:
    """Resize a [0, 1] heatmap to ``(H, W) = target_hw`` and Otsu-threshold it
    into a boolean foreground mask."""
    height, width = target_hw[0], target_hw[1]
    scaled = cv2.resize(heatmap_np, (width, height))
    as_bytes = np.uint8(255.0 * scaled)
    # Otsu picks the threshold automatically; the 0 here is ignored.
    _, thresholded = cv2.threshold(as_bytes, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return thresholded > 0
265
+
266
+
267
def calculate_iou(mask: np.ndarray, box: list, img_shape: tuple) -> float:
    """Intersection-over-union between a boolean mask and an axis-aligned box.

    The box is clipped to the ``img_shape = (H, W)`` canvas before rasterizing.
    Returns 0.0 when both regions are empty.
    """
    height, width = img_shape[0], img_shape[1]
    x0, y0, x1, y1 = (int(v) for v in box)
    # Clip the box into the image bounds.
    x0, y0 = max(x0, 0), max(y0, 0)
    x1, y1 = min(x1, width), min(y1, height)
    box_mask = np.zeros(img_shape, dtype=bool)
    box_mask[y0:y1, x0:x1] = True
    intersection = np.logical_and(mask, box_mask).sum()
    union = np.logical_or(mask, box_mask).sum()
    if union == 0:
        return 0.0
    return float(intersection) / union
278
+
279
+
280
def grade_alignment_with_detector(
    image_pil: Image.Image,
    tokens: List[str],
    heatmaps: List[np.ndarray],
    detector,
    min_detection_score: float = 0.05,
) -> List[dict]:
    """Score how well each content word's heatmap overlaps its detection box.

    Args:
        image_pil: original input image (detector runs at full resolution).
        tokens: decoded token strings, aligned with ``heatmaps``.
        heatmaps: per-token rollout heatmaps in [0, 1].
        detector: zero-shot detection pipeline; called as
            ``detector(image, candidate_labels=[word])`` and expected to yield
            dicts with ``"score"`` and ``"box"`` (xmin/ymin/xmax/ymax) —
            the Hugging Face zero-shot-object-detection output schema.
        min_detection_score: detections below this confidence are ignored.

    Returns:
        One dict per gradable word: word, 1-based position, IoU, detection
        score, and the winning box. Words that are short, stop words,
        non-alphabetic, or undetected are skipped entirely.
    """
    results = []
    img_shape = (image_pil.height, image_pil.width)
    for idx, (word, hm) in enumerate(zip(tokens, heatmaps)):
        # Strip WordPiece continuation markers; skip ungroundable tokens.
        clean_word = word.replace("##", "").lower()
        if len(clean_word) < 3 or clean_word in STOP_WORDS or not clean_word.isalpha():
            continue

        detections = detector(image_pil, candidate_labels=[clean_word])
        # Keep the highest-confidence box above the score floor.
        best_box, best_score = None, 0.0
        for detection in detections:
            if detection["score"] > best_score and detection["score"] >= min_detection_score:
                best_score = detection["score"]
                best_box = [
                    detection["box"]["xmin"],
                    detection["box"]["ymin"],
                    detection["box"]["xmax"],
                    detection["box"]["ymax"],
                ]
        if best_box is None:
            continue

        # Binarize the heatmap at image resolution, then IoU against the box.
        mask = binarize_heatmap(hm, img_shape)
        iou = calculate_iou(mask, best_box, img_shape)
        results.append(
            {
                "word": clean_word,
                "position": idx + 1,
                "iou": float(iou),
                "det_score": float(best_score),
                "box": best_box,
            }
        )

    return results
321
+
322
+
323
def summarize_caption_alignment(results: List[dict], caption_length: int) -> dict:
    """Collapse per-word IoU records into a single summary row.

    An empty ``results`` list yields a mean IoU of 0.0.
    """
    mean_iou = 0.0
    if results:
        mean_iou = float(np.mean([entry["iou"] for entry in results]))
    return {"caption_length": caption_length, "mean_alignment_iou": mean_iou}
328
+
requirements.txt CHANGED
@@ -12,3 +12,5 @@ tqdm
12
  accelerate
13
  sentencepiece
14
  pycocoevalcap
 
 
 
12
  accelerate
13
  sentencepiece
14
  pycocoevalcap
15
+ matplotlib
16
+ opencv-python-headless