strelizi commited on
Commit
5d7db16
·
verified ·
1 Parent(s): d28e568

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -366
app.py CHANGED
@@ -1,379 +1,48 @@
 
1
  import torch
2
- import torch.nn as nn
3
  from torchvision import models, transforms
4
  from PIL import Image
5
- import gradio as gr
6
- from captum.attr import LayerGradCam, IntegratedGradients
7
- import numpy as np
8
  import matplotlib.pyplot as plt
 
9
  from io import BytesIO
10
- import urllib.request
11
- from torch.nn.functional import interpolate
12
- import warnings
13
- warnings.filterwarnings('ignore')
14
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
- @torch.no_grad()
17
- def load_models_and_labels():
18
- """Load multiple models for ensemble prediction"""
19
- print("🚀 Loading advanced models...")
20
 
21
- # Model 1: EfficientNet-B4 (Best accuracy/speed ratio)
22
- efficientnet = models.efficientnet_b4(weights='IMAGENET1K_V1')
23
- efficientnet.eval().to(DEVICE)
 
 
24
 
25
- # Model 2: ResNet152 (Deep architecture)
26
- resnet152 = models.resnet152(weights='IMAGENET1K_V2')
27
- resnet152.eval().to(DEVICE)
28
 
29
- # Model 3: ConvNeXt-Base (State-of-the-art)
30
- convnext = models.convnext_base(weights='IMAGENET1K_V1')
31
- convnext.eval().to(DEVICE)
 
 
32
 
33
- # Load ImageNet labels
34
- url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
35
- response = urllib.request.urlopen(url)
36
- labels = [line.decode('utf-8').strip() for line in response.readlines()]
 
37
 
