Spaces:
Sleeping
Sleeping
Deploy Streamlit Space app
Browse files- app.py +235 -3
- models/attention_flow.py +328 -0
- requirements.txt +2 -0
app.py
CHANGED
|
@@ -14,6 +14,7 @@ Features:
|
|
| 14 |
import os
|
| 15 |
import warnings
|
| 16 |
import torch
|
|
|
|
| 17 |
import streamlit as st
|
| 18 |
from PIL import Image
|
| 19 |
from models.blip_tuner import generate_with_mask
|
|
@@ -476,6 +477,42 @@ def load_toxicity_filter():
|
|
| 476 |
return tok, mdl
|
| 477 |
|
| 478 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 479 |
# ─────────────────────────────────────────────────────────────────────────────
|
| 480 |
# Toxicity Check
|
| 481 |
# ─────────────────────────────────────────────────────────────────────────────
|
|
@@ -744,8 +781,8 @@ def render_caption_card(model_name, caption, weight_src, num_beams, length_penal
|
|
| 744 |
# Tabs
|
| 745 |
# ─────────────────────────────────────────────────────────────────────────────
|
| 746 |
|
| 747 |
-
tab_caption, tab_compare, tab_results = st.tabs([
|
| 748 |
-
"🖼️ Caption", "🔀 Compare All Models", "📊 Experiment Results"
|
| 749 |
])
|
| 750 |
|
| 751 |
|
|
@@ -961,7 +998,202 @@ with tab_compare:
|
|
| 961 |
|
| 962 |
|
| 963 |
# ═══════════════════════════════════════════════════════════════════════════
|
| 964 |
-
# Tab 3 —
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 965 |
# ═══════════════════════════════════════════════════════════════════════════
|
| 966 |
|
| 967 |
with tab_results:
|
|
|
|
| 14 |
import os
|
| 15 |
import warnings
|
| 16 |
import torch
|
| 17 |
+
import numpy as np
|
| 18 |
import streamlit as st
|
| 19 |
from PIL import Image
|
| 20 |
from models.blip_tuner import generate_with_mask
|
|
|
|
| 477 |
return tok, mdl
|
| 478 |
|
| 479 |
|
| 480 |
+
@st.cache_resource(show_spinner=False)
def load_blip_attention_model(weight_source="base"):
    """Load a BLIP captioning model configured for attention analysis.

    Args:
        weight_source: "base" for the pretrained checkpoint, or a fine-tuned
            checkpoint tag (e.g. "best"/"latest") resolved via the project's
            output directory helpers.

    Returns:
        (processor, model, device) — the BLIP processor, the model in eval
        mode on `device`, and the device itself.
    """
    # Imported lazily so the heavy transformers import only happens when this
    # cached resource is first built.
    from transformers import BlipForConditionalGeneration, BlipProcessor
    device = get_device()
    processor = BlipProcessor.from_pretrained(
        "Salesforce/blip-image-captioning-base", use_fast=True
    )
    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )

    if weight_source != "base":
        # Overlay fine-tuned weights on top of the base model, if a non-empty
        # checkpoint directory exists; strict=False tolerates key mismatches.
        output_root, _, _ = _resolve_weight_paths(
            need_outputs=True, need_shakespeare=False
        )
        ckpt = _ckpt_path(output_root, "blip", weight_source)
        if os.path.isdir(ckpt) and os.listdir(ckpt):
            loaded = BlipForConditionalGeneration.from_pretrained(ckpt)
            model.load_state_dict(loaded.state_dict(), strict=False)
            del loaded  # free the temporary copy immediately

    # Gradient checkpointing would interfere with the attention hooks used by
    # the explorer; disable it if the model supports the call.
    try:
        model.gradient_checkpointing_disable()
    except Exception:
        pass
    # KV-cache off: the attention-flow decoder re-runs the full sequence each
    # step and needs fresh attention tensors every forward pass.
    model.config.use_cache = False
    model.to(device).eval()
    return processor, model, device
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
@st.cache_resource(show_spinner=False)
def load_alignment_detector():
    """Build (once, cached) the OWL-ViT zero-shot detector used for IoU grading."""
    from models.attention_flow import load_owlvit_detector
    return load_owlvit_detector(get_device())
