import math
from typing import List, Tuple

import cv2
import matplotlib
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

matplotlib.use("Agg")  # headless backend; must be selected before importing pyplot
import matplotlib.pyplot as plt

# Tokens the detector cannot meaningfully ground; skipped during alignment grading.
STOP_WORDS = {
    "a", "an", "the", "and", "or", "but", "is", "are", "was", "were",
    "in", "on", "at", "to", "for", "with", "by", "it", "this", "that",
    "there", "here", "of", "up", "out", ".", ",", "!", "##",
}
class FlowExtractor:
    """Hooks every cross-attention layer of the text decoder and caches each
    layer's attention probabilities plus their gradients for Grad-CAM rollout."""

    def __init__(self, model):
        self.model = model
        self._hooks = []
        self.layers = []
        for layer in model.text_decoder.bert.encoder.layer:
            if hasattr(layer, "crossattention"):
                holder = {"fwd": None, "grad": None}
                self.layers.append(holder)

                def _make_hook(h):
                    def _fwd(module, inputs, outputs):
                        # outputs[1] holds the attention probabilities when the
                        # decoder is called with output_attentions=True.
                        if len(outputs) > 1 and outputs[1] is not None:
                            h["fwd"] = outputs[1]
                            if h["fwd"].requires_grad:
                                h["fwd"].register_hook(
                                    lambda g, _h=h: _h.update({"grad": g.detach()})
                                )
                    return _fwd

                target = layer.crossattention.self
                self._hooks.append(target.register_forward_hook(_make_hook(holder)))

    def clear(self):
        for holder in self.layers:
            holder["fwd"] = None
            holder["grad"] = None

    def remove(self):
        for hook in self._hooks:
            hook.remove()
        self._hooks = []
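
# Hedged usage sketch (not from the original file): FlowExtractor assumes a
# BLIP-style captioning model whose decoder exposes
# text_decoder.bert.encoder.layer[*].crossattention, as in Hugging Face's BLIP.
# The checkpoint name below is an assumption; swap in whichever BLIP checkpoint
# the app actually loads.
def _demo_build_extractor():
    from transformers import BlipForConditionalGeneration, BlipProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    model = (
        BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
        .to(device)
        .eval()
    )
    extractor = FlowExtractor(model)  # one forward hook per cross-attention layer
    extractor.remove()  # detach when done; the decoding helpers below manage their own
    return model, processor, device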

def encode_image_for_flow(model, processor, device, image_pil: Image.Image):
    """Run the frozen vision encoder once; its hidden states are reused for
    every decoding step."""
    image_224 = image_pil.resize((224, 224), Image.LANCZOS)
    inputs = processor(images=image_224, return_tensors="pt").to(device)
    with torch.no_grad():
        vision_out = model.vision_model(pixel_values=inputs["pixel_values"])
    encoder_hidden = vision_out[0].detach()
    encoder_mask = torch.ones(encoder_hidden.size()[:-1], dtype=torch.long, device=device)
    return image_224, encoder_hidden, encoder_mask

def _single_layer_gradcam(holder, token_idx: int = -1) -> torch.Tensor:
    """Grad-CAM over one cross-attention layer: gradient-weighted attention for
    the query token at token_idx (default: last), averaged over heads."""
    attn = holder["fwd"][:, :, token_idx, :]
    grad = holder["grad"][:, :, token_idx, :]
    cam = (attn * grad).mean(dim=1).squeeze()
    return torch.clamp(cam, min=0.0)


def _normalize1d(tensor: torch.Tensor) -> torch.Tensor:
    denom = tensor.sum()
    if denom > 0:
        return tensor / denom
    return tensor

def compute_attention_flow(
    extractor: FlowExtractor,
    num_image_tokens: int | None = None,
    residual_weight: float = 0.05,
    out_resolution: int = 224,
) -> np.ndarray:
    """Roll per-layer Grad-CAMs together (elementwise product across layers,
    plus a small uniform residual term), then upsample the spatial tokens into
    a [0, 1]-normalized heatmap."""
    valid_cams = []
    for holder in extractor.layers:
        if holder["fwd"] is None or holder["grad"] is None:
            continue
        valid_cams.append(_single_layer_gradcam(holder).detach())
    if not valid_cams:
        return np.zeros((out_resolution, out_resolution), dtype=np.float32)
    if num_image_tokens is None:
        num_image_tokens = int(valid_cams[0].numel())
    valid_cams = [cam for cam in valid_cams if int(cam.numel()) == int(num_image_tokens)]
    if not valid_cams:
        return np.zeros((out_resolution, out_resolution), dtype=np.float32)
    uniform = torch.ones(num_image_tokens, device=valid_cams[0].device) / num_image_tokens
    rollout = _normalize1d(valid_cams[0])
    for cam in valid_cams[1:]:
        rollout = _normalize1d(rollout) * _normalize1d(cam) + residual_weight * uniform
    rollout = torch.clamp(rollout, min=0.0)
    spatial = rollout[1:]  # drop the CLS token; the rest form a square patch grid
    grid_size = int(math.sqrt(spatial.numel()))
    hm_tensor = spatial.detach().cpu().reshape(1, 1, grid_size, grid_size).float()
    hm_up = F.interpolate(
        hm_tensor,
        size=(out_resolution, out_resolution),
        mode="bicubic",
        align_corners=False,
    ).squeeze()
    hm_np = hm_up.numpy()
    lo, hi = hm_np.min(), hm_np.max()
    if hi > lo:
        hm_np = (hm_np - lo) / (hi - lo)
    else:
        hm_np = np.zeros_like(hm_np)
    return hm_np.astype(np.float32)
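
# Hedged sanity check (synthetic, not from the original file): feed
# compute_attention_flow fake holders shaped like BLIP cross-attention
# (batch=1, 12 heads, 5 text tokens, 1 CLS + 24x24 image tokens) and confirm
# the output is a normalized 224x224 map.
def _demo_attention_flow_shapes():
    from types import SimpleNamespace

    fake_layers = []
    for _ in range(3):
        attn = torch.rand(1, 12, 5, 577)
        fake_layers.append({"fwd": attn, "grad": torch.rand_like(attn)})
    fake_extractor = SimpleNamespace(layers=fake_layers)  # duck-types FlowExtractor
    heatmap = compute_attention_flow(fake_extractor)
    assert heatmap.shape == (224, 224)
    assert 0.0 <= heatmap.min() and heatmap.max() <= 1.0
    return heatmap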

def decode_generated_caption_with_flow(
    model,
    processor,
    device,
    encoder_hidden,
    encoder_mask,
    max_tokens: int = 20,
) -> Tuple[List[str], List[np.ndarray]]:
    """Greedy decoding, one token at a time; after each step, backprop the
    chosen token's logit to obtain a cross-attention heatmap for that token."""
    extractor = FlowExtractor(model)
    input_ids = torch.LongTensor([[model.config.text_config.bos_token_id]]).to(device)
    tokens, heatmaps = [], []
    for _ in range(max_tokens):
        model.zero_grad()
        extractor.clear()
        outputs = model.text_decoder(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden,
            encoder_attention_mask=encoder_mask,
            output_attentions=True,
            return_dict=True,
        )
        logits = outputs.logits[:, -1, :]
        next_token = torch.argmax(logits, dim=-1)
        if next_token.item() == model.config.text_config.sep_token_id:
            break
        logits[0, next_token.item()].backward(retain_graph=False)
        heatmaps.append(compute_attention_flow(extractor))
        tokens.append(processor.tokenizer.decode([next_token.item()]).strip())
        input_ids = torch.cat([input_ids, next_token.reshape(1, 1)], dim=-1)
    extractor.remove()
    return tokens, heatmaps
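
# Hedged end-to-end sketch (assumes model/processor/device as returned by the
# illustrative _demo_build_extractor above, and an RGB PIL image): caption the
# image greedily and collect one heatmap per generated token.
def _demo_caption_with_flow(model, processor, device, image: Image.Image):
    image_224, encoder_hidden, encoder_mask = encode_image_for_flow(
        model, processor, device, image
    )
    tokens, heatmaps = decode_generated_caption_with_flow(
        model, processor, device, encoder_hidden, encoder_mask
    )
    print("caption:", " ".join(tokens))
    return image_224, tokens, heatmaps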

def decode_custom_text_with_flow(
    model,
    processor,
    device,
    encoder_hidden,
    encoder_mask,
    text: str,
    max_tokens: int = 20,
) -> Tuple[List[str], List[np.ndarray]]:
    """Teacher-force a user-supplied caption through the decoder; backprop each
    target token's logit to get its heatmap, regardless of what the model
    would have generated."""
    extractor = FlowExtractor(model)
    token_ids = processor.tokenizer(
        text,
        add_special_tokens=False,
        return_attention_mask=False,
    )["input_ids"][:max_tokens]
    input_ids = torch.LongTensor([[model.config.text_config.bos_token_id]]).to(device)
    tokens, heatmaps = [], []
    for target_token_id in token_ids:
        model.zero_grad()
        extractor.clear()
        outputs = model.text_decoder(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden,
            encoder_attention_mask=encoder_mask,
            output_attentions=True,
            return_dict=True,
        )
        logits = outputs.logits[:, -1, :]
        score = logits[0, target_token_id]
        score.backward(retain_graph=False)
        heatmaps.append(compute_attention_flow(extractor))
        tokens.append(processor.tokenizer.decode([target_token_id]).strip())
        next_tensor = torch.LongTensor([[target_token_id]]).to(device)
        input_ids = torch.cat([input_ids, next_tensor], dim=-1)
    extractor.remove()
    return tokens, heatmaps
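
# Hedged probe sketch: teacher-force a user-supplied caption (the string here
# is a made-up example) to see where each of its tokens attends, independent of
# the model's own greedy caption.
def _demo_probe_custom_text(model, processor, device, encoder_hidden, encoder_mask):
    tokens, heatmaps = decode_custom_text_with_flow(
        model, processor, device, encoder_hidden, encoder_mask,
        text="a dog running on the beach",
    )
    return tokens, heatmaps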

def overlay_heatmap_on_image(
    image_pil: Image.Image,
    heatmap_np: np.ndarray,
    alpha: float = 0.5,
    hot_threshold: float = 0.1,
) -> Image.Image:
    """Blend an inferno-colored heatmap onto the image, but only where the
    heatmap exceeds hot_threshold, so cool regions stay unobscured."""
    h, w = heatmap_np.shape
    image_np = np.array(image_pil.resize((w, h), Image.LANCZOS))
    hm_u8 = np.uint8(255.0 * heatmap_np)
    colored = cv2.applyColorMap(hm_u8, cv2.COLORMAP_INFERNO)
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)  # OpenCV colormaps are BGR
    mask = (heatmap_np > hot_threshold).astype(np.float32)[..., None]
    blended = image_np * (1 - mask * alpha) + colored * (mask * alpha)
    return Image.fromarray(blended.astype(np.uint8))

def build_attention_grid_figure(
    image_pil: Image.Image,
    tokens: List[str],
    heatmaps: List[np.ndarray],
    n_rows: int = 2,
    n_cols: int = 5,
):
    """Panel grid: original image first, then one heatmap overlay per token."""
    n_panels = n_rows * n_cols
    n_words = min(n_panels - 1, len(tokens))
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(n_cols * 3.2, n_rows * 3.2), squeeze=False
    )
    axes = axes.flatten()
    axes[0].imshow(image_pil)
    axes[0].set_title("Original", fontsize=11, fontweight="bold")
    axes[0].axis("off")
    for index in range(n_words):
        overlay = overlay_heatmap_on_image(image_pil, heatmaps[index])
        axes[index + 1].imshow(overlay)
        axes[index + 1].set_title(f"'{tokens[index]}'", fontsize=10, fontweight="bold")
        axes[index + 1].axis("off")
    for index in range(n_words + 1, n_panels):
        axes[index].axis("off")
    caption_preview = " ".join(tokens[:12])
    fig.suptitle(
        f"Cross-Attention Flow ({n_rows}x{n_cols})\nCaption Tokens: {caption_preview}",
        fontsize=12,
        fontweight="bold",
        y=1.02,
    )
    plt.tight_layout()
    return fig
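
# Hedged rendering sketch: build the default 2x5 grid for the first nine tokens
# and write it to disk (the filename is illustrative).
def _demo_render_grid(image_224, tokens, heatmaps):
    fig = build_attention_grid_figure(image_224, tokens, heatmaps)
    fig.savefig("attention_grid.png", dpi=150, bbox_inches="tight")
    plt.close(fig)  # free the figure; Agg backend never displays it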

def load_owlvit_detector(device):
    from transformers import pipeline

    pipe_device = 0 if str(device).startswith("cuda") else -1
    return pipeline(
        task="zero-shot-object-detection",
        model="google/owlvit-base-patch32",
        device=pipe_device,
    )

def binarize_heatmap(heatmap_np: np.ndarray, target_hw: tuple) -> np.ndarray:
    # cv2.resize expects (width, height), so swap the (h, w) pair.
    hm = cv2.resize(heatmap_np, (target_hw[1], target_hw[0]))
    hm_u8 = np.uint8(255.0 * hm)
    # Otsu picks the binarization threshold automatically per heatmap.
    _, binary = cv2.threshold(hm_u8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return binary > 0

def calculate_iou(mask: np.ndarray, box: list, img_shape: tuple) -> float:
    """IoU between a boolean heatmap mask and an (xmin, ymin, xmax, ymax) box."""
    box_mask = np.zeros(img_shape, dtype=bool)
    xmin, ymin, xmax, ymax = map(int, box)
    xmin = max(0, xmin)
    ymin = max(0, ymin)
    xmax = min(img_shape[1], xmax)
    ymax = min(img_shape[0], ymax)
    box_mask[ymin:ymax, xmin:xmax] = True
    inter = np.logical_and(mask, box_mask).sum()
    union = np.logical_or(mask, box_mask).sum()
    return float(inter) / union if union > 0 else 0.0
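
# Hedged worked example: a 100x100 mask filling the top-left 50x50 quadrant
# scores IoU = 1.0 against a box covering that same quadrant, and
# 2500 / 10000 = 0.25 against a box covering the whole image.
def _demo_iou():
    mask = np.zeros((100, 100), dtype=bool)
    mask[:50, :50] = True
    assert calculate_iou(mask, [0, 0, 50, 50], (100, 100)) == 1.0
    assert abs(calculate_iou(mask, [0, 0, 100, 100], (100, 100)) - 0.25) < 1e-9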

def grade_alignment_with_detector(
    image_pil: Image.Image,
    tokens: List[str],
    heatmaps: List[np.ndarray],
    detector,
    min_detection_score: float = 0.05,
) -> List[dict]:
    """For each content word, ask OWL-ViT where that object is, then score how
    well the word's attention heatmap overlaps the best detection box (IoU)."""
    results = []
    img_shape = (image_pil.height, image_pil.width)
    for idx, (word, hm) in enumerate(zip(tokens, heatmaps)):
        clean_word = word.replace("##", "").lower()
        # Skip subword fragments, stop words, and punctuation: the detector can
        # only ground concrete content words.
        if len(clean_word) < 3 or clean_word in STOP_WORDS or not clean_word.isalpha():
            continue
        detections = detector(image_pil, candidate_labels=[clean_word])
        best_box, best_score = None, 0.0
        for detection in detections:
            if detection["score"] > best_score and detection["score"] >= min_detection_score:
                best_score = detection["score"]
                best_box = [
                    detection["box"]["xmin"],
                    detection["box"]["ymin"],
                    detection["box"]["xmax"],
                    detection["box"]["ymax"],
                ]
        if best_box is None:
            continue
        mask = binarize_heatmap(hm, img_shape)
        iou = calculate_iou(mask, best_box, img_shape)
        results.append(
            {
                "word": clean_word,
                "position": idx + 1,
                "iou": float(iou),
                "det_score": float(best_score),
                "box": best_box,
            }
        )
    return results

def summarize_caption_alignment(results: List[dict], caption_length: int) -> dict:
    if not results:
        return {"caption_length": caption_length, "mean_alignment_iou": 0.0}
    mean_iou = float(np.mean([item["iou"] for item in results]))
    return {"caption_length": caption_length, "mean_alignment_iou": mean_iou}