project_02_DS / models /attention_flow.py
griddev's picture
Deploy Streamlit Space app
64b98e5 verified
import math
from typing import List, Tuple
import cv2
import matplotlib
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# Caption tokens that are not worth grading against an object detector.
# grade_alignment_with_detector skips a token when its cleaned form ("##"
# removed, lower-cased) is a member of this set.
# NOTE(review): that same filter first drops tokens shorter than three
# characters and strips "##" before the lookup, so the one- and two-character
# entries below (including the punctuation and "##") are never reached through
# that path.
STOP_WORDS = {
"a", "an", "the", "and", "or", "but", "is", "are", "was", "were",
"in", "on", "at", "to", "for", "with", "by", "it", "this", "that",
"there", "here", "of", "up", "out", ".", ",", "!", "##",
}
class FlowExtractor:
    """Record cross-attention outputs and their gradients through hooks.

    For every decoder layer that has a ``crossattention`` attribute, one
    holder dict ``{"fwd": ..., "grad": ...}`` is appended to ``self.layers``
    (in layer order) and a forward hook is attached to that layer's
    ``crossattention.self`` module.
    """

    def __init__(self, model):
        """Attach the forward hooks.

        Args:
            model: Object exposing ``text_decoder.bert.encoder.layer`` as an
                iterable of decoder layers.
        """
        self.model = model
        self._hooks = []
        self.layers = []
        cross_layers = (
            layer
            for layer in model.text_decoder.bert.encoder.layer
            if hasattr(layer, "crossattention")
        )
        for layer in cross_layers:
            holder = {"fwd": None, "grad": None}
            self.layers.append(holder)
            attention_module = layer.crossattention.self
            handle = attention_module.register_forward_hook(
                self._build_forward_hook(holder)
            )
            self._hooks.append(handle)

    @staticmethod
    def _build_forward_hook(holder):
        """Return a forward hook that writes its captures into ``holder``."""

        def _on_forward(module, inputs, outputs):
            # The second output is only present when the module was asked to
            # return it; otherwise leave the holder untouched.
            if len(outputs) <= 1 or outputs[1] is None:
                return
            captured = outputs[1]
            holder["fwd"] = captured
            if captured.requires_grad:

                def _on_grad(grad):
                    # Store a detached copy; returning None keeps the
                    # gradient flowing unchanged.
                    holder["grad"] = grad.detach()

                captured.register_hook(_on_grad)

        return _on_forward

    def clear(self):
        """Forget the tensors captured so far; the hooks stay registered."""
        for holder in self.layers:
            holder["fwd"] = None
            holder["grad"] = None

    def remove(self):
        """Detach every forward hook registered by this extractor."""
        for handle in self._hooks:
            handle.remove()
        self._hooks = []
def encode_image_for_flow(model, processor, device, image_pil: Image.Image):
    """Run the vision encoder once and package its output for the decoder.

    Args:
        model: Object exposing a callable ``vision_model``.
        processor: Callable turning a PIL image into a batch that contains
            ``"pixel_values"``.
        device: Device the inputs and the mask are placed on.
        image_pil: Source image.

    Returns:
        A tuple ``(image_224, encoder_hidden, encoder_mask)``: the image
        resized to 224x224, the first output of the vision model detached
        from autograd, and an all-ones ``long`` mask shaped like
        ``encoder_hidden`` without its last dimension.
    """
    image_224 = image_pil.resize((224, 224), Image.LANCZOS)
    batch = processor(images=image_224, return_tensors="pt").to(device)
    pixel_values = batch["pixel_values"]

    # The encoder output is a constant for the attribution pass, so no graph
    # is built for it.
    with torch.no_grad():
        vision_out = model.vision_model(pixel_values=pixel_values)
    encoder_hidden = vision_out[0].detach().requires_grad_(False)

    mask_shape = encoder_hidden.size()[:-1]
    encoder_mask = torch.ones(mask_shape, dtype=torch.long, device=device)
    return image_224, encoder_hidden, encoder_mask
