griddev committed on
Commit
64b98e5
Β·
verified Β·
1 Parent(s): ce25d0a

Deploy Streamlit Space app

Browse files
Files changed (2) hide show
  1. app.py +6 -4
  2. models/attention_flow.py +7 -2
app.py CHANGED
@@ -782,7 +782,7 @@ def render_caption_card(model_name, caption, weight_src, num_beams, length_penal
782
  # ─────────────────────────────────────────────────────────────────────────────
783
 
784
  tab_caption, tab_compare, tab_attention, tab_results = st.tabs([
785
- "πŸ–ΌοΈ Caption", "πŸ”€ Compare All Models", "🧠 Attention Explorer", "πŸ“Š Experiment Results"
786
  ])
787
 
788
 
@@ -998,11 +998,12 @@ with tab_compare:
998
 
999
 
1000
  # ═══════════════════════════════════════════════════════════════════════════
1001
- # Tab 3 β€” Attention Explorer (Task 2)
1002
  # ═══════════════════════════════════════════════════════════════════════════
1003
 
1004
  with tab_attention:
1005
- st.markdown("### 🧠 BLIP Attention Explorer")
 
1006
  st.caption(
1007
  "Step-by-step cross-attention analysis with rollout across decoder layers, "
1008
  "2x5 heatmap grid, IoU grounding score, and caption-length summary."
@@ -1049,10 +1050,11 @@ with tab_attention:
1049
  )
1050
 
1051
  max_attn_steps = st.slider(
1052
- "Caption Steps to Analyze",
1053
  min_value=3,
1054
  max_value=12,
1055
  value=9,
 
1056
  key="attn_steps",
1057
  )
1058
  run_iou = st.toggle(
 
782
  # ─────────────────────────────────────────────────────────────────────────────
783
 
784
  tab_caption, tab_compare, tab_attention, tab_results = st.tabs([
785
+ "πŸ–ΌοΈ Caption", "πŸ”€ Compare All Models", "🧭 Word Focus Map", "πŸ“Š Experiment Results"
786
  ])
787
 
788
 
 
998
 
999
 
1000
  # ═══════════════════════════════════════════════════════════════════════════
1001
+ # Tab 3 β€” Word Focus Map (Task 2)
1002
  # ═══════════════════════════════════════════════════════════════════════════
1003
 
1004
  with tab_attention:
1005
+ st.markdown("### 🧭 Word Focus Map")
1006
+ st.markdown("`Task: Attention Weight Visualization & Cross-Attention Rollout for Caption Generation`")
1007
  st.caption(
1008
  "Step-by-step cross-attention analysis with rollout across decoder layers, "
1009
  "2x5 heatmap grid, IoU grounding score, and caption-length summary."
 
1050
  )
1051
 
1052
  max_attn_steps = st.slider(
1053
+ "How many words to trace",
1054
  min_value=3,
1055
  max_value=12,
1056
  value=9,
1057
+ help="One step = one word position in the generated/custom text (word 1, word 2, ...).",
1058
  key="attn_steps",
1059
  )
1060
  run_iou = st.toggle(
models/attention_flow.py CHANGED
@@ -80,7 +80,7 @@ def _normalize1d(tensor: torch.Tensor) -> torch.Tensor:
80
 
81
  def compute_attention_flow(
82
  extractor: FlowExtractor,
83
- num_image_tokens: int = 197,
84
  residual_weight: float = 0.05,
85
  out_resolution: int = 224,
86
  ) -> np.ndarray:
@@ -93,6 +93,12 @@ def compute_attention_flow(
93
  if not valid_cams:
94
  return np.zeros((out_resolution, out_resolution), dtype=np.float32)
95
 
 
 
 
 
 
 
96
  uniform = torch.ones(num_image_tokens, device=valid_cams[0].device) / num_image_tokens
97
  rollout = _normalize1d(valid_cams[0])
98
  for cam in valid_cams[1:]:
@@ -325,4 +331,3 @@ def summarize_caption_alignment(results: List[dict], caption_length: int) -> dic
325
  return {"caption_length": caption_length, "mean_alignment_iou": 0.0}
326
  mean_iou = float(np.mean([item["iou"] for item in results]))
327
  return {"caption_length": caption_length, "mean_alignment_iou": mean_iou}
328
-
 
80
 
81
  def compute_attention_flow(
82
  extractor: FlowExtractor,
83
+ num_image_tokens: int | None = None,
84
  residual_weight: float = 0.05,
85
  out_resolution: int = 224,
86
  ) -> np.ndarray:
 
93
  if not valid_cams:
94
  return np.zeros((out_resolution, out_resolution), dtype=np.float32)
95
 
96
+ if num_image_tokens is None:
97
+ num_image_tokens = int(valid_cams[0].numel())
98
+ valid_cams = [cam for cam in valid_cams if int(cam.numel()) == int(num_image_tokens)]
99
+ if not valid_cams:
100
+ return np.zeros((out_resolution, out_resolution), dtype=np.float32)
101
+
102
  uniform = torch.ones(num_image_tokens, device=valid_cams[0].device) / num_image_tokens
103
  rollout = _normalize1d(valid_cams[0])
104
  for cam in valid_cams[1:]:
 
331
  return {"caption_length": caption_length, "mean_alignment_iou": 0.0}
332
  mean_iou = float(np.mean([item["iou"] for item in results]))
333
  return {"caption_length": caption_length, "mean_alignment_iou": mean_iou}