# NOTE(review): removed page-scrape residue (Hugging Face Spaces page header —
# "Spaces / Sleeping / File size: 8,309 Bytes / 039c47c" plus a line-number
# gutter). It was not Python and prevented the file from parsing.
'''
Example of using Captum Integrated Gradients with a Vision Transformer (ViT) model
to explain image classification predictions.
This example downloads a random image from the web, runs it through a pre-trained
ViT model, and uses Captum to compute and visualize attributions.
IG: It’s like asking the computer not just what’s in the image,
but which parts of the picture convinced it to give that answer.
IG: Integrated Gradients
Like turning up the brightness on a photo and seeing which parts
of the picture made the model confident in its answer.
'''
# --- Standard library ---
import logging
import os
import random
from io import BytesIO
from logging.handlers import RotatingFileHandler

# --- Third-party ---
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from PIL import Image, ImageFilter
from PIL import Image as PILImage  # alias kept for backward compatibility (same object as Image)
from torchvision import transforms  # NOTE(review): imported but apparently unused in this script

# ---- Logging: console + rotating file handler writing to ./logs/interp.log ----
LOG_DIR = os.path.join(os.path.dirname(__file__), "logs")
os.makedirs(LOG_DIR, exist_ok=True)
logfile = os.path.join(LOG_DIR, "interp.log")
logger = logging.getLogger("vit_and_captum")
if not logger.handlers:  # guard against duplicate handlers on re-import
    logger.setLevel(logging.INFO)
    sh = logging.StreamHandler()
    fh = RotatingFileHandler(logfile, maxBytes=5_000_000, backupCount=3, encoding="utf-8")
    fmt = logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
    sh.setFormatter(fmt)
    fh.setFormatter(fmt)
    logger.addHandler(sh)
    logger.addHandler(fh)
# ---- Step 1: Load model ----
# Using a Vision Transformer (ViT) model from Hugging Face Transformers
from transformers import ViTForImageClassification, ViTImageProcessor
# Load pre-trained model and processor.
# NOTE(review): `from_pretrained` with a hub id downloads weights on first
# run — requires network access / a populated HF cache.
model_name = "google/vit-base-patch16-224"
model = ViTForImageClassification.from_pretrained(model_name)
processor = ViTImageProcessor.from_pretrained(model_name)
# run in eval mode for inference (disables dropout etc.; does not disable autograd)
model.eval()
# ---- Step 2: Load an image ----
# Function to download a random image from DuckDuckGo
def download_random_image(search_terms=None, size=(224, 224)):
    """Download a random photo from one of several free image providers.

    Args:
        search_terms: optional list of keyword strings; one is chosen at
            random. Defaults to a set of common ImageNet-style classes.
        size: (width, height) the returned image is resized to. Defaults
            to (224, 224), the ViT input resolution used by this script.

    Returns:
        PIL.Image.Image in RGB mode at the requested size. If every
        provider fails, a solid mid-gray placeholder is returned so the
        rest of the pipeline still runs.
    """
    if search_terms is None:
        search_terms = ["dog", "cat", "bird", "car", "airplane", "horse",
                        "elephant", "tiger", "lion", "bear"]
    term = random.choice(search_terms)
    w, h = size
    # Multiple providers, tried in order, to improve reliability.
    providers = [
        f"https://source.unsplash.com/{w}x{h}/?{term}",
        f"https://picsum.photos/seed/{term}/{w}/{h}",
        f"https://loremflickr.com/{w}/{h}/{term}",
        # placekitten serves an image for any request — last-resort fallback
        f"https://placekitten.com/{w}/{h}",
    ]
    headers = {"User-Agent": "Mozilla/5.0 (compatible; ImageFetcher/1.0)"}
    # Prefer the modern Pillow resampling enum; fall back for old Pillow
    # versions that only have the module-level LANCZOS constant.
    resample = getattr(Image, "Resampling", Image).LANCZOS
    for url in providers:
        try:
            response = requests.get(url, timeout=10, headers=headers, allow_redirects=True)
        except requests.RequestException as e:
            logger.warning("Request failed for %s: %s", url, e)
            continue
        if response.status_code != 200:
            logger.warning("Provider %s returned status %d", url, response.status_code)
            continue
        # Try to identify and open image content
        try:
            img = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as img_err:
            logger.warning("Failed to parse image from %s: %s", url, img_err)
            continue
        # Ensure the exact requested dimensions.
        img = img.resize(size, resample)
        logger.info("Downloaded random image from %s for term=%s", url, term)
        return img
    logger.error("All providers failed; using fallback solid-color image.")
    return Image.new("RGB", size, color=(128, 128, 128))
# Fetch a random image and run it through the classifier.
img = download_random_image()

