"""
Token→region cross-attention visualization for GroundingDINO, packaged as a helper.

Usage from other modules:

    from vg_token_attention import run_token_ca_visualization

    paths = run_token_ca_visualization(
        cfg_path="VG/config/GroundingDINO_SwinT_OGC_2.py",
        ckpt_path="VG/weights/checkpoint0399_log4.pth",
        image_path=image_path,
        prompt=text_prompt,
        terms=chexbert_terms,  # e.g. ["edema", "effusion"]
        out_dir="outputs/attn_overlays",
        device="cuda",  # or "cpu"
    )
"""
|
|
import math
import os
import re

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from torch import nn
from transformers import AutoTokenizer

from VG.groundingdino.util.inference import load_model
from VG.groundingdino.util.misc import NestedTensor

DEVICE_DEFAULT = "cuda" if torch.cuda.is_available() else "cpu"
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
|
|
|
|
def preprocess_image_fn_factory(device=DEVICE_DEFAULT, longest=1024, pad_divisor=32):
    """Build a preprocessing closure: resize longest side, normalize, pad to a multiple of pad_divisor."""
    to_tensor = T.ToTensor()
    normalize = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

    def _resize_longest(pil_img: Image.Image, longest_side=1024):
        w, h = pil_img.size
        scale = float(longest_side) / max(w, h)
        new_w, new_h = int(round(w * scale)), int(round(h * scale))
        return pil_img.resize((new_w, new_h), Image.BICUBIC)

    def preprocess_image_fn(pil_img: Image.Image):
        img_resized = _resize_longest(pil_img, longest_side=longest)
        x = normalize(to_tensor(img_resized))
        _, H, W = x.shape

        # Pad bottom/right so both sides are divisible by pad_divisor.
        H_pad = math.ceil(H / pad_divisor) * pad_divisor
        W_pad = math.ceil(W / pad_divisor) * pad_divisor
        pad_h, pad_w = H_pad - H, W_pad - W
        x = F.pad(x, (0, pad_w, 0, pad_h), value=0.0)

        # Padding mask: True marks padded pixels (NestedTensor convention).
        mask = torch.zeros((H_pad, W_pad), dtype=torch.bool)
        if pad_h > 0:
            mask[H:, :] = True
        if pad_w > 0:
            mask[:, W:] = True

        return x.unsqueeze(0).to(device), mask.unsqueeze(0).to(device)

    return preprocess_image_fn
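
# Quick sanity check of the padding math (hypothetical input size): a 2048x2500
# image resizes to 839x1024 (longest side 1024), then pads to 864x1024 so both
# sides divide by 32; the mask is True exactly on the 25 padded right columns.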
|
|
|
|
# Tokenizer for the text branch; must match the checkpoint's text encoder.
BIOMEDVLP_TOKENIZER_PATH = "VG/weights/BiomedVLP-CXR-BERT/"

_tokenizer = AutoTokenizer.from_pretrained(BIOMEDVLP_TOKENIZER_PATH)
|
|
|
|
def tokenize_with_offsets(prompt: str, device=DEVICE_DEFAULT):
    """Tokenize the prompt, returning token strings and character offsets.

    return_offsets_mapping requires a fast (Rust-backed) tokenizer.
    """
    enc = _tokenizer(
        prompt,
        return_tensors="pt",
        return_offsets_mapping=True,
        add_special_tokens=True,
        truncation=True,
    )
    tokens = _tokenizer.convert_ids_to_tokens(enc["input_ids"][0])
    offsets = enc["offset_mapping"][0].tolist()
    return {
        "input_ids": enc["input_ids"].to(device),
        "attention_mask": enc["attention_mask"].to(device),
        "tokens": tokens,
        "offsets": offsets,
    }
|
|
|
|
def find_token_span_by_offsets(prompt: str, offsets, term: str):
    """Return indices of tokens whose character spans overlap `term` in `prompt`."""
    s = prompt.lower()
    t = term.lower()
    # Prefer a whole-word match; fall back to a plain substring match.
    m = re.search(r'\b' + re.escape(t) + r'\b', s) or re.search(re.escape(t), s)
    if not m:
        return []
    a, b = m.start(), m.end()
    idxs = []
    for i, (u, v) in enumerate(offsets):
        # Skip special tokens, which carry empty or sentinel offsets.
        if (
            u is None or v is None or
            u < 0 or v < 0 or
            (u == 0 and v == 0)
        ):
            continue
        # Keep tokens whose span [u, v) intersects the match span [a, b).
        if not (v <= a or u >= b):
            idxs.append(i)
    return idxs
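
# Worked example (the subword split is hypothetical; actual pieces depend on the
# vocab): for prompt "mild edema" tokenized with offsets as
#     [CLS] (0,0) | mild (0,4) | ed (5,7) | ##ema (7,10) | [SEP] (0,0)
# find_token_span_by_offsets("mild edema", offsets, "edema") returns [2, 3]:
# the (0,0) specials are skipped and only tokens overlapping chars 5..10 match.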
|
|
|
|
def model_span_indices_for_term(prompt: str, tokens, offsets, attn_T, term: str):
    """Map a term to column indices of the model's text cross-attention matrix.

    HF offsets index into the original prompt string, so the character search
    must run on `prompt` itself, not on a string rebuilt from the tokens.
    """
    raw_hf_idxs = find_token_span_by_offsets(prompt, offsets, term)
    if not raw_hf_idxs:
        # Fallback: substring match against the token strings themselves.
        low = term.lower()
        raw_hf_idxs = [i for i, t in enumerate(tokens) if low in t.lower()]

    # The model's text sequence drops special tokens; build an
    # HF-index → model-column map over the first attn_T usable tokens.
    non_special_hf = []
    for i, (tok_i, (u, v)) in enumerate(zip(tokens, offsets)):
        if tok_i in ("[CLS]", "[SEP]", "[PAD]"):
            continue
        if u is None or v is None or u < 0 or v < 0 or (u == 0 and v == 0):
            continue
        non_special_hf.append(i)

    non_special_hf = non_special_hf[:attn_T]
    hf2model = {hf_idx: j for j, hf_idx in enumerate(non_special_hf)}
    model_term_idxs = [hf2model[i] for i in raw_hf_idxs if i in hf2model]

    return torch.tensor(model_term_idxs, dtype=torch.long)
|
|
|
|
class CrossAttnRecorder:
    """Capture text cross-attention weights from decoder layers via forward hooks."""

    def __init__(self, decoder_layers, attn_attr_name='ca_text'):
        self.attn_weights = []
        self.handles = []
        self._register(decoder_layers, attn_attr_name)

    def _hook(self, module, input, output):
        # nn.MultiheadAttention returns (attn_output, attn_weights).
        if isinstance(output, tuple) and len(output) >= 2:
            attn_w = output[1]
        elif hasattr(module, 'attn_output_weights'):
            attn_w = module.attn_output_weights
        else:
            attn_w = None
        if attn_w is not None:
            self.attn_weights.append(attn_w.detach().to('cpu', dtype=torch.float32))

    def _wrap_forward(self, mha_module: nn.MultiheadAttention):
        orig_forward = mha_module.forward

        def wrapped_forward(*args, **kwargs):
            # Force per-head weights to be returned. Assumes callers pass these
            # flags (if at all) as keywords; average_attn_weights needs torch >= 1.11.
            kwargs['need_weights'] = True
            kwargs['average_attn_weights'] = False
            return orig_forward(*args, **kwargs)

        return orig_forward, wrapped_forward

    def _register(self, decoder_layers, attn_attr_name):
        for layer in decoder_layers:
            attn_module = getattr(layer, attn_attr_name, None)
            if attn_module is None:
                continue
            if isinstance(attn_module, nn.MultiheadAttention):
                # Patch forward so the module actually emits weights, then hook it.
                orig_fwd, wrapped = self._wrap_forward(attn_module)
                attn_module.forward = wrapped
                handle = attn_module.register_forward_hook(self._hook)
                self.handles.append((attn_module, handle, orig_fwd))
            else:
                handle = attn_module.register_forward_hook(self._hook)
                self.handles.append((attn_module, handle, None))

    def close(self):
        # Remove hooks and restore any patched forward methods.
        for attn_module, handle, orig_fwd in self.handles:
            handle.remove()
            if (orig_fwd is not None) and isinstance(attn_module, nn.MultiheadAttention):
                attn_module.forward = orig_fwd
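
# Intended usage (mirrors run_token_ca_visualization below):
#     recorder = CrossAttnRecorder(model.transformer.decoder.layers, "ca_text")
#     _ = model(samples, captions=[prompt])   # hooks capture per-layer weights
#     weights = recorder.attn_weights          # list of (B, heads, Q, T) tensors
#     recorder.close()                         # remove hooks, restore forward()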
|
|
|
|
def boxes_to_heatmap(boxes_xyxy, weights, hw, score_scale=None, blur_ksize=51, blur_sigma=0):
    """Splat per-box weights into an HxW heatmap, optionally scaled by detection scores."""
    H, W = hw
    heat = np.zeros((H, W), dtype=np.float32)

    w = weights.detach().cpu().numpy()
    if score_scale is not None:
        s = score_scale.detach().cpu().numpy()
        w = w * s

    for i, box in enumerate(boxes_xyxy):
        x1, y1, x2, y2 = map(int, box.tolist())
        # Clamp to image bounds; x2/y2 are exclusive slice ends, so allow W/H.
        x1 = max(0, min(W - 1, x1)); x2 = max(0, min(W, x2))
        y1 = max(0, min(H - 1, y1)); y2 = max(0, min(H, y2))
        if x2 <= x1 or y2 <= y1:
            continue
        heat[y1:y2, x1:x2] += float(w[i])

    # Gaussian smoothing; kernel size must be odd and >= 3.
    if blur_ksize is not None and blur_ksize >= 3 and blur_ksize % 2 == 1:
        heat = cv2.GaussianBlur(heat, (blur_ksize, blur_ksize), blur_sigma)

    # Normalize to [0, 1] for colormap rendering.
    mx = heat.max()
    if mx > 1e-6:
        heat /= mx
    return heat
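
# Example: two overlapping boxes with weights [1.0, 0.5] accumulate to 1.5 in
# their overlap before blurring; the map is then rescaled so its maximum is 1.0.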
|
|
|
|
def overlay_heatmap(img_pil: Image.Image, heatmap, alpha=0.45, cmap=cv2.COLORMAP_JET):
    """Alpha-blend a [0, 1] heatmap over an RGB image of the same size."""
    img = np.array(img_pil.convert("RGB"))
    h = (np.clip(heatmap, 0, 1) * 255).astype(np.uint8)
    # applyColorMap emits BGR; convert to RGB (cvtColor also yields a contiguous array).
    h_color = cv2.cvtColor(cv2.applyColorMap(h, cmap), cv2.COLOR_BGR2RGB)
    blended = cv2.addWeighted(h_color, alpha, img, 1 - alpha, 0)
    return Image.fromarray(blended)
|
|
|
|
def load_image_keep_longest(path, longest=1024):
    """Load an RGB image resized so its longest side equals `longest` (matches preprocessing)."""
    img = Image.open(path).convert("RGB")
    w, h = img.size
    s = float(longest) / max(w, h)
    new_w, new_h = int(round(w * s)), int(round(h * s))
    return img.resize((new_w, new_h), Image.BICUBIC)
|
|
|
|
@torch.no_grad()
def run_token_ca_visualization(
    cfg_path: str,
    ckpt_path: str,
    image_path: str,
    prompt: str,
    terms,
    out_dir: str,
    device: str = DEVICE_DEFAULT,
    score_thresh: float = 0.25,
    topk: int = 100,
    term_agg: str = "mean",
    save_per_term: bool = True,
):
    """
    Returns:
        {
            "combined": <path_to_combined_overlay>,
            "per_term": {term: path_to_overlay, ...}
        }
    """
    if isinstance(terms, str):
        terms = [terms]

    # Keep only terms that actually occur in the prompt.
    prompt_lower = prompt.lower()
    terms = [t for t in terms if t.lower() in prompt_lower]

    if not terms:
        print(f"[TokenCA] No configured terms found in prompt: {prompt!r}")
        return {}
|
    device = device or DEVICE_DEFAULT
    model = load_model(cfg_path, ckpt_path).to(device).eval()
    preprocess_image_fn = preprocess_image_fn_factory(device=device, longest=1024, pad_divisor=32)

    img_pil = load_image_keep_longest(image_path, longest=1024)

    os.makedirs(out_dir, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(image_path))[0]
    combined_path = os.path.join(out_dir, f"{base_name}__attn_combined.png")

    # Hook the decoder's text cross-attention modules before the forward pass.
    decoder_layers = model.transformer.decoder.layers
    recorder = CrossAttnRecorder(decoder_layers, attn_attr_name="ca_text")

    img_tensor, mask = preprocess_image_fn(img_pil)
    samples = NestedTensor(img_tensor, mask)

    outputs = model(samples, captions=[prompt])
|
|
    # Select queries: sigmoid scores, threshold, fall back to top-k if none pass.
    pred_logits = outputs["pred_logits"]
    pred_boxes = outputs["pred_boxes"]
    logits = pred_logits[0].sigmoid()          # (num_queries, max_text_len)
    scores, _ = logits.max(dim=1)              # best text-token score per query

    keep = torch.nonzero(scores > score_thresh).squeeze(1)
    if keep.numel() == 0:
        keep = torch.argsort(scores, descending=True)[:min(topk, scores.numel())]
    else:
        keep = keep[:topk]

    # Convert normalized cxcywh boxes to absolute xyxy in the resized image.
    W, H = img_pil.size
    boxes_cxcywh = pred_boxes[0][keep]
    cx, cy, w, h = boxes_cxcywh.unbind(-1)
    x1 = (cx - 0.5 * w) * W
    y1 = (cy - 0.5 * h) * H
    x2 = (cx + 0.5 * w) * W
    y2 = (cy + 0.5 * h) * H
    boxes_xyxy = torch.stack([x1, y1, x2, y2], dim=-1)
    kept_scores = scores[keep]

    keep_cpu = keep.cpu()

    # Average the captured cross-attention over heads, then over decoder layers.
    if len(recorder.attn_weights) == 0:
        recorder.close()
        raise RuntimeError("No attention weights captured. Check that 'ca_text' exists.")
    attn_qt_layers = []
    for w_att in recorder.attn_weights:
        w_att = w_att.squeeze(0).mean(0)       # (1, heads, Q, T) -> (Q, T)
        attn_qt_layers.append(w_att)
    attn_qt = torch.stack(attn_qt_layers, 0).mean(0)   # (Q, T)
    recorder.close()
|
|
    # Re-tokenize the prompt on CPU to recover token strings and char offsets.
    tok = tokenize_with_offsets(prompt, device="cpu")
    tokens, offsets = tok["tokens"], tok["offsets"]
    T_text = attn_qt.shape[1]

    per_term_attn_kept = {}
    per_term_attn_full = {}

    for t in terms:
        model_idxs = model_span_indices_for_term(prompt, tokens, offsets, T_text, t)
        if model_idxs.numel() == 0:
            continue
        # Mean attention over the term's token columns, per query.
        attn_per_query = attn_qt[:, model_idxs].mean(1)
        attn_kept = attn_per_query[keep_cpu]
        attn_kept = (attn_kept - attn_kept.min()) / (attn_kept.max() - attn_kept.min() + 1e-6)
        per_term_attn_kept[t] = attn_kept
        per_term_attn_full[t] = attn_per_query

    if not per_term_attn_kept:
        print(f"[TokenCA] None of the terms were found in the first {T_text} tokens: {terms}")
        return {}
|
|
    # Aggregate per-term attention across terms: "sum", "max", or (default) "mean".
    agg = None
    for v in per_term_attn_full.values():
        if agg is None:
            agg = v
        elif term_agg == "max":
            agg = torch.maximum(agg, v)
        else:  # "sum" and "mean" both accumulate; "mean" divides below
            agg = agg + v
    if term_agg == "mean":
        agg = agg / float(len(per_term_attn_full))

    agg_kept = agg[keep_cpu]
    agg_kept = (agg_kept - agg_kept.min()) / (agg_kept.max() - agg_kept.min() + 1e-6)
|
|
    heat = boxes_to_heatmap(
        boxes_xyxy=boxes_xyxy,
        weights=agg_kept,
        hw=(H, W),
        score_scale=kept_scores,
        blur_ksize=61,
        blur_sigma=0,
    )
    overlay = overlay_heatmap(img_pil, heat, alpha=0.45)
    overlay.save(combined_path)
|
|
    per_term_paths = {}
    if save_per_term and len(per_term_attn_kept) > 1:
        for t, v in per_term_attn_kept.items():
            heat_t = boxes_to_heatmap(
                boxes_xyxy=boxes_xyxy,
                weights=v,
                hw=(H, W),
                score_scale=kept_scores,
                blur_ksize=61,
                blur_sigma=0,
            )
            ov_t = overlay_heatmap(img_pil, heat_t, alpha=0.45)
            term_tag = re.sub(r"[^a-zA-Z0-9]+", "_", t.lower())[:32]
            p = os.path.join(out_dir, f"{base_name}__{term_tag}.png")
            ov_t.save(p)
            per_term_paths[t] = p
|
|
    return {
        "combined": combined_path,
        "per_term": per_term_paths,
    }
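

# Minimal smoke test; the config/checkpoint paths come from the module docstring
# and the image path, prompt, and terms are placeholders that must exist locally.
if __name__ == "__main__":
    paths = run_token_ca_visualization(
        cfg_path="VG/config/GroundingDINO_SwinT_OGC_2.py",
        ckpt_path="VG/weights/checkpoint0399_log4.pth",
        image_path="example_cxr.png",  # placeholder image path
        prompt="findings consistent with edema and pleural effusion",
        terms=["edema", "effusion"],
        out_dir="outputs/attn_overlays",
        device=DEVICE_DEFAULT,
    )
    print(paths)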
|
|