def _single_layer_gradcam(holder, token_idx: int = -1) -> torch.Tensor:
    """Compute a non-negative Grad-CAM style map for one captured layer.

    Args:
        holder: Dict with ``"fwd"`` and ``"grad"`` tensors that are indexed
            as ``[:, :, token_idx, :]``.
        token_idx: Index along the third dimension to select (default: last).

    Returns:
        The element-wise product of the selected slices, averaged over
        dimension 1, squeezed, and clamped at zero from below.
    """
    weights = holder["fwd"][:, :, token_idx, :]
    gradients = holder["grad"][:, :, token_idx, :]
    averaged = torch.mean(weights * gradients, dim=1)
    return averaged.squeeze().clamp(min=0.0)
def _normalize1d(tensor: torch.Tensor) -> torch.Tensor:
    """Scale ``tensor`` so its elements sum to one.

    When the sum is not strictly positive the input is returned unchanged
    (the same object), which avoids dividing by zero.
    """
    total = tensor.sum()
    return tensor / total if total > 0 else tensor
def compute_attention_flow(
    extractor: FlowExtractor,
    num_image_tokens: int | None = None,
    residual_weight: float = 0.05,
    out_resolution: int = 224,
) -> np.ndarray:
    """Fuse the per-layer Grad-CAM maps into one normalized heatmap.

    Layers whose forward tensor or gradient was not captured are ignored.
    The remaining maps are combined in layer order: each step multiplies the
    normalized running map by the normalized next map and adds
    ``residual_weight`` times a uniform distribution. The first element of
    the result is dropped, the rest is laid out as a square grid, upsampled
    bicubically, and min-max scaled.

    Args:
        extractor: Extractor whose ``layers`` hold captured tensors.
        num_image_tokens: Expected number of elements per map; maps of a
            different size are discarded. Defaults to the size of the first
            valid map.
        residual_weight: Weight of the uniform term added at every fusion
            step.
        out_resolution: Side length of the returned square heatmap.

    Returns:
        A ``float32`` array of shape ``(out_resolution, out_resolution)``
        with values in ``[0, 1]``. An all-zero array is returned when no
        usable map exists, when the fused map is constant, or when the
        tokens left after dropping the first one cannot form a non-empty
        square grid.
    """
    valid_cams = [
        _single_layer_gradcam(holder).detach()
        for holder in extractor.layers
        if holder["fwd"] is not None and holder["grad"] is not None
    ]
    if not valid_cams:
        return np.zeros((out_resolution, out_resolution), dtype=np.float32)

    if num_image_tokens is None:
        num_image_tokens = int(valid_cams[0].numel())
    valid_cams = [
        cam for cam in valid_cams if int(cam.numel()) == int(num_image_tokens)
    ]
    if not valid_cams:
        return np.zeros((out_resolution, out_resolution), dtype=np.float32)

    uniform = (
        torch.ones(num_image_tokens, device=valid_cams[0].device)
        / num_image_tokens
    )
    rollout = _normalize1d(valid_cams[0])
    for cam in valid_cams[1:]:
        rollout = (
            _normalize1d(rollout) * _normalize1d(cam)
            + residual_weight * uniform
        )
        rollout = torch.clamp(rollout, min=0.0)

    # The first element is dropped before the spatial layout.
    # NOTE(review): presumably a CLS token - confirm against the encoder.
    spatial = rollout[1:]
    n_spatial = int(spatial.numel())
    grid_size = int(math.sqrt(n_spatial))
    # Without this guard the reshape below raises for an empty or non-square
    # token count; fall back to the same blank map as the other degenerate
    # cases instead.
    if grid_size == 0 or grid_size * grid_size != n_spatial:
        return np.zeros((out_resolution, out_resolution), dtype=np.float32)

    hm_tensor = (
        spatial.detach().cpu().reshape(1, 1, grid_size, grid_size).float()
    )
    hm_up = F.interpolate(
        hm_tensor,
        size=(out_resolution, out_resolution),
        mode="bicubic",
        align_corners=False,
    ).squeeze()
    hm_np = hm_up.numpy()

    # Min-max scale; bicubic interpolation may overshoot below zero, and a
    # constant map carries no information.
    lo, hi = hm_np.min(), hm_np.max()
    if hi > lo:
        hm_np = (hm_np - lo) / (hi - lo)
    else:
        hm_np = np.zeros_like(hm_np)
    return hm_np.astype(np.float32)
