Spaces:
Sleeping
Sleeping
Deploy Streamlit Space app
Browse files- app.py +6 -4
- models/attention_flow.py +7 -2
app.py
CHANGED
|
@@ -782,7 +782,7 @@ def render_caption_card(model_name, caption, weight_src, num_beams, length_penal
|
|
| 782 |
# ─────────────────────────────────────────────────────────────────────────────
|
| 783 |
|
| 784 |
tab_caption, tab_compare, tab_attention, tab_results = st.tabs([
|
| 785 |
-
"πΌοΈ Caption", "π Compare All Models", "
|
| 786 |
])
|
| 787 |
|
| 788 |
|
|
@@ -998,11 +998,12 @@ with tab_compare:
|
|
| 998 |
|
| 999 |
|
| 1000 |
# ───────────────────────────────────────────────────────────────────────────
|
| 1001 |
-
# Tab 3 —
|
| 1002 |
# ───────────────────────────────────────────────────────────────────────────
|
| 1003 |
|
| 1004 |
with tab_attention:
|
| 1005 |
-
st.markdown("###
|
|
|
|
| 1006 |
st.caption(
|
| 1007 |
"Step-by-step cross-attention analysis with rollout across decoder layers, "
|
| 1008 |
"2x5 heatmap grid, IoU grounding score, and caption-length summary."
|
|
@@ -1049,10 +1050,11 @@ with tab_attention:
|
|
| 1049 |
)
|
| 1050 |
|
| 1051 |
max_attn_steps = st.slider(
|
| 1052 |
-
"
|
| 1053 |
min_value=3,
|
| 1054 |
max_value=12,
|
| 1055 |
value=9,
|
|
|
|
| 1056 |
key="attn_steps",
|
| 1057 |
)
|
| 1058 |
run_iou = st.toggle(
|
|
|
|
| 782 |
# ─────────────────────────────────────────────────────────────────────────────
|
| 783 |
|
| 784 |
tab_caption, tab_compare, tab_attention, tab_results = st.tabs([
|
| 785 |
+
"🖼️ Caption", "π Compare All Models", "🧠 Word Focus Map", "π Experiment Results"
|
| 786 |
])
|
| 787 |
|
| 788 |
|
|
|
|
| 998 |
|
| 999 |
|
| 1000 |
# ───────────────────────────────────────────────────────────────────────────
|
| 1001 |
+
# Tab 3 — Word Focus Map (Task 2)
|
| 1002 |
# ───────────────────────────────────────────────────────────────────────────
|
| 1003 |
|
| 1004 |
with tab_attention:
|
| 1005 |
+
st.markdown("### 🧠 Word Focus Map")
|
| 1006 |
+
st.markdown("`Task: Attention Weight Visualization & Cross-Attention Rollout for Caption Generation`")
|
| 1007 |
st.caption(
|
| 1008 |
"Step-by-step cross-attention analysis with rollout across decoder layers, "
|
| 1009 |
"2x5 heatmap grid, IoU grounding score, and caption-length summary."
|
|
|
|
| 1050 |
)
|
| 1051 |
|
| 1052 |
max_attn_steps = st.slider(
|
| 1053 |
+
"How many words to trace",
|
| 1054 |
min_value=3,
|
| 1055 |
max_value=12,
|
| 1056 |
value=9,
|
| 1057 |
+
help="One step = one word position in the generated/custom text (word 1, word 2, ...).",
|
| 1058 |
key="attn_steps",
|
| 1059 |
)
|
| 1060 |
run_iou = st.toggle(
|
models/attention_flow.py
CHANGED
|
@@ -80,7 +80,7 @@ def _normalize1d(tensor: torch.Tensor) -> torch.Tensor:
|
|
| 80 |
|
| 81 |
def compute_attention_flow(
|
| 82 |
extractor: FlowExtractor,
|
| 83 |
-
num_image_tokens: int =
|
| 84 |
residual_weight: float = 0.05,
|
| 85 |
out_resolution: int = 224,
|
| 86 |
) -> np.ndarray:
|
|
@@ -93,6 +93,12 @@ def compute_attention_flow(
|
|
| 93 |
if not valid_cams:
|
| 94 |
return np.zeros((out_resolution, out_resolution), dtype=np.float32)
|
| 95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
uniform = torch.ones(num_image_tokens, device=valid_cams[0].device) / num_image_tokens
|
| 97 |
rollout = _normalize1d(valid_cams[0])
|
| 98 |
for cam in valid_cams[1:]:
|
|
@@ -325,4 +331,3 @@ def summarize_caption_alignment(results: List[dict], caption_length: int) -> dic
|
|
| 325 |
return {"caption_length": caption_length, "mean_alignment_iou": 0.0}
|
| 326 |
mean_iou = float(np.mean([item["iou"] for item in results]))
|
| 327 |
return {"caption_length": caption_length, "mean_alignment_iou": mean_iou}
|
| 328 |
-
|
|
|
|
| 80 |
|
| 81 |
def compute_attention_flow(
|
| 82 |
extractor: FlowExtractor,
|
| 83 |
+
num_image_tokens: int | None = None,
|
| 84 |
residual_weight: float = 0.05,
|
| 85 |
out_resolution: int = 224,
|
| 86 |
) -> np.ndarray:
|
|
|
|
| 93 |
if not valid_cams:
|
| 94 |
return np.zeros((out_resolution, out_resolution), dtype=np.float32)
|
| 95 |
|
| 96 |
+
if num_image_tokens is None:
|
| 97 |
+
num_image_tokens = int(valid_cams[0].numel())
|
| 98 |
+
valid_cams = [cam for cam in valid_cams if int(cam.numel()) == int(num_image_tokens)]
|
| 99 |
+
if not valid_cams:
|
| 100 |
+
return np.zeros((out_resolution, out_resolution), dtype=np.float32)
|
| 101 |
+
|
| 102 |
uniform = torch.ones(num_image_tokens, device=valid_cams[0].device) / num_image_tokens
|
| 103 |
rollout = _normalize1d(valid_cams[0])
|
| 104 |
for cam in valid_cams[1:]:
|
|
|
|
| 331 |
return {"caption_length": caption_length, "mean_alignment_iou": 0.0}
|
| 332 |
mean_iou = float(np.mean([item["iou"] for item in results]))
|
| 333 |
return {"caption_length": caption_length, "mean_alignment_iou": mean_iou}
|
|
|