38
- print(" Models loaded successfully!")
39
- return efficientnet, resnet152, convnext, labels
40
-
41
- efficientnet_model, resnet152_model, convnext_model, IMAGENET_LABELS = load_models_and_labels()
42
-
43
- # Setup Grad-CAM for each model
44
- efficientnet_target = efficientnet_model.features[-1]
45
- resnet152_target = resnet152_model.layer4[-1]
46
- convnext_target = convnext_model.features[-1][-1]
47
-
48
- gradcam_efficient = LayerGradCam(efficientnet_model, efficientnet_target)
49
- gradcam_resnet = LayerGradCam(resnet152_model, resnet152_target)
50
- gradcam_convnext = LayerGradCam(convnext_model, convnext_target)
51
-
52
- # Transforms for different models
53
- transform_standard = transforms.Compose([
54
- transforms.Resize((224, 224)),
55
- transforms.ToTensor(),
56
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
57
- ])
58
-
59
- transform_large = transforms.Compose([
60
- transforms.Resize((380, 380)),
61
- transforms.ToTensor(),
62
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
63
- ])
64
-
65
-
66
- def predict_and_explain(image, use_tta=True):
67
- if image is None:
68
- return "Please upload an image", None, None
69
-
70
- try:
71
- # Prepare inputs
72
- img_tensor_224 = transform_standard(image).unsqueeze(0).to(DEVICE)
73
- img_tensor_380 = transform_large(image).unsqueeze(0).to(DEVICE)
74
-
75
- predictions = []
76
- model_names = []
77
-
78
- with torch.no_grad():
79
- # EfficientNet-B4 prediction (380x380)
80
- output_eff = efficientnet_model(img_tensor_380)
81
- prob_eff = torch.softmax(output_eff, dim=1)
82
- predictions.append(prob_eff)
83
- model_names.append("EfficientNet-B4")
84
-
85
- # ResNet152 prediction
86
- output_res = resnet152_model(img_tensor_224)
87
- prob_res = torch.softmax(output_res, dim=1)
88
- predictions.append(prob_res)
89
- model_names.append("ResNet152")
90
-
91
- # ConvNeXt prediction
92
- output_conv = convnext_model(img_tensor_224)
93
- prob_conv = torch.softmax(output_conv, dim=1)
94
- predictions.append(prob_conv)
95
- model_names.append("ConvNeXt-Base")
96
-
97
- # Test-Time Augmentation (optional)
98
- if use_tta:
99
- # Horizontal flip
100
- img_flip = transforms.functional.hflip(image)
101
- img_flip_tensor = transform_standard(img_flip).unsqueeze(0).to(DEVICE)
102
-
103
- output_flip_res = resnet152_model(img_flip_tensor)
104
- predictions.append(torch.softmax(output_flip_res, dim=1))
105
- model_names.append("ResNet152-Flip")
106
-
107
- # Ensemble: Weighted average (EfficientNet gets highest weight)
108
- weights = [0.40, 0.30, 0.25, 0.05] if use_tta else [0.45, 0.30, 0.25]
109
- ensemble_prob = sum(w * p for w, p in zip(weights, predictions))
110
-
111
- top10_prob, top10_idx = torch.topk(ensemble_prob, 10)
112
- pred_class = top10_idx[0][0].item()
113
- confidence = top10_prob[0][0].item()
114
-
115
- # Generate Grad-CAM from best performing model (EfficientNet)
116
- attributions = gradcam_efficient.attribute(img_tensor_380, target=pred_class)
117
- attr_resized = interpolate(attributions, size=(224, 224), mode='bilinear', align_corners=False)
118
- attr_np = attr_resized.squeeze().cpu().detach().numpy()
119
- attr_np = (attr_np - attr_np.min()) / (attr_np.max() - attr_np.min() + 1e-8)
120
-
121
- # Alternative: Get Grad-CAM from all models and average
122
- attr_resnet = gradcam_resnet.attribute(img_tensor_224, target=pred_class)
123
- attr_resnet = interpolate(attr_resnet, size=(224, 224), mode='bilinear', align_corners=False)
124
- attr_resnet_np = attr_resnet.squeeze().cpu().detach().numpy()
125
- attr_resnet_np = (attr_resnet_np - attr_resnet_np.min()) / (attr_resnet_np.max() - attr_resnet_np.min() + 1e-8)
126
-
127
- # Average heatmaps for better visualization
128
- attr_avg = (attr_np * 0.6 + attr_resnet_np * 0.4)
129
-
130
- # Main visualization
131
- fig = plt.figure(figsize=(24, 14))
132
- fig.patch.set_facecolor('#0a0a0a')
133
-
134
- gs = fig.add_gridspec(2, 3, height_ratios=[2, 1], hspace=0.3, wspace=0.15)
135
-
136
- ax1 = fig.add_subplot(gs[0, 0])
137
- ax2 = fig.add_subplot(gs[0, 1])
138
- ax3 = fig.add_subplot(gs[0, 2])
139
- ax4 = fig.add_subplot(gs[1, :])
140
-
141
- ax1.imshow(image)
142
- ax1.set_title("Original Image", fontsize=18, fontweight='700', color='#e0e0e0', pad=20)
143
- ax1.axis('off')
144
-
145
- im = ax2.imshow(attr_avg, cmap='jet', interpolation='bilinear')
146
- ax2.set_title("Ensemble Grad-CAM", fontsize=18, fontweight='700', color='#e0e0e0', pad=20)
147
- ax2.axis('off')
148
- cbar = plt.colorbar(im, ax=ax2, fraction=0.046, pad=0.04)
149
- cbar.ax.tick_params(labelsize=12, colors='#a0a0a0')
150
- cbar.set_label('Importance', rotation=270, labelpad=25, color='#e0e0e0', fontsize=13, fontweight='600')
151
 
