"""Grad-CAM-style cross-attention "flow" visualisation for a BLIP-like captioning model.

Hooks the decoder's cross-attention layers, captures attention probabilities and
their gradients while decoding token-by-token, rolls them into a per-token
spatial heatmap, and grades heatmap/word alignment against a zero-shot object
detector (OWL-ViT).
"""

import math
from typing import List, Tuple

import cv2
import matplotlib
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

# Headless backend must be selected before pyplot is imported.
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# Words (and wordpiece artifacts) that are never groundable objects, so they are
# skipped when grading alignment.
STOP_WORDS = {
    "a", "an", "the", "and", "or", "but", "is", "are", "was", "were",
    "in", "on", "at", "to", "for", "with", "by", "it", "this", "that",
    "there", "here", "of", "up", "out", ".", ",", "!", "##",
}


class FlowExtractor:
    """Registers forward hooks on every decoder cross-attention layer.

    Each hooked layer gets a ``holder`` dict with:
      - ``"fwd"``:  the attention-probability tensor from the last forward pass
      - ``"grad"``: the gradient w.r.t. that tensor from the last backward pass
    Call :meth:`clear` between decoding steps and :meth:`remove` when done.
    """

    def __init__(self, model):
        self.model = model
        self._hooks = []
        self.layers = []  # one holder dict per cross-attention layer, in order
        for layer in model.text_decoder.bert.encoder.layer:
            if not hasattr(layer, "crossattention"):
                continue
            holder = {"fwd": None, "grad": None}
            self.layers.append(holder)

            def _make_hook(h):
                # Factory binds ``h`` per-layer (avoids late-binding closure bug).
                def _fwd(module, inputs, outputs):
                    # outputs[1] is the attention-probability tensor when the
                    # model is run with output_attentions=True.
                    if len(outputs) > 1 and outputs[1] is not None:
                        h["fwd"] = outputs[1]
                        if h["fwd"].requires_grad:
                            # Tensor hook captures the upstream gradient on backward.
                            h["fwd"].register_hook(
                                lambda g, _h=h: _h.update({"grad": g.detach()})
                            )

                return _fwd

            target = layer.crossattention.self
            self._hooks.append(target.register_forward_hook(_make_hook(holder)))

    def clear(self):
        """Drop captured activations/gradients before the next decode step."""
        for holder in self.layers:
            holder["fwd"] = None
            holder["grad"] = None

    def remove(self):
        """Detach all forward hooks from the model."""
        for hook in self._hooks:
            hook.remove()
        self._hooks = []


def encode_image_for_flow(model, processor, device, image_pil: Image.Image):
    """Encode an image once with the vision tower for repeated decoder calls.

    Returns ``(image_224, encoder_hidden, encoder_mask)`` where ``image_224``
    is the 224x224 resized PIL image used for all downstream overlays.
    """
    image_224 = image_pil.resize((224, 224), Image.LANCZOS)
    inputs = processor(images=image_224, return_tensors="pt").to(device)
    with torch.no_grad():
        vision_out = model.vision_model(pixel_values=inputs["pixel_values"])
    # Gradients must not flow into the frozen vision features.
    encoder_hidden = vision_out[0].detach().requires_grad_(False)
    encoder_mask = torch.ones(
        encoder_hidden.size()[:-1], dtype=torch.long, device=device
    )
    return image_224, encoder_hidden, encoder_mask


def _single_layer_gradcam(holder, token_idx: int = -1) -> torch.Tensor:
    """Grad-CAM for one layer: ReLU(attn * grad) averaged over heads.

    ``token_idx`` selects the query (text) position; -1 is the newest token.
    Returns a 1-D tensor over image tokens.
    """
    attn = holder["fwd"][:, :, token_idx, :]
    grad = holder["grad"][:, :, token_idx, :]
    cam = (attn * grad).mean(dim=1).squeeze()
    return torch.clamp(cam, min=0.0)


def _normalize1d(tensor: torch.Tensor) -> torch.Tensor:
    """L1-normalize a non-negative 1-D tensor; return unchanged if it sums to 0."""
    denom = tensor.sum()
    if denom > 0:
        return tensor / denom
    return tensor