# Convert the PIL image into the tensor dict the ViT model expects.
inputs = processor(images=img, return_tensors="pt")

# ---- Step 3: Run prediction ----
# Pure inference: autograd is disabled for speed/memory.
with torch.no_grad():
    logits = model(**inputs).logits
probs = logits.softmax(-1)          # class probabilities
pred_idx = int(probs.argmax(-1))    # index of the most probable class
logger.info("Predicted %s (idx=%d)", model.config.id2label[pred_idx], pred_idx)

# Report the top-k classes to give context around the chosen prediction.
topk = 5
top_p, top_i = torch.topk(probs, k=topk)
top_p = top_p.squeeze().cpu().numpy()
top_i = top_i.squeeze().cpu().numpy()
print("Top-{} predictions:".format(topk))
for v, i in zip(top_p, top_i):
    print(f" {model.config.id2label[int(i)]:30s} {float(v):.4f}")
print("Chosen prediction:", model.config.id2label[pred_idx])
# ---- Step 4: Captum Integrated Gradients ----
from captum.attr import IntegratedGradients

def forward_func(pixel_values):
    """Forward pass that returns raw logits as a plain Tensor.

    Captum requires the forward callable to return a Tensor, while the
    HF model returns a ModelOutput dataclass — hence this thin wrapper.
    """
    return model(pixel_values=pixel_values).logits

# Hand the wrapper (not the model itself) to IntegratedGradients.
ig = IntegratedGradients(forward_func)

# Captum needs a detached copy of the input with gradients enabled.
input_tensor = inputs["pixel_values"].clone().detach().requires_grad_(True)

# Attribute the predicted class to input pixels. More steps give a
# better approximation of the path integral; the convergence delta is a
# sanity check on that approximation.
attributions, convergence_delta = ig.attribute(
    input_tensor,
    target=pred_idx,
    n_steps=100,
    return_convergence_delta=True,
)
logger.info("IG convergence delta: %s", convergence_delta)
# ---- Step 5: Visualize attribution heatmap (normalized + overlay) ----
# Collapse the channel dimension with a signed mean, so positive and
# negative contributions keep their sign.
attr = attributions.squeeze().mean(dim=0).detach().cpu().numpy()

# Scale by the largest absolute value into roughly [-1, 1], so a
# diverging colormap can separate positive from negative evidence.
scale = max(abs(float(attr.min())), abs(float(attr.max()))) + 1e-8
attr_signed = attr / scale

# Lightly blur the heatmap so the overlay reads more intuitively.
try:
    blurred = PILImage.fromarray(np.uint8((attr_signed + 1) * 127.5))
    blurred = blurred.filter(ImageFilter.GaussianBlur(radius=1.5))
    attr_signed = (np.array(blurred).astype(float) / 127.5) - 1.0
except Exception:
    # Smoothing is cosmetic — carry on without it if the filter fails.
    pass

# Overlay on the input image with a diverging colormap
# (positive = warm/red, negative = cool/blue).
plt.figure(figsize=(6, 6))
plt.imshow(img)
plt.imshow(attr_signed, cmap="seismic", alpha=0.45, vmin=-1, vmax=1)
cb = plt.colorbar(fraction=0.046, pad=0.04)
cb.set_label("Signed attribution (normalized)")
plt.title(f"IG overlay — pred: {model.config.id2label[pred_idx]} ({float(probs.squeeze()[pred_idx]):.3f})")
plt.axis("off")

# Standalone signed heatmap for closer inspection.
plt.figure(figsize=(4, 4))
plt.imshow(attr_signed, cmap="seismic", vmin=-1, vmax=1)
plt.colorbar()
plt.title("Signed IG Attribution (neg=blue, pos=red)")
plt.axis("off")
plt.show()
# Concise runtime guidance for reading the interpretability outputs.
def print_interpretability_summary():
    """Print a short plain-text guide to interpreting the plots above."""
    guide = [
        "\nHow to read the results (quick guide):",
        "- IG signed heatmap: red/warm = supports the predicted class; blue/cool = opposes it.",
        "- Normalize by max-abs when comparing images. Check IG 'convergence delta' — large values mean treat attributions cautiously.",
        "- LIME panel (if used): green/highlighted superpixels indicate locally important regions; background-dominated explanations are a red flag.",
        "- MC Dropout histogram: narrow peak → stable belief; wide/multi-modal → epistemic uncertainty.",
        "- TTA histogram: many flips under small augmentations → fragile/aleatoric sensitivity.",
        "- Predictive entropy: higher → more uncertainty in the full distribution.",
        "- Variation ratio: fraction of samples not matching majority; higher → more disagreement.\n",
    ]
    for line in guide:
        print(line)

print_interpretability_summary()
# NOTE(review): removed a trailing "|" — page-scrape residue, not Python.