'''
streamlitapp.py – Vision Transformer Interpretability Dashboard (Streamlit app)
This Streamlit app provides interpretability tools for vision transformer and CNN models.
Features:
- LIME explanations for image classification predictions
- Uncertainty analysis via MC Dropout and Test-Time Augmentation (TTA)
- Switch between Hugging Face (ViT, Swin, DeiT) and timm (ResNet, EfficientNet, ConvNeXt) models
- Support for custom finetuned models and class mappings
- Interactive sidebar for model selection and checkpoint upload
- Feynman-style explanations and cheat-sheet for interpretability concepts
Inspired by and reuses code from:
- vit_and_captum.py (Integrated Gradients with Captum)
- vit_lime_uncertainty.py (LIME explanations and uncertainty)
- detr_and_interp.py (Grad-CAM for DETR, logging setup)
'''
import streamlit as st
import html
import numpy as np, torch, matplotlib.pyplot as plt
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor, PreTrainedModel
from lime import lime_image
import torchvision.transforms as T
import timm
from skimage.segmentation import slic, mark_boundaries
import streamlit.components.v1 as components
# Add logging
import logging, os
from logging.handlers import RotatingFileHandler
LOG_DIR = os.path.join(os.path.dirname(__file__), "logs")
os.makedirs(LOG_DIR, exist_ok=True)
logfile = os.path.join(LOG_DIR, "interp.log")
logger = logging.getLogger("interp")
if not logger.handlers:
    logger.setLevel(logging.INFO)
    sh = logging.StreamHandler()
    sh.setLevel(logging.INFO)
    fh = RotatingFileHandler(logfile, maxBytes=5_000_000, backupCount=3, encoding="utf-8")
    fh.setLevel(logging.INFO)
    fmt = logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
    sh.setFormatter(fmt)
    fh.setFormatter(fmt)
    logger.addHandler(sh)
    logger.addHandler(fh)
# ---------------- Setup ----------------
MODEL_NAME = "google/vit-base-patch16-224"
device = "cuda" if torch.cuda.is_available() else "cpu"
# ---------- Sidebar model selectors ----------
# Quick lists you can edit to test other HF / timm models
HF_MODELS = [
"google/vit-base-patch16-224",
"facebook/deit-base-patch16-224",
"microsoft/swin-tiny-patch4-window7-224",
"google/vit-large-patch16-224",
]
TIMM_MODELS = [
"convnext_base",
"resnet50",
"efficientnet_b0",
]
def model_selector(slot_key: str, default_source="hf"):
    source = st.sidebar.selectbox(
        f"{slot_key} source",
        ["hf", "timm"],
        index=0 if default_source == "hf" else 1,
        key=f"{slot_key}_source",
    )
    if source == "hf":
        hf_choice = st.sidebar.selectbox(
            f"{slot_key} Hugging Face model",
            HF_MODELS,
            index=0,
            key=f"{slot_key}_hf",
        )
        return f"hf:{hf_choice}"
    else:
        timm_choice = st.sidebar.selectbox(
            f"{slot_key} timm model",
            TIMM_MODELS,
            index=0,
            key=f"{slot_key}_timm",
        )
        return f"timm:{timm_choice}"
# ---------- Model Loader ----------
# Use Streamlit caching when available to avoid repeated downloads
try:
    cache_decorator = st.cache_resource
except Exception:
    # Fallback outside Streamlit; note lru_cache requires hashable arguments,
    # so calls that pass a dict class_map will not be cacheable this way.
    from functools import lru_cache
    cache_decorator = lru_cache(maxsize=8)
@cache_decorator
def load_model(choice, checkpoint=None, class_map=None, num_classes=None):
    """
    Load a model from HF, timm, or a custom checkpoint.
    Args:
        choice: Model identifier ('hf:model_name' or 'timm:model_name')
        checkpoint: Optional path to a custom checkpoint file
        class_map: Optional dict mapping class indices to labels
        num_classes: Optional number of classes for custom models
    """
    logger.info("Loading model: %s", choice)
    is_hf = choice.startswith("hf:")
    # Parse the model identifier
    if is_hf:
        hf_name = choice.split("hf:")[1]
        if checkpoint:  # Custom checkpoint
            # For a custom HF model, first load the architecture, then apply weights
            try:
                if num_classes:
                    model = AutoModelForImageClassification.from_pretrained(
                        hf_name, num_labels=num_classes, ignore_mismatched_sizes=True
                    ).to(device)
                else:
                    model = AutoModelForImageClassification.from_pretrained(hf_name).to(device)
                # Load the checkpoint with error handling
                state_dict = torch.load(checkpoint, map_location=device)
                # If the state dict is wrapped (common in training checkpoints)
                if "model" in state_dict:
                    state_dict = state_dict["model"]
                elif "state_dict" in state_dict:
                    state_dict = state_dict["state_dict"]
                # Handle prefix differences by checking and stripping if needed
                if all(k.startswith('model.') for k in state_dict if k != 'config'):
                    state_dict = {k[6:]: v for k, v in state_dict.items() if k != 'config'}
                # Load with flexible partial loading (ignore missing/unexpected keys)
                model.load_state_dict(state_dict, strict=False)
                logger.info("Custom checkpoint loaded for HF model")
                # If a custom class mapping is provided, update the config
                if class_map:
                    model.config.id2label = class_map
                    model.config.label2id = {v: int(k) for k, v in class_map.items()}
            except Exception as e:
                logger.error(f"Error loading custom HF model: {e}")
                st.error(f"Failed to load custom model: {e}")
                # Fall back to the base model
                model = AutoModelForImageClassification.from_pretrained(hf_name).to(device)
        else:
            # Standard HF model
            model = AutoModelForImageClassification.from_pretrained(hf_name).to(device)
        processor = AutoImageProcessor.from_pretrained(hf_name)
    elif choice.startswith("timm:"):
        name = choice.split("timm:")[1]
        if checkpoint:  # Custom checkpoint
            try:
                # For timm, specify a custom number of classes if provided
                if num_classes:
                    model = timm.create_model(name, pretrained=False, num_classes=num_classes).to(device)
                else:
                    model = timm.create_model(name, pretrained=True).to(device)
                # Load the checkpoint
                state_dict = torch.load(checkpoint, map_location=device)
                # Handle common checkpoint formats
                if "model" in state_dict:
                    state_dict = state_dict["model"]
                elif "state_dict" in state_dict:
                    state_dict = state_dict["state_dict"]
                # Strip DataParallel-style 'module.' prefixes if present
                if state_dict and all(k.startswith('module.') for k in state_dict):
                    state_dict = {k[7:]: v for k, v in state_dict.items()}
                model.load_state_dict(state_dict, strict=False)
                logger.info("Custom checkpoint loaded for timm model")
            except Exception as e:
                logger.error(f"Error loading custom timm model: {e}")
                st.error(f"Failed to load custom model: {e}")
                # Fall back to the pretrained model
                model = timm.create_model(name, pretrained=True).to(device)
        else:
            # Standard timm model
            model = timm.create_model(name, pretrained=True).to(device)
        # Reuse a standard 224x224 image processor for timm models
        processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224")
    # Set the model to eval mode
    model.eval()
    logger.info("Model %s loaded (eval mode)", choice)
    # Return model, processor, HF flag, and class map
    return model, processor, is_hf, class_map
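# The checkpoint-unwrapping logic above, distilled into a standalone sketch so it
# can be tested without real weights. `unwrap_state_dict` is a hypothetical helper
# for illustration only; the app keeps the equivalent logic inline in load_model.
def unwrap_state_dict(ckpt):
    """Unwrap common training-checkpoint wrappers and strip 'module.' prefixes."""
    # Training loops often save {'model': ...} or {'state_dict': ...}
    for key in ("model", "state_dict"):
        if isinstance(ckpt.get(key), dict):
            ckpt = ckpt[key]
            break
    # DataParallel/DistributedDataParallel prepend 'module.' to every key
    if ckpt and all(k.startswith("module.") for k in ckpt):
        ckpt = {k[len("module."):]: v for k, v in ckpt.items()}
    return ckpt

# Example: a DataParallel-style checkpoint wrapped under 'state_dict'
assert list(unwrap_state_dict({"state_dict": {"module.fc.weight": 0}})) == ["fc.weight"]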
# Add sidebar with clear sections
st.sidebar.title("Model Selection")
# Enhanced sidebar with custom model support
with st.sidebar:
    # Tabs for standard vs custom models
    tab1, tab2 = st.tabs(["Standard Models", "Custom Finetuned Models"])
    with tab1:
        st.markdown("### Standard Models")
        st.markdown("Choose from pre-trained models:")
        m1 = model_selector("Active Model", default_source="hf")
        # Button to apply a standard model change
        if st.button("Set as Active Model", help="Click to use the selected model for analysis", key="std_model_btn"):
            with st.spinner(f"Loading {m1}..."):
                model, processor, is_hf_model, _ = load_model(m1)
                st.session_state.model = model
                st.session_state.processor = processor
                st.session_state.is_hf_model = is_hf_model
                st.session_state.active_model = m1
                st.session_state.using_custom = False
                st.session_state.class_map = None
                st.success(f"✅ Model activated: {m1}")
    with tab2:
        st.markdown("### 🔧 Custom Finetuned Model")
        st.markdown("Use your own finetuned model:")
        # Select the base architecture
        custom_source = st.selectbox(
            "Base architecture source",
            ["hf", "timm"],
            key="custom_source"
        )
        if custom_source == "hf":
            custom_base = st.selectbox(
                "Hugging Face base model",
                HF_MODELS,
                key="custom_hf_base"
            )
            base_model = f"hf:{custom_base}"
        else:
            custom_base = st.selectbox(
                "timm base model",
                TIMM_MODELS,
                key="custom_timm_base"
            )
            base_model = f"timm:{custom_base}"
        # Upload a checkpoint file
        uploaded_checkpoint = st.file_uploader(
            "Upload model checkpoint (.pth, .bin)",
            type=["pth", "bin", "pt", "ckpt"],
            help="Upload your finetuned model weights"
        )
        # Optional class mapping
        custom_classes = st.number_input(
            "Number of classes (if different from base model)",
            min_value=0, max_value=1000, value=0,
            help="Leave at 0 to use the default classes from the base model"
        )
        uploaded_labels = st.file_uploader(
            "Upload class labels (optional JSON)",
            type=["json"],
            help="JSON file mapping class indices to labels: {\"0\": \"cat\", \"1\": \"dog\"}"
        )
        # Process the label mapping
        class_map = None
        if uploaded_labels:
            try:
                import json
                class_map = json.loads(uploaded_labels.getvalue().decode("utf-8"))
                st.success(f"✅ Loaded {len(class_map)} class labels")
            except Exception as e:
                st.error(f"Error loading class labels: {e}")
        # Persist the uploaded checkpoint to a temporary file
        if uploaded_checkpoint:
            import tempfile
            with tempfile.NamedTemporaryFile(delete=False, suffix='.pth') as tmp_file:
                tmp_file.write(uploaded_checkpoint.getvalue())
                checkpoint_path = tmp_file.name
            # Always refresh the stored path so a newly uploaded file takes effect
            st.session_state.checkpoint_path = checkpoint_path
            st.success("✅ Checkpoint ready to use")
        # Button to apply the custom model
        if st.button("Load Custom Model", help="Click to use your custom model"):
            checkpoint_path = st.session_state.get('checkpoint_path')
            if not checkpoint_path:
                st.error("Please upload a checkpoint file first.")
            else:
                with st.spinner(f"Loading custom model based on {base_model}..."):
                    try:
                        num_classes = custom_classes if custom_classes > 0 else None
                        model, processor, is_hf_model, class_map = load_model(
                            base_model, checkpoint_path, class_map, num_classes
                        )
                        st.session_state.model = model
                        st.session_state.processor = processor
                        st.session_state.is_hf_model = is_hf_model
                        st.session_state.active_model = f"Custom {base_model}"
                        st.session_state.using_custom = True
                        st.session_state.class_map = class_map
                        st.success("✅ Custom model activated!")
                    except Exception as e:
                        st.error(f"Failed to load custom model: {str(e)}")
    # Explanation section
    st.markdown("---")
    st.markdown("### ℹ️ Model Types")
    st.markdown("""
- **HF (Hugging Face)**: Vision Transformer models with standard interpretability
- **timm (PyTorch Image Models)**: classical CNN architectures like ResNet and EfficientNet

*Custom models must match the base architecture's format.*
""")
# Initialize the model and processor from session state
if 'active_model' not in st.session_state:
    # First load - use the default model
    m1 = "hf:google/vit-base-patch16-224"
    st.session_state.active_model = m1
    model, processor, is_hf_model, _ = load_model(m1)
    st.session_state.model = model
    st.session_state.processor = processor
    st.session_state.is_hf_model = is_hf_model
    st.session_state.using_custom = False
    st.session_state.class_map = None
else:
    # Restore from session state
    model = st.session_state.model
    processor = st.session_state.processor
    is_hf_model = st.session_state.is_hf_model
# Initialize explainer
explainer = lime_image.LimeImageExplainer()
st.title("🧠 Vision Transformer Interpretability Dashboard")
st.write("Upload an image and explore explanations with **LIME** and **Uncertainty Analysis**.")
# Add a Feynman-style "How it works" explanation as a collapsible expander
with st.expander("How it works – Feynman-style explanations (click to expand)", expanded=False):
    st.markdown("""
## 🧠 Vision Transformer Interpretability – Feynman-Style Explanations

### Why do we care about interpretability & uncertainty?
Imagine you ask a kid to identify whether a picture is a cat. They point to the fur, ears, maybe whiskers. But what if the kid always focused on shadows, or background trees, instead of the cat itself? We want two things:
1. **Why** did the model say "cat"? What parts of the image made it decide so?
2. **How confident** is the model in that decision? Could small changes flip it?

Interpretability methods show us #1. Uncertainty estimation shows us #2. Together, they help us see not just *what* the model does, but *whether* we should trust it.

### Key techniques, in plain analogies
- **LIME (Local Interpretable Model-agnostic Explanations)**: For a single image & prediction, LIME perturbs (changes) parts of the image, watches how the prediction changes, and fits a simple model locally to understand which parts are most influential.
  - Analogy: Like shining small spotlights on different parts of a stage during a play: you dim a section and see how the actor's reaction changes. The parts whose dimming changes the reaction most are the parts the actor depends on.
- **Uncertainty in LIME (multiple LIME runs)**: Because LIME uses randomness (perturbing patches), different runs can give different "important" regions. Measuring how much they differ tells you how stable or fragile the explanation is.
  - Analogy: If you ask several cooks what the dominant spice in a stew is and everyone agrees, you're confident; if opinions vary, your knowledge is shakier.
- **MC Dropout (Monte Carlo Dropout)**: Leave dropout on at inference time and run the model multiple times. The spread of predictions is a proxy for epistemic uncertainty.
  - Analogy: Like a jury where each juror occasionally misses a sentence; if the verdict remains the same across many "faulty hearing" runs, trust it more.
- **Test-Time Augmentation (TTA) Uncertainty**: Apply small transforms (crops, flips) at inference and watch the prediction variance. High variance → brittle model.
  - Analogy: Take photos under slightly different lighting/angles; if the label flips, the model may depend on superficial cues.

### How to read the visuals
- LIME highlights: bright / colored superpixels = influential regions. If background or artifacts light up, that's a red flag.
- LIME uncertainty heatmap: high std in a region means attributions are unstable there.
- MC Dropout / TTA histograms: a narrow, tall peak = confident; wide or multi-modal = uncertain.

### Limitations & caveats
- Stable explanations can still be consistently wrong if the model learned a bias.
- MC Dropout is an approximation – it helps, but doesn't fully replace calibrated probabilistic methods.
- TTA shows input sensitivity, not full distributional-shift robustness.

### Quick example (walkthrough)
1. Upload an image → the model predicts a label with some probability.
2. LIME finds important superpixels; multiple LIME runs give mean + std maps.
3. MC Dropout produces a histogram over runs; use it to judge epistemic uncertainty.
4. TTA shows sensitivity to small input changes.

### Practical tips
- Use explanation + uncertainty to guide active learning: label cases where the model is uncertain or explanations are unstable.
- For safety-critical systems, combine these visual signals with human review and stricter failure thresholds.

### Where to read more
- Christoph Molnar, Interpretable Machine Learning (chapter on LIME): https://christophm.github.io/interpretable-ml-book/lime.html
- Ribeiro et al., "Why Should I Trust You?" (original LIME paper): https://homes.cs.washington.edu/~marcotcr/blog/lime/
- Zhang et al., "Why Should You Trust My Explanation?" (LIME reliability): https://arxiv.org/abs/1904.12991
- MC Dropout practical guide & notes: https://medium.com/@ciaranbench/monte-carlo-dropout-a-practical-guide-4b4dc18014b5
""")
# Compact one-page cheat-sheet (quick flags & checks)
with st.expander("Cheat-sheet – Quick flags & warnings", expanded=False):
    cheat_text = """
Quick checks when an explanation looks suspicious
- Red flag: LIME highlights background or repeated dataset artifacts (logos, borders) → the model may have learned spurious cues.
- Red flag: LIME attribution std is high in key regions → the explanation is unstable; try different segmentations or more samples.
- Red flag: MC Dropout or TTA histograms are multi-modal or very wide → the model is uncertain; consider human review or abstaining.
- Quick fixes: increase dataset diversity, add regularization, try different segmentation_fn parameters, or collect more labels for uncertain cases.

One-line definitions
- LIME: perturb + fit a simple local model to explain a single prediction.
- MC Dropout: enable dropout at inference and sample to estimate epistemic uncertainty.
- TTA: apply small input transforms at inference to measure sensitivity / aleatoric uncertainty.

Pro-tip: use explanation + uncertainty to drive active learning: pick instances with high prediction uncertainty or unstable explanations for labeling.
"""
    # Show the cheat-sheet as markdown
    st.markdown(cheat_text)
    # Download button for the cheat-sheet as plain text
    try:
        st.download_button(
            label="Download cheat-sheet (.txt)",
            data=cheat_text,
            file_name="cheat_sheet.txt",
            mime="text/plain",
        )
    except Exception:
        # Streamlit may raise if download_button isn't available in some environments; ignore gracefully
        pass
    # Copy-to-clipboard button using a small HTML+JS snippet
    escaped = html.escape(cheat_text)
    copy_html = f"""
    <div>
      <button id='copy-btn' style='padding:6px 10px;border-radius:4px;'>Copy cheat-sheet</button>
      <script>
      const btn = document.getElementById('copy-btn');
      btn.addEventListener('click', async () => {{
        try {{
          await navigator.clipboard.writeText(`{escaped}`);
          btn.innerText = 'Copied!';
          setTimeout(() => btn.innerText = 'Copy cheat-sheet', 1500);
        }} catch (e) {{
          btn.innerText = 'Copy failed';
        }}
      }});
      </script>
    </div>
    """
    components.html(copy_html, height=70)
# Display active model clearly in the main panel
is_custom = st.session_state.get('using_custom', False)
custom_badge = " 🔧 Custom" if is_custom else ""
st.markdown(f"### Active Model: `{st.session_state.active_model}{custom_badge}`")
model_type = "Hugging Face Transformer" if is_hf_model else "timm CNN Architecture"
st.caption(f"Model type: {model_type}")
# ---------------- Helpers ----------------
def classifier_fn(images_batch):
    # Use the currently active model/processor (module-level globals set above).
    # LIME passes a batch of HxWxC uint8 arrays.
    inputs = processor(images=[Image.fromarray(x.astype(np.uint8)) for x in images_batch],
                       return_tensors="pt").to(device)
    with torch.no_grad():
        if is_hf_model:
            outputs = model(**inputs)
            logits = outputs.logits
        else:
            x = inputs['pixel_values']
            logits = model(x)
    probs = torch.softmax(logits, dim=-1).cpu().numpy()
    return probs
def predict_probs(pil_img):
    # Single-image convenience wrapper around the active model/processor
    inputs = processor(images=pil_img, return_tensors="pt").to(device)
    with torch.no_grad():
        if is_hf_model:
            outputs = model(**inputs)
            logits = outputs.logits
        else:
            x = inputs['pixel_values']
            logits = model(x)
    probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
    return probs
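# Optional sanity check of the LIME contract: classifier_fn must accept a batch
# of HxWxC uint8 arrays and return an (N, num_classes) probability array whose
# rows sum to ~1. Defined but not called at startup (a model forward pass is
# comparatively expensive); invoke manually when debugging.
def _check_classifier_contract():
    dummy = np.random.randint(0, 255, size=(2, 224, 224, 3), dtype=np.uint8)
    out = classifier_fn(dummy)
    assert out.shape[0] == 2 and np.allclose(out.sum(axis=1), 1.0, atol=1e-4)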
# ---------------- Upload ----------------
uploaded = st.file_uploader("Upload an image", type=["png","jpg","jpeg"])
if uploaded:
    img = Image.open(uploaded).convert("RGB").resize((224, 224))
    logger.info("Uploaded image received (size=%s)", img.size)
    # Streamlit 1.x: replaces the deprecated `use_container_width` with `width`.
    # For full-width behavior use width='stretch' (or 'content' for intrinsic size).
    st.image(img, caption="Uploaded image", width='stretch')
    # ---------------- Prediction ----------------
    probs = predict_probs(img)
    pred_idx = int(np.argmax(probs))
    # Get the label - handling differs by model source
    if is_hf_model:
        # Use the model's config.id2label if available
        pred_label = model.config.id2label[pred_idx]
    elif st.session_state.get('class_map'):
        # Use the custom class map if provided
        pred_label = st.session_state['class_map'].get(str(pred_idx), f"Class {pred_idx}")
    else:
        # timm models without label names
        pred_label = f"Class {pred_idx}"
    pred_prob = float(probs[pred_idx])
    logger.info("Prediction: %s (%.3f)", pred_label, pred_prob)
    st.subheader("🔮 Prediction")
    st.write(f"**Top-1:** {pred_label} ({pred_prob:.3f})")
    if not is_hf_model and not st.session_state.get('class_map'):
        st.info("ℹ️ Using a model without class names. Upload a class mapping in the sidebar for friendly labels.")
    # ---------------- LIME ----------------
    st.subheader("LIME Attribution")
    st.markdown("""
**Local Interpretable Model-agnostic Explanations (LIME)** approximates how a complex model (like ViT or ResNet) makes a decision for a specific input by fitting a simpler, interpretable model around it.
It perturbs the image segment by segment and sees which segments most influence the prediction, revealing what the model "sees" as important.
This is crucial for debugging biases and checking whether the model focuses on relevant features rather than artifacts.
""")
    img_np = np.array(img)
    with st.spinner("Generating LIME explanation..."):
        exp = explainer.explain_instance(
            img_np, classifier_fn=classifier_fn, top_labels=1, num_samples=1000,
            segmentation_fn=lambda x: slic(x, n_segments=60, compactness=9, start_label=0)
        )
    temp, mask = exp.get_image_and_mask(pred_idx, positive_only=True,
                                        num_features=8, hide_rest=False)
    lime_img = mark_boundaries(temp / 255.0, mask)
    st.image(lime_img, caption=f"LIME highlights regions important for '{pred_label}'")
    st.info("""
**How to read:** Bright (or colored) segments show the areas the model relied on most for its prediction – the "superpixels" that, when altered, change the output the most.
Green/red overlays often indicate positive/negative contributions. If irrelevant background or edges light up, the model may have learned spurious correlations (e.g., from training-data artifacts).
Checking this builds trust by showing whether the AI's decisions align with human intuition.
""")
    # ---------------- LIME Uncertainty ----------------
    st.subheader("LIME Attribution Uncertainty")
    st.markdown("""
Uncertainty in explanations arises because LIME is stochastic – it samples perturbations randomly. By running LIME multiple times, we can measure the variability of the attributions,
which shows whether the model's reasoning is consistent or fragile for this image. High variability suggests the explanation (and thus model confidence) isn't robust.
""")
    logger.info("Starting LIME uncertainty runs (n=5)")
    maps = []
    for i in range(5):
        logger.debug("LIME run %d", i + 1)
        exp = explainer.explain_instance(
            img_np, classifier_fn=classifier_fn, top_labels=1, num_samples=500,
            segmentation_fn=lambda x: slic(x, n_segments=60, compactness=9, start_label=0)
        )
        local_exp = dict(exp.local_exp)[pred_idx]
        segments = exp.segments
        attr_map = np.zeros(segments.shape)
        for seg_id, weight in local_exp:
            attr_map[segments == seg_id] = weight
        maps.append(attr_map)
    maps = np.stack(maps)
    mean_attr, std_attr = maps.mean(0), maps.std(0)
    fig, ax = plt.subplots(1, 2, figsize=(8, 4))
    im1 = ax[0].imshow(mean_attr, cmap="jet"); ax[0].set_title("Mean attribution"); ax[0].axis("off")
    plt.colorbar(im1, ax=ax[0], fraction=0.046)
    im2 = ax[1].imshow(std_attr, cmap="hot"); ax[1].set_title("Attribution std (uncertainty)"); ax[1].axis("off")
    plt.colorbar(im2, ax=ax[1], fraction=0.046)
    st.pyplot(fig)
    st.info("""
**How to read:** The left heatmap shows the average importance across runs (hotter = more influential). The right shows the standard deviation – high std (yellow/red) means the explanation is unstable in those regions.
If uncertainty is high in key areas, the model might be overfitting or need more diverse training data. This helps practitioners quantify explanation reliability.
""")
    logger.info("Completed LIME uncertainty runs")
    # ---------------- MC Dropout ----------------
    st.subheader("🎲 MC Dropout Uncertainty")
    st.markdown("""
Monte Carlo (MC) Dropout treats dropout layers (normally off during inference) as a Bayesian approximation to estimate epistemic uncertainty – how much the model "doesn't know" due to limited training.
By enabling dropout and sampling predictions multiple times, we see whether the model consistently agrees on the class or wavers, indicating potential unreliability.
""")
    logger.info("Starting MC Dropout sampling")
    # Note: train() enables dropout, but for CNNs it also switches BatchNorm to
    # batch statistics; see the dropout-only helper sketch below this section.
    model.train()  # enable dropout
    mc_preds = []
    with torch.no_grad():
        for _ in range(30):
            probs_mc = predict_probs(img)
            mc_preds.append(probs_mc)
    model.eval()
    mc_preds = np.stack(mc_preds)
    mc_mean = mc_preds.mean(0)
    mc_top = mc_mean.argmax()
    if is_hf_model:
        mc_label = model.config.id2label[mc_top]
    elif st.session_state.get('class_map'):
        mc_label = st.session_state['class_map'].get(str(mc_top), f"Class {mc_top}")
    else:
        mc_label = f"Class {mc_top}"
    p = mc_preds[:, mc_top]
    fig, ax = plt.subplots()
    ax.hist(p, bins=15, color="C0")
    ax.set_title(f"MC Dropout: p({mc_label}) across samples")
    st.pyplot(fig)
    st.info("""
**How to read:** This histogram shows the probability of the top class across 30 samples. A narrow, peaked distribution means stable confidence (low uncertainty).
A wide spread or multiple modes suggests the model is unsure, possibly due to out-of-distribution inputs. This flags risky predictions that need human review.
""")
    logger.info("Completed MC Dropout: top=%s", mc_label)
    # ---------------- Test-Time Augmentation (TTA) Uncertainty ----------------
    st.subheader("Test-Time Augmentation (TTA) Uncertainty")
    st.markdown("""
Test-Time Augmentation (TTA) applies random transformations (crops, flips) at inference to probe aleatoric uncertainty – noise inherent in the input rather than gaps in the model's knowledge.
If predictions vary wildly under small changes, the model relies on brittle features, revealing data-related issues.
""")
    logger.info("Starting TTA sampling")
    tta_tfms = T.Compose([T.Resize(256), T.RandomResizedCrop(224, scale=(0.9, 1.0)), T.RandomHorizontalFlip(p=0.5)])
    tta_preds = []
    with torch.no_grad():
        for _ in range(20):
            aug = tta_tfms(img)
            probs_tta = predict_probs(aug)
            tta_preds.append(probs_tta)
    tta_preds = np.stack(tta_preds)
    tta_mean = tta_preds.mean(0)
    tta_top = tta_mean.argmax()
    if is_hf_model:
        tta_label = model.config.id2label[tta_top]
    elif st.session_state.get('class_map'):
        tta_label = st.session_state['class_map'].get(str(tta_top), f"Class {tta_top}")
    else:
        tta_label = f"Class {tta_top}"
    p_tta = tta_preds[:, tta_top]
    fig, ax = plt.subplots()
    ax.hist(p_tta, bins=15, color="C1")
    ax.set_title(f"TTA: p({tta_label}) across augmentations")
    st.pyplot(fig)
    st.info("""
**How to read:** Similar to MC Dropout, but focused on input variations. Low variance means the prediction is robust to perturbations (a good sign). High variance indicates sensitivity to details like lighting or position,
common in overfitted models. Use this to assess whether your system handles real-world variability well.
""")
    logger.info("Completed TTA: top=%s", tta_label)
    # ---------------- Summary ----------------