|
| 514 |
+
|
| 515 |
+
|
| 516 |
# ─────────────────────────────────────────────────────────────────────────────
|
| 517 |
# Toxicity Check
|
| 518 |
# ─────────────────────────────────────────────────────────────────────────────
|
|
|
|
| 781 |
# Tabs
|
| 782 |
# ─────────────────────────────────────────────────────────────────────────────
|
| 783 |
|
| 784 |
+
# Top-level tab layout; the Attention Explorer tab hosts the Task-2 analysis UI.
tab_caption, tab_compare, tab_attention, tab_results = st.tabs([
    "🖼️ Caption", "🔀 Compare All Models", "🧠 Attention Explorer", "📊 Experiment Results"
])
|
| 787 |
|
| 788 |
|
|
|
|
| 998 |
|
| 999 |
|
| 1000 |
# ═══════════════════════════════════════════════════════════════════════════
|
| 1001 |
+
# Tab 3 — Attention Explorer (Task 2)
|
| 1002 |
+
# ═══════════════════════════════════════════════════════════════════════════
|
| 1003 |
+
|
| 1004 |
+
with tab_attention:
    # ── Attention Explorer UI ────────────────────────────────────────────
    st.markdown("### 🧠 BLIP Attention Explorer")
    st.caption(
        "Step-by-step cross-attention analysis with rollout across decoder layers, "
        "2x5 heatmap grid, IoU grounding score, and caption-length summary."
    )

    # Left column: image upload + preview. Right column: analysis controls.
    attn_col_left, attn_col_right = st.columns([1, 1], gap="large")
    with attn_col_left:
        attn_file = st.file_uploader(
            "Upload an image for attention analysis",
            type=["jpg", "jpeg", "png", "webp"],
            key="attention_uploader",
        )
        if attn_file:
            attn_image = Image.open(attn_file).convert("RGB")
            st.image(attn_image, caption="Attention Input Image", use_column_width=True)

    with attn_col_right:
        _ensure_model_outputs_available("blip")
        # Only offer fine-tuned weight options that actually exist on disk.
        attn_weight_options = {"Base (Pretrained)": "base"}
        if _has_finetuned("blip", "best"):
            attn_weight_options["Fine-tuned (Best)"] = "best"
        if _has_finetuned("blip", "latest"):
            attn_weight_options["Fine-tuned (Latest)"] = "latest"
        attn_weight_choice = st.selectbox(
            "BLIP Weight Source",
            list(attn_weight_options.keys()),
            index=0,
            key="attn_weight_choice",
        )
        attn_weight_source = attn_weight_options[attn_weight_choice]

        # Token source: trace the model's own greedy caption, or force-decode
        # a user-supplied prompt token-by-token.
        token_mode = st.radio(
            "Token Source",
            ["Generated Caption", "Custom Text Prompt"],
            horizontal=True,
            key="attn_token_mode",
        )
        custom_text = ""
        if token_mode == "Custom Text Prompt":
            custom_text = st.text_input(
                "Enter custom text/words for heatmap tracing",
                value="a dog playing with a ball",
                key="attn_custom_text",
            )

        max_attn_steps = st.slider(
            "Caption Steps to Analyze",
            min_value=3,
            max_value=12,
            value=9,
            key="attn_steps",
        )
        run_iou = st.toggle(
            "Compute IoU Alignment with OWL-ViT (slower)",
            value=True,
            key="attn_iou_toggle",
        )

    # Button stays disabled until there is an image, and (in custom mode) text.
    run_attention_btn = st.button(
        "Run Step-by-Step Attention Analysis",
        disabled=(attn_file is None or (token_mode == "Custom Text Prompt" and not custom_text.strip())),
        key="attn_run_btn",
    )

    if run_attention_btn and attn_file:
        # Imported lazily so the attention-flow module (torch/cv2/matplotlib
        # heavy) is only loaded when the analysis is actually run.
        from models.attention_flow import (
            build_attention_grid_figure,
            decode_custom_text_with_flow,
            decode_generated_caption_with_flow,
            encode_image_for_flow,
            grade_alignment_with_detector,
            summarize_caption_alignment,
        )

        attn_image = Image.open(attn_file).convert("RGB")
        iou_results = []

        # Five-step pipeline shown live in a collapsible status widget.
        with st.status("Running attention pipeline...", expanded=True) as status:
            st.write("Step 1/5: Loading BLIP model and selected weights")
            attn_processor, attn_model, attn_device = load_blip_attention_model(attn_weight_source)

            st.write("Step 2/5: Encoding image through ViT")
            image_224, enc_hidden, enc_mask = encode_image_for_flow(
                attn_model, attn_processor, attn_device, attn_image
            )

            st.write("Step 3/5: Extracting rollout heatmaps token-by-token")
            if token_mode == "Custom Text Prompt":
                tokens, heatmaps = decode_custom_text_with_flow(
                    attn_model,
                    attn_processor,
                    attn_device,
                    enc_hidden,
                    enc_mask,
                    custom_text,
                    max_tokens=max_attn_steps,
                )
            else:
                tokens, heatmaps = decode_generated_caption_with_flow(
                    attn_model,
                    attn_processor,
                    attn_device,
                    enc_hidden,
                    enc_mask,
                    max_tokens=max_attn_steps,
                )

            st.write("Step 4/5: Building 2x5 attention grid")
            fig_grid = build_attention_grid_figure(image_224, tokens, heatmaps, n_rows=2, n_cols=5)

            if run_iou:
                st.write("Step 5/5: Computing IoU alignment using OWL-ViT detections")
                detector = load_alignment_detector()
                iou_results = grade_alignment_with_detector(attn_image, tokens, heatmaps, detector)
            else:
                st.write("Step 5/5: IoU grading skipped by user")

            status.update(label="Attention pipeline complete", state="complete", expanded=False)

        # ── Results: heatmap grid, token list, alignment summary ─────────
        st.pyplot(fig_grid, use_container_width=True)
        caption_tokens = " ".join(tokens) if tokens else "[No tokens generated]"
        st.markdown(f"**Decoded tokens:** `{caption_tokens}`")

        # Summary is computed even when IoU was skipped (mean defaults to 0.0).
        summary = summarize_caption_alignment(iou_results, len(tokens))
        st.markdown(
            f"**Caption length:** `{summary['caption_length']}` | "
            f"**Mean alignment IoU:** `{summary['mean_alignment_iou']:.4f}`"
        )

        if run_iou:
            st.markdown("#### Word-level Alignment (IoU)")
            if iou_results:
                table_rows = [
                    {
                        "word": item["word"],
                        "position": item["position"],
                        "iou": round(item["iou"], 4),
                        "det_score": round(item["det_score"], 4),
                        "box": [int(x) for x in item["box"]],
                    }
                    for item in iou_results
                ]
                st.dataframe(table_rows, use_container_width=True)

                # Heuristic thresholds: >=0.30 strongly grounded, <0.10 weak.
                strong = [item["word"] for item in iou_results if item["iou"] >= 0.30]
                weak = [item["word"] for item in iou_results if item["iou"] < 0.10]
                if strong:
                    st.success("Strongly grounded words: " + ", ".join(strong))
                if weak:
                    st.warning("Weakly grounded words: " + ", ".join(weak))
            else:
                st.info("No detectable object-word matches found for IoU grading on this run.")

        # Accumulate per-run summaries across the session for the trend chart.
        if "alignment_history" not in st.session_state:
            st.session_state["alignment_history"] = []
        st.session_state["alignment_history"].append(
            {
                "caption_length": int(summary["caption_length"]),
                "mean_alignment_iou": float(summary["mean_alignment_iou"]),
                "mode": token_mode,
                "weights": attn_weight_source,
            }
        )

        st.markdown("#### Caption Length -> Mean Alignment IoU")
        history = st.session_state["alignment_history"]
        if history:
            # Chart is best-effort: any plotting failure silently falls back
            # to the raw table below.
            try:
                import matplotlib.pyplot as plt

                x_vals = [item["caption_length"] for item in history]
                y_vals = [item["mean_alignment_iou"] for item in history]
                fig_summary, ax_summary = plt.subplots(figsize=(6, 3.2))
                ax_summary.scatter(x_vals, y_vals, color="#58a6ff", alpha=0.85)
                if len(x_vals) > 1:
                    # Least-squares linear trend across runs.
                    z = np.polyfit(x_vals, y_vals, 1)
                    trend = np.poly1d(z)
                    xs = sorted(x_vals)
                    ax_summary.plot(xs, [trend(v) for v in xs], linestyle="--", color="#ff7b72")
                ax_summary.set_xlabel("Caption length")
                ax_summary.set_ylabel("Mean IoU")
                ax_summary.set_title("Alignment Trend")
                ax_summary.grid(alpha=0.35, linestyle="--")
                st.pyplot(fig_summary, use_container_width=True)
            except Exception:
                pass
            st.dataframe(history[-20:], use_container_width=True)
