'''
streamlitapp.py - Vision Transformer Interpretability Dashboard (Streamlit app)

This Streamlit app provides interpretability tools for vision transformer and CNN models.
Features:
- LIME explanations for image classification predictions
- Uncertainty analysis via MC Dropout and Test-Time Augmentation (TTA)
- Switch between Hugging Face (ViT, Swin, DeiT) and timm (ResNet, EfficientNet, ConvNeXt) models
- Support for custom finetuned models and class mappings
- Interactive sidebar for model selection and checkpoint upload
- Feynman-style explanations and cheat-sheet for interpretability concepts

Inspired by and reuses code from:
- vit_and_captum.py (Integrated Gradients with Captum)
- vit_lime_uncertainty.py (LIME explanations and uncertainty)
- detr_and_interp.py (Grad-CAM for DETR, logging setup)
'''

import streamlit as st
import html
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor, PreTrainedModel
from lime import lime_image
import torchvision.transforms as T
import timm
from skimage.segmentation import slic, mark_boundaries
import streamlit.components.v1 as components


# Add logging
import logging, os
from logging.handlers import RotatingFileHandler

LOG_DIR = os.path.join(os.path.dirname(__file__), "logs")
os.makedirs(LOG_DIR, exist_ok=True)
logfile = os.path.join(LOG_DIR, "interp.log")

logger = logging.getLogger("interp")
if not logger.handlers:
    logger.setLevel(logging.INFO)
    sh = logging.StreamHandler()
    sh.setLevel(logging.INFO)
    fh = RotatingFileHandler(logfile, maxBytes=5_000_000, backupCount=3, encoding="utf-8")
    fh.setLevel(logging.INFO)
    fmt = logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
    sh.setFormatter(fmt)
    fh.setFormatter(fmt)
    logger.addHandler(sh)
    logger.addHandler(fh)


# ---------------- Setup ----------------
MODEL_NAME = "google/vit-base-patch16-224"
device = "cuda" if torch.cuda.is_available() else "cpu"

# ---------- Sidebar model selectors ----------
# Quick lists you can edit to test other HF / timm models
HF_MODELS = [
    "google/vit-base-patch16-224",
    "facebook/deit-base-patch16-224",
    "microsoft/swin-tiny-patch4-window7-224",
    "google/vit-large-patch16-224",
]
TIMM_MODELS = [
    "convnext_base",
    "resnet50",
    "efficientnet_b0",
]

def model_selector(slot_key: str, default_source="hf"):
    source = st.sidebar.selectbox(
        f"{slot_key} source",
        ["hf", "timm"],
        index=0 if default_source == "hf" else 1,
        key=f"{slot_key}_source",
    )
    if source == "hf":
        hf_choice = st.sidebar.selectbox(
            f"{slot_key} Hugging Face model",
            HF_MODELS,
            index=0,
            key=f"{slot_key}_hf",
        )
        return f"hf:{hf_choice}"
    else:
        timm_choice = st.sidebar.selectbox(
            f"{slot_key} timm model",
            TIMM_MODELS,
            index=0,
            key=f"{slot_key}_timm",
        )
        return f"timm:{timm_choice}"

# ---------- Model Loader ----------
# Use Streamlit caching when available to avoid repeated downloads
try:
    cache_decorator = st.cache_resource
except Exception:
    from functools import lru_cache
    cache_decorator = lru_cache(maxsize=8)
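
# Caveat (an added note, not in the original flow): the lru_cache fallback
# cannot hash dict arguments such as class_map, while st.cache_resource can;
# in a non-Streamlit-cache environment, custom class maps will raise TypeError.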

@cache_decorator
def load_model(choice, checkpoint=None, class_map=None, num_classes=None):
    """
    Load a model from HF, timm, or a custom checkpoint
    Args:
        choice: Model identifier ('hf:model_name' or 'timm:model_name')
        checkpoint: Optional path to custom checkpoint file
        class_map: Optional dict mapping class indices to labels
        num_classes: Optional number of classes for custom models
    """
    logger.info("Loading model: %s", choice)
    is_hf = choice.startswith("hf:")
    
    # Parse model identifier
    if is_hf:
        hf_name = choice.split("hf:")[1]
        if checkpoint:  # Custom checkpoint
            # For custom HF model, first load the architecture then apply weights
            try:
                if num_classes:
                    model = AutoModelForImageClassification.from_pretrained(
                        hf_name, num_labels=num_classes, ignore_mismatched_sizes=True
                    ).to(device)
                else:
                    model = AutoModelForImageClassification.from_pretrained(hf_name).to(device)
                
                # Load checkpoint with error handling
                state_dict = torch.load(checkpoint, map_location=device)
                # If state_dict is wrapped (common in training checkpoints)
                if "model" in state_dict:
                    state_dict = state_dict["model"]
                elif "state_dict" in state_dict:
                    state_dict = state_dict["state_dict"]
                    
                # Handle any prefix differences by checking and stripping if needed
                if all(k.startswith('model.') for k in state_dict if k != 'config'):
                    state_dict = {k[6:]: v for k, v in state_dict.items() if k != 'config'}
                
                # Load with flexible partial loading (ignore missing/unexpected)
                model.load_state_dict(state_dict, strict=False)
                logger.info("Custom checkpoint loaded for HF model")
                
                # If custom class mapping provided, update config
                if class_map:
                    # JSON keys are strings; HF's id2label is indexed by int
                    model.config.id2label = {int(k): v for k, v in class_map.items()}
                    model.config.label2id = {v: int(k) for k, v in class_map.items()}
            except Exception as e:
                logger.error(f"Error loading custom HF model: {e}")
                st.error(f"Failed to load custom model: {e}")
                # Fallback to base model
                model = AutoModelForImageClassification.from_pretrained(hf_name).to(device)
        else:
            # Standard HF model
            model = AutoModelForImageClassification.from_pretrained(hf_name).to(device)
        
        processor = AutoImageProcessor.from_pretrained(hf_name)
    
    elif choice.startswith("timm:"):
        name = choice.split("timm:")[1]
        if checkpoint:  # Custom checkpoint
            try:
                # For timm, specify custom number of classes if provided
                if num_classes:
                    model = timm.create_model(name, pretrained=False, num_classes=num_classes).to(device)
                else:
                    model = timm.create_model(name, pretrained=True).to(device)
                
                # Load checkpoint
                state_dict = torch.load(checkpoint, map_location=device)
                # Handle common checkpoint formats
                if "model" in state_dict:
                    state_dict = state_dict["model"]
                elif "state_dict" in state_dict:
                    state_dict = state_dict["state_dict"]
                
                # Handle any prefix differences
                if all(k.startswith('module.') for k in state_dict):
                    state_dict = {k[7:]: v for k, v in state_dict.items()}
                
                model.load_state_dict(state_dict, strict=False)
                logger.info("Custom checkpoint loaded for timm model")
            except Exception as e:
                logger.error(f"Error loading custom timm model: {e}")
                st.error(f"Failed to load custom model: {e}")
                # Fallback to pretrained
                model = timm.create_model(name, pretrained=True).to(device)
        else:
            # Standard timm model
            model = timm.create_model(name, pretrained=True).to(device)
        
        # Use a standard processor for timm
        processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224")
        
    # Set model to eval mode
    model.eval()
    logger.info("Model %s loaded (eval mode)", choice)
    
    # Return model, processor, flag for HF, and class map
    return model, processor, is_hf, class_map
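
# Hypothetical usage sketch (the path and class count are placeholders, not
# part of the app flow):
#   model, processor, is_hf, cmap = load_model(
#       "timm:resnet50", checkpoint="finetuned.pth", num_classes=10)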

# Add sidebar with clear sections
st.sidebar.title("Model Selection")

# Enhanced sidebar with custom model support
with st.sidebar:
    # Add tabs for standard vs custom models
    tab1, tab2 = st.tabs(["Standard Models", "Custom Finetuned Models"])
    
    with tab1:
        st.markdown("### ๐Ÿ“Š Standard Models")
        st.markdown("Choose from pre-trained models:")
        m1 = model_selector("Active Model", default_source="hf")
        
        # Button to apply standard model change
        if st.button("๐Ÿ“‹ Set as Active Model", help="Click to use the selected model for analysis", key="std_model_btn"):
            with st.spinner(f"Loading {m1}..."):
                model, processor, is_hf_model, _ = load_model(m1)
                st.session_state.model = model
                st.session_state.processor = processor
                st.session_state.is_hf_model = is_hf_model
                st.session_state.active_model = m1
                st.session_state.using_custom = False
                st.session_state.class_map = None
                st.success(f"โœ… Model activated: {m1}")
    
    with tab2:
        st.markdown("### ๐Ÿ”ง Custom Finetuned Model")
        st.markdown("Use your own finetuned model:")
        
        # Select base architecture
        custom_source = st.selectbox(
            "Base architecture source",
            ["hf", "timm"],
            key="custom_source"
        )
        
        if custom_source == "hf":
            custom_base = st.selectbox(
                "Hugging Face base model",
                HF_MODELS,
                key="custom_hf_base"
            )
            base_model = f"hf:{custom_base}"
        else:
            custom_base = st.selectbox(
                "timm base model",
                TIMM_MODELS,
                key="custom_timm_base"
            )
            base_model = f"timm:{custom_base}"
        
        # Upload checkpoint file
        uploaded_checkpoint = st.file_uploader(
            "Upload model checkpoint (.pth, .bin)",
            type=["pth", "bin", "pt", "ckpt"],
            help="Upload your finetuned model weights"
        )
        
        # Optional class mapping
        custom_classes = st.number_input(
            "Number of classes (if different from base model)",
            min_value=0, max_value=1000, value=0,
            help="Leave at 0 to use default classes from base model"
        )
        
        uploaded_labels = st.file_uploader(
            "Upload class labels (optional JSON)",
            type=["json"],
            help="JSON file mapping class indices to labels: {\"0\": \"cat\", \"1\": \"dog\"}"
        )
        
        # Process label mapping
        class_map = None
        if uploaded_labels:
            try:
                import json
                class_map = json.loads(uploaded_labels.getvalue().decode("utf-8"))
                st.success(f"โœ“ Loaded {len(class_map)} class labels")
            except Exception as e:
                st.error(f"Error loading class labels: {e}")
        
        # Store uploaded file in session state if provided
        if uploaded_checkpoint:
            # Save to a temporary file
            import tempfile
            with tempfile.NamedTemporaryFile(delete=False, suffix='.pth') as tmp_file:
                tmp_file.write(uploaded_checkpoint.getvalue())
                checkpoint_path = tmp_file.name
                
                # Store in session state (overwrite so a re-upload takes effect)
                st.session_state.checkpoint_path = checkpoint_path
            
            st.success("โœ“ Checkpoint ready to use")
            
            # Button to apply custom model
            if st.button("๐Ÿš€ Load Custom Model", help="Click to use your custom model"):
                with st.spinner(f"Loading custom model based on {base_model}..."):
                    try:
                        num_classes = custom_classes if custom_classes > 0 else None
                        model, processor, is_hf_model, class_map = load_model(
                            base_model, checkpoint_path, class_map, num_classes
                        )
                        st.session_state.model = model
                        st.session_state.processor = processor
                        st.session_state.is_hf_model = is_hf_model
                        st.session_state.active_model = f"Custom {base_model}"
                        st.session_state.using_custom = True
                        st.session_state.class_map = class_map
                        st.success(f"โœ… Custom model activated!")
                    except Exception as e:
                        st.error(f"Failed to load custom model: {str(e)}")
    
    # Explanation section
    st.markdown("---")
    st.markdown("### โ„น๏ธ Model Types")
    st.markdown("""
    - **HF (Hugging Face)**: Vision Transformer models with standard interpretability
    - **timm (PyTorch Image Models)**: Classical CNN architectures like ResNet, EfficientNet
    
    *Custom models must match the base architecture's format.*
    """)

# Initialize model and processor from session state
if 'active_model' not in st.session_state:
    # First time loading - use default model
    m1 = "hf:google/vit-base-patch16-224" 
    st.session_state.active_model = m1
    model, processor, is_hf_model, _ = load_model(m1)
    st.session_state.model = model
    st.session_state.processor = processor
    st.session_state.is_hf_model = is_hf_model
    st.session_state.using_custom = False
    st.session_state.class_map = None
else:
    # Get from session state
    model = st.session_state.model
    processor = st.session_state.processor
    is_hf_model = st.session_state.is_hf_model

# Initialize explainer
explainer = lime_image.LimeImageExplainer()

st.title("๐Ÿง  Vision Transformer Interpretability Dashboard")
st.write("Upload an image and explore explanations with **LIME** and **Uncertainty Analysis**.")

# Add a Feynman-style "How it works" explanation as a collapsible expander
with st.expander("How it works โ€” Feynman-style explanations (click to expand)", expanded=False):
    st.markdown("""
        ## 🧠 Vision Transformer Interpretability: Feynman-Style Explanations

        ### Why do we care about interpretability & uncertainty?

        Imagine you ask a kid to identify whether a picture shows a cat. They point to the fur, ears, maybe whiskers. But what if the kid always focused on shadows, or background trees, instead of the cat itself? We want two things:

        1. **Why** did the model say "cat"? What parts of the image made it decide so?
        2. **How confident** is the model in that decision? Could small changes flip it?

        Interpretable methods show us #1. Uncertainty estimation shows us #2. Together, they help us see not just *what* the model does, but *whether* we should trust it.

        ### Key techniques, in plain analogies

        - **LIME (Local Interpretable Model-agnostic Explanations)**: For a single image & prediction, LIME perturbs (changes) parts of the image, watches how the prediction changes, and fits a simple model locally to understand which parts are most influential.
            - Analogy: Like shining small spotlights on different parts of a stage during a play: you dim a section, see how the actor's reaction changes. The parts whose dimming changes the reaction most are parts the actor depends on.

        - **Uncertainty in LIME (multiple LIME runs)**: Because LIME uses randomness (perturbing patches), different runs can give different "important" regions. Measuring how much they differ tells you how stable or fragile the explanation is.
            - Analogy: If you ask several cooks what the dominant spice in a stew is and everyone agrees, you're confident; if opinions vary, your knowledge is shakier.

        - **MC Dropout (Monte Carlo Dropout)**: Leave dropout on at inference time and run the model multiple times. The spread of predictions is a proxy for epistemic uncertainty.
            - Analogy: Like a jury where each juror occasionally misses a sentence; if the verdict remains the same across many "faulty hearing" runs, trust it more.

        - **Test-Time Augmentation (TTA) Uncertainty**: Apply small transforms (crops, flips) at inference and watch prediction variance. High variance → brittle model.
            - Analogy: Take photos under slightly different lighting/angles; if the label flips, the model may depend on superficial cues.

        ### How to read the visuals

        - LIME highlights: bright / colored superpixels = influential regions. If background or artifacts light up, that's a red flag.
        - LIME uncertainty heatmap: high std in a region means attributions are unstable there.
        - MC Dropout / TTA histograms: narrow/tall peak = confident, wide/multi-modal = uncertain.

        ### Limitations & caveats

        - Stable explanations can still be consistently wrong if the model learned a bias.
        - MC Dropout is an approximation; it helps but doesn't fully replace calibrated probabilistic methods.
        - TTA shows input sensitivity, not full distributional shift robustness.

        ### Quick example (walkthrough)

        1. Upload image → model predicts label with some probability.
        2. LIME finds important superpixels; multiple LIME runs give mean + std maps.
        3. MC Dropout produces a histogram over runs; use it to judge epistemic uncertainty.
        4. TTA shows sensitivity to small input changes.

        ### Practical tips

        - Use explanation + uncertainty to guide active learning: label cases where the model is uncertain or explanations are unstable.
        - For safety-critical systems, combine these visual signals with human review and stricter failure thresholds.

        ### Where to read more

        - Christoph Molnar, Interpretable Machine Learning (chapter on LIME): https://christophm.github.io/interpretable-ml-book/lime.html
        - Ribeiro et al., "Why Should I Trust You?" (original LIME paper): https://homes.cs.washington.edu/~marcotcr/blog/lime/
        - Zhang et al., "Why Should You Trust My Explanation?" (LIME reliability): https://arxiv.org/abs/1904.12991
        - MC Dropout practical guide & notes: https://medium.com/@ciaranbench/monte-carlo-dropout-a-practical-guide-4b4dc18014b5
        """)

# Compact one-page cheat-sheet (quick flags & checks)
with st.expander("Cheat-sheet โ€” Quick flags & warnings", expanded=False):
        cheat_text = """
Quick checks when an explanation looks suspicious

- Red flag: LIME highlights background or repeated dataset artifacts (logos, borders); the model may have learned spurious cues.
- Red flag: LIME attribution std is high in key regions; the explanation is unstable, so try different segmentations or more samples.
- Red flag: MC Dropout or TTA histograms are multi-modal or very wide; the model is uncertain, so consider human review or abstaining.
- Quick fixes: increase dataset diversity, add regularization, try different segmentation_fn parameters, or collect more labels for uncertain cases.

One-line definitions
- LIME: perturb + fit simple local model to explain a single prediction.
- MC Dropout: enable dropout at inference and sample to estimate epistemic uncertainty.
- TTA: apply small input transforms at inference to measure sensitivity / aleatoric uncertainty.

Pro-tip: Use explanation + uncertainty to drive active learning: pick instances with high prediction uncertainty or unstable explanations for labeling.
"""

        # Show the cheat-sheet as markdown
        st.markdown(cheat_text)

        # Download button for the cheat-sheet as plain text
        try:
                st.download_button(
                        label="Download cheat-sheet (.txt)",
                        data=cheat_text,
                        file_name="cheat_sheet.txt",
                        mime="text/plain",
                )
        except Exception:
                # Streamlit may raise if download_button isn't available in some environments; ignore gracefully
                pass

        # Copy-to-clipboard button using a small HTML+JS snippet
        escaped = html.escape(cheat_text)
        copy_html = f"""
        <div>
            <button id='copy-btn' style='padding:6px 10px;border-radius:4px;'>Copy cheat-sheet</button>
            <script>
                const btn = document.getElementById('copy-btn');
                btn.addEventListener('click', async () => {{
                    try {{
                        await navigator.clipboard.writeText(`{escaped}`);
                        btn.innerText = 'Copied!';
                        setTimeout(() => btn.innerText = 'Copy cheat-sheet', 1500);
                    }} catch (e) {{
                        btn.innerText = 'Copy failed';
                    }}
                }});
            </script>
        </div>
        """
        components.html(copy_html, height=70)

# Display active model clearly in the main panel
is_custom = st.session_state.get('using_custom', False)
custom_badge = " 🔧 Custom" if is_custom else ""
st.markdown(f"### Active Model: `{st.session_state.active_model}{custom_badge}`")
model_type = "Hugging Face Transformer" if is_hf_model else "timm CNN Architecture"
st.caption(f"Model type: {model_type}")

# ---------------- Helpers ----------------
def classifier_fn(images_batch):
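    # LIME contract: receives a NumPy batch of shape (N, H, W, 3) and must
    # return an (N, num_classes) array of class probabilities.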
    # Use current model/processor from session state
    inputs = processor(images=[Image.fromarray(x.astype(np.uint8)) for x in images_batch],
                      return_tensors="pt").to(device)
    with torch.no_grad():
        if is_hf_model:
            outputs = model(**inputs)
            logits = outputs.logits
        else:
            x = inputs['pixel_values']
            logits = model(x)
        probs = torch.softmax(logits, dim=-1).cpu().numpy()
    return probs

def predict_probs(pil_img):
    # Use current model/processor from session state
    inputs = processor(images=pil_img, return_tensors="pt").to(device)
    with torch.no_grad():
        if is_hf_model:
            outputs = model(**inputs)
            logits = outputs.logits
        else:
            x = inputs['pixel_values']
            logits = model(x)
        probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
    return probs

# ---------------- Upload ----------------
uploaded = st.file_uploader("Upload an image", type=["png","jpg","jpeg"])
if uploaded:
    img = Image.open(uploaded).convert("RGB").resize((224,224))
    logger.info("Uploaded image received (size=%s)", img.size)
    # Streamlit 1.XX: replace deprecated `use_container_width` with `width`
    # For full-width behavior use width='stretch' (or 'content' for intrinsic size)
    st.image(img, caption="Uploaded image", width='stretch')

    # ---------------- Prediction ----------------
    probs = predict_probs(img)
    pred_idx = int(np.argmax(probs))
    
    # Get label - handle models differently based on source
    if is_hf_model:
        # Use model's config.id2label if available
        pred_label = model.config.id2label[pred_idx]
    elif st.session_state.get('class_map'):
        # Use the custom class map if provided (JSON keys are strings)
        pred_label = st.session_state['class_map'].get(str(pred_idx), f"Class {pred_idx}")
    else:
        # For timm models without labels
        pred_label = f"Class {pred_idx}"
    
    pred_prob = float(probs[pred_idx])
    logger.info("Prediction: %s (%.3f)", pred_label, pred_prob)

    st.subheader("๐Ÿ”ฎ Prediction")
    st.write(f"**Top-1:** {pred_label} ({pred_prob:.3f})")
    
    if not is_hf_model and not st.session_state.get('class_map'):
        st.info("โ„น๏ธ Using model without class names. Upload a class mapping in the sidebar for friendly labels.")

    # ---------------- LIME ----------------
    st.subheader("๐Ÿ“ LIME Attribution")
    st.markdown("""
    **Local Interpretable Model-agnostic Explanations (LIME)** is a technique that approximates how a complex model (like ViT or ResNet) makes decisions for a specific input by creating a simpler, interpretable model around it. 
    It segments the image, perturbs those segments, and observes which ones most influence the prediction, revealing what the model "sees" as important.
    This is crucial for debugging biases or understanding if the model focuses on relevant features vs. artifacts.
    """)
    img_np = np.array(img)
    
    with st.spinner("Generating LIME explanation..."):
        exp = explainer.explain_instance(
            img_np, classifier_fn=classifier_fn, top_labels=1, num_samples=1000,
            segmentation_fn=lambda x: slic(x, n_segments=60, compactness=9, start_label=0)
        )
        temp, mask = exp.get_image_and_mask(pred_idx, positive_only=True,
                                            num_features=8, hide_rest=False)
        lime_img = mark_boundaries(temp/255.0, mask)

    st.image(lime_img, caption=f"LIME highlights regions important for '{pred_label}'")
    st.info("""
    **How to read:** Bright (or colored) segments show areas the model relied on most for its prediction – these are the "superpixels" that, when altered, change the output the most.
    Green/red overlays often indicate positive/negative contributions. If irrelevant background or edges light up, it might signal the model learned spurious correlations (e.g., from training data artifacts).
    This also builds trust by showing whether the model's decisions align with human intuition.
    """)

    # ---------------- LIME Uncertainty ----------------
    st.subheader("๐Ÿ“Š LIME Attribution Uncertainty")
    st.markdown("""
    Uncertainty in explanations arises because LIME is stochastic – it samples perturbations randomly. By running LIME multiple times, we can measure variability in attributions,
    highlighting if the model's reasoning is consistent or fragile for this image. High variability suggests the explanation (and thus model confidence) isn't robust.
    """)
    logger.info("Starting LIME uncertainty runs (n=5)")
    maps = []
    for i in range(5):
        logger.debug("LIME run %d", i+1)
        exp = explainer.explain_instance(
            img_np, classifier_fn=classifier_fn, top_labels=1, num_samples=500,
            segmentation_fn=lambda x: slic(x, n_segments=60, compactness=9, start_label=0)
        )
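        # exp.local_exp maps each label to (segment_id, weight) pairs; paint
        # every superpixel with its weight to build a dense attribution map.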
        local_exp = dict(exp.local_exp)[pred_idx]
        segments = exp.segments
        attr_map = np.zeros(segments.shape)
        for seg_id, weight in local_exp:
            attr_map[segments == seg_id] = weight
        maps.append(attr_map)
    maps = np.stack(maps)
    mean_attr, std_attr = maps.mean(0), maps.std(0)

    fig, ax = plt.subplots(1,2, figsize=(8,4))
    im1 = ax[0].imshow(mean_attr, cmap="jet"); ax[0].set_title("Mean attribution"); ax[0].axis("off")
    plt.colorbar(im1, ax=ax[0], fraction=0.046)
    im2 = ax[1].imshow(std_attr, cmap="hot"); ax[1].set_title("Attribution std (uncertainty)"); ax[1].axis("off")
    plt.colorbar(im2, ax=ax[1], fraction=0.046)
    st.pyplot(fig)
    st.info("""
    **How to read:** The left heatmap shows average importance across runs (hotter = more influential). The right shows standard deviation – high std (yellow/red) means unstable explanations for those regions.
    If uncertainty is high in key areas, the model might overfit or need more diverse training data. This helps ML practitioners quantify explanation reliability.
    """)
    logger.info("Completed LIME uncertainty runs")

    # ---------------- MC Dropout ----------------
    st.subheader("๐ŸŽฒ MC Dropout Uncertainty")
    st.markdown("""
    Monte Carlo (MC) Dropout treats dropout layers (normally off during inference) as a Bayesian approximation to estimate epistemic uncertainty – how much the model "doesn't know" due to limited training.
    By enabling dropout and sampling predictions multiple times, we see if the model consistently agrees on the class or wavers, indicating potential unreliability.
    """)
    logger.info("Starting MC Dropout sampling")
    # Enable dropout layers only; calling model.train() would also switch
    # BatchNorm layers (present in timm CNNs) to batch statistics and skew
    # single-image predictions. Note: some pretrained configs use dropout
    # p=0.0, in which case all passes are identical.
    for m in model.modules():
        if isinstance(m, torch.nn.Dropout):
            m.train()
    mc_preds = []
    with torch.no_grad():
        for _ in range(30):
            probs_mc = predict_probs(img)
            mc_preds.append(probs_mc)
    model.eval()
    mc_preds = np.stack(mc_preds)
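    # mc_preds has shape (num_samples, num_classes): one softmax per pass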
    mc_mean = mc_preds.mean(0)
    mc_top = int(mc_mean.argmax())
    if is_hf_model:
        mc_label = model.config.id2label[mc_top] 
    elif st.session_state.get('class_map'):
        mc_label = st.session_state['class_map'].get(str(mc_top), f"Class {mc_top}")
    else:
        mc_label = f"Class {mc_top}"
    p = mc_preds[:, mc_top]

    fig, ax = plt.subplots()
    ax.hist(p, bins=15, color="C0")
    ax.set_title(f"MC Dropout: p({mc_label}) across samples")
    st.pyplot(fig)
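    # A minimal extra summary (a sketch added here, not in the original flow):
    # predictive entropy of the MC-mean distribution as one scalar uncertainty
    # score (higher = less certain).
    pred_entropy = float(-(mc_mean * np.log(mc_mean + 1e-12)).sum())
    st.caption(f"Predictive entropy of the MC-mean distribution: {pred_entropy:.3f} nats")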
    st.info("""
    **How to read:** This histogram shows probability distributions for the top class across 30 samples. A narrow, peaked distribution means stable confidence (low uncertainty).
    A wide spread or multiple modes suggests the model is unsure, possibly due to out-of-distribution inputs. For developers, this flags risky predictions that may need human review.
    """)
    logger.info("Completed MC Dropout: top=%s", mc_label)

    # ---------------- Test-Time Augmentation (TTA) Uncertainty ----------------
    st.subheader("๐Ÿ”„ Test-Time Augmentation (TTA) Uncertainty")
    st.markdown("""
    Test-Time Augmentation (TTA) applies random transformations (crops, flips) at inference to probe aleatoric uncertainty – noise inherent in the input or model.
    If predictions vary wildly under small changes, the model relies on brittle features, revealing data-related issues rather than model knowledge gaps.
    """)
    logger.info("Starting TTA sampling")
    tta_tfms = T.Compose([T.Resize(256), T.RandomResizedCrop(224, scale=(0.9,1.0)), T.RandomHorizontalFlip(p=0.5)])
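    # The pipeline stays in PIL space (no ToTensor), so each augmented image
    # can be passed straight back through predict_probs / the HF processor.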
    tta_preds = []
    with torch.no_grad():
        for _ in range(20):
            aug = tta_tfms(img)
            probs_tta = predict_probs(aug)
            tta_preds.append(probs_tta)
    tta_preds = np.stack(tta_preds)
    tta_mean = tta_preds.mean(0)
    tta_top = int(tta_mean.argmax())
    if is_hf_model:
        tta_label = model.config.id2label[tta_top]
    elif st.session_state.get('class_map'):
        tta_label = st.session_state['class_map'].get(str(tta_top), f"Class {tta_top}")
    else:
        tta_label = f"Class {tta_top}"
    p_tta = tta_preds[:, tta_top]

    fig, ax = plt.subplots()
    ax.hist(p_tta, bins=15, color="C1")
    ax.set_title(f"TTA: p({tta_label}) across augmentations")
    st.pyplot(fig)
    st.info("""
    **How to read:** Similar to MC Dropout, but focused on input variations. Low variance means the prediction is robust to perturbations (good sign). High variance indicates sensitivity to details like lighting/position, 
    common in overfitted models. Use this to assess if your AI system handles real-world variability well.
    """)
    logger.info("Completed TTA: top=%s", tta_label)
# ---------------- Summary ----------------