152
- ax3.imshow(image)
153
- ax3.imshow(attr_avg, cmap='jet', alpha=0.5, interpolation='bilinear')
154
- ax3.set_title(f"AI Focus: {IMAGENET_LABELS[pred_class]}", fontsize=18, fontweight='700', color='#e0e0e0', pad=20)
155
- ax3.axis('off')
156
-
157
- top10_labels = [IMAGENET_LABELS[idx.item()] for idx in top10_idx[0]]
158
- top10_probs = [prob.item() * 100 for prob in top10_prob[0]]
159
-
160
- colors = ['#10b981' if i == 9 else '#3b82f6' if i >= 7 else '#8b5cf6' for i in range(10)]
161
- bars = ax4.barh(range(10), top10_probs[::-1], color=colors[::-1], edgecolor='#1a1a1a', linewidth=2)
162
-
163
- ax4.set_yticks(range(10))
164
- ax4.set_yticklabels(top10_labels[::-1], fontsize=14, color='#e0e0e0', fontweight='600')
165
- ax4.set_xlabel('Confidence (%)', fontsize=15, color='#e0e0e0', fontweight='700')
166
- ax4.set_title('Top 10 Predictions (Ensemble)', fontsize=19, fontweight='800', color='#e0e0e0', pad=20)
167
- ax4.set_xlim([0, 100])
168
- ax4.grid(axis='x', alpha=0.2, color='#404040', linestyle='--')
169
- ax4.set_facecolor('#0a0a0a')
170
- ax4.spines['top'].set_visible(False)
171
- ax4.spines['right'].set_visible(False)
172
- ax4.spines['left'].set_color('#404040')
173
- ax4.spines['bottom'].set_color('#404040')
174
- ax4.tick_params(colors='#a0a0a0', labelsize=13)
175
-
176
- for bar, prob in zip(bars, top10_probs[::-1]):
177
- ax4.text(prob + 1.5, bar.get_y() + bar.get_height()/2,
178
- f'{prob:.1f}%', va='center', fontsize=13, color='#e0e0e0', fontweight='700')
179
-
180
- plt.tight_layout()
181
-
182
- buf = BytesIO()
183
- plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='#0a0a0a')
184
- buf.seek(0)
185
- result_image = Image.open(buf)
186
- plt.close(fig)
187
-
188
- # Detailed comparison: Individual model heatmaps
189
- fig2, axes = plt.subplots(2, 2, figsize=(20, 18))
190
- fig2.patch.set_facecolor('#0a0a0a')
191
-
192
- axes[0, 0].imshow(image)
193
- axes[0, 0].set_title("Original Image", fontsize=17, fontweight='700', color='#e0e0e0', pad=15)
194
- axes[0, 0].axis('off')
195
-
196
- axes[0, 1].imshow(image)
197
- axes[0, 1].imshow(attr_np, cmap='jet', alpha=0.6, interpolation='bilinear')
198
- axes[0, 1].set_title("EfficientNet-B4 Focus", fontsize=17, fontweight='700', color='#e0e0e0', pad=15)
199
- axes[0, 1].axis('off')
200
-
201
- axes[1, 0].imshow(image)
202
- axes[1, 0].imshow(attr_resnet_np, cmap='hot', alpha=0.6, interpolation='bilinear')
203
- axes[1, 0].set_title("ResNet152 Focus", fontsize=17, fontweight='700', color='#e0e0e0', pad=15)
204
- axes[1, 0].axis('off')
205
-
206
- axes[1, 1].imshow(image)
207
- axes[1, 1].imshow(attr_avg, cmap='viridis', alpha=0.6, interpolation='gaussian')
208
- axes[1, 1].contour(attr_avg, levels=6, colors='white', linewidths=2, alpha=0.9)
209
- axes[1, 1].set_title("Ensemble Average + Contours", fontsize=17, fontweight='700', color='#e0e0e0', pad=15)
210
- axes[1, 1].axis('off')
211
-
212
- plt.tight_layout()
213
-
214
- buf2 = BytesIO()
215
- plt.savefig(buf2, format='png', dpi=140, bbox_inches='tight', facecolor='#0a0a0a')
216
- buf2.seek(0)
217
- detailed_heatmap = Image.open(buf2)
218
- plt.close(fig2)
219
-
220
- # Enhanced prediction card with model breakdown
221
- badge = "high" if confidence > 0.8 else "medium" if confidence > 0.5 else "low"
222
- badge_text = "High Confidence" if confidence > 0.8 else "Medium Confidence" if confidence > 0.5 else "Low Confidence"
223
- badge_icon = "🎯" if confidence > 0.8 else "⚡" if confidence > 0.5 else "⚠️"
224
-
225
- # Individual model predictions for transparency
226
- individual_preds = []
227
- for i, (pred, name) in enumerate(zip(predictions[:3], model_names[:3])):
228
- top1_idx = pred[0].argmax().item()
229
- top1_prob = pred[0, top1_idx].item() * 100
230
- individual_preds.append(f"""
231
- <div class='model-pred'>
232
- <span class='model-name'>{name}</span>
233
- <span class='model-class'>{IMAGENET_LABELS[top1_idx]}</span>
234
- <span class='model-conf'>{top1_prob:.1f}%</span>
235
- </div>
236
- """)
237
-
238
- top5_html = "<div class='top5-grid'>"
239
- icons = ["🥇", "🥈", "🥉", "4️⃣", "5️⃣"]
240
- for i, (prob, idx) in enumerate(zip(top10_prob[0][:5], top10_idx[0][:5])):
241
- pct = prob.item() * 100
242
- top5_html += f"""
243
- <div class='top5-row'>
244
- <span class='rank'>{icons[i]}</span>
245
- <span class='label'>{IMAGENET_LABELS[idx.item()]}</span>
246
- <div class='bar-wrap'><div class='bar' style='width:{pct}%'></div></div>
247
- <span class='pct'>{pct:.2f}%</span>
248
- </div>"""
249
- top5_html += "</div>"
250
-
251
- prediction_text = f"""
252
- <div class="result-card">
253
- <div class="pred-header">
254
- <h2 class="pred-label">{IMAGENET_LABELS[pred_class]}</h2>
255
- <div class="badge badge-{badge}">{badge_icon} {badge_text}</div>
256
- </div>
257
- <div class="conf-score">{confidence*100:.2f}%</div>
258
- <div class="ensemble-tag">🔬 Multi-Model Ensemble (3 Networks + TTA)</div>
259
- <div class="divider"></div>
260
- <div class="model-breakdown">
261
- <h3>Individual Model Predictions:</h3>
262
- {''.join(individual_preds)}
263
- </div>
264
- <div class="divider"></div>
265
- {top5_html}
266
- </div>"""
267
-
268
- return prediction_text, result_image, detailed_heatmap
269
-
270
- except Exception as e:
271
- return f"<div class='error-msg'>⚠️ Error: {str(e)}</div>", None, None
272
-
273
-
274
- custom_css = """
275
- @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700;800;900&display=swap');
276
- * { box-sizing: border-box; margin: 0; padding: 0; }
277
- body, .gradio-container { margin: 0 !important; padding: 0 !important; width: 100vw !important; min-height: 100vh !important; max-width: 100vw !important; background: linear-gradient(135deg, #0a0a0a 0%, #1a1a1a 50%, #0f0f0f 100%) !important; font-family: 'Inter', sans-serif !important; color: #e0e0e0 !important; overflow-x: hidden !important; }
278
- .gradio-container { padding: 0 !important; }
279
- .main-wrapper { padding: 1.5rem; max-width: 1920px; margin: 0 auto; position: relative; z-index: 2; }
280
- .hero-header { text-align: center; padding: 2rem 1rem 1.5rem; margin-bottom: 1.5rem; position: relative; }
281
- .hero-header::before { content: ''; position: absolute; top: 0; left: 50%; transform: translateX(-50%); width: 300px; height: 300px; background: radial-gradient(circle, rgba(59, 130, 246, 0.15), transparent); filter: blur(80px); z-index: -1; }
282
- .hero-header h1 { font-size: clamp(2rem, 5vw, 3.5rem); font-weight: 900; background-color: #d8b4fe; -webkit-background-clip: text; -webkit-text-fill-color: transparent; margin: 0 0 0.5rem; letter-spacing: -1px; }
283
- .hero-header .subtitle { font-size: clamp(0.95rem, 2vw, 1.2rem); color: #808080; font-weight: 400; margin: 0 0 0.5rem; }
284
- .hero-header .model-tag { display: inline-block; background: #93c5fd; border: 1px solid rgba(59, 130, 246, 0.3); color: #3b82f6; padding: 0.5rem 1.5rem; border-radius: 50px; font-size: 0.85rem; font-weight: 700; letter-spacing: 0.5px; margin-top: 0.5rem; }
285
- .top-section { display: grid; grid-template-columns: 400px 1fr; gap: 1.25rem; margin-bottom: 1.25rem; }
286
- .upload-panel, .results-panel, .viz-section { background: rgba(20, 20, 20, 0.8); border: 1px solid rgba(255, 255, 255, 0.1); border-radius: 24px; padding: 1.5rem; backdrop-filter: blur(20px); box-shadow: 0 8px 32px rgba(0, 0, 0, 0.4); }
287
- .section-label { font-size: 1.1rem; font-weight: 700; background: #93c5fd; -webkit-background-clip: text; -webkit-text-fill-color: transparent; margin: 0 0 1rem; text-align: center; letter-spacing: 0.5px; }
288
- #input-image { border: 2px dashed rgba(59, 130, 246, 0.4) !important; border-radius: 20px !important; background: rgba(10, 10, 10, 0.6) !important; height: 320px !important; transition: all 0.3s ease; }
289
- #input-image:hover { border-color: #3b82f6 !important; background: rgba(20, 20, 30, 0.8) !important; transform: scale(1.02); box-shadow: 0 0 30px rgba(59, 130, 246, 0.2); }
290
- .btn-row { display: flex; gap: 0.75rem; margin-top: 1rem; }
291
- .gr-button { border-radius: 14px !important; font-weight: 700 !important; height: 50px !important; font-size: 0.95rem !important; transition: all 0.3s ease !important; border: none !important; letter-spacing: 0.5px; text-transform: uppercase; }
292
- .gr-button-primary { background: linear-gradient(135deg, #3b82f6, #8b5cf6) !important; color: white !important; box-shadow: 0 4px 20px rgba(59, 130, 246, 0.4) !important; }
293
- .gr-button-primary:hover { transform: translateY(-3px) !important; box-shadow: 0 8px 30px rgba(59, 130, 246, 0.6) !important; }
294
- .gr-button-secondary { background: rgba(40, 40, 40, 0.8) !important; color: #a0a0a0 !important; border: 1px solid rgba(255, 255, 255, 0.1) !important; }
295
- .pred-header { display: flex; align-items: center; justify-content: space-between; flex-wrap: wrap; gap: 1rem; margin-bottom: 0.75rem; }
296
- .pred-label { font-size: clamp(1.5rem, 3vw, 2rem); font-weight: 900; color: #ffffff; margin: 0; letter-spacing: -0.5px; }
297
- .badge { padding: 0.5rem 1.25rem; border-radius: 50px; font-size: 0.875rem; font-weight: 700; text-transform: uppercase; letter-spacing: 0.5px; box-shadow: 0 4px 15px rgba(0, 0, 0, 0.3); }
298
- .badge-high { background: linear-gradient(135deg, #10b981, #059669); color: white; }
299
- .badge-medium { background: linear-gradient(135deg, #f59e0b, #d97706); color: white; }
300
- .badge-low { background: linear-gradient(135deg, #ef4444, #dc2626); color: white; }
301
- .conf-score { font-size: clamp(2rem, 5vw, 3rem); font-weight: 900; background: linear-gradient(135deg, #3b82f6, #8b5cf6); -webkit-background-clip: text; -webkit-text-fill-color: transparent; margin-bottom: 1rem; letter-spacing: -1px; }
302
- .ensemble-tag { background: rgba(16, 185, 129, 0.15); border: 1px solid rgba(16, 185, 129, 0.3); color: #10b981; padding: 0.5rem 1rem; border-radius: 12px; font-size: 0.8rem; font-weight: 700; text-align: center; margin-bottom: 1rem; }
303
- .model-breakdown { background: rgba(30, 30, 30, 0.6); padding: 1rem; border-radius: 12px; margin-bottom: 1rem; }
304
- .model-breakdown h3 { font-size: 0.95rem; color: #a0a0a0; margin-bottom: 0.75rem; font-weight: 600; }
305
- .model-pred { display: grid; grid-template-columns: 140px 1fr auto; gap: 0.75rem; align-items: center; padding: 0.5rem; border-radius: 8px; background: rgba(40, 40, 40, 0.4); margin-bottom: 0.5rem; }
306
- .model-name { color: #3b82f6; font-weight: 700; font-size: 0.85rem; }
307
- .model-class { color: #e0e0e0; font-size: 0.875rem; font-weight: 500; }
308
- .model-conf { color: #10b981; font-weight: 700; font-size: 0.875rem; }
309
- .divider { height: 2px; background: linear-gradient(90deg, transparent, rgba(59, 130, 246, 0.3), transparent); margin: 1.5rem 0; }
310
- .top5-grid { display: flex; flex-direction: column; gap: 0.875rem; }
311
- .top5-row { display: grid; grid-template-columns: 40px 1fr auto 80px; align-items: center; gap: 0.875rem; font-size: 0.95rem; padding: 0.5rem; border-radius: 12px; background: rgba(30, 30, 30, 0.5); transition: all 0.3s ease; }
312
- .top5-row:hover { background: rgba(40, 40, 40, 0.7); transform: translateX(5px); }
313
- .rank { font-size: 1.5rem; text-align: center; }
314
- .label { color: #e0e0e0; font-weight: 600; white-space: nowrap; overflow: hidden; text-overflow: ellipsis; }
315
- .bar-wrap { background: rgba(40, 40, 40, 0.8); height: 10px; border-radius: 5px; overflow: hidden; min-width: 100px; box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.3); }
316
- .bar { background: linear-gradient(90deg, #3b82f6, #8b5cf6); height: 100%; transition: width 1s ease; border-radius: 5px; box-shadow: 0 0 10px rgba(59, 130, 246, 0.5); }
317
- .pct { color: #3b82f6; font-weight: 700; font-size: 0.9rem; text-align: right; }
318
- #result-image, #detailed-heatmap { border-radius: 16px !important; overflow: hidden; width: 100% !important; height: auto !important; min-height: 500px !important; box-shadow: 0 8px 32px rgba(0, 0, 0, 0.5); object-fit: contain !important; }
319
- .placeholder { text-align: center; padding: 4rem 1.5rem; color: #606060; font-size: 1.1rem; line-height: 1.6; }
320
- .placeholder strong { color: #3b82f6; }
321
- .error-msg { color: #ef4444; background: rgba(239, 68, 68, 0.1); padding: 1.5rem; border-radius: 16px; text-align: center; border: 1px solid rgba(239, 68, 68, 0.3); }
322
- footer, .footer { display: none !important; }
323
- ::-webkit-scrollbar { width: 10px; }
324
- ::-webkit-scrollbar-track { background: rgba(20, 20, 20, 0.5); }
325
- ::-webkit-scrollbar-thumb { background: rgba(59, 130, 246, 0.5); border-radius: 5px; }
326
- @media (max-width: 768px) {
327
- .top-section { grid-template-columns: 1fr; }
328
- #input-image { height: 240px !important; }
329
- .top5-row { grid-template-columns: 35px 1fr 70px; }
330
- .bar-wrap { grid-column: 1 / -1; margin-top: 0.375rem; }
331
- #result-image { min-height: 600px !important; max-height: none !important; }
332
- #detailed-heatmap { min-height: 450px !important; max-height: none !important; }
333
- .viz-section { padding: 1rem; }
334
- .section-label { font-size: 1rem; }
335
- .model-pred { grid-template-columns: 1fr; gap: 0.25rem; }
336
- }
337
- @media (max-width: 480px) {
338
- .main-wrapper { padding: 1rem; }
339
- #result-image { min-height: 550px !important; }
340
- #detailed-heatmap { min-height: 400px !important; }
341
- }
342
- """
343
-
344
-
345
- with gr.Blocks(css=custom_css, theme=gr.themes.Base(), title="Advanced XAI Classifier") as demo:
346
- gr.HTML('<link rel="icon" href="https://res.cloudinary.com/ddn0xuwut/image/upload/v1761284764/encryption_hc0fxo.png" type="image/png">')
347
-
348
- with gr.Column(elem_classes="main-wrapper"):
349
- gr.HTML('''
350
- <div class="hero-header">
351
- <h1>Advanced XAI Classifier</h1>
352
- <p class="subtitle">Multi-Model Ensemble with Test-Time Augmentation</p>
353
- <div class="model-tag">⚡ EfficientNet-B4 + ResNet152 + ConvNeXt</div>
354
- </div>
355
- ''')
356
-
357
- with gr.Row(elem_classes="top-section"):
358
- with gr.Column(scale=0, min_width=400, elem_classes="upload-panel"):
359
- gr.HTML("<div class='section-label'>📤 Upload Image</div>")
360
- input_image = gr.Image(type="pil", label=None, elem_id="input-image", show_label=False, container=False)
361
- with gr.Row(elem_classes="btn-row"):
362
- predict_btn = gr.Button("🚀 Analyze", variant="primary", size="lg", scale=2)
363
- clear_btn = gr.ClearButton([input_image], value="🗑️ Clear", size="lg", scale=1)
364
-
365
- with gr.Column(scale=1, elem_classes="results-panel"):
366
- output_text = gr.HTML('<div class="placeholder"><strong>👋 Welcome to Advanced XAI!</strong><br><br>This classifier uses 3 state-of-the-art models:<br>• EfficientNet-B4 (40%)<br>• ResNet152 (30%)<br>• ConvNeXt-Base (25%)<br>• + Test-Time Augmentation (5%)<br><br>Upload an image to see the magic! ✨</div>')
367
-
368
- with gr.Column(elem_classes="viz-section"):
369
- gr.HTML("<div class='section-label'>🎯 Ensemble Visual Explainability</div>")
370
- output_image = gr.Image(label=None, type="pil", show_label=False, elem_id="result-image", container=False)
371
-
372
- with gr.Column(elem_classes="viz-section"):
373
- gr.HTML("<div class='section-label'>🔬 Model Comparison Analysis</div>")
374
- detailed_heatmap = gr.Image(label=None, type="pil", show_label=False, elem_id="detailed-heatmap", container=False)
375
-
376
- predict_btn.click(fn=predict_and_explain, inputs=[input_image], outputs=[output_text, output_image, detailed_heatmap])
377
 