def decode_generated_caption_with_flow(
    model,
    processor,
    device,
    encoder_hidden,
    encoder_mask,
    max_tokens: int = 20,
) -> Tuple[List[str], List[np.ndarray]]:
    """Greedily decode a caption and compute one heatmap per emitted token.

    Each step runs the text decoder on the tokens produced so far, picks the
    arg-max token, back-propagates that token's logit so the extractor hooks
    capture gradients, and turns the captures into a heatmap. Decoding stops
    at the separator token or after ``max_tokens`` steps.

    Every step calls ``backward()``, so autograd must be enabled in the
    calling context, and ``model.zero_grad()`` is invoked before each step.

    Args:
        model: Captioning model exposing ``text_decoder`` and
            ``config.text_config`` (``bos_token_id``, ``sep_token_id``).
        processor: Object whose ``tokenizer.decode`` maps ids to text.
        device: Device for the decoder input ids.
        encoder_hidden: Encoder states passed to the decoder.
        encoder_mask: Attention mask matching ``encoder_hidden``.
        max_tokens: Maximum number of decoding steps.

    Returns:
        ``(tokens, heatmaps)`` - decoded token strings and their heatmaps,
        index-aligned.
    """
    extractor = FlowExtractor(model)
    tokens: List[str] = []
    heatmaps: List[np.ndarray] = []
    # The hooks live on the shared model; always detach them, even when a
    # step raises, so they do not pile up across calls.
    try:
        text_config = model.config.text_config
        sep_token_id = text_config.sep_token_id
        input_ids = torch.LongTensor([[text_config.bos_token_id]]).to(device)

        for _ in range(max_tokens):
            model.zero_grad()
            extractor.clear()
            outputs = model.text_decoder(
                input_ids=input_ids,
                encoder_hidden_states=encoder_hidden,
                encoder_attention_mask=encoder_mask,
                output_attentions=True,
                return_dict=True,
            )
            logits = outputs.logits[:, -1, :]
            next_token = torch.argmax(logits, dim=-1)
            token_id = next_token.item()
            if token_id == sep_token_id:
                break

            # Back-propagate the chosen logit so the hooks see its gradient.
            logits[0, token_id].backward(retain_graph=False)
            heatmaps.append(compute_attention_flow(extractor))
            tokens.append(processor.tokenizer.decode([token_id]).strip())
            input_ids = torch.cat(
                [input_ids, next_token.reshape(1, 1)], dim=-1
            )
    finally:
        extractor.remove()
    return tokens, heatmaps