|
| 1193 |
+
|
| 1194 |
+
|
| 1195 |
+
# ═══════════════════════════════════════════════════════════════════════════
|
| 1196 |
+
# Tab 4 — Experiment Results
|
| 1197 |
# ═══════════════════════════════════════════════════════════════════════════
|
| 1198 |
|
| 1199 |
with tab_results:
|
models/attention_flow.py
ADDED
|
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import List, Tuple
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import matplotlib
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from PIL import Image
|
| 10 |
+
|
| 11 |
+
matplotlib.use("Agg")
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Function words / punctuation / subword-marker tokens that are skipped during
# IoU grading — they have no visual referent for OWL-ViT to detect.
STOP_WORDS = {
    "a", "an", "the", "and", "or", "but", "is", "are", "was", "were",
    "in", "on", "at", "to", "for", "with", "by", "it", "this", "that",
    "there", "here", "of", "up", "out", ".", ",", "!", "##",
}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class FlowExtractor:
    """Capture cross-attention probabilities and their gradients from BLIP.

    A forward hook is registered on the self-attention sub-module of every
    decoder layer that has a ``crossattention`` block. Each hooked layer gets
    a holder dict ``{"fwd": attn_probs, "grad": grad_of_attn_probs}`` that is
    filled during the forward pass (and during backward, via a tensor hook),
    then consumed by ``compute_attention_flow``.
    """

    def __init__(self, model):
        self.model = model
        self._hooks = []   # registered hook handles, for later removal
        self.layers = []   # one holder dict per hooked cross-attention layer

        for layer in model.text_decoder.bert.encoder.layer:
            if hasattr(layer, "crossattention"):
                holder = {"fwd": None, "grad": None}
                self.layers.append(holder)

                def _make_hook(h):
                    # Factory closure binds each holder to its own hook,
                    # avoiding Python's late-binding loop-variable pitfall.
                    def _fwd(module, inputs, outputs):
                        # outputs[1] is the attention-probability tensor when
                        # the decoder is run with output_attentions=True.
                        if len(outputs) > 1 and outputs[1] is not None:
                            h["fwd"] = outputs[1]
                            if h["fwd"].requires_grad:
                                # Record the gradient of the attention map
                                # when backward() is called on a logit.
                                h["fwd"].register_hook(
                                    lambda g, _h=h: _h.update({"grad": g.detach()})
                                )
                    return _fwd

                target = layer.crossattention.self
                self._hooks.append(target.register_forward_hook(_make_hook(holder)))

    def clear(self):
        """Drop captured tensors; call before each new decoding step."""
        for holder in self.layers:
            holder["fwd"] = None
            holder["grad"] = None

    def remove(self):
        """Detach all forward hooks; call once the analysis is finished."""
        for hook in self._hooks:
            hook.remove()
        self._hooks = []
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def encode_image_for_flow(model, processor, device, image_pil: Image.Image):
    """Run the BLIP vision tower once and return reusable encoder outputs.

    Returns a tuple ``(image_224, encoder_hidden, encoder_mask)``: the image
    resized to 224x224, the frozen vision hidden states, and an all-ones
    attention mask over the image tokens.
    """
    image_224 = image_pil.resize((224, 224), Image.LANCZOS)
    batch = processor(images=image_224, return_tensors="pt").to(device)
    # The vision features are treated as a fixed input; gradients only need
    # to flow through the text decoder's cross-attention.
    with torch.no_grad():
        vision_out = model.vision_model(pixel_values=batch["pixel_values"])
    encoder_hidden = vision_out[0].detach().requires_grad_(False)
    mask_shape = encoder_hidden.size()[:-1]
    encoder_mask = torch.ones(mask_shape, dtype=torch.long, device=device)
    return image_224, encoder_hidden, encoder_mask
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _single_layer_gradcam(holder, token_idx: int = -1) -> torch.Tensor:
|
| 68 |
+
attn = holder["fwd"][:, :, token_idx, :]
|
| 69 |
+
grad = holder["grad"][:, :, token_idx, :]
|
| 70 |
+
cam = (attn * grad).mean(dim=1).squeeze()
|
| 71 |
+
return torch.clamp(cam, min=0.0)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _normalize1d(tensor: torch.Tensor) -> torch.Tensor:
|
| 75 |
+
denom = tensor.sum()
|
| 76 |
+
if denom > 0:
|
| 77 |
+
return tensor / denom
|
| 78 |
+
return tensor
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def compute_attention_flow(
    extractor: FlowExtractor,
    num_image_tokens: int = 197,
    residual_weight: float = 0.05,
    out_resolution: int = 224,
) -> np.ndarray:
    """Fuse per-layer Grad-CAMs into one normalized rollout heatmap.

    Layers whose forward/backward captures are missing are skipped; with no
    usable layers at all, an all-zero map is returned. The rollout multiplies
    normalized per-layer CAMs and mixes in a small uniform residual, drops the
    leading (CLS-like) token, reshapes the remaining tokens into a square
    grid, upsamples bicubically to ``out_resolution``, and min-max scales the
    result into [0, 1].
    """
    layer_cams = [
        _single_layer_gradcam(holder).detach()
        for holder in extractor.layers
        if holder["fwd"] is not None and holder["grad"] is not None
    ]
    if not layer_cams:
        return np.zeros((out_resolution, out_resolution), dtype=np.float32)

    uniform = torch.ones(num_image_tokens, device=layer_cams[0].device) / num_image_tokens
    rollout = _normalize1d(layer_cams[0])
    for cam in layer_cams[1:]:
        # Multiplicative rollout with a uniform residual to avoid collapse.
        rollout = _normalize1d(rollout) * _normalize1d(cam) + residual_weight * uniform
    rollout = torch.clamp(rollout, min=0.0)

    spatial = rollout[1:]  # drop the first (non-spatial) image token
    side = int(math.sqrt(spatial.numel()))
    upsampled = F.interpolate(
        spatial.detach().cpu().reshape(1, 1, side, side).float(),
        size=(out_resolution, out_resolution),
        mode="bicubic",
        align_corners=False,
    ).squeeze()

    heat = upsampled.numpy()
    lo, hi = heat.min(), heat.max()
    heat = (heat - lo) / (hi - lo) if hi > lo else np.zeros_like(heat)
    return heat.astype(np.float32)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def decode_generated_caption_with_flow(
    model,
    processor,
    device,
    encoder_hidden,
    encoder_mask,
    max_tokens: int = 20,
) -> Tuple[List[str], List[np.ndarray]]:
    """Greedy-decode a caption while recording a heatmap per emitted token.

    Each step re-runs the text decoder over the full prefix with
    ``output_attentions=True`` so the ``FlowExtractor`` hooks capture the
    cross-attention maps; the chosen token's logit is then backpropagated so
    the hooks also capture gradients, and the rollout heatmap is computed.
    Stops early at the SEP (end-of-caption) token or after ``max_tokens``.

    Returns:
        (tokens, heatmaps) — decoded token strings and matching 2-D heatmaps.
    """
    extractor = FlowExtractor(model)
    input_ids = torch.LongTensor([[model.config.text_config.bos_token_id]]).to(device)
    tokens, heatmaps = [], []

    for _ in range(max_tokens):
        # Fresh gradients and hook captures for this decoding step.
        model.zero_grad()
        extractor.clear()
        outputs = model.text_decoder(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden,
            encoder_attention_mask=encoder_mask,
            output_attentions=True,
            return_dict=True,
        )
        logits = outputs.logits[:, -1, :]
        next_token = torch.argmax(logits, dim=-1)  # greedy choice
        if next_token.item() == model.config.text_config.sep_token_id:
            break

        # Backprop the chosen logit so attention-gradient hooks fire.
        logits[0, next_token.item()].backward(retain_graph=False)
        heatmaps.append(compute_attention_flow(extractor))
        tokens.append(processor.tokenizer.decode([next_token.item()]).strip())
        input_ids = torch.cat([input_ids, next_token.reshape(1, 1)], dim=-1)

    extractor.remove()
    return tokens, heatmaps
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def decode_custom_text_with_flow(
    model,
    processor,
    device,
    encoder_hidden,
    encoder_mask,
    text: str,
    max_tokens: int = 20,
) -> Tuple[List[str], List[np.ndarray]]:
    """Force-decode a user-provided text, recording a heatmap per token.

    Unlike :func:`decode_generated_caption_with_flow`, the target sequence is
    fixed: ``text`` is tokenized (truncated to ``max_tokens``) and each target
    token's logit is backpropagated in turn, teacher-forcing the prefix.

    Returns:
        (tokens, heatmaps) — decoded token strings and matching 2-D heatmaps.
    """
    extractor = FlowExtractor(model)
    token_ids = processor.tokenizer(
        text,
        add_special_tokens=False,
        return_attention_mask=False,
    )["input_ids"][:max_tokens]

    input_ids = torch.LongTensor([[model.config.text_config.bos_token_id]]).to(device)
    tokens, heatmaps = [], []

    for target_token_id in token_ids:
        # Fresh gradients and hook captures for this step.
        model.zero_grad()
        extractor.clear()
        outputs = model.text_decoder(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden,
            encoder_attention_mask=encoder_mask,
            output_attentions=True,
            return_dict=True,
        )
        logits = outputs.logits[:, -1, :]
        # Backprop the logit of the *forced* target token (not the argmax).
        score = logits[0, target_token_id]
        score.backward(retain_graph=False)

        heatmaps.append(compute_attention_flow(extractor))
        tokens.append(processor.tokenizer.decode([target_token_id]).strip())
        next_tensor = torch.LongTensor([[target_token_id]]).to(device)
        input_ids = torch.cat([input_ids, next_tensor], dim=-1)

    extractor.remove()
    return tokens, heatmaps
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def overlay_heatmap_on_image(
    image_pil: Image.Image,
    heatmap_np: np.ndarray,
    alpha: float = 0.5,
    hot_threshold: float = 0.1,
) -> Image.Image:
    """Blend an INFERNO-colored heatmap over the image.

    Only pixels whose heatmap value exceeds ``hot_threshold`` are tinted
    (weighted by ``alpha``); the rest of the image is left untouched.
    """
    h, w = heatmap_np.shape
    base = np.array(image_pil.resize((w, h), Image.LANCZOS))
    colored = cv2.cvtColor(
        cv2.applyColorMap(np.uint8(255.0 * heatmap_np), cv2.COLORMAP_INFERNO),
        cv2.COLOR_BGR2RGB,
    )
    # Per-pixel blend weight: alpha where "hot", zero elsewhere.
    weight = (heatmap_np > hot_threshold).astype(np.float32)[..., None] * alpha
    blended = base * (1 - weight) + colored * weight
    return Image.fromarray(blended.astype(np.uint8))
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def build_attention_grid_figure(
    image_pil: Image.Image,
    tokens: List[str],
    heatmaps: List[np.ndarray],
    n_rows: int = 2,
    n_cols: int = 5,
):
    """Build an n_rows x n_cols matplotlib figure of attention overlays.

    Panel 0 shows the original image; each following panel overlays one
    token's heatmap. Unused panels are blanked. Returns the Figure.
    """
    n_panels = n_rows * n_cols
    n_words = min(n_panels - 1, len(tokens))  # reserve panel 0 for the original
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 3.2, n_rows * 3.2))
    panels = axes.flatten()

    panels[0].imshow(image_pil)
    panels[0].set_title("Original", fontsize=11, fontweight="bold")
    panels[0].axis("off")

    for i in range(n_words):
        ax = panels[i + 1]
        ax.imshow(overlay_heatmap_on_image(image_pil, heatmaps[i]))
        ax.set_title(f"'{tokens[i]}'", fontsize=10, fontweight="bold")
        ax.axis("off")

    # Hide any panels left over when there are fewer tokens than slots.
    for j in range(n_words + 1, n_panels):
        panels[j].axis("off")

    caption_preview = " ".join(tokens[:12])
    fig.suptitle(
        f"Cross-Attention Flow (2x5)\nCaption Tokens: {caption_preview}",
        fontsize=12,
        fontweight="bold",
        y=1.02,
    )
    plt.tight_layout()
    return fig
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def load_owlvit_detector(device):
    """Create an OWL-ViT zero-shot object-detection pipeline on ``device``.

    CUDA devices map to pipeline device 0; anything else runs on CPU (-1).
    """
    from transformers import pipeline
    return pipeline(
        task="zero-shot-object-detection",
        model="google/owlvit-base-patch32",
        device=0 if str(device).startswith("cuda") else -1,
    )
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def binarize_heatmap(heatmap_np: np.ndarray, target_hw: tuple) -> np.ndarray:
    """Resize a [0,1] heatmap to ``target_hw`` (H, W) and Otsu-threshold it.

    Returns a boolean mask of the "hot" region for IoU computation.
    """
    resized = cv2.resize(heatmap_np, (target_hw[1], target_hw[0]))
    as_u8 = np.uint8(255.0 * resized)
    _, binary = cv2.threshold(as_u8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return binary > 0
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def calculate_iou(mask: np.ndarray, box: list, img_shape: tuple) -> float:
    """Intersection-over-union between a boolean mask and an [x0,y0,x1,y1] box.

    The box is clipped to the image bounds; an empty union yields 0.0.
    """
    x0, y0, x1, y1 = (int(v) for v in box)
    x0, y0 = max(0, x0), max(0, y0)
    x1, y1 = min(img_shape[1], x1), min(img_shape[0], y1)

    box_mask = np.zeros(img_shape, dtype=bool)
    box_mask[y0:y1, x0:x1] = True

    union = np.logical_or(mask, box_mask).sum()
    if union == 0:
        return 0.0
    inter = np.logical_and(mask, box_mask).sum()
    return float(inter) / union
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def grade_alignment_with_detector(
    image_pil: Image.Image,
    tokens: List[str],
    heatmaps: List[np.ndarray],
    detector,
    min_detection_score: float = 0.05,
) -> List[dict]:
    """Score how well each token's heatmap overlaps its detected object.

    For every content word (stop-words, short tokens, and non-alphabetic
    tokens are skipped) the detector is queried with that word as the sole
    candidate label; the highest-scoring detection at or above
    ``min_detection_score`` is compared against the Otsu-binarized heatmap.

    Returns a list of dicts: word, 1-based position, iou, det_score, box.
    """
    img_shape = (image_pil.height, image_pil.width)
    graded = []
    for position, (raw_word, heat) in enumerate(zip(tokens, heatmaps)):
        word = raw_word.replace("##", "").lower()  # strip WordPiece markers
        if len(word) < 3 or word in STOP_WORDS or not word.isalpha():
            continue

        best_box = None
        best_score = 0.0
        for det in detector(image_pil, candidate_labels=[word]):
            score = det["score"]
            if score > best_score and score >= min_detection_score:
                best_score = score
                box = det["box"]
                best_box = [box["xmin"], box["ymin"], box["xmax"], box["ymax"]]
        if best_box is None:
            continue  # nothing detectable for this word

        attn_mask = binarize_heatmap(heat, img_shape)
        graded.append(
            {
                "word": word,
                "position": position + 1,
                "iou": float(calculate_iou(attn_mask, best_box, img_shape)),
                "det_score": float(best_score),
                "box": best_box,
            }
        )

    return graded
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def summarize_caption_alignment(results: List[dict], caption_length: int) -> dict:
    """Reduce per-word IoU results to {caption_length, mean_alignment_iou}.

    With no graded words the mean defaults to 0.0.
    """
    mean_iou = float(np.mean([item["iou"] for item in results])) if results else 0.0
    return {"caption_length": caption_length, "mean_alignment_iou": mean_iou}
|
| 328 |
+
|
requirements.txt
CHANGED
|
@@ -12,3 +12,5 @@ tqdm
|
|
| 12 |
accelerate
|
| 13 |
sentencepiece
|
| 14 |
pycocoevalcap
|
|
|
|
|
|
|
|
|
| 12 |
accelerate
|
| 13 |
sentencepiece
|
| 14 |
pycocoevalcap
|
| 15 |
+
matplotlib
|
| 16 |
+
opencv-python-headless
|