def compute_attention_flow(
    extractor: FlowExtractor,
    num_image_tokens: int | None = None,
    residual_weight: float = 0.05,
    out_resolution: int = 224,
) -> np.ndarray:
    """Roll per-layer Grad-CAMs into one normalized heatmap in [0, 1].

    Layers whose capture failed (no fwd/grad) or whose token count disagrees
    with ``num_image_tokens`` are skipped; ``residual_weight`` mixes a uniform
    distribution into each rollout step to keep mass from collapsing.
    Returns an ``(out_resolution, out_resolution)`` float32 array (all zeros
    when nothing usable was captured).
    """
    valid_cams = []
    for holder in extractor.layers:
        if holder["fwd"] is None or holder["grad"] is None:
            continue
        valid_cams.append(_single_layer_gradcam(holder).detach())

    if not valid_cams:
        return np.zeros((out_resolution, out_resolution), dtype=np.float32)

    if num_image_tokens is None:
        num_image_tokens = int(valid_cams[0].numel())
    valid_cams = [
        cam for cam in valid_cams if int(cam.numel()) == int(num_image_tokens)
    ]
    if not valid_cams:
        return np.zeros((out_resolution, out_resolution), dtype=np.float32)

    uniform = (
        torch.ones(num_image_tokens, device=valid_cams[0].device) / num_image_tokens
    )
    rollout = _normalize1d(valid_cams[0])
    for cam in valid_cams[1:]:
        # Multiplicative rollout with a small uniform residual.
        rollout = _normalize1d(rollout) * _normalize1d(cam) + residual_weight * uniform
    rollout = torch.clamp(rollout, min=0.0)

    spatial = rollout[1:]  # drop the CLS token; the rest form a square patch grid
    grid_size = int(math.sqrt(spatial.numel()))
    if grid_size * grid_size != spatial.numel():
        # FIX: non-square token layout would make reshape() raise; degrade to an
        # empty map instead of crashing the whole decode loop.
        return np.zeros((out_resolution, out_resolution), dtype=np.float32)
    hm_tensor = spatial.detach().cpu().reshape(1, 1, grid_size, grid_size).float()
    hm_up = F.interpolate(
        hm_tensor,
        size=(out_resolution, out_resolution),
        mode="bicubic",
        align_corners=False,
    ).squeeze()

    hm_np = hm_up.numpy()
    lo, hi = hm_np.min(), hm_np.max()
    if hi > lo:
        hm_np = (hm_np - lo) / (hi - lo)
    else:
        hm_np = np.zeros_like(hm_np)
    return hm_np.astype(np.float32)


def decode_generated_caption_with_flow(
    model,
    processor,
    device,
    encoder_hidden,
    encoder_mask,
    max_tokens: int = 20,
) -> Tuple[List[str], List[np.ndarray]]:
    """Greedy-decode a caption, producing one attention heatmap per token.

    Each step backprops the argmax token's logit to populate the extractor's
    gradients, then converts them into a heatmap. Stops at SEP or after
    ``max_tokens`` steps. Returns ``(tokens, heatmaps)`` of equal length.
    """
    extractor = FlowExtractor(model)
    input_ids = torch.LongTensor([[model.config.text_config.bos_token_id]]).to(device)
    tokens, heatmaps = [], []
    for _ in range(max_tokens):
        model.zero_grad()
        extractor.clear()
        outputs = model.text_decoder(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden,
            encoder_attention_mask=encoder_mask,
            output_attentions=True,
            return_dict=True,
        )
        logits = outputs.logits[:, -1, :]
        next_token = torch.argmax(logits, dim=-1)
        if next_token.item() == model.config.text_config.sep_token_id:
            break
        # Backward from the chosen token's logit fills holder["grad"] via hooks.
        logits[0, next_token.item()].backward(retain_graph=False)
        heatmaps.append(compute_attention_flow(extractor))
        tokens.append(processor.tokenizer.decode([next_token.item()]).strip())
        input_ids = torch.cat([input_ids, next_token.reshape(1, 1)], dim=-1)
    extractor.remove()
    return tokens, heatmaps