def decode_custom_text_with_flow(
    model,
    processor,
    device,
    encoder_hidden,
    encoder_mask,
    text: str,
    max_tokens: int = 20,
) -> Tuple[List[str], List[np.ndarray]]:
    """Teacher-force ``text`` through the decoder, one heatmap per token.

    The text is tokenized without special tokens and truncated to
    ``max_tokens`` ids. For each id the decoder is run on the prefix, the
    logit of that id is back-propagated so the extractor hooks capture
    gradients, and the captures are turned into a heatmap.

    Every step calls ``backward()``, so autograd must be enabled in the
    calling context, and ``model.zero_grad()`` is invoked before each step.

    Args:
        model: Captioning model exposing ``text_decoder`` and
            ``config.text_config.bos_token_id``.
        processor: Object whose ``tokenizer`` encodes and decodes text.
        device: Device for the decoder input ids.
        encoder_hidden: Encoder states passed to the decoder.
        encoder_mask: Attention mask matching ``encoder_hidden``.
        text: Text to attribute.
        max_tokens: Maximum number of token ids taken from ``text``.

    Returns:
        ``(tokens, heatmaps)`` - decoded token strings and their heatmaps,
        index-aligned. Both are empty when ``text`` yields no token ids.
    """
    extractor = FlowExtractor(model)
    tokens: List[str] = []
    heatmaps: List[np.ndarray] = []
    # The hooks live on the shared model; always detach them, even when the
    # tokenizer or a decoding step raises.
    try:
        token_ids = processor.tokenizer(
            text,
            add_special_tokens=False,
            return_attention_mask=False,
        )["input_ids"][:max_tokens]
        input_ids = torch.LongTensor(
            [[model.config.text_config.bos_token_id]]
        ).to(device)

        for target_token_id in token_ids:
            model.zero_grad()
            extractor.clear()
            outputs = model.text_decoder(
                input_ids=input_ids,
                encoder_hidden_states=encoder_hidden,
                encoder_attention_mask=encoder_mask,
                output_attentions=True,
                return_dict=True,
            )
            logits = outputs.logits[:, -1, :]

            # Attribute the requested token, whatever the model would have
            # predicted itself.
            score = logits[0, target_token_id]
            score.backward(retain_graph=False)
            heatmaps.append(compute_attention_flow(extractor))
            tokens.append(
                processor.tokenizer.decode([target_token_id]).strip()
            )
            next_tensor = torch.LongTensor([[target_token_id]]).to(device)
            input_ids = torch.cat([input_ids, next_tensor], dim=-1)
    finally:
        extractor.remove()
    return tokens, heatmaps
def overlay_heatmap_on_image(
    image_pil: Image.Image,
    heatmap_np: np.ndarray,
    alpha: float = 0.5,
    hot_threshold: float = 0.1,
) -> Image.Image:
    """Blend an inferno-colored heatmap over an image.

    The image is resized to the heatmap's resolution. Only pixels whose
    heatmap value exceeds ``hot_threshold`` are tinted; everything else
    keeps the original pixel values.

    Args:
        image_pil: Image to draw on; any PIL mode is accepted.
        heatmap_np: 2-D heatmap. It is scaled by 255 for coloring, so values
            are expected in ``[0, 1]``.
        alpha: Opacity of the heatmap color in the tinted region.
        hot_threshold: Heatmap value above which a pixel is tinted.

    Returns:
        The blended RGB image at the heatmap's resolution.
    """
    h, w = heatmap_np.shape
    # Force three channels: a grayscale, palette, or RGBA input would not
    # broadcast against the (h, w, 3) color map below.
    image_rgb = image_pil.convert("RGB").resize((w, h), Image.LANCZOS)
    image_np = np.array(image_rgb)

    hm_u8 = np.uint8(255.0 * heatmap_np)
    colored = cv2.applyColorMap(hm_u8, cv2.COLORMAP_INFERNO)
    # OpenCV color maps are BGR; convert to match the RGB image.
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)

    mask = (heatmap_np > hot_threshold).astype(np.float32)[..., None]
    blended = image_np * (1 - mask * alpha) + colored * (mask * alpha)
    return Image.fromarray(blended.astype(np.uint8))
