''' streamlitapp.py โ Vision Transformer Interpretability Dashboard (Streamlit app) This Streamlit app provides interpretability tools for vision transformer and CNN models. Features: - LIME explanations for image classification predictions - Uncertainty analysis via MC Dropout and Test-Time Augmentation (TTA) - Switch between Hugging Face (ViT, Swin, DeiT) and timm (ResNet, EfficientNet, ConvNeXt) models - Support for custom finetuned models and class mappings - Interactive sidebar for model selection and checkpoint upload - Feynman-style explanations and cheat-sheet for interpretability concepts Inspired by and reuses code from: - vit_and_captum.py (Integrated Gradients with Captum) - vit_lime_uncertainty.py (LIME explanations and uncertainty) - detr_and_interp.py (Grad-CAM for DETR, logging setup) ''' import streamlit as st import html import numpy as np, torch, matplotlib.pyplot as plt from PIL import Image from transformers import AutoModelForImageClassification, AutoImageProcessor, PreTrainedModel from lime import lime_image import torchvision.transforms as T import timm from skimage.segmentation import slic, mark_boundaries import streamlit.components.v1 as components # Add logging import logging, os from logging.handlers import RotatingFileHandler LOG_DIR = os.path.join(os.path.dirname(__file__), "logs") os.makedirs(LOG_DIR, exist_ok=True) logfile = os.path.join(LOG_DIR, "interp.log") logger = logging.getLogger("interp") if not logger.handlers: logger.setLevel(logging.INFO) sh = logging.StreamHandler() sh.setLevel(logging.INFO) fh = RotatingFileHandler(logfile, maxBytes=5_000_000, backupCount=3, encoding="utf-8") fh.setLevel(logging.INFO) fmt = logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s") sh.setFormatter(fmt) fh.setFormatter(fmt) logger.addHandler(sh) logger.addHandler(fh) # ---------------- Setup ---------------- MODEL_NAME = "google/vit-base-patch16-224" device = "cuda" if torch.cuda.is_available() else "cpu" # ---------- Sidebar model selectors ---------- # Quick lists you can edit to test other HF / timm models HF_MODELS = [ "google/vit-base-patch16-224", "facebook/deit-base-patch16-224", "microsoft/swin-tiny-patch4-window7-224", "google/vit-large-patch16-224", ] TIMM_MODELS = [ "convnext_base", "resnet50", "efficientnet_b0", ] def model_selector(slot_key: str, default_source="hf"): source = st.sidebar.selectbox( f"{slot_key} source", ["hf", "timm"], index=0 if default_source == "hf" else 1, key=f"{slot_key}_source", ) if source == "hf": hf_choice = st.sidebar.selectbox( f"{slot_key} Hugging Face model", HF_MODELS, index=0, key=f"{slot_key}_hf", ) return f"hf:{hf_choice}" else: timm_choice = st.sidebar.selectbox( f"{slot_key} timm model", TIMM_MODELS, index=0, key=f"{slot_key}_timm", ) return f"timm:{timm_choice}" # ---------- Model Loader ---------- # Use Streamlit caching when available to avoid repeated downloads try: cache_decorator = st.cache_resource except Exception: from functools import lru_cache cache_decorator = lru_cache(maxsize=8) @cache_decorator def load_model(choice, checkpoint=None, class_map=None, num_classes=None): """ Load a model from HF, timm, or a custom checkpoint Args: choice: Model identifier ('hf:model_name' or 'timm:model_name') checkpoint: Optional path to custom checkpoint file class_map: Optional dict mapping class indices to labels num_classes: Optional number of classes for custom models """ logger.info("Loading model: %s", choice) is_hf = choice.startswith("hf:") # Parse model identifier if is_hf: hf_name = choice.split("hf:")[1] if checkpoint: # Custom checkpoint # For custom HF model, first load the architecture then apply weights try: if num_classes: model = AutoModelForImageClassification.from_pretrained( hf_name, num_labels=num_classes, ignore_mismatched_sizes=True ).to(device) else: model = AutoModelForImageClassification.from_pretrained(hf_name).to(device) # Load checkpoint with error handling state_dict = torch.load(checkpoint, map_location=device) # If state_dict is wrapped (common in training checkpoints) if "model" in state_dict: state_dict = state_dict["model"] elif "state_dict" in state_dict: state_dict = state_dict["state_dict"] # Handle any prefix differences by checking and stripping if needed if all(k.startswith('model.') for k in state_dict if k != 'config'): state_dict = {k[6:]: v for k, v in state_dict.items() if k != 'config'} # Load with flexible partial loading (ignore missing/unexpected) model.load_state_dict(state_dict, strict=False) logger.info("Custom checkpoint loaded for HF model") # If custom class mapping provided, update config if class_map: model.config.id2label = class_map model.config.label2id = {v: int(k) for k, v in class_map.items()} except Exception as e: logger.error(f"Error loading custom HF model: {e}") st.error(f"Failed to load custom model: {e}") # Fallback to base model model = AutoModelForImageClassification.from_pretrained(hf_name).to(device) else: # Standard HF model model = AutoModelForImageClassification.from_pretrained(hf_name).to(device) processor = AutoImageProcessor.from_pretrained(hf_name) elif choice.startswith("timm:"): name = choice.split("timm:")[1] if checkpoint: # Custom checkpoint try: # For timm, specify custom number of classes if provided if num_classes: model = timm.create_model(name, pretrained=False, num_classes=num_classes).to(device) else: model = timm.create_model(name, pretrained=True).to(device) # Load checkpoint state_dict = torch.load(checkpoint, map_location=device) # Handle common checkpoint formats if "model" in state_dict: state_dict = state_dict["model"] elif "state_dict" in state_dict: state_dict = state_dict["state_dict"] # Handle any prefix differences if all(k.startswith('module.') for k in state_dict): state_dict = {k[7:]: v for k, v in state_dict} model.load_state_dict(state_dict, strict=False) logger.info("Custom checkpoint loaded for timm model") except Exception as e: logger.error(f"Error loading custom timm model: {e}") st.error(f"Failed to load custom model: {e}") # Fallback to pretrained model = timm.create_model(name, pretrained=True).to(device) else: # Standard timm model model = timm.create_model(name, pretrained=True).to(device) # Use a standard processor for timm processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224") # Set model to eval mode model.eval() logger.info("Model %s loaded (eval mode)", choice) # Return model, processor, flag for HF, and class map return model, processor, is_hf, class_map # Add sidebar with clear sections st.sidebar.title("Model Selection") # Enhanced sidebar with custom model support with st.sidebar: # Add tabs for standard vs custom models tab1, tab2 = st.tabs(["Standard Models", "Custom Finetuned Models"]) with tab1: st.markdown("### ๐ Standard Models") st.markdown("Choose from pre-trained models:") m1 = model_selector("Active Model", default_source="hf") # Button to apply standard model change if st.button("๐ Set as Active Model", help="Click to use the selected model for analysis", key="std_model_btn"): with st.spinner(f"Loading {m1}..."): model, processor, is_hf_model, _ = load_model(m1) st.session_state.model = model st.session_state.processor = processor st.session_state.is_hf_model = is_hf_model st.session_state.active_model = m1 st.session_state.using_custom = False st.session_state.class_map = None st.success(f"โ Model activated: {m1}") with tab2: st.markdown("### ๐ง Custom Finetuned Model") st.markdown("Use your own finetuned model:") # Select base architecture custom_source = st.selectbox( "Base architecture source", ["hf", "timm"], key="custom_source" ) if custom_source == "hf": custom_base = st.selectbox( "Hugging Face base model", HF_MODELS, key="custom_hf_base" ) base_model = f"hf:{custom_base}" else: custom_base = st.selectbox( "timm base model", TIMM_MODELS, key="custom_timm_base" ) base_model = f"timm:{custom_base}" # Upload checkpoint file uploaded_checkpoint = st.file_uploader( "Upload model checkpoint (.pth, .bin)", type=["pth", "bin", "pt", "ckpt"], help="Upload your finetuned model weights" ) # Optional class mapping custom_classes = st.number_input( "Number of classes (if different from base model)", min_value=0, max_value=1000, value=0, help="Leave at 0 to use default classes from base model" ) uploaded_labels = st.file_uploader( "Upload class labels (optional JSON)", type=["json"], help="JSON file mapping class indices to labels: {\"0\": \"cat\", \"1\": \"dog\"}" ) # Process label mapping class_map = None if uploaded_labels: try: import json class_map = json.loads(uploaded_labels.getvalue().decode("utf-8")) st.success(f"โ Loaded {len(class_map)} class labels") except Exception as e: st.error(f"Error loading class labels: {e}") # Store uploaded file in session state if provided if uploaded_checkpoint: # Save to a temporary file import tempfile with tempfile.NamedTemporaryFile(delete=False, suffix='.pth') as tmp_file: tmp_file.write(uploaded_checkpoint.getvalue()) checkpoint_path = tmp_file.name # Store in session state if 'checkpoint_path' not in st.session_state: st.session_state.checkpoint_path = checkpoint_path st.success("โ Checkpoint ready to use") # Button to apply custom model if st.button("๐ Load Custom Model", help="Click to use your custom model"): with st.spinner(f"Loading custom model based on {base_model}..."): try: num_classes = custom_classes if custom_classes > 0 else None model, processor, is_hf_model, class_map = load_model( base_model, checkpoint_path, class_map, num_classes ) st.session_state.model = model st.session_state.processor = processor st.session_state.is_hf_model = is_hf_model st.session_state.active_model = f"Custom {base_model}" st.session_state.using_custom = True st.session_state.class_map = class_map st.success(f"โ Custom model activated!") except Exception as e: st.error(f"Failed to load custom model: {str(e)}") # Explanation section st.markdown("---") st.markdown("### โน๏ธ Model Types") st.markdown(""" - **HF (Hugging Face)**: Vision Transformer models with standard interpretability - **timm (PyTorch Image Models)**: Classical CNN architectures like ResNet, EfficientNet *Custom models must match the base architecture's format.* """) # Initialize model and processor from session state if 'active_model' not in st.session_state: # First time loading - use default model m1 = "hf:google/vit-base-patch16-224" st.session_state.active_model = m1 model, processor, is_hf_model, _ = load_model(m1) st.session_state.model = model st.session_state.processor = processor st.session_state.is_hf_model = is_hf_model st.session_state.using_custom = False st.session_state.class_map = None else: # Get from session state model = st.session_state.model processor = st.session_state.processor is_hf_model = st.session_state.is_hf_model # Initialize explainer explainer = lime_image.LimeImageExplainer() st.title("๐ง Vision Transformer Interpretability Dashboard") st.write("Upload an image and explore explanations with **LIME** and **Uncertainty Analysis**.") # Add a Feynman-style "How it works" explanation as a collapsible expander with st.expander("How it works โ Feynman-style explanations (click to expand)", expanded=False): st.markdown(""" ## ๐ง Vision Transformer Interpretability โ Feynman-Style Explanations ### Why do we care about interpretability & uncertainty? Imagine you ask a kid to identify whether a picture is a cat. They point to the fur, ears, maybe whiskers. But what if the kid always focused on shadows, or background trees, instead of the cat itself? We want two things: 1. **Why** did the model say โcatโ? What parts of the image made it decide so? 2. **How confident** is the model in that decision? Could small changes flip it? Interpretable methods show us #1. Uncertainty estimation shows us #2. Together, they help us see not just *what* the model does, but *whether* we should trust it. ### Key techniques, in plain analogies - **LIME (Local Interpretable Model-agnostic Explanations)**: For a single image & prediction, LIME perturbs (changes) parts of the image, watches how the prediction changes, and fits a simple model locally to understand which parts are most influential. - Analogy: Like shining small spotlights on different parts of a stage during a play: you dim a section, see how the actorโs reaction changes. The parts whose dimming changes the reaction most are parts the actor depends on. - **Uncertainty in LIME (multiple LIME runs)**: Because LIME uses randomness (perturbing patches), different runs can give different โimportantโ regions. Measuring how much they differ tells you how stable/fragile the explanation is. - Analogy: If you ask several cooks what the dominant spice in a stew is and everyone agrees, you're confident; if opinions vary, your knowledge is shakier. - **MC Dropout (Monte Carlo Dropout)**: Leave dropout on at inference time and run the model multiple times. The spread of predictions is a proxy for epistemic uncertainty. - Analogy: Like a jury where each juror occasionally misses a sentence; if the verdict remains the same across many "faulty hearing" runs, trust it more. - **Test-Time Augmentation (TTA) Uncertainty**: Apply small transforms (crops, flips) at inference and watch prediction variance. High variance โ brittle model. - Analogy: Take photos under slightly different lighting/angles; if the label flips, the model may depend on superficial cues. ### How to read the visuals - LIME highlights: bright / colored superpixels = influential regions. If background or artifacts light up, that's a red flag. - LIME uncertainty heatmap: high std in a region means attributions are unstable there. - MC Dropout / TTA histograms: narrow/tall peak = confident, wide/multi-modal = uncertain. ### Limitations & caveats - Stable explanations can still be consistently wrong if the model learned a bias. - MC Dropout is an approximation โ it helps but doesn't fully replace calibrated probabilistic methods. - TTA shows input sensitivity, not full distributional shift robustness. ### Quick example (walkthrough) 1. Upload image โ model predicts label with some probability. 2. LIME finds important superpixels; multiple LIME runs give mean + std maps. 3. MC Dropout produces a histogram over runs; use it to judge epistemic uncertainty. 4. TTA shows sensitivity to small input changes. ### Practical tips - Use explanation + uncertainty to guide active learning: label cases where the model is uncertain or explanations are unstable. - For safety-critical systems, combine these visual signals with human review and stricter failure thresholds. ### Where to read more - Christoph Molnar โ Interpretable Machine Learning (chapter on LIME): https://christophm.github.io/interpretable-ml-book/lime.html - Ribeiro et al., "Why Should I Trust You?" (original LIME paper): https://homes.cs.washington.edu/~marcotcr/blog/lime/ - Zhang et al., "Why Should You Trust My Explanation?" (LIME reliability): https://arxiv.org/abs/1904.12991 - MC Dropout practical guide & notes: https://medium.com/@ciaranbench/monte-carlo-dropout-a-practical-guide-4b4dc18014b5 """) # Compact one-page cheat-sheet (quick flags & checks) with st.expander("Cheat-sheet โ Quick flags & warnings", expanded=False): cheat_text = """ Quick checks when an explanation looks suspicious - Red flag: LIME highlights background or repeated dataset artifacts (logos, borders) โ model may have learned spurious cues. - Red flag: LIME attribution std is high in key regions โ explanation unstable; try different segmentations or more samples. - Red flag: MC Dropout or TTA histograms are multi-modal or very wide โ model uncertain; consider human review or abstain. - Quick fixes: increase dataset diversity, add regularization, try different segmentation_fn parameters, or collect more labels for uncertain cases. One-line definitions - LIME: perturb + fit simple local model to explain a single prediction. - MC Dropout: enable dropout at inference and sample to estimate epistemic uncertainty. - TTA: apply small input transforms at inference to measure sensitivity / aleatoric uncertainty. Pro-tip: Use explanation + uncertainty to drive active learning: pick instances with high prediction uncertainty or unstable explanations for labeling. """ # Show the cheat-sheet as markdown st.markdown(cheat_text) # Download button for the cheat-sheet as plain text try: st.download_button( label="Download cheat-sheet (.txt)", data=cheat_text, file_name="cheat_sheet.txt", mime="text/plain", ) except Exception: # Streamlit may raise if download_button isn't available in some environments; ignore gracefully pass # Copy-to-clipboard button using a small HTML+JS snippet escaped = html.escape(cheat_text) copy_html = f"""