Spaces:
Sleeping
Sleeping
| """ | |
| step5_iou_grade.py | |
| ================== | |
| STEP 5 — Quantitative Alignment Grading via Zero-Shot Object Detection. | |
| Responsibilities: | |
| - Load OWL-ViT (zero-shot open-vocabulary object detector). | |
| - For each meaningful word in the caption, find its bounding box in the image. | |
| - Binarize the Attention Flow heatmap with Otsu's thresholding. | |
| - Compute Intersection over Union (IoU) between heatmap mask and bounding box. | |
| - Plot and save a word-position vs IoU scatter chart. | |
| Why OWL-ViT? | |
| OWL-ViT is a zero-shot detector: it can find *any* object in an image | |
| just by reading its name. This means we do NOT need any pre-annotated | |
| bounding boxes — just our generated caption words. It acts as a fully | |
| automated judge of how well the AI's attention was grounded. | |
| Why Otsu's Thresholding? | |
| Otsu's method automatically finds the optimal binary split point of the | |
| heatmap histogram, separating "looking here" from "not looking here" | |
| without needing a hand-tuned cut-off value. | |
| IoU Interpretation: | |
| 0.0 = No overlap (attention fired in the wrong place). | |
| 0.1–0.3 = Weak grounding (partial overlap, some drift). | |
| 0.3+ = Good grounding (attention focused on the right region). | |
| """ | |
| import os | |
| import sys | |
| import numpy as np | |
| import cv2 | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| _THIS_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| _PROJECT_ROOT = os.path.dirname(os.path.dirname(_THIS_DIR)) | |
| if _PROJECT_ROOT not in sys.path: | |
| sys.path.insert(0, _PROJECT_ROOT) | |
| # ── Stop-word filter ────────────────────────────────────────────────────────── | |
| _STOP_WORDS = { | |
| "a", "an", "the", "and", "or", "but", "is", "are", "was", "were", | |
| "in", "on", "at", "to", "for", "with", "by", "it", "this", "that", | |
| "there", "here", "of", "up", "out", ".", ",", "!", "##", | |
| } | |
| def load_detector(device, verbose: bool = True): | |
| """ | |
| Load the OWL-ViT zero-shot object detection pipeline. | |
| Args: | |
| device : torch.device (MPS, CUDA, or CPU). | |
| verbose : Print loading message. | |
| Returns: | |
| detector – transformers.pipeline object. | |
| """ | |
| from transformers import pipeline | |
| if verbose: | |
| print("🔭 Loading OWL-ViT zero-shot object detector …") | |
| detector = pipeline( | |
| task="zero-shot-object-detection", | |
| model="google/owlvit-base-patch32", | |
| device=device, | |
| ) | |
| if verbose: | |
| print("✅ OWL-ViT ready") | |
| return detector | |
| def binarize_heatmap(heatmap_np: np.ndarray, target_hw: tuple) -> np.ndarray: | |
| """ | |
| Resize and Otsu-threshold a float heatmap into a boolean mask. | |
| Args: | |
| heatmap_np : (H, W) float32 heatmap in [0, 1]. | |
| target_hw : (height, width) of the original image. | |
| Returns: | |
| (H, W) boolean mask. | |
| """ | |
| hm = cv2.resize(heatmap_np, (target_hw[1], target_hw[0])) | |
| hm_u8 = np.uint8(255.0 * hm) | |
| _, binary = cv2.threshold(hm_u8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) | |
| return binary > 0 | |
| def calculate_iou(mask: np.ndarray, box: list, img_shape: tuple) -> float: | |
| """ | |
| Calculate Intersection over Union between a boolean mask and a bounding box. | |
| Args: | |
| mask : (H, W) boolean numpy array. | |
| box : [xmin, ymin, xmax, ymax] in image pixel coords. | |
| img_shape : (H, W) of the image. | |
| Returns: | |
| IoU score in [0, 1]. | |
| """ | |
| box_mask = np.zeros(img_shape, dtype=bool) | |
| xmin, ymin, xmax, ymax = map(int, box) | |
| # Clamp to image bounds | |
| xmin = max(0, xmin); ymin = max(0, ymin) | |
| xmax = min(img_shape[1], xmax); ymax = min(img_shape[0], ymax) | |
| box_mask[ymin:ymax, xmin:xmax] = True | |
| inter = np.logical_and(mask, box_mask).sum() | |
| union = np.logical_or(mask, box_mask).sum() | |
| return float(inter) / union if union > 0 else 0.0 | |
| def grade_alignment( | |
| image_pil, | |
| tokens: list, | |
| heatmaps: list, | |
| detector, | |
| min_detection_score: float = 0.05, | |
| verbose: bool = True, | |
| ) -> list: | |
| """ | |
| For each meaningful token, attempt to detect its object in the image | |
| and compute IoU against the Attention Flow heatmap. | |
| Args: | |
| image_pil : Original PIL image (un-resized). | |
| tokens : List of decoded word strings. | |
| heatmaps : Parallel list of (H, W) numpy heatmaps. | |
| detector : OWL-ViT pipeline. | |
| min_detection_score : Only accept detections above this confidence. | |
| verbose : Print per-token results. | |
| Returns: | |
| List of dicts: { 'word', 'position', 'iou', 'det_score' } | |
| """ | |
| img_shape = (image_pil.height, image_pil.width) | |
| results = [] | |
| if verbose: | |
| print("\n📊 Grading alignment (Attention Flow IoU)…") | |
| for idx, (word, hm) in enumerate(zip(tokens, heatmaps)): | |
| clean_word = word.replace("##", "").lower() | |
| if len(clean_word) < 3 or clean_word in _STOP_WORDS or not clean_word.isalpha(): | |
| continue | |
| detections = detector(image_pil, candidate_labels=[clean_word]) | |
| best_box, best_score = None, 0.0 | |
| for d in detections: | |
| if d["score"] > best_score and d["score"] >= min_detection_score: | |
| best_score = d["score"] | |
| best_box = [d["box"]["xmin"], d["box"]["ymin"], | |
| d["box"]["xmax"], d["box"]["ymax"]] | |
| if best_box is not None: | |
| mask = binarize_heatmap(hm, img_shape) | |
| iou = calculate_iou(mask, best_box, img_shape) | |
| if verbose: | |
| print(f" '{clean_word}' (pos {idx+1}): det_score={best_score:.2f}, IoU={iou:.4f}") | |
| results.append({"word": clean_word, "position": idx + 1, | |
| "iou": iou, "det_score": best_score}) | |
| else: | |
| if verbose: | |
| print(f" '{clean_word}' (pos {idx+1}): no detection found") | |
| mean_iou = np.mean([r["iou"] for r in results]) if results else 0.0 | |
| if verbose: | |
| print(f"\n Mean Alignment IoU: {mean_iou:.4f}") | |
| return results | |
| def plot_iou_chart( | |
| all_results: list, | |
| out_path: str, | |
| verbose: bool = True, | |
| ) -> str: | |
| """ | |
| Save a scatter plot of word position vs Attention Flow IoU. | |
| Args: | |
| all_results : Flat list of result dicts from grade_alignment(). | |
| out_path : Absolute path to save the PNG. | |
| verbose : Print save confirmation. | |
| Returns: | |
| out_path. | |
| """ | |
| if not all_results: | |
| if verbose: | |
| print("⚠️ No IoU results to plot.") | |
| return out_path | |
| positions = [r["position"] for r in all_results] | |
| ious = [r["iou"] for r in all_results] | |
| words = [r["word"] for r in all_results] | |
| fig, ax = plt.subplots(figsize=(9, 5)) | |
| sc = ax.scatter(positions, ious, zorder=5, alpha=0.8, s=80, | |
| c=ious, cmap="RdYlGn", vmin=0, vmax=0.5) | |
| plt.colorbar(sc, ax=ax, label="IoU") | |
| # Annotate each point with the word | |
| for pos, iou, word in zip(positions, ious, words): | |
| ax.annotate(word, (pos, iou), textcoords="offset points", | |
| xytext=(4, 4), fontsize=8, alpha=0.9) | |
| # Trend line | |
| if len(positions) > 1: | |
| z = np.polyfit(positions, ious, 1) | |
| p_fn = np.poly1d(z) | |
| xs = sorted(positions) | |
| ax.plot(xs, [p_fn(x) for x in xs], "b--", alpha=0.5, label="Trend") | |
| ax.legend() | |
| ax.set_title("Word Position vs. Attention Flow Alignment (IoU)\n" | |
| "(Higher = model actually looked at the right region)", fontsize=13) | |
| ax.set_xlabel("Word position in caption") | |
| ax.set_ylabel("Alignment IoU") | |
| ax.grid(True, linestyle="--", alpha=0.5) | |
| plt.tight_layout() | |
| plt.savefig(out_path, dpi=150, bbox_inches="tight") | |
| plt.close() | |
| if verbose: | |
| print(f"✅ IoU plot saved → {out_path}") | |
| return out_path | |