def build_attention_grid_figure(
    image_pil: Image.Image,
    tokens: List[str],
    heatmaps: List[np.ndarray],
    n_rows: int = 2,
    n_cols: int = 5,
):
    """Draw the original image followed by one heatmap overlay per token.

    The first panel shows ``image_pil``; the following panels show the
    overlays for the leading tokens, as many as fit in the grid and have
    both a token and a heatmap. Unused panels are blanked.

    Args:
        image_pil: Image the heatmaps refer to.
        tokens: Caption tokens, index-aligned with ``heatmaps``.
        heatmaps: One 2-D heatmap per token.
        n_rows: Number of grid rows.
        n_cols: Number of grid columns.

    Returns:
        The matplotlib figure.
    """
    # squeeze=False always yields a 2-D array of axes, so flattening also
    # works for single-row, single-column, and 1x1 grids.
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(n_cols * 3.2, n_rows * 3.2),
        squeeze=False,
    )
    axes = axes.flatten()

    axes[0].imshow(image_pil)
    axes[0].set_title("Original", fontsize=11, fontweight="bold")
    axes[0].axis("off")

    # zip stops at the shortest input: the free panels, the tokens, or the
    # heatmaps, so a short heatmap list cannot cause an IndexError.
    n_words = 0
    for axis, token, heatmap in zip(axes[1:], tokens, heatmaps):
        axis.imshow(overlay_heatmap_on_image(image_pil, heatmap))
        axis.set_title(f"'{token}'", fontsize=10, fontweight="bold")
        axis.axis("off")
        n_words += 1

    for axis in axes[n_words + 1:]:
        axis.axis("off")

    caption_preview = " ".join(tokens[:12])
    # Report the actual grid size rather than a fixed "2x5".
    fig.suptitle(
        f"Cross-Attention Flow ({n_rows}x{n_cols})\n"
        f"Caption Tokens: {caption_preview}",
        fontsize=12,
        fontweight="bold",
        y=1.02,
    )
    plt.tight_layout()
    return fig
def load_owlvit_detector(device):
    """Create a zero-shot object-detection pipeline backed by OWL-ViT.

    Args:
        device: Target device, e.g. ``"cpu"``, ``"cuda"``, ``"cuda:1"`` or
            an object whose string form looks like one of those.

    Returns:
        The ``transformers`` pipeline for ``google/owlvit-base-patch32``.
    """
    # Imported lazily so the module loads without pulling in the pipeline
    # machinery until a detector is actually requested.
    from transformers import pipeline

    device_name = str(device)
    if device_name.startswith("cuda"):
        # Honor an explicit GPU index ("cuda:1" -> 1) instead of always
        # using GPU 0; a bare "cuda" still maps to 0.
        _, _, index = device_name.partition(":")
        pipe_device = int(index) if index.isdigit() else 0
    else:
        pipe_device = -1

    return pipeline(
        task="zero-shot-object-detection",
        model="google/owlvit-base-patch32",
        device=pipe_device,
    )
def binarize_heatmap(heatmap_np: np.ndarray, target_hw: tuple) -> np.ndarray:
    """Resize a heatmap and threshold it into a boolean mask with Otsu.

    Args:
        heatmap_np: 2-D heatmap; it is scaled by 255 before thresholding.
        target_hw: ``(height, width)`` of the returned mask.

    Returns:
        Boolean array that is ``True`` where the resized heatmap lies above
        the Otsu threshold.
    """
    # cv2.resize takes its size as (width, height).
    width_height = (target_hw[1], target_hw[0])
    resized = cv2.resize(heatmap_np, width_height)
    as_bytes = np.uint8(255.0 * resized)
    otsu_flags = cv2.THRESH_BINARY | cv2.THRESH_OTSU
    binary = cv2.threshold(as_bytes, 0, 255, otsu_flags)[1]
    return binary > 0