378
- if __name__ == "__main__":
379
- demo.launch(share=False, show_error=True)
 
1
+ import gradio as gr
2
  import torch
 
3
  from torchvision import models, transforms
4
  from PIL import Image
 
 
 
5
  import matplotlib.pyplot as plt
6
+ import numpy as np
7
  from io import BytesIO
 
 
 
 
 
8
 
9
+ # Minimal working version
10
+ def minimal_predict(image):
11
+ if image is None:
12
+ return "Please upload an image", None
13
 
14
+ # Simple transform
15
+ transform = transforms.Compose([
16
+ transforms.Resize((224, 224)),
17
+ transforms.ToTensor(),
18
+ ])
19
 
20
+ img_tensor = transform(image).unsqueeze(0)
 
 
21
 
22
+ # Create a simple visualization
23
+ fig, ax = plt.subplots(1, 1, figsize=(8, 6))
24
+ ax.imshow(image)
25
+ ax.set_title("Processed Image")
26
+ ax.axis('off')
27
 
28
+ buf = BytesIO()
29
+ plt.savefig(buf, format='png')
30
+ buf.seek(0)
31
+ result_img = Image.open(buf)
32
+ plt.close(fig)
33
 
34
+ return "Analysis complete", result_img
35
+
36
+ with gr.Blocks() as demo:
37
+ gr.Markdown("# Simple Image Classifier")
38
+ with gr.Row():
39
+ with gr.Column():
40
+ image_input = gr.Image(type="pil")
41
+ analyze_btn = gr.Button("Analyze")
42
+ with gr.Column():
43
+ text_output = gr.Textbox()
44
+ image_output = gr.Image()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
+ analyze_btn.click(fn=minimal_predict, inputs=[image_input], outputs=[text_output, image_output])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
+ demo.launch(share=False)