File size: 8,309 Bytes
039c47c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
'''
Example of using Captum Integrated Gradients with a Vision Transformer (ViT) model
to explain image classification predictions.
This example downloads a random image from the web, runs it through a pre-trained
ViT model, and uses Captum to compute and visualize attributions.

IG: It’s like asking the computer not just what’s in the image, 
but which parts of the picture convinced it to give that answer.

IG: Integrated Gradients

Like gradually turning up the brightness on a photo, starting from a blank
baseline image, and seeing which parts of the picture made the model
confident in its answer.
'''

import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import requests
import random
from io import BytesIO
import numpy as np
from PIL import Image as PILImage
import requests
import random
from PIL import ImageFilter

# Logging setup: messages go both to the console and to a size-rotated
# file under a "logs" directory that sits next to this script.
import logging
import os
from logging.handlers import RotatingFileHandler

LOG_DIR = os.path.join(os.path.dirname(__file__), "logs")
os.makedirs(LOG_DIR, exist_ok=True)
logfile = os.path.join(LOG_DIR, "interp.log")

logger = logging.getLogger("vit_and_captum")
# Only attach handlers the first time; a logger that already has handlers
# (e.g. after a re-import) is left exactly as it is.
if not logger.handlers:
    logger.setLevel(logging.INFO)

    # Console handler
    sh = logging.StreamHandler()
    # File handler: rotate at ~5 MB and keep three old files
    fh = RotatingFileHandler(
        logfile,
        maxBytes=5_000_000,
        backupCount=3,
        encoding="utf-8",
    )

    # Both handlers share one line format
    fmt = logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
    sh.setFormatter(fmt)
    fh.setFormatter(fmt)

    logger.addHandler(sh)
    logger.addHandler(fh)

# ---- Step 1: Load model ----
# Using a Vision Transformer (ViT) model from Hugging Face Transformers
from transformers import ViTForImageClassification, ViTImageProcessor

# Checkpoint identifier shared by the classifier and its image processor
model_name = "google/vit-base-patch16-224"

# Load the pre-trained classifier and switch it to eval mode for inference
# in one expression (.eval() returns the module itself).
model = ViTForImageClassification.from_pretrained(model_name).eval()

# Matching processor that turns images into model inputs
processor = ViTImageProcessor.from_pretrained(model_name)

# ---- Step 2: Load an image ----
# Function to download a random image from one of several public image providers
def download_random_image(search_terms=None, size=(224, 224), timeout=10):
    """Download a random photo and return it as an RGB PIL image of ``size``.

    A search term is picked at random and a list of public image providers
    is tried in order; the first one that answers with HTTP 200 and a
    payload PIL can decode wins. If every provider fails, a solid mid-gray
    image is returned instead, so this function does not raise for network
    or decoding problems.

    Args:
        search_terms: Optional iterable of subject keywords to choose from.
            ``None`` or an empty iterable selects the built-in list.
        size: ``(width, height)`` of the returned image in pixels.
        timeout: Per-request timeout in seconds passed to ``requests.get``.

    Returns:
        A ``PIL.Image.Image`` in RGB mode with dimensions ``size``.
    """
    # Local import: only needed to make caller-supplied terms URL-safe.
    from urllib.parse import quote

    default_terms = ["dog", "cat", "bird", "car", "airplane", "horse", "elephant", "tiger", "lion", "bear"]
    terms = list(search_terms) if search_terms else default_terms
    term = random.choice(terms)
    url_term = quote(term, safe="")
    width, height = size

    # multiple providers to improve reliability
    providers = [
        f"https://source.unsplash.com/{width}x{height}/?{url_term}",
        f"https://picsum.photos/seed/{url_term}/{width}/{height}",
        f"https://loremflickr.com/{width}/{height}/{url_term}",
        # placekitten is a fallback for cat-like images; its URL does not
        # include the search term at all
        f"https://placekitten.com/{width}/{height}",
    ]

    # Newer Pillow exposes resampling filters under Image.Resampling, older
    # releases only on the Image module itself; pick whichever exists.
    lanczos = getattr(Image, "Resampling", Image).LANCZOS

    headers = {"User-Agent": "Mozilla/5.0 (compatible; ImageFetcher/1.0)"}
    for url in providers:
        # Only the HTTP call can raise a requests exception, so only it is
        # guarded by this handler.
        try:
            response = requests.get(url, timeout=timeout, headers=headers, allow_redirects=True)
        except requests.RequestException as e:
            logger.warning("Request failed for %s: %s", url, e)
            continue

        if response.status_code != 200:
            logger.warning("Provider %s returned status %d", url, response.status_code)
            continue

        # A provider may answer 200 with something that is not an image
        # (HTML error page, truncated data). Deliberately broad: any decode
        # failure just means "try the next provider".
        try:
            img = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as img_err:
            logger.warning("Failed to parse image from %s: %s", url, img_err)
            continue

        # Providers are asked for the right size, but enforce it anyway
        img = img.resize((width, height), lanczos)
        logger.info("Downloaded random image from %s for term=%s", url, term)
        return img

    logger.error("All providers failed; using fallback solid-color image.")
    return Image.new("RGB", (width, height), color=(128, 128, 128))

# Fetch a random image and let the processor turn it into PyTorch tensors
img = download_random_image()
inputs = processor(images=img, return_tensors="pt")