def decode_custom_text_with_flow(
    model,
    processor,
    device,
    encoder_hidden,
    encoder_mask,
    text: str,
    max_tokens: int = 20,
) -> Tuple[List[str], List[np.ndarray]]:
    """Teacher-force a user-supplied ``text`` through the decoder with heatmaps.

    Same mechanics as :func:`decode_generated_caption_with_flow`, but the next
    token is taken from the tokenized ``text`` (truncated to ``max_tokens``)
    instead of argmax decoding. Returns ``(tokens, heatmaps)``.
    """
    extractor = FlowExtractor(model)
    token_ids = processor.tokenizer(
        text,
        add_special_tokens=False,
        return_attention_mask=False,
    )["input_ids"][:max_tokens]
    input_ids = torch.LongTensor([[model.config.text_config.bos_token_id]]).to(device)
    tokens, heatmaps = [], []
    for target_token_id in token_ids:
        model.zero_grad()
        extractor.clear()
        outputs = model.text_decoder(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden,
            encoder_attention_mask=encoder_mask,
            output_attentions=True,
            return_dict=True,
        )
        logits = outputs.logits[:, -1, :]
        # Backprop the forced token's score, not the model's own argmax.
        score = logits[0, target_token_id]
        score.backward(retain_graph=False)
        heatmaps.append(compute_attention_flow(extractor))
        tokens.append(processor.tokenizer.decode([target_token_id]).strip())
        next_tensor = torch.LongTensor([[target_token_id]]).to(device)
        input_ids = torch.cat([input_ids, next_tensor], dim=-1)
    extractor.remove()
    return tokens, heatmaps


def overlay_heatmap_on_image(
    image_pil: Image.Image,
    heatmap_np: np.ndarray,
    alpha: float = 0.5,
    hot_threshold: float = 0.1,
) -> Image.Image:
    """Blend an INFERNO-colored heatmap onto the image where it exceeds threshold.

    ``heatmap_np`` is assumed normalized to [0, 1]; only pixels above
    ``hot_threshold`` are tinted, with strength ``alpha``.
    """
    h, w = heatmap_np.shape
    image_np = np.array(image_pil.resize((w, h), Image.LANCZOS))
    hm_u8 = np.uint8(255.0 * heatmap_np)
    colored = cv2.applyColorMap(hm_u8, cv2.COLORMAP_INFERNO)
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)  # OpenCV returns BGR
    mask = (heatmap_np > hot_threshold).astype(np.float32)[..., None]
    blended = image_np * (1 - mask * alpha) + colored * (mask * alpha)
    return Image.fromarray(blended.astype(np.uint8))


def build_attention_grid_figure(
    image_pil: Image.Image,
    tokens: List[str],
    heatmaps: List[np.ndarray],
    n_rows: int = 2,
    n_cols: int = 5,
):
    """Build an ``n_rows x n_cols`` figure: original image + per-token overlays.

    Panel 0 is the raw image; the remaining panels show up to
    ``n_rows * n_cols - 1`` token heatmap overlays. Returns the figure.
    """
    n_panels = n_rows * n_cols
    n_words = min(n_panels - 1, len(tokens))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 3.2, n_rows * 3.2))
    # FIX: plt.subplots returns a bare Axes (not an array) for a 1x1 grid;
    # atleast_1d makes flatten() safe for every grid shape.
    axes = np.atleast_1d(axes).flatten()
    axes[0].imshow(image_pil)
    axes[0].set_title("Original", fontsize=11, fontweight="bold")
    axes[0].axis("off")
    for index in range(n_words):
        overlay = overlay_heatmap_on_image(image_pil, heatmaps[index])
        axes[index + 1].imshow(overlay)
        axes[index + 1].set_title(f"'{tokens[index]}'", fontsize=10, fontweight="bold")
        axes[index + 1].axis("off")
    # Hide any unused panels.
    for index in range(n_words + 1, n_panels):
        axes[index].axis("off")
    caption_preview = " ".join(tokens[:12])
    # FIX: title previously hard-coded "2x5" regardless of the actual grid shape.
    fig.suptitle(
        f"Cross-Attention Flow ({n_rows}x{n_cols})\nCaption Tokens: {caption_preview}",
        fontsize=12,
        fontweight="bold",
        y=1.02,
    )
    plt.tight_layout()
    return fig