def calculate_iou(mask: np.ndarray, box: list, img_shape: tuple) -> float:
    """Return the intersection-over-union of a boolean mask and a box.

    Args:
        mask: Boolean array of shape ``img_shape``.
        box: ``[xmin, ymin, xmax, ymax]`` in pixels; values are truncated to
            integers and clipped to the image.
        img_shape: ``(height, width)`` of the image.

    Returns:
        IoU in ``[0, 1]``, or ``0.0`` when both the mask and the clipped box
        are empty.
    """
    height, width = img_shape[0], img_shape[1]
    xmin, ymin, xmax, ymax = map(int, box)

    # Clamp every coordinate into the image. A negative upper bound must not
    # reach the slice below, where it would count from the end of the axis
    # and mark a region the box never covered.
    xmin = min(max(xmin, 0), width)
    xmax = min(max(xmax, 0), width)
    ymin = min(max(ymin, 0), height)
    ymax = min(max(ymax, 0), height)

    box_mask = np.zeros(img_shape, dtype=bool)
    box_mask[ymin:ymax, xmin:xmax] = True

    inter = np.logical_and(mask, box_mask).sum()
    union = np.logical_or(mask, box_mask).sum()
    return float(inter) / float(union) if union > 0 else 0.0
def grade_alignment_with_detector(
    image_pil: Image.Image,
    tokens: List[str],
    heatmaps: List[np.ndarray],
    detector,
    min_detection_score: float = 0.05,
) -> List[dict]:
    """Score how well each token's heatmap overlaps a detected object box.

    Tokens are cleaned (``"##"`` removed, lower-cased) and skipped when they
    are shorter than three characters, listed in ``STOP_WORDS``, or not
    purely alphabetic. For the rest, the detector is queried with the word
    as its only candidate label; the highest-scoring detection at or above
    ``min_detection_score`` is compared with the binarized heatmap.

    Args:
        image_pil: Image passed to the detector; its size defines the mask
            shape.
        tokens: Caption tokens, index-aligned with ``heatmaps``.
        heatmaps: One heatmap per token.
        detector: Callable ``detector(image, candidate_labels=[...])`` that
            returns dicts with ``"score"`` and ``"box"`` entries.
        min_detection_score: Minimum score for a detection to count.

    Returns:
        One dict per graded token with keys ``"word"``, ``"position"``
        (1-based token index), ``"iou"``, ``"det_score"`` and ``"box"``.
    """
    img_shape = (image_pil.height, image_pil.width)
    results = []

    for position, (word, hm) in enumerate(zip(tokens, heatmaps), start=1):
        clean_word = word.replace("##", "").lower()
        gradable = (
            len(clean_word) >= 3
            and clean_word not in STOP_WORDS
            and clean_word.isalpha()
        )
        if not gradable:
            continue

        detections = detector(image_pil, candidate_labels=[clean_word])
        # Only strictly positive scores at or above the minimum qualify.
        candidates = [
            detection
            for detection in detections
            if detection["score"] >= min_detection_score
            and detection["score"] > 0.0
        ]
        if not candidates:
            continue

        # max() keeps the first of equally scored detections.
        top = max(candidates, key=lambda detection: detection["score"])
        best_score = top["score"]
        best_box = [
            top["box"]["xmin"],
            top["box"]["ymin"],
            top["box"]["xmax"],
            top["box"]["ymax"],
        ]

        mask = binarize_heatmap(hm, img_shape)
        iou = calculate_iou(mask, best_box, img_shape)
        results.append(
            {
                "word": clean_word,
                "position": position,
                "iou": float(iou),
                "det_score": float(best_score),
                "box": best_box,
            }
        )
    return results
def summarize_caption_alignment(results: List[dict], caption_length: int) -> dict:
    """Summarize per-word alignment results for one caption.

    Args:
        results: Dicts that each carry an ``"iou"`` value.
        caption_length: Caption length, copied into the summary.

    Returns:
        ``{"caption_length": ..., "mean_alignment_iou": ...}`` where the
        mean is ``0.0`` when ``results`` is empty.
    """
    if results:
        mean_iou = float(np.mean([entry["iou"] for entry in results]))
    else:
        mean_iou = 0.0
    return {"caption_length": caption_length, "mean_alignment_iou": mean_iou}