import matplotlib.cm as cm
import numpy as np
import torch
from PIL import Image

from src.interpretability import cross_attention_to_image


def resize_for_display(pil_img, max_dim=5000):
    """Downscale ``pil_img`` so that its longest side is at most ``max_dim`` pixels.

    Images already within the limit are returned unchanged (the same object).
    The aspect ratio is preserved, and each side is kept at least 1 pixel so
    very elongated images cannot collapse to a zero-sized result.
    """
    w, h = pil_img.size
    longest = max(w, h)
    if longest <= max_dim:
        return pil_img
    scale = max_dim / longest
    new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
    return pil_img.resize(new_size, Image.LANCZOS)


def _heatmap_to_pil(attn_vec):
    """Turn one attention vector over encoder positions into a heatmap image.

    ``cross_attention_to_image`` may return a tuple, in which case its first
    element is used. If the result is a NumPy array it is scaled by 255 into an
    8-bit PIL image; other return types are passed through unchanged.
    NOTE(review): presumes the array holds values in [0, 1] — confirm against
    ``cross_attention_to_image``.
    """
    heat = cross_attention_to_image(attn_vec)
    if isinstance(heat, tuple):
        heat = heat[0]
    if isinstance(heat, np.ndarray):
        # Clip before the uint8 cast: out-of-range values would otherwise wrap
        # around (e.g. 1.01 * 255 -> 257 -> 1) instead of saturating at 255.
        heat = Image.fromarray(np.clip(heat * 255, 0, 255).astype("uint8"))
    return heat


def _image_tensor_to_rgba(image_tensor):
    """Min-max normalise the first image of a channel-first batch into an RGBA PIL image."""
    # Cast to float so half/bfloat16 tensors can be converted with .numpy().
    img = image_tensor[0].detach().float().cpu().permute(1, 2, 0).numpy()
    lo, hi = img.min(), img.max()
    span = hi - lo
    # A constant image has span 0; dividing would yield NaN and garbage pixels.
    img = (img - lo) / span if span > 0 else np.zeros_like(img)
    img_uint8 = (img * 255).astype("uint8")
    if img_uint8.shape[2] == 1:
        # PIL cannot build an image from an (H, W, 1) array; use (H, W) grayscale.
        img_uint8 = img_uint8[:, :, 0]
    return Image.fromarray(img_uint8).convert("RGBA")


def _blend_heatmap(base_rgba, heatmap, alpha):
    """Colour ``heatmap`` with the inferno colormap and alpha-composite it over ``base_rgba``.

    ``heatmap`` is resized to the base image with bilinear filtering and its
    pixel values are divided by 255 before colouring (presumably an 8-bit,
    single-band image — confirm). ``alpha`` is clamped to [0, 1] and applied
    uniformly to the coloured layer. Returns an RGB PIL image.
    """
    width, height = base_rgba.size
    heatmap = heatmap.resize((width, height), Image.BILINEAR)
    heat_np = np.asarray(heatmap).astype("float32") / 255.0
    colored = cm.inferno(heat_np)  # RGBA float array in [0, 1]
    heat = Image.fromarray((colored * 255).astype("uint8")).convert("RGBA")
    heat.putalpha(int(min(max(alpha, 0.0), 1.0) * 255))
    return Image.alpha_composite(base_rgba, heat).convert("RGB")