def load_owlvit_detector(device):
    """Create an OWL-ViT zero-shot object-detection pipeline on ``device``."""
    from transformers import pipeline

    # HF pipeline expects an int device index: 0 for CUDA, -1 for CPU.
    pipe_device = 0 if str(device).startswith("cuda") else -1
    return pipeline(
        task="zero-shot-object-detection",
        model="google/owlvit-base-patch32",
        device=pipe_device,
    )


def binarize_heatmap(heatmap_np: np.ndarray, target_hw: tuple) -> np.ndarray:
    """Resize a [0, 1] heatmap to ``(h, w)`` and Otsu-threshold it to a bool mask."""
    # cv2.resize takes (width, height), hence the index swap.
    hm = cv2.resize(heatmap_np, (target_hw[1], target_hw[0]))
    hm_u8 = np.uint8(255.0 * hm)
    _, binary = cv2.threshold(hm_u8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return binary > 0


def calculate_iou(mask: np.ndarray, box: list, img_shape: tuple) -> float:
    """IoU between a boolean mask and an ``[xmin, ymin, xmax, ymax]`` box.

    The box is clipped to ``img_shape`` (h, w); returns 0.0 for an empty union.
    """
    box_mask = np.zeros(img_shape, dtype=bool)
    xmin, ymin, xmax, ymax = map(int, box)
    xmin = max(0, xmin)
    ymin = max(0, ymin)
    xmax = min(img_shape[1], xmax)
    ymax = min(img_shape[0], ymax)
    box_mask[ymin:ymax, xmin:xmax] = True
    inter = np.logical_and(mask, box_mask).sum()
    union = np.logical_or(mask, box_mask).sum()
    return float(inter) / union if union > 0 else 0.0


def grade_alignment_with_detector(
    image_pil: Image.Image,
    tokens: List[str],
    heatmaps: List[np.ndarray],
    detector,
    min_detection_score: float = 0.05,
) -> List[dict]:
    """Score each content word's heatmap against the detector's best box.

    Skips stop words, sub-3-letter words, and non-alphabetic tokens. For each
    remaining word, queries the detector with the word as the candidate label,
    keeps the highest-scoring box at or above ``min_detection_score``, and
    records the IoU between that box and the Otsu-binarized heatmap.
    Returns a list of dicts with word, 1-based position, iou, det_score, box.
    """
    results = []
    img_shape = (image_pil.height, image_pil.width)
    for idx, (word, hm) in enumerate(zip(tokens, heatmaps)):
        clean_word = word.replace("##", "").lower()  # strip wordpiece marker
        if len(clean_word) < 3 or clean_word in STOP_WORDS or not clean_word.isalpha():
            continue
        detections = detector(image_pil, candidate_labels=[clean_word])
        best_box, best_score = None, 0.0
        for detection in detections:
            if (
                detection["score"] > best_score
                and detection["score"] >= min_detection_score
            ):
                best_score = detection["score"]
                best_box = [
                    detection["box"]["xmin"],
                    detection["box"]["ymin"],
                    detection["box"]["xmax"],
                    detection["box"]["ymax"],
                ]
        if best_box is None:
            continue
        mask = binarize_heatmap(hm, img_shape)
        iou = calculate_iou(mask, best_box, img_shape)
        results.append(
            {
                "word": clean_word,
                "position": idx + 1,
                "iou": float(iou),
                "det_score": float(best_score),
                "box": best_box,
            }
        )
    return results


def summarize_caption_alignment(results: List[dict], caption_length: int) -> dict:
    """Collapse per-word alignment results into a caption-level summary dict."""
    if not results:
        return {"caption_length": caption_length, "mean_alignment_iou": 0.0}
    mean_iou = float(np.mean([item["iou"] for item in results]))
    return {"caption_length": caption_length, "mean_alignment_iou": mean_iou}