'''
streamlitapp.py — Vision Transformer Interpretability Dashboard (Streamlit app)

This Streamlit app provides interpretability tools for vision transformer and CNN models.

Features:
- LIME explanations for image classification predictions
- Uncertainty analysis via MC Dropout and Test-Time Augmentation (TTA)
- Switch between Hugging Face (ViT, Swin, DeiT) and timm (ResNet, EfficientNet, ConvNeXt) models
- Support for custom finetuned models and class mappings
- Interactive sidebar for model selection and checkpoint upload
- Feynman-style explanations and cheat-sheet for interpretability concepts

Inspired by and reuses code from:
- vit_and_captum.py (Integrated Gradients with Captum)
- vit_lime_uncertainty.py (LIME explanations and uncertainty)
- detr_and_interp.py (Grad-CAM for DETR, logging setup)
'''
import streamlit as st
import html
import numpy as np, torch, matplotlib.pyplot as plt
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor, PreTrainedModel
from lime import lime_image
import torchvision.transforms as T
import timm
from skimage.segmentation import slic, mark_boundaries
import streamlit.components.v1 as components

# Add logging
import logging, os
from logging.handlers import RotatingFileHandler

LOG_DIR = os.path.join(os.path.dirname(__file__), "logs")
os.makedirs(LOG_DIR, exist_ok=True)
logfile = os.path.join(LOG_DIR, "interp.log")

logger = logging.getLogger("interp")
if not logger.handlers:
    logger.setLevel(logging.INFO)
    sh = logging.StreamHandler()
    sh.setLevel(logging.INFO)
    fh = RotatingFileHandler(logfile, maxBytes=5_000_000, backupCount=3, encoding="utf-8")
    fh.setLevel(logging.INFO)
    fmt = logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
    sh.setFormatter(fmt)
    fh.setFormatter(fmt)
    logger.addHandler(sh)
    logger.addHandler(fh)
# ---------------- Setup ----------------
MODEL_NAME = "google/vit-base-patch16-224"
device = "cuda" if torch.cuda.is_available() else "cpu"

# ---------- Sidebar model selectors ----------
# Quick lists you can edit to test other HF / timm models
HF_MODELS = [
    "google/vit-base-patch16-224",
    "facebook/deit-base-patch16-224",
    "microsoft/swin-tiny-patch4-window7-224",
    "google/vit-large-patch16-224",
]
TIMM_MODELS = [
    "convnext_base",
    "resnet50",
    "efficientnet_b0",
]
def model_selector(slot_key: str, default_source="hf"):
    source = st.sidebar.selectbox(
        f"{slot_key} source",
        ["hf", "timm"],
        index=0 if default_source == "hf" else 1,
        key=f"{slot_key}_source",
    )
    if source == "hf":
        hf_choice = st.sidebar.selectbox(
            f"{slot_key} Hugging Face model",
            HF_MODELS,
            index=0,
            key=f"{slot_key}_hf",
        )
        return f"hf:{hf_choice}"
    else:
        timm_choice = st.sidebar.selectbox(
            f"{slot_key} timm model",
            TIMM_MODELS,
            index=0,
            key=f"{slot_key}_timm",
        )
        return f"timm:{timm_choice}"
# ---------- Model Loader ----------
# Use Streamlit caching when available to avoid repeated downloads
try:
    cache_decorator = st.cache_resource
except Exception:
    from functools import lru_cache
    cache_decorator = lru_cache(maxsize=8)
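# Note: load_model below is not decorated with cache_decorator; its class_map
# argument is a dict, which the lru_cache fallback cannot hash (it would raise
# TypeError). Apply the decorator manually if you only pass hashable arguments.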
def load_model(choice, checkpoint=None, class_map=None, num_classes=None):
    """
    Load a model from HF, timm, or a custom checkpoint.

    Args:
        choice: Model identifier ('hf:model_name' or 'timm:model_name')
        checkpoint: Optional path to a custom checkpoint file
        class_map: Optional dict mapping class indices to labels
        num_classes: Optional number of classes for custom models
    """
| logger.info("Loading model: %s", choice) | |
| is_hf = choice.startswith("hf:") | |
| # Parse model identifier | |
| if is_hf: | |
| hf_name = choice.split("hf:")[1] | |
| if checkpoint: # Custom checkpoint | |
| # For custom HF model, first load the architecture then apply weights | |
| try: | |
| if num_classes: | |
| model = AutoModelForImageClassification.from_pretrained( | |
| hf_name, num_labels=num_classes, ignore_mismatched_sizes=True | |
| ).to(device) | |
| else: | |
| model = AutoModelForImageClassification.from_pretrained(hf_name).to(device) | |
| # Load checkpoint with error handling | |
| state_dict = torch.load(checkpoint, map_location=device) | |
| # If state_dict is wrapped (common in training checkpoints) | |
| if "model" in state_dict: | |
| state_dict = state_dict["model"] | |
| elif "state_dict" in state_dict: | |
| state_dict = state_dict["state_dict"] | |
| # Handle any prefix differences by checking and stripping if needed | |
| if all(k.startswith('model.') for k in state_dict if k != 'config'): | |
| state_dict = {k[6:]: v for k, v in state_dict.items() if k != 'config'} | |
| # Load with flexible partial loading (ignore missing/unexpected) | |
| model.load_state_dict(state_dict, strict=False) | |
| logger.info("Custom checkpoint loaded for HF model") | |
                # If a custom class mapping is provided, update the config
                if class_map:
                    # JSON keys arrive as strings, but config.id2label is keyed by int
                    model.config.id2label = {int(k): v for k, v in class_map.items()}
                    model.config.label2id = {v: int(k) for k, v in class_map.items()}
            except Exception as e:
                logger.error(f"Error loading custom HF model: {e}")
                st.error(f"Failed to load custom model: {e}")
                # Fallback to base model
                model = AutoModelForImageClassification.from_pretrained(hf_name).to(device)
        else:
            # Standard HF model
            model = AutoModelForImageClassification.from_pretrained(hf_name).to(device)
        processor = AutoImageProcessor.from_pretrained(hf_name)
    elif choice.startswith("timm:"):
        name = choice.split("timm:")[1]
        if checkpoint:  # Custom checkpoint
            try:
                # For timm, specify a custom number of classes if provided
                if num_classes:
                    model = timm.create_model(name, pretrained=False, num_classes=num_classes).to(device)
                else:
                    model = timm.create_model(name, pretrained=True).to(device)
                # Load checkpoint
                state_dict = torch.load(checkpoint, map_location=device)
                # Handle common checkpoint formats
                if "model" in state_dict:
                    state_dict = state_dict["model"]
                elif "state_dict" in state_dict:
                    state_dict = state_dict["state_dict"]
                # Strip a uniform 'module.' prefix (left behind by DataParallel training)
                if all(k.startswith('module.') for k in state_dict):
                    state_dict = {k[7:]: v for k, v in state_dict.items()}
                model.load_state_dict(state_dict, strict=False)
                logger.info("Custom checkpoint loaded for timm model")
            except Exception as e:
                logger.error(f"Error loading custom timm model: {e}")
                st.error(f"Failed to load custom model: {e}")
                # Fallback to pretrained weights
                model = timm.create_model(name, pretrained=True).to(device)
        else:
            # Standard timm model
            model = timm.create_model(name, pretrained=True).to(device)
        # Use a generic HF image processor for timm models
        # (its normalization may differ slightly from the timm model's own config)
        processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224")
    # Set model to eval mode
    model.eval()
    logger.info("Model %s loaded (eval mode)", choice)
    # Return model, processor, HF flag, and class map
    return model, processor, is_hf, class_map
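# Illustrative usage (the checkpoint path and class mapping below are hypothetical):
#   model, processor, is_hf, cmap = load_model(
#       "hf:google/vit-base-patch16-224",
#       checkpoint="my_finetuned.pth",            # hypothetical checkpoint file
#       class_map={"0": "cat", "1": "dog"},
#       num_classes=2,
#   )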
# Add sidebar with clear sections
st.sidebar.title("Model Selection")

# Enhanced sidebar with custom model support
with st.sidebar:
    # Add tabs for standard vs custom models
    tab1, tab2 = st.tabs(["Standard Models", "Custom Finetuned Models"])
    with tab1:
        st.markdown("### 📊 Standard Models")
        st.markdown("Choose from pre-trained models:")
        m1 = model_selector("Active Model", default_source="hf")
        # Button to apply a standard model change
        if st.button("📋 Set as Active Model", help="Click to use the selected model for analysis", key="std_model_btn"):
            with st.spinner(f"Loading {m1}..."):
                model, processor, is_hf_model, _ = load_model(m1)
                st.session_state.model = model
                st.session_state.processor = processor
                st.session_state.is_hf_model = is_hf_model
                st.session_state.active_model = m1
                st.session_state.using_custom = False
                st.session_state.class_map = None
            st.success(f"✅ Model activated: {m1}")
    with tab2:
        st.markdown("### 🔧 Custom Finetuned Model")
        st.markdown("Use your own finetuned model:")
        # Select base architecture
        custom_source = st.selectbox(
            "Base architecture source",
            ["hf", "timm"],
            key="custom_source"
        )
        if custom_source == "hf":
            custom_base = st.selectbox(
                "Hugging Face base model",
                HF_MODELS,
                key="custom_hf_base"
            )
            base_model = f"hf:{custom_base}"
        else:
            custom_base = st.selectbox(
                "timm base model",
                TIMM_MODELS,
                key="custom_timm_base"
            )
            base_model = f"timm:{custom_base}"
        # Upload checkpoint file
        uploaded_checkpoint = st.file_uploader(
            "Upload model checkpoint (.pth, .bin)",
            type=["pth", "bin", "pt", "ckpt"],
            help="Upload your finetuned model weights"
        )
        # Optional class mapping
        custom_classes = st.number_input(
            "Number of classes (if different from base model)",
            min_value=0, max_value=1000, value=0,
            help="Leave at 0 to use the default classes from the base model"
        )
        uploaded_labels = st.file_uploader(
            "Upload class labels (optional JSON)",
            type=["json"],
            help="JSON file mapping class indices to labels: {\"0\": \"cat\", \"1\": \"dog\"}"
        )
        # Process label mapping
        class_map = None
        if uploaded_labels:
            try:
                import json
                class_map = json.loads(uploaded_labels.getvalue().decode("utf-8"))
                st.success(f"✓ Loaded {len(class_map)} class labels")
            except Exception as e:
                st.error(f"Error loading class labels: {e}")
        # Persist the uploaded checkpoint to a temporary file
        if uploaded_checkpoint:
            import tempfile
            with tempfile.NamedTemporaryFile(delete=False, suffix='.pth') as tmp_file:
                tmp_file.write(uploaded_checkpoint.getvalue())
            # Always track the most recent upload (overwriting any earlier one)
            st.session_state.checkpoint_path = tmp_file.name
            st.success("✓ Checkpoint ready to use")
        # Button to apply the custom model
        if st.button("🚀 Load Custom Model", help="Click to use your custom model"):
            checkpoint_path = st.session_state.get('checkpoint_path')
            if not checkpoint_path:
                st.error("Please upload a checkpoint file first.")
            else:
                with st.spinner(f"Loading custom model based on {base_model}..."):
                    try:
                        num_classes = custom_classes if custom_classes > 0 else None
                        model, processor, is_hf_model, class_map = load_model(
                            base_model, checkpoint_path, class_map, num_classes
                        )
                        st.session_state.model = model
                        st.session_state.processor = processor
                        st.session_state.is_hf_model = is_hf_model
                        st.session_state.active_model = f"Custom {base_model}"
                        st.session_state.using_custom = True
                        st.session_state.class_map = class_map
                        st.success("✅ Custom model activated!")
                    except Exception as e:
                        st.error(f"Failed to load custom model: {e}")
    # Explanation section
    st.markdown("---")
    st.markdown("### ℹ️ Model Types")
    st.markdown("""
- **HF (Hugging Face)**: Vision Transformer models with standard interpretability
- **timm (PyTorch Image Models)**: classical CNN architectures like ResNet and EfficientNet

*Custom models must match the base architecture's format.*
""")
# Initialize model and processor from session state
if 'active_model' not in st.session_state:
    # First load: use the default model
    m1 = "hf:google/vit-base-patch16-224"
    st.session_state.active_model = m1
    model, processor, is_hf_model, _ = load_model(m1)
    st.session_state.model = model
    st.session_state.processor = processor
    st.session_state.is_hf_model = is_hf_model
    st.session_state.using_custom = False
    st.session_state.class_map = None
else:
    # Restore from session state
    model = st.session_state.model
    processor = st.session_state.processor
    is_hf_model = st.session_state.is_hf_model
# Initialize explainer
explainer = lime_image.LimeImageExplainer()

st.title("🧠 Vision Transformer Interpretability Dashboard")
st.write("Upload an image and explore explanations with **LIME** and **Uncertainty Analysis**.")
# Add a Feynman-style "How it works" explanation as a collapsible expander
with st.expander("How it works — Feynman-style explanations (click to expand)", expanded=False):
    st.markdown("""
## 🧠 Vision Transformer Interpretability — Feynman-Style Explanations

### Why do we care about interpretability & uncertainty?

Imagine you ask a kid to say whether a picture shows a cat. They point to the fur, the ears, maybe the whiskers. But what if the kid always focused on shadows, or background trees, instead of the cat itself? We want two things:

1. **Why** did the model say "cat"? What parts of the image made it decide so?
2. **How confident** is the model in that decision? Could small changes flip it?

Interpretability methods answer #1. Uncertainty estimation answers #2. Together, they show us not just *what* the model does, but *whether* we should trust it.

### Key techniques, in plain analogies

- **LIME (Local Interpretable Model-agnostic Explanations)**: For a single image and prediction, LIME perturbs (changes) parts of the image, watches how the prediction changes, and fits a simple model locally to find which parts are most influential.
  - Analogy: Like dimming spotlights on different parts of a stage during a play: you dim a section and watch how the actor's reaction changes. The sections whose dimming changes the reaction most are the ones the actor depends on.
- **Uncertainty in LIME (multiple LIME runs)**: Because LIME uses randomness (perturbing patches), different runs can highlight different "important" regions. Measuring how much they differ tells you how stable or fragile the explanation is.
  - Analogy: If you ask several cooks what the dominant spice in a stew is and everyone agrees, you can be confident; if opinions vary, your knowledge is shakier.
- **MC Dropout (Monte Carlo Dropout)**: Leave dropout on at inference time and run the model multiple times. The spread of predictions is a proxy for epistemic uncertainty (see the sketch after this list).
  - Analogy: Like a jury where each juror occasionally misses a sentence; if the verdict stays the same across many "faulty hearing" runs, you can trust it more.
- **Test-Time Augmentation (TTA) Uncertainty**: Apply small transforms (crops, flips) at inference and watch the prediction variance. High variance → brittle model.
  - Analogy: Take photos under slightly different lighting and angles; if the label flips, the model may depend on superficial cues.
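
A minimal MC Dropout sketch (illustrative only, not this app's exact code; it assumes a plain PyTorch classifier `net` that maps a batch `x` to logits):

```python
import torch

@torch.no_grad()
def mc_dropout_probs(net, x, n_samples=30):
    # Keep only dropout layers stochastic; everything else stays in eval mode
    for m in net.modules():
        if isinstance(m, torch.nn.Dropout):
            m.train()
    probs = torch.stack([torch.softmax(net(x), dim=-1) for _ in range(n_samples)])
    net.eval()
    return probs.mean(0), probs.std(0)  # mean prediction and per-class spread
```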

### How to read the visuals

- LIME highlights: bright / colored superpixels = influential regions. If background or artifacts light up, that's a red flag.
- LIME uncertainty heatmap: high std in a region means attributions are unstable there.
- MC Dropout / TTA histograms: a narrow, tall peak = confident; a wide or multi-modal spread = uncertain.

### Limitations & caveats

- Stable explanations can still be consistently wrong if the model learned a bias.
- MC Dropout is an approximation — it helps but doesn't fully replace calibrated probabilistic methods.
- TTA shows input sensitivity, not full robustness to distribution shift.

### Quick example (walkthrough)

1. Upload an image → the model predicts a label with some probability.
2. LIME finds important superpixels; multiple LIME runs give mean + std maps.
3. MC Dropout produces a histogram over runs; use it to judge epistemic uncertainty.
4. TTA shows sensitivity to small input changes.

### Practical tips

- Use explanation + uncertainty to guide active learning: label cases where the model is uncertain or explanations are unstable.
- For safety-critical systems, combine these visual signals with human review and stricter failure thresholds.

### Where to read more

- Christoph Molnar — Interpretable Machine Learning (chapter on LIME): https://christophm.github.io/interpretable-ml-book/lime.html
- Ribeiro et al., "Why Should I Trust You?" (original LIME paper): https://homes.cs.washington.edu/~marcotcr/blog/lime/
- Zhang et al., "Why Should You Trust My Explanation?" (LIME reliability): https://arxiv.org/abs/1904.12991
- MC Dropout practical guide & notes: https://medium.com/@ciaranbench/monte-carlo-dropout-a-practical-guide-4b4dc18014b5
""")
# Compact one-page cheat-sheet (quick flags & checks)
with st.expander("Cheat-sheet — Quick flags & warnings", expanded=False):
    cheat_text = """
Quick checks when an explanation looks suspicious

- Red flag: LIME highlights background or repeated dataset artifacts (logos, borders) — the model may have learned spurious cues.
- Red flag: LIME attribution std is high in key regions — the explanation is unstable; try different segmentations or more samples.
- Red flag: MC Dropout or TTA histograms are multi-modal or very wide — the model is uncertain; consider human review or abstaining.
- Quick fixes: increase dataset diversity, add regularization, try different segmentation_fn parameters, or collect more labels for uncertain cases.

One-line definitions

- LIME: perturb + fit a simple local model to explain a single prediction.
- MC Dropout: enable dropout at inference and sample to estimate epistemic uncertainty.
- TTA: apply small input transforms at inference to measure sensitivity / aleatoric uncertainty.

Pro-tip: Use explanation + uncertainty to drive active learning: pick instances with high prediction uncertainty or unstable explanations for labeling.
"""

    # Show the cheat-sheet as markdown
    st.markdown(cheat_text)
    # Download button for the cheat-sheet as plain text
    try:
        st.download_button(
            label="Download cheat-sheet (.txt)",
            data=cheat_text,
            file_name="cheat_sheet.txt",
            mime="text/plain",
        )
    except Exception:
        # st.download_button may be unavailable in some environments; skip gracefully
        pass
    # Copy-to-clipboard button using a small HTML+JS snippet
    escaped = html.escape(cheat_text)
    copy_html = f"""
    <div>
      <button id='copy-btn' style='padding:6px 10px;border-radius:4px;'>Copy cheat-sheet</button>
      <script>
      const btn = document.getElementById('copy-btn');
      btn.addEventListener('click', async () => {{
          try {{
              await navigator.clipboard.writeText(`{escaped}`);
              btn.innerText = 'Copied!';
              setTimeout(() => btn.innerText = 'Copy cheat-sheet', 1500);
          }} catch (e) {{
              btn.innerText = 'Copy failed';
          }}
      }});
      </script>
    </div>
    """
    components.html(copy_html, height=70)
# Display the active model clearly in the main panel
is_custom = st.session_state.get('using_custom', False)
custom_badge = " 🔧 Custom" if is_custom else ""
st.markdown(f"### Active Model: `{st.session_state.active_model}{custom_badge}`")
model_type = "Hugging Face Transformer" if is_hf_model else "timm CNN architecture"
st.caption(f"Model type: {model_type}")
# ---------------- Helpers ----------------
def classifier_fn(images_batch):
    # Use the current model/processor from session state
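    # LIME calls this with a batch of HxWx3 uint8 arrays and expects an
    # (N, num_classes) array of probabilities back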
    inputs = processor(images=[Image.fromarray(x.astype(np.uint8)) for x in images_batch],
                       return_tensors="pt").to(device)
    with torch.no_grad():
        if is_hf_model:
            outputs = model(**inputs)
            logits = outputs.logits
        else:
            x = inputs['pixel_values']
            logits = model(x)
    probs = torch.softmax(logits, dim=-1).cpu().numpy()
    return probs
def predict_probs(pil_img):
    # Use the current model/processor from session state
    inputs = processor(images=pil_img, return_tensors="pt").to(device)
    with torch.no_grad():
        if is_hf_model:
            outputs = model(**inputs)
            logits = outputs.logits
        else:
            x = inputs['pixel_values']
            logits = model(x)
    probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
    return probs
# ---------------- Upload ----------------
uploaded = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
if uploaded:
    img = Image.open(uploaded).convert("RGB").resize((224, 224))
    logger.info("Uploaded image received (size=%s)", img.size)
    # Streamlit 1.XX: replace deprecated `use_container_width` with `width`
    # (use width='stretch' for full-width behavior, 'content' for intrinsic size)
    st.image(img, caption="Uploaded image", width='stretch')
    # ---------------- Prediction ----------------
    probs = predict_probs(img)
    pred_idx = int(np.argmax(probs))
    # Resolve the label, handling models differently by source
    if is_hf_model:
        # Use the model's config.id2label if available
        pred_label = model.config.id2label[pred_idx]
    elif st.session_state.get('class_map'):
        # Use the custom class map if provided (JSON keys are strings)
        pred_label = st.session_state.class_map.get(str(pred_idx), f"Class {pred_idx}")
    else:
        # timm models without labels
        pred_label = f"Class {pred_idx}"
    pred_prob = float(probs[pred_idx])
    logger.info("Prediction: %s (%.3f)", pred_label, pred_prob)

    st.subheader("🔮 Prediction")
    st.write(f"**Top-1:** {pred_label} ({pred_prob:.3f})")
    if not is_hf_model and not st.session_state.get('class_map'):
        st.info("ℹ️ Using a model without class names. Upload a class mapping in the sidebar for friendly labels.")
    # ---------------- LIME ----------------
    st.subheader("📍 LIME Attribution")
    st.markdown("""
**Local Interpretable Model-agnostic Explanations (LIME)** approximates how a complex model (like ViT or ResNet) makes its decision for a specific input by fitting a simpler, interpretable model around that input.
It perturbs the image segment by segment and measures which segments most influence the prediction, revealing what the model "sees" as important.
This is crucial for debugging biases and for checking whether the model focuses on relevant features rather than artifacts.
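
Under the hood, the LIME call this app makes is essentially the following (`image_array` and `pred_class` stand for the uploaded image and the predicted class index):

```python
exp = explainer.explain_instance(
    image_array, classifier_fn=classifier_fn, top_labels=1, num_samples=1000,
    segmentation_fn=lambda x: slic(x, n_segments=60, compactness=9, start_label=0)
)
temp, mask = exp.get_image_and_mask(pred_class, positive_only=True, num_features=8)
```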
| """) | |
| img_np = np.array(img) | |
| with st.spinner("Generating LIME explanation..."): | |
| exp = explainer.explain_instance( | |
| img_np, classifier_fn=classifier_fn, top_labels=1, num_samples=1000, | |
| segmentation_fn=lambda x: slic(x, n_segments=60, compactness=9, start_label=0) | |
| ) | |
| temp, mask = exp.get_image_and_mask(pred_idx, positive_only=True, | |
| num_features=8, hide_rest=False) | |
| lime_img = mark_boundaries(temp/255.0, mask) | |
| st.image(lime_img, caption=f"LIME highlights regions important for '{pred_label}'") | |
| st.info(""" | |
| **How to read:** Bright (or colored) segments show areas the model relied on most for its prediction – these are the "superpixels" that, when altered, change the output the most. | |
| Green/red overlays often indicate positive/negative contributions. If irrelevant background or edges light up, it might signal the model learned spurious correlations (e.g., from training data artifacts). | |
| Furthermore, this builds trust by showing if AI decisions align with human intuition. | |
| """) | |
    # ---------------- LIME Uncertainty ----------------
    st.subheader("📊 LIME Attribution Uncertainty")
    st.markdown("""
Uncertainty in explanations arises because LIME is stochastic – it samples perturbations randomly. By running LIME multiple times, we can measure the variability of the attributions,
revealing whether the model's reasoning is consistent or fragile for this image. High variability suggests the explanation (and thus model confidence) isn't robust.
""")
| logger.info("Starting LIME uncertainty runs (n=5)") | |
| maps = [] | |
| for i in range(5): | |
| logger.debug("LIME run %d", i+1) | |
| exp = explainer.explain_instance( | |
| img_np, classifier_fn=classifier_fn, top_labels=1, num_samples=500, | |
| segmentation_fn=lambda x: slic(x, n_segments=60, compactness=9, start_label=0) | |
| ) | |
| local_exp = dict(exp.local_exp)[pred_idx] | |
| segments = exp.segments | |
| attr_map = np.zeros(segments.shape) | |
| for seg_id, weight in local_exp: | |
| attr_map[segments == seg_id] = weight | |
| maps.append(attr_map) | |
| maps = np.stack(maps) | |
| mean_attr, std_attr = maps.mean(0), maps.std(0) | |
    fig, ax = plt.subplots(1, 2, figsize=(8, 4))
    im1 = ax[0].imshow(mean_attr, cmap="jet"); ax[0].set_title("Mean attribution"); ax[0].axis("off")
    plt.colorbar(im1, ax=ax[0], fraction=0.046)
    im2 = ax[1].imshow(std_attr, cmap="hot"); ax[1].set_title("Attribution std (uncertainty)"); ax[1].axis("off")
    plt.colorbar(im2, ax=ax[1], fraction=0.046)
    st.pyplot(fig)
    st.info("""
**How to read:** The left heatmap shows average importance across runs (hotter = more influential). The right shows the standard deviation – high std (yellow/red) means unstable explanations for those regions.
If uncertainty is high in key areas, the model might be overfitting or need more diverse training data. This helps ML practitioners quantify explanation reliability.
""")
    logger.info("Completed LIME uncertainty runs")
    # ---------------- MC Dropout ----------------
    st.subheader("🎲 MC Dropout Uncertainty")
    st.markdown("""
Monte Carlo (MC) Dropout treats dropout layers (normally off during inference) as a Bayesian approximation to estimate epistemic uncertainty – how much the model "doesn't know" due to limited training.
By enabling dropout and sampling predictions multiple times, we see whether the model consistently agrees on the class or wavers, indicating potential unreliability.
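
In this app, the sampling loop is essentially (with the dropout layers left active):

```python
mc_preds = [predict_probs(img) for _ in range(30)]  # 30 stochastic forward passes
```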
| """) | |
| logger.info("Starting MC Dropout sampling") | |
| model.train() # enable dropout | |
| mc_preds = [] | |
| with torch.no_grad(): | |
| for _ in range(30): | |
| probs_mc = predict_probs(img) | |
| mc_preds.append(probs_mc) | |
| model.eval() | |
| mc_preds = np.stack(mc_preds) | |
| mc_mean = mc_preds.mean(0) | |
    mc_top = int(mc_mean.argmax())
    if is_hf_model:
        mc_label = model.config.id2label[mc_top]
    elif st.session_state.get('class_map'):
        mc_label = st.session_state.class_map.get(str(mc_top), f"Class {mc_top}")
    else:
        mc_label = f"Class {mc_top}"
    p = mc_preds[:, mc_top]
    fig, ax = plt.subplots()
    ax.hist(p, bins=15, color="C0")
    ax.set_title(f"MC Dropout: p({mc_label}) across samples")
    st.pyplot(fig)
| st.info(""" | |
| **How to read:** This histogram shows probability distributions for the top class across 30 samples. A narrow, peaked distribution means stable confidence (low uncertainty). | |
| A wide spread or multiple modes suggests the model is unsure, possibly due to out-of-distribution inputs. For devs, this flags cases needing human review; it highlights risky predictions. | |
| """) | |
| logger.info("Completed MC Dropout: top=%s", mc_label) | |
    # ---------------- Test-Time Augmentation (TTA) Uncertainty ----------------
    st.subheader("🔄 Test-Time Augmentation (TTA) Uncertainty")
    st.markdown("""
Test-Time Augmentation (TTA) applies random transformations (crops, flips) at inference to probe aleatoric uncertainty – noise inherent in the input or model.
If predictions vary wildly under small changes, the model relies on brittle features, revealing data-related issues rather than gaps in model knowledge.
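
The augmentation loop in this app is essentially:

```python
import torchvision.transforms as T

tta_tfms = T.Compose([T.Resize(256), T.RandomResizedCrop(224, scale=(0.9, 1.0)),
                      T.RandomHorizontalFlip(p=0.5)])
tta_preds = [predict_probs(tta_tfms(img)) for _ in range(20)]  # 20 random views
```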
| """) | |
| logger.info("Starting TTA sampling") | |
| tta_tfms = T.Compose([T.Resize(256), T.RandomResizedCrop(224, scale=(0.9,1.0)), T.RandomHorizontalFlip(p=0.5)]) | |
| tta_preds = [] | |
| with torch.no_grad(): | |
| for _ in range(20): | |
| aug = tta_tfms(img) | |
| probs_tta = predict_probs(aug) | |
| tta_preds.append(probs_tta) | |
| tta_preds = np.stack(tta_preds) | |
| tta_mean = tta_preds.mean(0) | |
    tta_top = int(tta_mean.argmax())
    if is_hf_model:
        tta_label = model.config.id2label[tta_top]
    elif st.session_state.get('class_map'):
        tta_label = st.session_state.class_map.get(str(tta_top), f"Class {tta_top}")
    else:
        tta_label = f"Class {tta_top}"
    p_tta = tta_preds[:, tta_top]
    fig, ax = plt.subplots()
    ax.hist(p_tta, bins=15, color="C1")
    ax.set_title(f"TTA: p({tta_label}) across augmentations")
    st.pyplot(fig)
| st.info(""" | |
| **How to read:** Similar to MC Dropout, but focused on input variations. Low variance means the prediction is robust to perturbations (good sign). High variance indicates sensitivity to details like lighting/position, | |
| common in overfitted models. Use this to assess if your AI system handles real-world variability well. | |
| """) | |
| logger.info("Completed TTA: top=%s", tta_label) | |
    # ---------------- Summary ----------------