@torch.no_grad()
def generate_rollout_for_demo(model, tokenizer, img, preprocess, device="cuda",
                              max_new_tokens=32, alpha=0.45):
    """Greedily caption ``img`` and render cross-attention overlays for each generated token.

    The image goes through ``preprocess``, ``model.vision_encoder`` and
    ``model.projector``; the result is used as encoder output for ``model.t5``,
    which is decoded one token at a time with argmax. At every step the
    last decoder layer's cross-attention for the newest position is rendered
    over the input image, both averaged over heads and per head.

    Side effect: puts ``model`` into eval mode (it is not restored).

    Args:
        model: object exposing ``vision_encoder``, ``projector`` and ``t5``.
        tokenizer: tokenizer with ``decode`` and ``eos_token_id``.
        img: input image accepted by ``preprocess``.
        preprocess: callable turning ``img`` into a single-image tensor
            (a batch axis is added here).
        device: device the image tensor is moved to.
        max_new_tokens: maximum number of decoding steps.
        alpha: opacity of the heatmap layer in the overlays.

    Returns:
        dict with
          - "caption": decoded caption with special tokens skipped;
          - "avg": {"frames": list of PIL images, one per step,
                    "labels": one label string per step};
          - "heads": {"frames": list[step][head] of PIL images,
                      "labels": same list as "avg",
                      "num_heads": head count, or None if no step ran}.
        The step that produced EOS also gets a frame and a label.
    """
    model.eval()
    img_tensor = preprocess(img).unsqueeze(0).to(device)

    vision_out = model.vision_encoder(img_tensor)
    img_embeds = vision_out["image_embeds"]
    if img_embeds.dim() == 2:
        # Presumably a pooled (batch, dim) embedding: add a length-1 sequence
        # axis so it can serve as encoder output.
        img_embeds = img_embeds.unsqueeze(1)
    projected = model.projector(img_embeds)

    # The normalised base image is the same for every overlay; build it once
    # rather than once per (step, head).
    base_rgba = _image_tensor_to_rgba(img_tensor)

    decoder_input_ids = torch.tensor(
        [[model.t5.config.decoder_start_token_id]], device=device
    )

    generated_ids = []
    avg_frames = []
    labels = []
    per_head_frames = []
    num_heads = None

    for step in range(max_new_tokens):
        # The whole decoder prefix is re-run each step (no KV cache); only the
        # last target position's attention and logits are used below.
        outputs = model.t5(
            encoder_outputs=(projected,),
            decoder_input_ids=decoder_input_ids,
            output_attentions=True,
            return_dict=True,
        )

        # Cross-attention of the last decoder layer, batch item 0: (heads, tgt, src).
        last_cross = outputs.cross_attentions[-1][0]
        num_heads = last_cross.size(0)

        # Head-averaged attention of the newest target position.
        attn_vec = last_cross.mean(dim=0)[-1]
        avg_frames.append(_blend_heatmap(base_rgba, _heatmap_to_pil(attn_vec), alpha))

        per_head_frames.append([
            _blend_heatmap(base_rgba, _heatmap_to_pil(last_cross[h][-1]), alpha)
            for h in range(num_heads)
        ])

        # Greedy choice of the next token.
        next_token = outputs.logits[:, -1, :].argmax(-1)
        token_str = tokenizer.decode(next_token, skip_special_tokens=True)
        labels.append(f"Token #{step}: \"{token_str}\"")
        generated_ids.append(int(next_token))

        if next_token.item() == tokenizer.eos_token_id:
            break

        decoder_input_ids = torch.cat(
            [decoder_input_ids, next_token.unsqueeze(0)], dim=1
        )

    caption = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Structured result consumed by the Gradio demo.
    return {
        "caption": caption,
        "avg": {
            "frames": avg_frames,
            "labels": labels,
        },
        "heads": {
            "frames": per_head_frames,  # list[step][head] = PIL image
            "labels": labels,
            "num_heads": num_heads,
        },
    }


def overlay_attention_for_demo(image_tensor, heatmap, alpha=0.45):
    """Overlay an attention heatmap on the first image of ``image_tensor``.

    The image is min-max normalised to 8-bit (a constant image becomes
    black instead of NaN). ``heatmap`` is resized to match, coloured with the
    inferno colormap and alpha-composited on top with opacity ``alpha``.
    Returns an RGB PIL image.
    """
    return _blend_heatmap(_image_tensor_to_rgba(image_tensor), heatmap, alpha)