# ---- Step 3: Run prediction ----
# Pure inference: gradient tracking is switched off for the forward pass.
with torch.no_grad():
    outputs = model(**inputs)  # inputs is unpacked as keyword arguments
    probs = outputs.logits.softmax(-1)  # logits -> class probabilities
    pred_idx = probs.argmax(-1).item()  # index of the most probable class

id2label = model.config.id2label
logger.info("Predicted %s (idx=%d)", id2label[pred_idx], pred_idx)

# Show the top-k predictions to give the chosen class some context
topk = 5
topk_vals, topk_idx = torch.topk(probs, k=topk)
topk_vals = topk_vals.squeeze().cpu().numpy()
topk_idx = topk_idx.squeeze().cpu().numpy()

print(f"Top-{topk} predictions:")
for score, class_id in zip(topk_vals, topk_idx):
    label = id2label[int(class_id)]
    print(f"  {label:30s} {float(score):.4f}")
print("Chosen prediction:", id2label[pred_idx])

# ---- Step 4: Captum Integrated Gradients ----
from captum.attr import IntegratedGradients
# Captum expects a forward function that returns a tensor (not a ModelOutput dataclass)
def forward_func(pixel_values):
    """Run the model on ``pixel_values`` and return only its logits.

    The model call returns an output object rather than a plain tensor;
    this wrapper hands back just the ``logits`` attribute so the result can
    be consumed where a tensor is expected.
    """
    return model(pixel_values=pixel_values).logits

# Build the Integrated Gradients explainer around the logits-only forward function
ig = IntegratedGradients(forward_func)

# Attribute on a detached copy of the pixel tensor with gradient tracking
# enabled, leaving the original `inputs` untouched.
input_tensor = inputs["pixel_values"].clone().detach().requires_grad_(True)

# Attribute the predicted class; also request the convergence delta so the
# quality of the approximation can be checked in the log.
attributions, convergence_delta = ig.attribute(
    input_tensor,
    target=pred_idx,
    n_steps=100,
    return_convergence_delta=True,
)
logger.info("IG convergence delta: %s", convergence_delta)

# ---- Step 5: Visualize attribution heatmap (normalized + overlay) ----

# Aggregate over channels with a signed mean so the sign of each pixel's
# contribution is kept.
# NOTE(review): assumes `attributions` squeezes to (channels, H, W) — confirm.
attr = attributions.squeeze().mean(dim=0).detach().cpu().numpy()

# Normalize by the largest absolute value so the result lies in approx [-1, 1]
# and can be drawn with a diverging colormap. The epsilon avoids a division
# by zero when every attribution is exactly 0.
min_v, max_v = float(attr.min()), float(attr.max())
norm_denom = max(abs(min_v), abs(max_v)) + 1e-8
attr_signed = attr / norm_denom

# OPTIONAL: smooth heatmap slightly to make overlays more intuitive.
# The map is round-tripped through an 8-bit image ([-1, 1] -> [0, 255] and
# back), so the smoothed values are quantized to 256 levels.
try:
    heat_pil = PILImage.fromarray(np.uint8((attr_signed + 1) * 127.5))
    heat_pil = heat_pil.filter(ImageFilter.GaussianBlur(radius=1.5))
except Exception as smooth_err:
    # Smoothing is cosmetic: keep the unsmoothed map, but say so instead of
    # failing silently.
    logger.warning("Heatmap smoothing skipped: %s", smooth_err)
else:
    attr_signed = (np.array(heat_pil).astype(float) / 127.5) - 1.0

# Create overlay using a diverging colormap (positive = warm, negative = cool)
plt.figure(figsize=(6, 6))
plt.imshow(img)
plt.imshow(attr_signed, cmap="seismic", alpha=0.45, vmin=-1, vmax=1)
cb = plt.colorbar(fraction=0.046, pad=0.04)
cb.set_label("Signed attribution (normalized)")
plt.title(f"IG overlay — pred: {model.config.id2label[pred_idx]} ({float(probs.squeeze()[pred_idx]):.3f})")
plt.axis("off")

# Show standalone signed heatmap for clearer inspection
plt.figure(figsize=(4, 4))
plt.imshow(attr_signed, cmap="seismic", vmin=-1, vmax=1)
plt.colorbar()
plt.title("Signed IG Attribution (neg=blue, pos=red)")
plt.axis("off")

# Display both figures
plt.show()

# Add concise runtime interpretability guidance
def print_interpretability_summary():
    """Print a short guide to reading the attribution and uncertainty outputs.

    The guide is written to standard output, one ``print`` call per entry;
    nothing is returned.
    """
    guide_lines = (
        "\nHow to read the results (quick guide):",
        "- IG signed heatmap: red/warm = supports the predicted class; blue/cool = opposes it.",
        "- Normalize by max-abs when comparing images. Check IG 'convergence delta' — large values mean treat attributions cautiously.",
        "- LIME panel (if used): green/highlighted superpixels indicate locally important regions; background-dominated explanations are a red flag.",
        "- MC Dropout histogram: narrow peak → stable belief; wide/multi-modal → epistemic uncertainty.",
        "- TTA histogram: many flips under small augmentations → fragile/aleatoric sensitivity.",
        "- Predictive entropy: higher → more uncertainty in the full distribution.",
        "- Variation ratio: fraction of samples not matching majority; higher → more disagreement.\n",
    )
    for line in guide_lines:
        print(line)


print_interpretability_summary()