'''
streamlit_app.py — Vision Transformer Interpretability Dashboard (Streamlit app)
This Streamlit app provides interpretability tools for vision transformer and CNN models.
Features:
- LIME explanations for image classification predictions
- Uncertainty analysis via MC Dropout and Test-Time Augmentation (TTA)
- Switch between Hugging Face (ViT, Swin, DeiT) and timm (ResNet, EfficientNet, ConvNeXt) models
- Support for custom finetuned models and class mappings
- Interactive sidebar for model selection and checkpoint upload
- Feynman-style explanations and cheat-sheet for interpretability concepts
Inspired by and reuses code from:
- vit_and_captum.py (Integrated Gradients with Captum)
- vit_lime_uncertainty.py (LIME explanations and uncertainty)
- detr_and_interp.py (Grad-CAM for DETR, logging setup)
'''
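# To launch the dashboard locally (assuming this file lives at src/streamlit_app.py and
# the dependencies imported below, e.g. streamlit, torch, transformers, timm, lime and
# scikit-image, are installed):
#   streamlit run src/streamlit_app.py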
import streamlit as st
import html
import numpy as np, torch, matplotlib.pyplot as plt
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor, PreTrainedModel
from lime import lime_image
import torchvision.transforms as T
import timm
from skimage.segmentation import slic, mark_boundaries
import streamlit.components.v1 as components
# Add logging
import logging, os
from logging.handlers import RotatingFileHandler
LOG_DIR = os.path.join(os.path.dirname(__file__), "logs")
os.makedirs(LOG_DIR, exist_ok=True)
logfile = os.path.join(LOG_DIR, "interp.log")
logger = logging.getLogger("interp")
if not logger.handlers:
logger.setLevel(logging.INFO)
sh = logging.StreamHandler()
sh.setLevel(logging.INFO)
fh = RotatingFileHandler(logfile, maxBytes=5_000_000, backupCount=3, encoding="utf-8")
fh.setLevel(logging.INFO)
fmt = logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
sh.setFormatter(fmt)
fh.setFormatter(fmt)
logger.addHandler(sh)
logger.addHandler(fh)
# ---------------- Setup ----------------
MODEL_NAME = "google/vit-base-patch16-224"
device = "cuda" if torch.cuda.is_available() else "cpu"
# ---------- Sidebar model selectors ----------
# Quick lists you can edit to test other HF / timm models
HF_MODELS = [
"google/vit-base-patch16-224",
"facebook/deit-base-patch16-224",
"microsoft/swin-tiny-patch4-window7-224",
"google/vit-large-patch16-224",
]
TIMM_MODELS = [
"convnext_base",
"resnet50",
"efficientnet_b0",
]
def model_selector(slot_key: str, default_source="hf"):
source = st.sidebar.selectbox(
f"{slot_key} source",
["hf", "timm"],
index=0 if default_source == "hf" else 1,
key=f"{slot_key}_source",
)
if source == "hf":
hf_choice = st.sidebar.selectbox(
f"{slot_key} Hugging Face model",
HF_MODELS,
index=0,
key=f"{slot_key}_hf",
)
return f"hf:{hf_choice}"
else:
timm_choice = st.sidebar.selectbox(
f"{slot_key} timm model",
TIMM_MODELS,
index=0,
key=f"{slot_key}_timm",
)
return f"timm:{timm_choice}"
# ---------- Model Loader ----------
# Use Streamlit caching when available to avoid repeated downloads
try:
cache_decorator = st.cache_resource
except Exception:
from functools import lru_cache
cache_decorator = lru_cache(maxsize=8)
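# Note: the lru_cache fallback only accepts hashable arguments, so on that code path
# load_model should be called with the default (None) class_map rather than a dict.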
@cache_decorator
def load_model(choice, checkpoint=None, class_map=None, num_classes=None):
"""
Load a model from HF, timm, or a custom checkpoint
Args:
choice: Model identifier ('hf:model_name' or 'timm:model_name')
checkpoint: Optional path to custom checkpoint file
class_map: Optional dict mapping class indices to labels
num_classes: Optional number of classes for custom models
"""
logger.info("Loading model: %s", choice)
is_hf = choice.startswith("hf:")
# Parse model identifier
if is_hf:
hf_name = choice.split("hf:")[1]
if checkpoint: # Custom checkpoint
# For custom HF model, first load the architecture then apply weights
try:
if num_classes:
model = AutoModelForImageClassification.from_pretrained(
hf_name, num_labels=num_classes, ignore_mismatched_sizes=True
).to(device)
else:
model = AutoModelForImageClassification.from_pretrained(hf_name).to(device)
# Load checkpoint with error handling
state_dict = torch.load(checkpoint, map_location=device)
# If state_dict is wrapped (common in training checkpoints)
if "model" in state_dict:
state_dict = state_dict["model"]
elif "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
# Handle any prefix differences by checking and stripping if needed
if all(k.startswith('model.') for k in state_dict if k != 'config'):
state_dict = {k[6:]: v for k, v in state_dict.items() if k != 'config'}
# Load with flexible partial loading (ignore missing/unexpected)
model.load_state_dict(state_dict, strict=False)
logger.info("Custom checkpoint loaded for HF model")
# If custom class mapping provided, update config
if class_map:
model.config.id2label = {int(k): v for k, v in class_map.items()}  # JSON keys are strings; id2label is indexed with integer ids later
model.config.label2id = {v: int(k) for k, v in class_map.items()}
except Exception as e:
logger.error(f"Error loading custom HF model: {e}")
st.error(f"Failed to load custom model: {e}")
# Fallback to base model
model = AutoModelForImageClassification.from_pretrained(hf_name).to(device)
else:
# Standard HF model
model = AutoModelForImageClassification.from_pretrained(hf_name).to(device)
processor = AutoImageProcessor.from_pretrained(hf_name)
elif choice.startswith("timm:"):
name = choice.split("timm:")[1]
if checkpoint: # Custom checkpoint
try:
# For timm, specify custom number of classes if provided
if num_classes:
model = timm.create_model(name, pretrained=False, num_classes=num_classes).to(device)
else:
model = timm.create_model(name, pretrained=True).to(device)
# Load checkpoint
state_dict = torch.load(checkpoint, map_location=device)
# Handle common checkpoint formats
if "model" in state_dict:
state_dict = state_dict["model"]
elif "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
# Handle any prefix differences
if all(k.startswith('module.') for k in state_dict):
state_dict = {k[7:]: v for k, v in state_dict.items()}
model.load_state_dict(state_dict, strict=False)
logger.info("Custom checkpoint loaded for timm model")
except Exception as e:
logger.error(f"Error loading custom timm model: {e}")
st.error(f"Failed to load custom model: {e}")
# Fallback to pretrained
model = timm.create_model(name, pretrained=True).to(device)
else:
# Standard timm model
model = timm.create_model(name, pretrained=True).to(device)
# Use a generic HF image processor for timm models (note: its normalization may not match every timm model's own default config)
processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224")
# Set model to eval mode
model.eval()
logger.info("Model %s loaded (eval mode)", choice)
# Return model, processor, flag for HF, and class map
return model, processor, is_hf, class_map
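# Illustrative usage (not executed here): loading one of the TIMM_MODELS entries without
# a checkpoint returns the pretrained backbone, the generic HF image processor,
# is_hf=False and class_map=None, e.g.
#   model, processor, is_hf, cmap = load_model("timm:resnet50")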
# Add sidebar with clear sections
st.sidebar.title("Model Selection")
# Enhanced sidebar with custom model support
with st.sidebar:
# Add tabs for standard vs custom models
tab1, tab2 = st.tabs(["Standard Models", "Custom Finetuned Models"])
with tab1:
st.markdown("### 📊 Standard Models")
st.markdown("Choose from pre-trained models:")
m1 = model_selector("Active Model", default_source="hf")
# Button to apply standard model change
if st.button("📋 Set as Active Model", help="Click to use the selected model for analysis", key="std_model_btn"):
with st.spinner(f"Loading {m1}..."):
model, processor, is_hf_model, _ = load_model(m1)
st.session_state.model = model
st.session_state.processor = processor
st.session_state.is_hf_model = is_hf_model
st.session_state.active_model = m1
st.session_state.using_custom = False
st.session_state.class_map = None
st.success(f"✅ Model activated: {m1}")
with tab2:
st.markdown("### 🔧 Custom Finetuned Model")
st.markdown("Use your own finetuned model:")
# Select base architecture
custom_source = st.selectbox(
"Base architecture source",
["hf", "timm"],
key="custom_source"
)
if custom_source == "hf":
custom_base = st.selectbox(
"Hugging Face base model",
HF_MODELS,
key="custom_hf_base"
)
base_model = f"hf:{custom_base}"
else:
custom_base = st.selectbox(
"timm base model",
TIMM_MODELS,
key="custom_timm_base"
)
base_model = f"timm:{custom_base}"
# Upload checkpoint file
uploaded_checkpoint = st.file_uploader(
"Upload model checkpoint (.pth, .bin)",
type=["pth", "bin", "pt", "ckpt"],
help="Upload your finetuned model weights"
)
# Optional class mapping
custom_classes = st.number_input(
"Number of classes (if different from base model)",
min_value=0, max_value=1000, value=0,
help="Leave at 0 to use default classes from base model"
)
uploaded_labels = st.file_uploader(
"Upload class labels (optional JSON)",
type=["json"],
help="JSON file mapping class indices to labels: {\"0\": \"cat\", \"1\": \"dog\"}"
)
# Process label mapping
class_map = None
if uploaded_labels:
try:
import json
class_map = json.loads(uploaded_labels.getvalue().decode("utf-8"))
st.success(f"✓ Loaded {len(class_map)} class labels")
except Exception as e:
st.error(f"Error loading class labels: {e}")
# Store uploaded file in session state if provided
if uploaded_checkpoint:
# Save to a temporary file
import tempfile
with tempfile.NamedTemporaryFile(delete=False, suffix='.pth') as tmp_file:
tmp_file.write(uploaded_checkpoint.getvalue())
checkpoint_path = tmp_file.name
# Store in session state
st.session_state.checkpoint_path = checkpoint_path  # refresh on every upload so a new file takes effect
st.success("✓ Checkpoint ready to use")
# Button to apply custom model
if st.button("🚀 Load Custom Model", help="Click to use your custom model"):
with st.spinner(f"Loading custom model based on {base_model}..."):
try:
num_classes = custom_classes if custom_classes > 0 else None
model, processor, is_hf_model, class_map = load_model(
base_model, st.session_state.get('checkpoint_path'), class_map, num_classes
)
st.session_state.model = model
st.session_state.processor = processor
st.session_state.is_hf_model = is_hf_model
st.session_state.active_model = f"Custom {base_model}"
st.session_state.using_custom = True
st.session_state.class_map = class_map
st.success(f"✅ Custom model activated!")
except Exception as e:
st.error(f"Failed to load custom model: {str(e)}")
# Explanation section
st.markdown("---")
st.markdown("### ℹ️ Model Types")
st.markdown("""
- **HF (Hugging Face)**: Vision Transformer models with standard interpretability
- **timm (PyTorch Image Models)**: Classical CNN architectures like ResNet, EfficientNet
*Custom models must match the base architecture's format.*
""")
# Initialize model and processor from session state
if 'active_model' not in st.session_state:
# First time loading - use default model
m1 = "hf:google/vit-base-patch16-224"
st.session_state.active_model = m1
model, processor, is_hf_model, _ = load_model(m1)
st.session_state.model = model
st.session_state.processor = processor
st.session_state.is_hf_model = is_hf_model
st.session_state.using_custom = False
st.session_state.class_map = None
else:
# Get from session state
model = st.session_state.model
processor = st.session_state.processor
is_hf_model = st.session_state.is_hf_model
# Initialize explainer
explainer = lime_image.LimeImageExplainer()
st.title("🧠 Vision Transformer Interpretability Dashboard")
st.write("Upload an image and explore explanations with **LIME** and **Uncertainty Analysis**.")
# Add a Feynman-style "How it works" explanation as a collapsible expander
with st.expander("How it works — Feynman-style explanations (click to expand)", expanded=False):
st.markdown("""
## 🧠 Vision Transformer Interpretability — Feynman-Style Explanations
### Why do we care about interpretability & uncertainty?
Imagine you ask a kid whether a picture shows a cat. They point to the fur, the ears, maybe the whiskers. But what if the kid always focused on shadows or background trees instead of the cat itself? We want two things:
1. **Why** did the model say “cat”? What parts of the image made it decide so?
2. **How confident** is the model in that decision? Could small changes flip it?
Interpretability methods show us #1. Uncertainty estimation shows us #2. Together, they help us see not just *what* the model does, but *whether* we should trust it.
### Key techniques, in plain analogies
- **LIME (Local Interpretable Model-agnostic Explanations)**: For a single image & prediction, LIME perturbs (changes) parts of the image, watches how the prediction changes, and fits a simple model locally to understand which parts are most influential.
- Analogy: Like dimming the lights over different sections of a stage during a play and watching how the scene changes. The sections whose dimming changes the scene most are the ones the performance depends on.
- **Uncertainty in LIME (multiple LIME runs)**: Because LIME uses randomness (perturbing patches), different runs can give different “important” regions. Measuring how much they differ tells you how stable/fragile the explanation is.
- Analogy: If you ask several cooks what the dominant spice in a stew is and everyone agrees, you're confident; if opinions vary, your knowledge is shakier.
- **MC Dropout (Monte Carlo Dropout)**: Leave dropout on at inference time and run the model multiple times. The spread of predictions is a proxy for epistemic uncertainty (a minimal code sketch follows this list).
- Analogy: Like a jury where each juror occasionally misses a sentence; if the verdict remains the same across many "faulty hearing" runs, trust it more.
- **Test-Time Augmentation (TTA) Uncertainty**: Apply small transforms (crops, flips) at inference and watch prediction variance. High variance → brittle model.
- Analogy: Take photos under slightly different lighting/angles; if the label flips, the model may depend on superficial cues.
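For readers who prefer code, here is a minimal MC Dropout sketch (illustrative only; it assumes a generic PyTorch classifier `model` containing `nn.Dropout` layers and a preprocessed input batch `x`, not this app's exact pipeline):

```python
import torch

def mc_dropout_probs(model, x, n_samples=30):
    # Sample softmax outputs with dropout active; return predictive mean and spread.
    model.eval()
    # Re-enable only the dropout layers, leaving the rest of the network in eval mode.
    for m in model.modules():
        if isinstance(m, torch.nn.Dropout):
            m.train()
    with torch.no_grad():
        samples = torch.stack([torch.softmax(model(x), dim=-1) for _ in range(n_samples)])
    model.eval()
    return samples.mean(dim=0), samples.std(dim=0)
```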
### How to read the visuals
- LIME highlights: bright / colored superpixels = influential regions. If background or artifacts light up, that's a red flag.
- LIME uncertainty heatmap: high std in a region means attributions are unstable there.
- MC Dropout / TTA histograms: narrow/tall peak = confident, wide/multi-modal = uncertain.
### Limitations & caveats
- Stable explanations can still be consistently wrong if the model learned a bias.
- MC Dropout is an approximation — it helps but doesn't fully replace calibrated probabilistic methods.
- TTA shows input sensitivity, not full distributional shift robustness.
### Quick example (walkthrough)
1. Upload image → model predicts label with some probability.
2. LIME finds important superpixels; multiple LIME runs give mean + std maps.
3. MC Dropout produces a histogram over runs; use it to judge epistemic uncertainty.
4. TTA shows sensitivity to small input changes.
### Practical tips
- Use explanation + uncertainty to guide active learning: label cases where the model is uncertain or explanations are unstable.
- For safety-critical systems, combine these visual signals with human review and stricter failure thresholds.
### Where to read more
- Christoph Molnar — Interpretable Machine Learning (chapter on LIME): https://christophm.github.io/interpretable-ml-book/lime.html
- Ribeiro et al., "Why Should I Trust You?" (the LIME authors' overview of the original paper): https://homes.cs.washington.edu/~marcotcr/blog/lime/
- Zhang et al., "Why Should You Trust My Explanation?" (LIME reliability): https://arxiv.org/abs/1904.12991
- MC Dropout practical guide & notes: https://medium.com/@ciaranbench/monte-carlo-dropout-a-practical-guide-4b4dc18014b5
""")
# Compact one-page cheat-sheet (quick flags & checks)
with st.expander("Cheat-sheet — Quick flags & warnings", expanded=False):
cheat_text = """
Quick checks when an explanation looks suspicious
- Red flag: LIME highlights background or repeated dataset artifacts (logos, borders) — model may have learned spurious cues.
- Red flag: LIME attribution std is high in key regions — explanation unstable; try different segmentations or more samples.
- Red flag: MC Dropout or TTA histograms are multi-modal or very wide — model uncertain; consider human review or abstain.
- Quick fixes: increase dataset diversity, add regularization, try different segmentation_fn parameters, or collect more labels for uncertain cases.
One-line definitions
- LIME: perturb + fit simple local model to explain a single prediction.
- MC Dropout: enable dropout at inference and sample to estimate epistemic uncertainty.
- TTA: apply small input transforms at inference to measure sensitivity / aleatoric uncertainty.
Pro-tip: Use explanation + uncertainty to drive active learning: pick instances with high prediction uncertainty or unstable explanations for labeling.
"""
# Show the cheat-sheet as markdown
st.markdown(cheat_text)
# Download button for the cheat-sheet as plain text
try:
st.download_button(
label="Download cheat-sheet (.txt)",
data=cheat_text,
file_name="cheat_sheet.txt",
mime="text/plain",
)
except Exception:
# Streamlit may raise if download_button isn't available in some environments; ignore gracefully
pass
# Copy-to-clipboard button using a small HTML+JS snippet
escaped = html.escape(cheat_text)
copy_html = f"""
<div>
<button id='copy-btn' style='padding:6px 10px;border-radius:4px;'>Copy cheat-sheet</button>
<script>
const btn = document.getElementById('copy-btn');
btn.addEventListener('click', async () => {{
try {{
await navigator.clipboard.writeText(`{escaped}`);
btn.innerText = 'Copied!';
setTimeout(() => btn.innerText = 'Copy cheat-sheet', 1500);
}} catch (e) {{
btn.innerText = 'Copy failed';
}}
}});
</script>
</div>
"""
components.html(copy_html, height=70)
# Display active model clearly in the main panel
is_custom = st.session_state.get('using_custom', False)
custom_badge = " 🔧 Custom" if is_custom else ""
st.markdown(f"### Active Model: `{st.session_state.active_model}{custom_badge}`")
model_type = "Hugging Face Transformer" if is_hf_model else "timm CNN Architecture"
st.caption(f"Model type: {model_type}")
# ---------------- Helpers ----------------
def classifier_fn(images_batch):
# Use current model/processor from session state
inputs = processor(images=[Image.fromarray(x.astype(np.uint8)) for x in images_batch],
return_tensors="pt").to(device)
with torch.no_grad():
if is_hf_model:
outputs = model(**inputs)
logits = outputs.logits
else:
x = inputs['pixel_values']
logits = model(x)
probs = torch.softmax(logits, dim=-1).cpu().numpy()
return probs
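# Note: lime_image expects classifier_fn to take a batch of HxWx3 uint8 images and
# return an (N, num_classes) array of probabilities, which is what this function produces.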
def predict_probs(pil_img):
# Use current model/processor from session state
inputs = processor(images=pil_img, return_tensors="pt").to(device)
with torch.no_grad():
if is_hf_model:
outputs = model(**inputs)
logits = outputs.logits
else:
x = inputs['pixel_values']
logits = model(x)
probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
return probs
# ---------------- Upload ----------------
uploaded = st.file_uploader("Upload an image", type=["png","jpg","jpeg"])
if uploaded:
img = Image.open(uploaded).convert("RGB").resize((224,224))
logger.info("Uploaded image received (size=%s)", img.size)
# Streamlit 1.XX: replace deprecated `use_container_width` with `width`
# For full-width behavior use width='stretch' (or 'content' for intrinsic size)
st.image(img, caption="Uploaded image", width='stretch')
# ---------------- Prediction ----------------
probs = predict_probs(img)
pred_idx = int(np.argmax(probs))
# Get label - handle models differently based on source
if is_hf_model:
# Use model's config.id2label if available
pred_label = model.config.id2label[pred_idx]
elif st.session_state.get('class_map'):
# Use custom class map if provided (access defensively)
_class_map = st.session_state.get('class_map')
pred_label = _class_map.get(str(pred_idx), f"Class {pred_idx}") if _class_map is not None else f"Class {pred_idx}"
else:
# For timm models without labels
pred_label = f"Class {pred_idx}"
pred_prob = float(probs[pred_idx])
logger.info("Prediction: %s (%.3f)", pred_label, pred_prob)
st.subheader("🔮 Prediction")
st.write(f"**Top-1:** {pred_label} ({pred_prob:.3f})")
if not is_hf_model and not st.session_state.get('class_map'):
st.info("ℹ️ Using model without class names. Upload a class mapping in the sidebar for friendly labels.")
# ---------------- LIME ----------------
st.subheader("📍 LIME Attribution")
st.markdown("""
**Local Interpretable Model-agnostic Explanations (LIME)** is a technique that approximates how a complex model (like ViT or ResNet) makes decisions for a specific input by creating a simpler, interpretable model around it.
It perturbs the image into segments and sees which ones most influence the prediction, revealing what the model "sees" as important.
This is crucial for debugging biases or understanding if the model focuses on relevant features vs. artifacts.
""")
img_np = np.array(img)
with st.spinner("Generating LIME explanation..."):
exp = explainer.explain_instance(
img_np, classifier_fn=classifier_fn, top_labels=1, num_samples=1000,
segmentation_fn=lambda x: slic(x, n_segments=60, compactness=9, start_label=0)
)
temp, mask = exp.get_image_and_mask(pred_idx, positive_only=True,
num_features=8, hide_rest=False)
lime_img = mark_boundaries(temp/255.0, mask)
st.image(lime_img, caption=f"LIME highlights regions important for '{pred_label}'")
st.info("""
**How to read:** Bright (or colored) segments show areas the model relied on most for its prediction – these are the "superpixels" that, when altered, change the output the most.
Green/red overlays often indicate positive/negative contributions. If irrelevant background or edges light up, it might signal the model learned spurious correlations (e.g., from training data artifacts).
Seeing whether the model's focus aligns with human intuition is also what builds (or erodes) trust in its decisions.
""")
# ---------------- LIME Uncertainty ----------------
st.subheader("📊 LIME Attribution Uncertainty")
st.markdown("""
Uncertainty in explanations arises because LIME is stochastic – it samples perturbations randomly. By running LIME multiple times, we can measure variability in attributions,
highlighting if the model's reasoning is consistent or fragile for this image. High variability suggests the explanation (and thus model confidence) isn't robust.
""")
logger.info("Starting LIME uncertainty runs (n=5)")
maps = []
for i in range(5):
logger.debug("LIME run %d", i+1)
exp = explainer.explain_instance(
img_np, classifier_fn=classifier_fn, top_labels=1, num_samples=500,
segmentation_fn=lambda x: slic(x, n_segments=60, compactness=9, start_label=0)
)
local_exp = dict(exp.local_exp)[pred_idx]
segments = exp.segments
attr_map = np.zeros(segments.shape)
for seg_id, weight in local_exp:
attr_map[segments == seg_id] = weight
maps.append(attr_map)
maps = np.stack(maps)
mean_attr, std_attr = maps.mean(0), maps.std(0)
fig, ax = plt.subplots(1,2, figsize=(8,4))
im1 = ax[0].imshow(mean_attr, cmap="jet"); ax[0].set_title("Mean attribution"); ax[0].axis("off")
plt.colorbar(im1, ax=ax[0], fraction=0.046)
im2 = ax[1].imshow(std_attr, cmap="hot"); ax[1].set_title("Attribution std (uncertainty)"); ax[1].axis("off")
plt.colorbar(im2, ax=ax[1], fraction=0.046)
st.pyplot(fig)
st.info("""
**How to read:** The left heatmap shows average importance across runs (hotter = more influential). The right shows standard deviation – high std (yellow/red) means unstable explanations for those regions.
If uncertainty is high in key areas, the model may be overfitting or need more diverse training data. This helps ML practitioners quantify explanation reliability.
""")
logger.info("Completed LIME uncertainty runs")
# ---------------- MC Dropout ----------------
st.subheader("🎲 MC Dropout Uncertainty")
st.markdown("""
Monte Carlo (MC) Dropout keeps dropout layers (normally off during inference) active and treats repeated stochastic forward passes as an approximate Bayesian sampling scheme for epistemic uncertainty: how much the model "doesn't know" due to limited training.
By enabling dropout and sampling predictions multiple times, we see if the model consistently agrees on the class or wavers, indicating potential unreliability.
""")
logger.info("Starting MC Dropout sampling")
model.train()  # enable dropout at inference (note: this also puts BatchNorm layers in train mode; for CNN backbones consider enabling only nn.Dropout modules)
mc_preds = []
with torch.no_grad():
for _ in range(30):
probs_mc = predict_probs(img)
mc_preds.append(probs_mc)
model.eval()
mc_preds = np.stack(mc_preds)
mc_mean = mc_preds.mean(0)
mc_top = mc_mean.argmax()
if is_hf_model:
mc_label = model.config.id2label[mc_top]
elif st.session_state.get('class_map'):
_class_map = st.session_state.get('class_map')
mc_label = _class_map.get(str(mc_top), f"Class {mc_top}") if _class_map is not None else f"Class {mc_top}"
else:
mc_label = f"Class {mc_top}"
p = mc_preds[:, mc_top]
fig, ax = plt.subplots()
ax.hist(p, bins=15, color="C0")
ax.set_title(f"MC Dropout: p({mc_label}) across samples")
st.pyplot(fig)
st.info("""
**How to read:** This histogram shows probability distributions for the top class across 30 samples. A narrow, peaked distribution means stable confidence (low uncertainty).
A wide spread or multiple modes suggests the model is unsure, possibly due to out-of-distribution inputs. For developers, this flags risky predictions that may need human review.
""")
logger.info("Completed MC Dropout: top=%s", mc_label)
# ---------------- Test-Time Augmentation (TTA) Uncertainty ----------------
st.subheader("🔄 Test-Time Augmentation (TTA) Uncertainty")
st.markdown("""
Test-Time Augmentation (TTA) applies random transformations (crops, flips) at inference to probe aleatoric uncertainty: sensitivity to noise and nuisance variation in the input itself.
If predictions vary wildly under small changes, the model relies on brittle features, revealing data-related issues rather than model knowledge gaps.
""")
logger.info("Starting TTA sampling")
tta_tfms = T.Compose([T.Resize(256), T.RandomResizedCrop(224, scale=(0.9,1.0)), T.RandomHorizontalFlip(p=0.5)])
tta_preds = []
with torch.no_grad():
for _ in range(20):
aug = tta_tfms(img)
probs_tta = predict_probs(aug)
tta_preds.append(probs_tta)
tta_preds = np.stack(tta_preds)
tta_mean = tta_preds.mean(0)
tta_top = tta_mean.argmax()
if is_hf_model:
tta_label = model.config.id2label[tta_top]
elif st.session_state.get('class_map'):
_class_map = st.session_state.get('class_map')
tta_label = _class_map.get(str(tta_top), f"Class {tta_top}") if _class_map is not None else f"Class {tta_top}"
else:
tta_label = f"Class {tta_top}"
p_tta = tta_preds[:, tta_top]
fig, ax = plt.subplots()
ax.hist(p_tta, bins=15, color="C1")
ax.set_title(f"TTA: p({tta_label}) across augmentations")
st.pyplot(fig)
st.info("""
**How to read:** Similar to MC Dropout, but focused on input variations. Low variance means the prediction is robust to perturbations (good sign). High variance indicates sensitivity to details like lighting/position,
common in overfitted models. Use this to assess if your AI system handles real-world variability well.
""")
logger.info("Completed TTA: top=%s", tta_label)
# ---------------- Summary ----------------