File size: 10,113 Bytes
bca9b06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd4c84f
 
 
 
 
 
 
 
 
 
 
 
 
 
bca9b06
 
bd4c84f
bca9b06
 
 
 
bd4c84f
bca9b06
 
 
bd4c84f
bca9b06
bd4c84f
bca9b06
 
bd4c84f
bca9b06
 
 
 
 
 
 
bd4c84f
bca9b06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd4c84f
bca9b06
 
bd4c84f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bca9b06
 
bd4c84f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bca9b06
bd4c84f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bca9b06
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import time, math
import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoImageProcessor, AutoModelForImageClassification

# --- Model choice (CPU-friendly) ---
# Other image-classification checkpoints can be swapped in here, for example
# "microsoft/resnet-50" or "facebook/convnext-tiny-224".
MODEL_ID = "google/vit-base-patch16-224"

# Load the preprocessing pipeline and the weights a single time, at import,
# so every UI request reuses the same objects. nn.Module.eval() switches the
# model to evaluation mode and returns the module itself.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID).eval()
LABELS = model.config.id2label  # class index -> human-readable label

def _softmax_with_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Softmax(logits / T) with numerical stability."""
    if temperature <= 0:
        temperature = 1.0
    scaled = logits / float(temperature)
    scaled = scaled - torch.max(scaled)
    exp = torch.exp(scaled)
    return exp / torch.sum(exp)

def _entropy(probs: np.ndarray) -> float:
    """Shannon entropy in nats: -Σ p log p (ignore zeros)."""
    p = probs[probs > 0]
    return float(-(p * np.log(p)).sum())

def _kl(p: np.ndarray, q: np.ndarray) -> float:
    """KL divergence KL(p||q), with small epsilon for stability."""
    eps = 1e-12
    p = p + eps
    q = q + eps
    return float(np.sum(p * np.log(p / q)))

def _jsd(p: np.ndarray, q: np.ndarray) -> float:
    """Jensen–Shannon divergence between two probability vectors, in nats.

    JSD(p, q) = ½·KL(p ‖ m) + ½·KL(q ‖ m) with the mixture m = ½·(p + q).
    It is symmetric in its arguments and, for normalized inputs, bounded
    above by ln 2.
    """
    mixture = 0.5 * (p + q)
    kl_p = _kl(p, mixture)
    kl_q = _kl(q, mixture)
    return 0.5 * kl_p + 0.5 * kl_q

def _make_bar(labels, probs, title="Top-K predicted classes"):
    """Return a matplotlib Figure with a horizontal bar chart of per-label values.

    Args:
        labels: Tick labels; the first one is drawn at the top.
        probs: One value per label. Usually probabilities in [0, 1]; signed
            values (e.g. A − B differences) are also supported, in which case
            the x-axis spans [-1, 1] so bars left of zero stay visible.
        title: Axes title.

    Returns:
        The Figure. It is closed in pyplot (``plt.close``) before being
        returned, so repeated calls from a long-running server do not pile up
        open pyplot figures; the Figure object itself can still be rendered.
    """
    labels = list(labels)
    values = [float(v) for v in probs]

    fig, ax = plt.subplots(figsize=(6, 3.2))
    y = np.arange(len(labels))
    ax.barh(y, values)  # default colors only (per teaching tool rules)
    # Two-call form works on matplotlib versions older than 3.5 as well.
    ax.set_yticks(y)
    ax.set_yticklabels(labels)
    ax.invert_yaxis()  # keep rank order top-to-bottom
    # A fixed [0, 1] range would clip negative bars out of view.
    if values and min(values) < 0:
        ax.set_xlim(-1, 1)
    else:
        ax.set_xlim(0, 1)
    ax.set_xlabel("Probability")
    ax.set_title(title)
    fig.tight_layout()
    # Release pyplot's reference to the figure; the caller keeps the object.
    plt.close(fig)
    return fig

def _analyze_single(img, top_k=5, temperature=1.0):
    """
    Return: (quick_label_dict, bar_plot, table_rows, notes_markdown, full_probs_numpy)
    """
    if img is None:
        return ({"<no image>": 1.0}, None, [], "Please upload an image.", None)

    t0 = time.perf_counter()
    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits[0]  # shape [num_labels]
    probs = _softmax_with_temperature(logits, temperature)

    k = max(1, int(top_k))
    k = min(k, probs.shape[0])
    top_vals, top_idx = torch.topk(probs, k=k, dim=-1)
    top_idx = top_idx.tolist()
    top_vals = top_vals.tolist()
    labels = [LABELS[i] for i in top_idx]
    logits_top = [float(logits[i]) for i in top_idx]

    quick = {lab: float(p) for lab, p in zip(labels, top_vals)}
    fig = _make_bar(labels, top_vals)

    rows = []
    for rank, (lab, p, lg) in enumerate(zip(labels, top_vals, logits_top), start=1):
        rows.append([rank, lab, round(float(p), 6), round(float(lg), 6)])

    probs_np = probs.detach().cpu().numpy()
    H = _entropy(probs_np)
    top1 = float(top_vals[0])
    top2 = float(top_vals[1]) if len(top_vals) > 1 else 0.0
    margin = top1 - top2
    cum_topk = float(sum(top_vals))
    infer_ms = (time.perf_counter() - t0) * 1000.0

    size = processor.size if hasattr(processor, "size") else {}
    target_h = size.get("height", None)
    target_w = size.get("width", None)
    size_str = f"{target_h}×{target_w}" if (target_h and target_w) else "model default"

    md = (
        f"**Uncertainty**  \n"
        f"- Entropy (lower→more confident): **{H:.3f} nats**  \n"
        f"- Top-1 margin (Top-1 − Top-2): **{margin:.3f}**  \n"
        f"- Cumulative Top-K probability: **{cum_topk:.3f}**  \n\n"
        f"**Preprocessing & Runtime**  \n"
        f"- Processor target size: **{size_str}**  \n"
        f"- Inference time: **{infer_ms:.1f} ms**  \n"
    )

    return quick, fig, rows, md, probs_np

def _align_topk(labelsA, probsA, labelsB, probsB, K=5):
    """Make a unified label set of size up to K using union-of-top labels then rank by max(prob)."""
    dA = dict(zip(labelsA, probsA))
    dB = dict(zip(labelsB, probsB))
    union = set(labelsA) | set(labelsB)
    # rank by max(prob from A, prob from B)
    ranked = sorted(list(union), key=lambda x: max(dA.get(x, 0.0), dB.get(x, 0.0)), reverse=True)
    chosen = ranked[:K]
    a = [float(dA.get(l, 0.0)) for l in chosen]
    b = [float(dB.get(l, 0.0)) for l in chosen]
    return chosen, a, b

def analyze_single(img, top_k=5, temperature=1.0):
    """Gradio handler for the single-image tab.

    Delegates to ``_analyze_single`` and returns only the four values the tab
    displays: quick-glance dict, bar chart, table rows and notes markdown.
    The full probability vector (last element) is dropped.
    """
    *shown, _full_probs = _analyze_single(img, top_k, temperature)
    return tuple(shown)

def analyze_pair(imgA, imgB, top_k=5, temperature=1.0):
    """
    A/B analysis:
      - show per-image quick dict, bar chart, table, notes
      - show aligned Top-K delta bar and divergence metrics
    """
    # Analyze each side
    qa, figa, rowsa, mda, pa = _analyze_single(imgA, top_k, temperature)
    qb, figb, rowsb, mdb, pb = _analyze_single(imgB, top_k, temperature)

    # If either missing, return as-is
    if pa is None or pb is None:
        return qa, figa, rowsa, mda, qb, figb, rowsb, mdb, None, "Upload both images for delta metrics."

    # Build aligned top-K over labels
    # We need label sets and probs for both to compute aligned bars
    # Recover top-K labels directly from rows (rank, label, prob, logit)
    labelsA = [r[1] for r in rowsa]
    probsA  = [r[2] for r in rowsa]
    labelsB = [r[1] for r in rowsb]
    probsB  = [r[2] for r in rowsb]
    chosen, a, b = _align_topk(labelsA, probsA, labelsB, probsB, K=max(int(top_k), 1))

    # Delta bar (A−B)
    deltas = [float(x - y) for x, y in zip(a, b)]
    fig_delta = _make_bar([f"{lbl} (Δ)" for lbl in chosen], deltas, title="Aligned Top-K Δ Probabilities (A − B)")

    # Distribution-level differences (full softmax vectors)
    # Ensure same length and normalize to prob distributions
    pa = pa / (pa.sum() + 1e-12)
    pb = pb / (pb.sum() + 1e-12)
    H_a = _entropy(pa)
    H_b = _entropy(pb)
    jsd = _jsd(pa, pb)

    # Top-1 labels for each side
    top1_a_idx = int(np.argmax(pa))
    top1_b_idx = int(np.argmax(pb))
    top1_a = LABELS[top1_a_idx]
    top1_b = LABELS[top1_b_idx]

    diff_md = (
        f"**A/B Divergence**  \n"
        f"- Jensen–Shannon divergence: **{jsd:.4f}** (0=same, higher=more different)  \n"
        f"- Entropy A / B: **{H_a:.3f} / {H_b:.3f}** nats  \n"
        f"- Top-1 A / B: **{top1_a} / {top1_b}**  \n"
        f"- Aligned Top-K shown above is ranked by max(prob_A, prob_B).  \n"
        f"_Tip:_ Try different crops/lighting or adjust **Temperature** to watch distributions change."
    )

    return qa, figa, rowsa, mda, qb, figb, rowsb, mdb, fig_delta, diff_md

# --- Gradio UI ---------------------------------------------------------------
# Two tabs: "Single Image" wires analyze_single to one upload, and "A/B Compare"
# wires analyze_pair to two uploads. In each .click(fn, inputs, outputs) call,
# the inputs map positionally onto fn's parameters and fn's return tuple maps
# positionally onto the outputs list.
with gr.Blocks(fill_height=True, analytics_enabled=False) as demo:
    gr.Markdown("# 🖼️ Image Classification — ‘Watch the Decision’\nVisualize probabilities, logits, entropy, and A/B deltas.\n\n"
                "_Notes:_ Predictions reflect the ImageNet‑1k label space; unusual objects or logos may be misclassified. "
                "Do not use for identity or sensitive inferences.")

    # Single Image tab: controls in the left column, results in the right one.
    with gr.Tab("Single Image"):
        with gr.Row():
            with gr.Column(scale=1):
                # type="pil" makes the handler receive a PIL image (None when empty).
                img = gr.Image(type="pil", label="Upload image (JPEG/PNG)", height=340)
                topk = gr.Slider(1, 10, value=5, step=1, label="Top‑K predictions")
                # Lower bound 0.25 keeps the temperature strictly positive.
                temp = gr.Slider(0.25, 2.0, value=1.0, step=0.05, label="Softmax Temperature")
                run = gr.Button("Analyze", variant="primary")
            with gr.Column(scale=1):
                gr.Markdown("### Quick glance")
                # num_top_classes=10 matches the Top-K slider's maximum.
                glance = gr.Label(num_top_classes=10)
                gr.Markdown("### Probabilities (Top‑K)")
                plot = gr.Plot()
                gr.Markdown("### Details (Top‑K)")
                # Columns mirror the [rank, label, probability, logit] rows built in _analyze_single.
                table = gr.Dataframe(headers=["Rank", "Label", "Probability", "Logit"], datatype=["number", "str", "number", "number"], row_count=5)
                gr.Markdown("### Metrics & Notes")
                notes = gr.Markdown()
        # (img, top_k, temperature) -> (quick dict, figure, table rows, notes markdown)
        run.click(analyze_single, [img, topk, temp], [glance, plot, table, notes])

    # A/B Compare tab: shared controls, then one results column per image,
    # then a full-width row for the aligned delta chart and divergence notes.
    with gr.Tab("A/B Compare"):
        with gr.Row():
            with gr.Column(scale=1):
                imgA = gr.Image(type="pil", label="Image A", height=300)
                imgB = gr.Image(type="pil", label="Image B", height=300)
                topkAB = gr.Slider(1, 10, value=5, step=1, label="Aligned Top‑K")
                tempAB = gr.Slider(0.25, 2.0, value=1.0, step=0.05, label="Softmax Temperature")
                runAB = gr.Button("Analyze A/B", variant="primary")
            with gr.Column(scale=1):
                gr.Markdown("### A — Quick glance")
                glanceA = gr.Label(num_top_classes=10)
                gr.Markdown("### A — Probabilities (Top‑K)")
                plotA = gr.Plot()
                gr.Markdown("### A — Details (Top‑K)")
                tableA = gr.Dataframe(headers=["Rank", "Label", "Probability", "Logit"], datatype=["number", "str", "number", "number"], row_count=5)
                gr.Markdown("### A — Notes")
                notesA = gr.Markdown()
            with gr.Column(scale=1):
                gr.Markdown("### B — Quick glance")
                glanceB = gr.Label(num_top_classes=10)
                gr.Markdown("### B — Probabilities (Top‑K)")
                plotB = gr.Plot()
                gr.Markdown("### B — Details (Top‑K)")
                tableB = gr.Dataframe(headers=["Rank", "Label", "Probability", "Logit"], datatype=["number", "str", "number", "number"], row_count=5)
                gr.Markdown("### B — Notes")
                notesB = gr.Markdown()
        with gr.Row():
            gr.Markdown("### Aligned Top‑K Δ (A − B) & Divergence")
        with gr.Row():
            deltaPlot = gr.Plot()
            deltaNotes = gr.Markdown()
        # The ten outputs receive analyze_pair's 10-tuple in order: A's four
        # views, B's four views, then the delta plot and divergence markdown.
        runAB.click(analyze_pair, [imgA, imgB, topkAB, tempAB], [glanceA, plotA, tableA, notesA, glanceB, plotB, tableB, notesB, deltaPlot, deltaNotes])

# Start the Gradio server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()