tanh1c commited on
Commit
d13c106
·
1 Parent(s): 94071ae

Add Gradio image demo

Browse files
.gitattributes CHANGED
@@ -1,35 +1,3 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  *.pth filter=lfs diff=lfs merge=lfs -text
2
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
3
+ *.png filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
app.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Hugging Face Space entry point.

Executes the assignment's Gradio app module in-process and exposes the
resulting ``demo`` object at module level (Spaces looks for it there),
launching it only when run directly as a script.
"""

import runpy

# Location of the real application module inside the repository.
_APP_PATH = "assignments/assignment-1/app/main.py"

# Run the app module and pull its factory out of the resulting namespace.
_namespace = runpy.run_path(_APP_PATH, run_name="hf_space_app")
demo = _namespace["create_app"]()

if __name__ == "__main__":
    demo.launch()
assignments/assignment-1/app/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # DL Assignment 1 - Application Demo
assignments/assignment-1/app/assets/style.css ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/*
Deep Learning Assignment - Custom Style System
=============================================
Designed for a premium, dark-themed experience.

Palette follows the GitHub dark theme: #0d1117 page background,
#161b22 panels, #30363d borders, #58a6ff accent blue.
*/

/* Main container stabilization */
.gradio-container {
    max-width: 1200px !important;
    margin: 0 auto !important;
    font-family: 'Inter', system-ui, -apple-system, sans-serif !important;
}

/* Header & Title Styling */
.app-header {
    text-align: center;
    padding: 30px 0;
    border-bottom: 1px solid #30363d;
    margin-bottom: 30px;
    background: linear-gradient(to bottom, #161b22, #0d1117);
}

/* Gradient-filled title text (background-clip: text trick) */
.app-header h1 {
    font-weight: 800 !important;
    letter-spacing: -0.02em;
    background: linear-gradient(135deg, #58a6ff 0%, #bc8cff 100%);
    -webkit-background-clip: text;
    background-clip: text;
    -webkit-text-fill-color: transparent;
}

/* Model Info Display */
.model-info-box {
    background: #161b22;
    border: 1px solid #30363d;
    border-radius: 12px;
    padding: 24px;
    margin: 15px 0;
    box-shadow: 0 4px 20px rgba(0,0,0,0.3);
}

/* Prediction Result Premium Card */
.prediction-label {
    font-size: 26px !important;
    font-weight: 700 !important;
    text-align: center;
    padding: 20px;
    background: linear-gradient(135deg, #238636 0%, #2ea043 100%);
    border-radius: 12px;
    color: white !important;
    margin: 15px 0;
    box-shadow: 0 8px 32px rgba(35, 134, 54, 0.2);
    border: 1px solid rgba(255,255,255,0.1);
}

/* Confidence Bars & Progress */
.confidence-bar {
    height: 32px;
    border-radius: 8px;
    background-color: #21262d;
    overflow: hidden;
    margin: 8px 0;
    border: 1px solid #30363d;
}

/* Modern Tabs Navigation */
.tab-nav {
    border-bottom: 1px solid #30363d !important;
    margin-bottom: 20px !important;
}

.tab-nav button {
    font-size: 15px !important;
    font-weight: 600 !important;
    padding: 14px 28px !important;
    color: #8b949e !important;
    transition: all 0.2s ease !important;
}

.tab-nav button:hover {
    color: #f0f6fc !important;
    background-color: rgba(139, 148, 158, 0.1) !important;
}

/* Active tab gets the accent color and underline */
.tab-nav button.selected {
    color: #58a6ff !important;
    border-bottom: 2px solid #58a6ff !important;
    background: transparent !important;
}

/* Calibration Metric Cards */
.metric-card {
    background: #161b22;
    border: 1px solid #30363d;
    border-radius: 12px;
    padding: 25px;
    text-align: center;
    transition: transform 0.2s ease;
}

/* Subtle lift + accent border on hover */
.metric-card:hover {
    transform: translateY(-2px);
    border-color: #58a6ff;
}

/* Custom Buttons Styling */
.gr-button-primary {
    background: linear-gradient(135deg, #1f6feb 0%, #58a6ff 100%) !important;
    border: none !important;
    font-weight: 600 !important;
    box-shadow: 0 4px 12px rgba(31, 111, 235, 0.3) !important;
}

.gr-button-primary:hover {
    filter: brightness(1.1);
    transform: translateY(-1px);
}

/* Footer Section */
.app-footer {
    text-align: center;
    padding: 40px 20px;
    color: #8b949e;
    font-size: 14px;
    border-top: 1px solid #30363d;
    margin-top: 40px;
    opacity: 0.8;
}

/* Glassmorphism utility */
.glass {
    background: rgba(22, 27, 34, 0.7) !important;
    backdrop-filter: blur(10px) !important;
    border: 1px solid rgba(48, 54, 61, 0.5) !important;
}
assignments/assignment-1/app/image/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Image handlers for Assignment 1."""
assignments/assignment-1/app/image/data.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilities for loading the CIFAR-10 test split from local project assets.
3
+
4
+ The workspace keeps the archive at ``image/data/cifar-10-python.tar.gz``.
5
+ Reading the test batch directly from that archive avoids permission issues
6
+ with extracted files while keeping calibration fully offline.
7
+ """
8
+
9
+ import os
10
+ import pickle
11
+ import tarfile
12
+ from functools import lru_cache
13
+ from typing import Tuple
14
+
15
+ import numpy as np
16
+ from PIL import Image
17
+ from torch.utils.data import Dataset
18
+
19
+
20
# Assignment root: three levels above this file (app/image/data.py).
ASSIGNMENT_ROOT = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
DEFAULT_DATA_DIR = os.path.join(ASSIGNMENT_ROOT, "image", "data")
DEFAULT_ARCHIVE_PATH = os.path.join(DEFAULT_DATA_DIR, "cifar-10-python.tar.gz")


@lru_cache(maxsize=1)
def load_cifar10_test_arrays(
    archive_path: str = DEFAULT_ARCHIVE_PATH,
) -> Tuple[np.ndarray, np.ndarray]:
    """Load CIFAR-10 test images and labels from the local archive.

    Args:
        archive_path: Path to ``cifar-10-python.tar.gz``.

    Returns:
        ``(images, labels)``: ``images`` is an ``(N, 32, 32, 3)`` uint8
        array in HWC channel order, ``labels`` is an ``(N,)`` int64 array.

    Raises:
        FileNotFoundError: If the archive, or the test batch inside it,
            is missing.
    """
    if not os.path.exists(archive_path):
        raise FileNotFoundError(
            f"CIFAR-10 archive not found at {archive_path}. "
            "Expected image/data/cifar-10-python.tar.gz to exist."
        )

    with tarfile.open(archive_path, "r:gz") as tar:
        # TarFile.extractfile(name) raises KeyError for a missing member
        # and returns None only for a non-regular member; normalize both
        # failure modes to FileNotFoundError (the None check alone never
        # caught a missing entry).
        try:
            member = tar.extractfile("cifar-10-batches-py/test_batch")
        except KeyError as err:
            raise FileNotFoundError(
                "Could not find cifar-10-batches-py/test_batch inside the archive."
            ) from err
        if member is None:
            raise FileNotFoundError(
                "Could not find cifar-10-batches-py/test_batch inside the archive."
            )

        # CIFAR batches are pickled dicts with bytes keys.
        batch = pickle.load(member, encoding="bytes")

    # Rows are 3072 bytes (R-plane, G-plane, B-plane); reshape to NCHW
    # then transpose to the NHWC layout PIL expects.
    images = batch[b"data"].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    labels = np.asarray(batch[b"labels"], dtype=np.int64)
    return images, labels
50
+
51
+
52
class LocalCIFAR10TestDataset(Dataset):
    """Serve the CIFAR-10 test split directly from local project files.

    The whole split is held in memory as one uint8 array; each item is
    converted to a PIL image (and optionally transformed) on access.
    """

    def __init__(self, transform=None, archive_path: str = DEFAULT_ARCHIVE_PATH):
        self.transform = transform
        self.images, self.labels = load_cifar10_test_arrays(archive_path)

    def __len__(self) -> int:
        # One label per image.
        return len(self.labels)

    def __getitem__(self, idx: int):
        label = int(self.labels[idx])
        sample = Image.fromarray(self.images[idx])
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label
70
+
71
+
72
def create_cifar10_test_dataset(transform=None) -> LocalCIFAR10TestDataset:
    """Build the CIFAR-10 test dataset backing the calibration tab."""
    dataset = LocalCIFAR10TestDataset(transform=transform)
    return dataset
assignments/assignment-1/app/image/resnet18.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CIFAR-10 ResNet-18 Model Handler
3
+
4
+ Handles prediction, Grad-CAM visualization, and calibration
5
+ for the ResNet-18 model trained on CIFAR-10.
6
+ """
7
+
8
+ import os
9
+ import torch
10
+ import torch.nn as nn
11
+ import numpy as np
12
+ from PIL import Image
13
+ from typing import Dict, List, Optional, Any
14
+ import torchvision.transforms as transforms
15
+ from torchvision.models import resnet18
16
+
17
+ from app.shared.model_registry import (
18
+ BaseModelHandler,
19
+ PredictionResult,
20
+ CalibrationResult,
21
+ )
22
+ from app.shared.artifact_utils import (
23
+ get_best_accuracy_from_history,
24
+ load_precomputed_calibration_result,
25
+ )
26
+ from app.image.data import create_cifar10_test_dataset
27
+
28
# CIFAR-10 class labels (index order matches the training labels)
CIFAR10_LABELS = [
    'Airplane', 'Automobile', 'Bird', 'Cat', 'Deer',
    'Dog', 'Frog', 'Horse', 'Ship', 'Truck'
]

# Per-channel normalization statistics for CIFAR-10
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Input resolution the ResNet backbone expects
IMAGE_SIZE = 224


def create_resnet18_cifar10(num_classes=10):
    """Build a ResNet-18 whose final layer predicts ``num_classes`` classes."""
    model = resnet18(weights=None)
    # Swap the ImageNet head for a CIFAR-10 sized linear classifier.
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, num_classes)
    return model
48
+
49
+
50
class GradCAM:
    """
    Grad-CAM implementation for visual explanation.
    Generates heatmap showing which regions the model focuses on.

    Hooks are registered once on ``target_layer`` at construction; every
    ``generate`` call reuses the activations/gradients captured during the
    most recent forward/backward pass through the model.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        # Filled in by the hooks on each forward/backward pass.
        self.gradients = None
        self.activations = None
        self._register_hooks()

    def _register_hooks(self):
        # Capture the layer output on forward and its gradient on backward.
        def forward_hook(module, input, output):
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()

        # NOTE(review): hook handles are not stored, so they can never be
        # removed — acceptable while the handler lives for the app lifetime.
        self.target_layer.register_forward_hook(forward_hook)
        self.target_layer.register_full_backward_hook(backward_hook)

    def generate(self, input_tensor, target_class=None):
        """Generate Grad-CAM heatmap.

        Runs a forward and backward pass for ``target_class`` (the argmax
        prediction when None) and returns a normalized heatmap as a numpy
        array resized to ``IMAGE_SIZE`` x ``IMAGE_SIZE``.
        """
        self.model.eval()
        output = self.model(input_tensor)

        if target_class is None:
            target_class = output.argmax(dim=1).item()

        # Backprop only the score of the target class.
        self.model.zero_grad()
        one_hot = torch.zeros_like(output)
        one_hot[0, target_class] = 1.0
        output.backward(gradient=one_hot, retain_graph=True)

        # Pool gradients across spatial dimensions
        weights = self.gradients.mean(dim=[2, 3], keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = torch.relu(cam)

        # Normalize to [0, 1] (guard against an all-zero map)
        cam = cam - cam.min()
        if cam.max() > 0:
            cam = cam / cam.max()

        # Resize to input size
        cam = torch.nn.functional.interpolate(
            cam, size=(IMAGE_SIZE, IMAGE_SIZE), mode='bilinear', align_corners=False
        )
        return cam.squeeze().cpu().numpy()
101
+
102
+
103
def create_gradcam_overlay(image_np, heatmap, alpha=0.5):
    """Render the input image, its Grad-CAM heatmap, and their blend.

    Returns an RGB uint8 array of the composed three-panel figure.
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm

    # Colorize the normalized heatmap and drop the alpha channel.
    heat_rgb = (cm.jet(heatmap)[:, :, :3] * 255).astype(np.uint8)

    # Bring the source image to the heatmap's resolution if needed.
    if image_np.shape[:2] != (IMAGE_SIZE, IMAGE_SIZE):
        image_np = np.array(
            Image.fromarray(image_np).resize((IMAGE_SIZE, IMAGE_SIZE))
        )

    blended = (alpha * heat_rgb + (1 - alpha) * image_np).astype(np.uint8)

    # Three panels: source, heatmap, blend — on the app's dark background.
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    fig.patch.set_facecolor('#0d1117')

    for ax, panel, title in zip(
        axes,
        [image_np, heat_rgb, blended],
        ['Original Image', 'Grad-CAM Heatmap', 'Overlay'],
    ):
        ax.imshow(panel)
        ax.set_title(title, color='white', fontsize=14, fontweight='bold', pad=10)
        ax.axis('off')
        ax.set_facecolor('#0d1117')

    plt.tight_layout(pad=2)

    # Rasterize; buffer_rgba() is robust across matplotlib versions.
    fig.canvas.draw()
    rendered = np.array(fig.canvas.buffer_rgba())[:, :, :3]
    plt.close(fig)
    return rendered
145
+
146
+
147
class Cifar10ResNet18Handler(BaseModelHandler):
    """Model handler for CIFAR-10 ResNet-18.

    Wraps checkpoint loading, single-image prediction with Grad-CAM
    explanations, and calibration (ECE + reliability diagram) over the
    local CIFAR-10 test split.
    """

    def __init__(self, model_path: str):
        self.model_path = model_path
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = None
        self.grad_cam = None
        # Training metadata pulled from the checkpoint, when present.
        self.history = {}
        self.config = {}
        self.best_accuracy = None
        # Calibration results keyed by "full" or "subset:<n>".
        self._calibration_cache = {}
        self.transform = transforms.Compose([
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
        ])
        self._load_model()

    def _load_model(self):
        """Load the trained model.

        Falls back to randomly-initialized weights when ``model_path`` does
        not exist. Also seeds the calibration cache from any precomputed
        artifact found for "resnet18".
        """
        self.model = create_resnet18_cifar10(num_classes=10)

        if os.path.exists(self.model_path):
            # weights_only=True avoids executing arbitrary pickled objects.
            checkpoint = torch.load(self.model_path, map_location=self.device,
                                    weights_only=True)
            if isinstance(checkpoint, dict):
                self.history = checkpoint.get('history', {}) or {}
                self.config = checkpoint.get('config', {}) or {}
                self.best_accuracy = get_best_accuracy_from_history(self.history)
            # Handle both state_dict and full model saves
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                self.model.load_state_dict(checkpoint['model_state_dict'])
            elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                self.model.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint)

        self.model = self.model.to(self.device)
        self.model.eval()

        # Initialize Grad-CAM with the last conv layer
        self.grad_cam = GradCAM(self.model, self.model.layer4[-1])

        precomputed_full = load_precomputed_calibration_result("resnet18")
        if precomputed_full is not None:
            self._calibration_cache["full"] = precomputed_full

    def get_model_name(self) -> str:
        """Human-readable model name shown in the UI."""
        return "ResNet-18"

    def get_dataset_name(self) -> str:
        """Name of the dataset the model was trained on."""
        return "CIFAR-10"

    def get_data_type(self) -> str:
        """Modality tag used by the app to pick the right input widget."""
        return "image"

    def get_class_labels(self) -> List[str]:
        """Class names in label-index order."""
        return CIFAR10_LABELS

    def get_model_info(self) -> Dict[str, str]:
        """Build the key/value table shown in the model-info box."""
        total_params = sum(p.numel() for p in self.model.parameters())
        best_accuracy = (
            f"{self.best_accuracy:.2f}%"
            if self.best_accuracy is not None
            else "N/A"
        )
        info = {
            "Architecture": "ResNet-18 (Transfer Learning from ImageNet)",
            "Dataset": "CIFAR-10 (10 classes, 60,000 images)",
            "Parameters": f"{total_params:,}",
            "Input Size": f"{IMAGE_SIZE}×{IMAGE_SIZE}×3",
            "Training": "Full fine-tune, AdamW, Cosine Annealing LR",
            "Best Accuracy": best_accuracy,
            "Device": str(self.device),
        }
        if "epochs" in self.config:
            info["Epochs"] = str(self.config["epochs"])
        # Surface the precomputed full-test ECE when available.
        full_result = self._calibration_cache.get("full")
        if full_result is not None:
            info["Full-Test ECE"] = f"{full_result.ece:.6f}"
        return info

    def predict(self, input_data) -> PredictionResult:
        """Run prediction with Grad-CAM visualization.

        Accepts a numpy array or PIL image; raises ValueError on None.
        """
        if input_data is None:
            raise ValueError("No input image provided")

        # Convert to PIL Image if numpy array
        if isinstance(input_data, np.ndarray):
            original_image = input_data.copy()
            pil_image = Image.fromarray(input_data).convert('RGB')
        else:
            pil_image = input_data.convert('RGB')
            original_image = np.array(pil_image)

        # Preprocess
        input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)

        # Forward pass
        with torch.no_grad():
            output = self.model(input_tensor)
            probabilities = torch.softmax(output, dim=1)[0]

        probs = probabilities.cpu().numpy()
        pred_idx = probs.argmax()
        pred_label = CIFAR10_LABELS[pred_idx]
        pred_conf = float(probs[pred_idx])

        # Generate Grad-CAM
        # Need to re-run with gradients enabled
        input_tensor_grad = self.transform(pil_image).unsqueeze(0).to(self.device)
        input_tensor_grad.requires_grad_(True)

        heatmap = self.grad_cam.generate(input_tensor_grad, target_class=pred_idx)
        explanation_image = create_gradcam_overlay(original_image, heatmap)

        return PredictionResult(
            label=pred_label,
            confidence=pred_conf,
            all_labels=CIFAR10_LABELS,
            all_confidences=probs.tolist(),
            explanation_image=explanation_image,
        )

    def get_example_inputs(self) -> List[Any]:
        """Return example images from CIFAR-10 test set if available."""
        # Currently no bundled examples.
        return []

    def get_calibration_data(
        self, max_samples: Optional[int] = None
    ) -> Optional[CalibrationResult]:
        """
        Compute calibration metrics on test set.
        This runs evaluation on the full test set - can be slow on CPU.

        Results are cached per ``max_samples`` value; returns None if the
        computation fails for any reason (e.g. missing data archive).
        """
        cache_key = "full" if max_samples is None else f"subset:{max_samples}"
        if cache_key in self._calibration_cache:
            return self._calibration_cache[cache_key]

        try:
            import matplotlib
            matplotlib.use('Agg')
            import matplotlib.pyplot as plt

            test_dataset = create_cifar10_test_dataset(transform=self.transform)
            # Evenly-spaced subset keeps the class mix representative.
            if max_samples is not None and 0 < max_samples < len(test_dataset):
                indices = np.linspace(
                    0, len(test_dataset) - 1, num=max_samples, dtype=int
                ).tolist()
                test_dataset = torch.utils.data.Subset(test_dataset, indices)

            test_loader = torch.utils.data.DataLoader(
                test_dataset, batch_size=128, shuffle=False, num_workers=0
            )

            all_probs = []
            all_preds = []
            all_targets = []

            self.model.eval()
            with torch.inference_mode():
                for inputs, targets in test_loader:
                    inputs = inputs.to(self.device)
                    outputs = self.model(inputs)
                    probs = torch.softmax(outputs, dim=1)
                    preds = outputs.argmax(1)

                    all_probs.extend(probs.cpu().numpy())
                    all_preds.extend(preds.cpu().numpy())
                    all_targets.extend(targets.numpy())

            all_probs = np.array(all_probs)
            all_preds = np.array(all_preds)
            all_targets = np.array(all_targets)

            # Compute ECE (Expected Calibration Error)
            n_bins = 10
            max_probs = np.max(all_probs, axis=1)
            correctness = (all_preds == all_targets).astype(float)

            bin_boundaries = np.linspace(0, 1, n_bins + 1)
            bin_accuracies = []
            bin_confidences = []
            bin_counts = []

            # Bucket predictions by confidence: bins are (lower, upper].
            for i in range(n_bins):
                lower = bin_boundaries[i]
                upper = bin_boundaries[i + 1]
                mask = (max_probs > lower) & (max_probs <= upper)
                count = mask.sum()
                bin_counts.append(int(count))

                if count > 0:
                    bin_acc = correctness[mask].mean()
                    bin_conf = max_probs[mask].mean()
                else:
                    bin_acc = 0.0
                    bin_conf = 0.0

                bin_accuracies.append(float(bin_acc))
                bin_confidences.append(float(bin_conf))

            # Compute ECE: weighted |accuracy - confidence| gap per bin.
            total = len(all_preds)
            ece = sum(
                (count / total) * abs(acc - conf)
                for count, acc, conf in zip(bin_counts, bin_accuracies, bin_confidences)
            )

            # Create reliability diagram
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
            fig.patch.set_facecolor('#0d1117')

            # Reliability Diagram
            ax1.set_facecolor('#161b22')
            bin_centers = [(bin_boundaries[i] + bin_boundaries[i + 1]) / 2 for i in range(n_bins)]
            width = 0.08

            bars1 = ax1.bar(
                [c - width/2 for c in bin_centers], bin_accuracies, width,
                label='Accuracy', color='#58a6ff', alpha=0.9, edgecolor='#58a6ff'
            )
            bars2 = ax1.bar(
                [c + width/2 for c in bin_centers], bin_confidences, width,
                label='Avg Confidence', color='#f97583', alpha=0.9, edgecolor='#f97583'
            )

            # Diagonal = perfectly calibrated model.
            ax1.plot([0, 1], [0, 1], '--', color='#8b949e', linewidth=2,
                     label='Perfect Calibration')
            ax1.set_xlim(0, 1)
            ax1.set_ylim(0, 1)
            ax1.set_xlabel('Confidence', color='white', fontsize=12)
            ax1.set_ylabel('Accuracy / Confidence', color='white', fontsize=12)
            ax1.set_title(
                f'Reliability Diagram (ECE: {ece:.4f})',
                color='white', fontsize=14, fontweight='bold', pad=15
            )
            ax1.legend(facecolor='#161b22', edgecolor='#30363d',
                       labelcolor='white', fontsize=10)
            ax1.tick_params(colors='white')
            for spine in ax1.spines.values():
                spine.set_edgecolor('#30363d')
            ax1.grid(True, alpha=0.1, color='white')

            # Confidence histogram
            ax2.set_facecolor('#161b22')
            ax2.bar(
                bin_centers, [c / total for c in bin_counts], 0.08,
                color='#56d364', alpha=0.9, edgecolor='#56d364'
            )
            ax2.set_xlim(0, 1)
            ax2.set_xlabel('Confidence', color='white', fontsize=12)
            ax2.set_ylabel('Fraction of Samples', color='white', fontsize=12)
            ax2.set_title(
                'Confidence Distribution',
                color='white', fontsize=14, fontweight='bold', pad=15
            )
            ax2.tick_params(colors='white')
            for spine in ax2.spines.values():
                spine.set_edgecolor('#30363d')
            ax2.grid(True, alpha=0.1, color='white')

            plt.tight_layout(pad=3)

            # Convert to numpy
            fig.canvas.draw()
            rgba_buffer = fig.canvas.buffer_rgba()
            diagram = np.array(rgba_buffer)[:, :, :3]  # Strip alpha channel
            plt.close(fig)

            self._calibration_cache[cache_key] = CalibrationResult(
                ece=ece,
                bin_accuracies=bin_accuracies,
                bin_confidences=bin_confidences,
                bin_counts=bin_counts,
                reliability_diagram=diagram,
                source="Live computation",
            )
            return self._calibration_cache[cache_key]

        except Exception as e:
            # Best-effort: the UI shows "unavailable" on None rather than
            # crashing; the broad catch is deliberate here.
            print(f"Error computing calibration: {e}")
            return None
assignments/assignment-1/app/image/vit_b16.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CIFAR-10 ViT-B/16 Model Handler
3
+
4
+ Handles prediction, Grad-CAM visualization, and calibration
5
+ for the ViT-B/16 model trained on CIFAR-10.
6
+ """
7
+
8
+ import os
9
+ import types
10
+ import torch
11
+ import torch.nn as nn
12
+ import numpy as np
13
+ from PIL import Image
14
+ from typing import Dict, List, Optional, Any
15
+ import torchvision.transforms as transforms
16
+ from torchvision.models import vit_b_16
17
+
18
+ from app.shared.model_registry import (
19
+ BaseModelHandler,
20
+ PredictionResult,
21
+ CalibrationResult,
22
+ )
23
+ from app.shared.artifact_utils import (
24
+ get_best_accuracy_from_history,
25
+ load_precomputed_calibration_result,
26
+ )
27
+ from app.image.data import create_cifar10_test_dataset
28
+
29
+ # CIFAR-10 class labels
30
+ CIFAR10_LABELS = [
31
+ 'airplane', 'automobile', 'bird', 'cat', 'deer',
32
+ 'dog', 'frog', 'horse', 'ship', 'truck'
33
+ ]
34
+
35
+ # CIFAR-10 normalization values
36
+ CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
37
+ CIFAR10_STD = (0.2470, 0.2435, 0.2616)
38
+
39
+ # Image size ViT expects
40
+ IMAGE_SIZE = 224
41
+
42
+
43
+ def create_vit_model(num_classes=10):
44
+ """Create ViT-B/16 with modified classifier for CIFAR-10."""
45
+ model = vit_b_16(weights=None)
46
+ # Replace classifier head
47
+ model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)
48
+ return model
49
+
50
+
51
class ViTAttentionVisualizer:
    """
    Attention visualization for ViT.
    Shows which patches the model attends to.

    Works by monkey-patching the last encoder block so that its attention
    weights are captured on every forward pass; ``generate_attention_map``
    then reads the latest captured weights.
    """

    def __init__(self, model):
        self.model = model
        # Latest attention weights, shape (batch, heads, seq_len, seq_len).
        self.attentions = None
        self._patch_last_encoder_block()

    def _patch_last_encoder_block(self):
        """
        Torchvision's ViT encoder block calls MultiheadAttention with
        need_weights=False, so a normal forward hook never receives attention
        maps. We patch only the last block to request weights during inference.
        """
        last_block = self.model.encoder.layers[-1]
        # Close over the visualizer so the patched forward can stash weights.
        visualizer = self

        def forward_with_attention(block, input_tensor):
            torch._assert(
                input_tensor.dim() == 3,
                f"Expected (batch_size, seq_length, hidden_dim) got {input_tensor.shape}",
            )

            # Mirrors torchvision's EncoderBlock.forward, but asks the
            # attention module for per-head weights.
            x = block.ln_1(input_tensor)
            attn_output, attn_weights = block.self_attention(
                x,
                x,
                x,
                need_weights=True,
                average_attn_weights=False,
            )
            visualizer.attentions = attn_weights.detach()

            x = block.dropout(attn_output)
            x = x + input_tensor

            y = block.ln_2(x)
            y = block.mlp(y)
            return x + y

        last_block.forward = types.MethodType(forward_with_attention, last_block)

    def generate_attention_map(self, input_tensor):
        """Generate attention map from input tensor.

        Returns a normalized 2-D patch-grid array, a raw 1-D array when the
        sequence length is not a perfect square, or None if no attention
        was captured.
        """
        self.model.eval()

        # Forward pass
        with torch.no_grad():
            _ = self.model(input_tensor)

        if self.attentions is None:
            return None

        # Get the [CLS] token attention across all heads
        # Shape: (batch, heads, seq_len, seq_len) -> take cls token row
        cls_attention = self.attentions[0, :, 0, 1:].mean(dim=0)  # Average over heads

        # Reshape to patch grid (assuming 16x16 patches for 224x224 image)
        num_patches = int(cls_attention.shape[0] ** 0.5)

        if num_patches * num_patches != cls_attention.shape[0]:
            # Fallback: just return raw attention
            return cls_attention.cpu().numpy()

        # Reshape to 2D grid
        attention_map = cls_attention.reshape(num_patches, num_patches).cpu().numpy()

        # Normalize to [0, 1] (guard against an all-zero map)
        attention_map = attention_map - attention_map.min()
        if attention_map.max() > 0:
            attention_map = attention_map / attention_map.max()

        return attention_map
127
+
128
+
129
def create_attention_overlay(image_np, attention_map, alpha=0.5):
    """Render the input image, its ViT attention map, and their blend.

    Returns ``image_np`` unchanged when ``attention_map`` is None,
    otherwise an RGB uint8 array of the composed three-panel figure.
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm

    if attention_map is None:
        return image_np

    from PIL import Image as PILImage

    # Upsample the patch-level attention grid to the input resolution.
    att_img = PILImage.fromarray((attention_map * 255).astype(np.uint8))
    att_img = att_img.resize((IMAGE_SIZE, IMAGE_SIZE), PILImage.BILINEAR)
    att = np.array(att_img).astype(np.float32) / 255.0

    if image_np.shape[:2] != (IMAGE_SIZE, IMAGE_SIZE):
        image_np = np.array(
            PILImage.fromarray(image_np).resize((IMAGE_SIZE, IMAGE_SIZE), PILImage.BILINEAR)
        )

    # Colorize the attention and blend it over the image.
    heat_rgb = (cm.jet(att)[:, :, :3] * 255).astype(np.uint8)
    blended = (alpha * heat_rgb + (1 - alpha) * image_np).astype(np.uint8)

    # Three panels: source, attention, blend — on the app's dark background.
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    fig.patch.set_facecolor('#0d1117')

    for ax, panel, title in zip(
        axes,
        [image_np, heat_rgb, blended],
        ['Original Image', 'Attention Map', 'Overlay'],
    ):
        ax.imshow(panel)
        ax.set_title(title, color='white', fontsize=14, fontweight='bold', pad=10)
        ax.axis('off')
        ax.set_facecolor('#0d1117')

    plt.tight_layout(pad=2)

    # Rasterize; buffer_rgba() is robust across matplotlib versions.
    fig.canvas.draw()
    rendered = np.array(fig.canvas.buffer_rgba())[:, :, :3]
    plt.close(fig)
    return rendered
180
+
181
+
182
+ class Cifar10ViTHandler(BaseModelHandler):
183
+ """Model handler for CIFAR-10 ViT-B/16."""
184
+
185
+ def __init__(self, model_path: str):
186
+ self.model_path = model_path
187
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
188
+ self.model = None
189
+ self.attention_viz = None
190
+ self.history = {}
191
+ self.best_accuracy = None
192
+ self._calibration_cache = {}
193
+ self.transform = transforms.Compose([
194
+ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
195
+ transforms.ToTensor(),
196
+ transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
197
+ ])
198
+ self._load_model()
199
+
200
+ def _load_model(self):
201
+ """Load the trained model."""
202
+ self.model = create_vit_model(num_classes=10)
203
+
204
+ if os.path.exists(self.model_path):
205
+ checkpoint = torch.load(self.model_path, map_location=self.device,
206
+ weights_only=True)
207
+ if isinstance(checkpoint, dict):
208
+ self.history = checkpoint.get('history', {}) or {}
209
+ self.best_accuracy = get_best_accuracy_from_history(self.history)
210
+ if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
211
+ self.model.load_state_dict(checkpoint['model_state_dict'])
212
+ else:
213
+ self.model.load_state_dict(checkpoint)
214
+
215
+ self.model = self.model.to(self.device)
216
+ self.model.eval()
217
+
218
+ # Initialize attention visualizer
219
+ self.attention_viz = ViTAttentionVisualizer(self.model)
220
+
221
+ precomputed_full = load_precomputed_calibration_result("vit_b16")
222
+ if precomputed_full is not None:
223
+ self._calibration_cache["full"] = precomputed_full
224
+
225
+ def get_model_name(self) -> str:
226
+ return "ViT-B/16"
227
+
228
+ def get_dataset_name(self) -> str:
229
+ return "CIFAR-10"
230
+
231
+ def get_data_type(self) -> str:
232
+ return "image"
233
+
234
+ def get_class_labels(self) -> List[str]:
235
+ return CIFAR10_LABELS
236
+
237
+ def get_model_info(self) -> Dict[str, str]:
238
+ total_params = sum(p.numel() for p in self.model.parameters())
239
+ best_accuracy = (
240
+ f"{self.best_accuracy:.2f}%"
241
+ if self.best_accuracy is not None
242
+ else "N/A"
243
+ )
244
+ info = {
245
+ "Architecture": "ViT-B/16 (Transfer Learning from ImageNet)",
246
+ "Dataset": "CIFAR-10 (10 classes, 60,000 images)",
247
+ "Parameters": f"{total_params:,}",
248
+ "Input Size": f"{IMAGE_SIZE}×{IMAGE_SIZE}×3",
249
+ "Training": "Full fine-tune, AdamW, Cosine Annealing LR",
250
+ "Best Accuracy": best_accuracy,
251
+ "Device": str(self.device),
252
+ }
253
+ if self.history:
254
+ info["Epochs"] = str(len(self.history.get("val_acc", [])))
255
+ full_result = self._calibration_cache.get("full")
256
+ if full_result is not None:
257
+ info["Full-Test ECE"] = f"{full_result.ece:.6f}"
258
+ return info
259
+
260
+ def predict(self, input_data) -> PredictionResult:
261
+ """Run prediction with attention visualization."""
262
+ if input_data is None:
263
+ raise ValueError("No input image provided")
264
+
265
+ # Convert to PIL Image if numpy array
266
+ if isinstance(input_data, np.ndarray):
267
+ original_image = input_data.copy()
268
+ pil_image = Image.fromarray(input_data).convert('RGB')
269
+ else:
270
+ pil_image = input_data.convert('RGB')
271
+ original_image = np.array(pil_image)
272
+
273
+ # Preprocess
274
+ input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)
275
+
276
+ # Forward pass
277
+ with torch.no_grad():
278
+ output = self.model(input_tensor)
279
+ probabilities = torch.softmax(output, dim=1)[0]
280
+
281
+ probs = probabilities.cpu().numpy()
282
+ pred_idx = probs.argmax()
283
+ pred_label = CIFAR10_LABELS[pred_idx]
284
+ pred_conf = float(probs[pred_idx])
285
+
286
+ # Generate attention visualization
287
+ attention_map = self.attention_viz.generate_attention_map(input_tensor)
288
+ explanation_image = create_attention_overlay(original_image, attention_map)
289
+
290
+ return PredictionResult(
291
+ label=pred_label,
292
+ confidence=pred_conf,
293
+ all_labels=CIFAR10_LABELS,
294
+ all_confidences=probs.tolist(),
295
+ explanation_image=explanation_image,
296
+ )
297
+
298
+ def get_example_inputs(self) -> List[Any]:
299
+ return []
300
+
301
+ def get_calibration_data(
302
+ self, max_samples: Optional[int] = None
303
+ ) -> Optional[CalibrationResult]:
304
+ """Compute calibration metrics on test set."""
305
+ cache_key = "full" if max_samples is None else f"subset:{max_samples}"
306
+ if cache_key in self._calibration_cache:
307
+ return self._calibration_cache[cache_key]
308
+
309
+ try:
310
+ import matplotlib
311
+ matplotlib.use('Agg')
312
+ import matplotlib.pyplot as plt
313
+
314
+ test_dataset = create_cifar10_test_dataset(transform=self.transform)
315
+ if max_samples is not None and 0 < max_samples < len(test_dataset):
316
+ indices = np.linspace(
317
+ 0, len(test_dataset) - 1, num=max_samples, dtype=int
318
+ ).tolist()
319
+ test_dataset = torch.utils.data.Subset(test_dataset, indices)
320
+
321
+ test_loader = torch.utils.data.DataLoader(
322
+ test_dataset, batch_size=128, shuffle=False, num_workers=0
323
+ )
324
+
325
+ all_probs = []
326
+ all_preds = []
327
+ all_targets = []
328
+
329
+ self.model.eval()
330
+ with torch.inference_mode():
331
+ for inputs, targets in test_loader:
332
+ inputs = inputs.to(self.device)
333
+ outputs = self.model(inputs)
334
+ probs = torch.softmax(outputs, dim=1)
335
+ preds = outputs.argmax(1)
336
+
337
+ all_probs.extend(probs.cpu().numpy())
338
+ all_preds.extend(preds.cpu().numpy())
339
+ all_targets.extend(targets.numpy())
340
+
341
+ all_probs = np.array(all_probs)
342
+ all_preds = np.array(all_preds)
343
+ all_targets = np.array(all_targets)
344
+
345
+ # Compute ECE
346
+ n_bins = 10
347
+ max_probs = np.max(all_probs, axis=1)
348
+ correctness = (all_preds == all_targets).astype(float)
349
+
350
+ bin_boundaries = np.linspace(0, 1, n_bins + 1)
351
+ bin_accuracies = []
352
+ bin_confidences = []
353
+ bin_counts = []
354
+
355
+ for i in range(n_bins):
356
+ lower = bin_boundaries[i]
357
+ upper = bin_boundaries[i + 1]
358
+ mask = (max_probs > lower) & (max_probs <= upper)
359
+ count = mask.sum()
360
+ bin_counts.append(int(count))
361
+
362
+ if count > 0:
363
+ bin_acc = correctness[mask].mean()
364
+ bin_conf = max_probs[mask].mean()
365
+ else:
366
+ bin_acc = 0.0
367
+ bin_conf = 0.0
368
+
369
+ bin_accuracies.append(float(bin_acc))
370
+ bin_confidences.append(float(bin_conf))
371
+
372
+ # Compute ECE
373
+ total = len(all_preds)
374
+ ece = sum(
375
+ (count / total) * abs(acc - conf)
376
+ for count, acc, conf in zip(bin_counts, bin_accuracies, bin_confidences)
377
+ )
378
+
379
+ # Create reliability diagram
380
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
381
+ fig.patch.set_facecolor('#0d1117')
382
+
383
+ # Reliability Diagram
384
+ ax1.set_facecolor('#161b22')
385
+ bin_centers = [(bin_boundaries[i] + bin_boundaries[i + 1]) / 2 for i in range(n_bins)]
386
+ width = 0.08
387
+
388
+ ax1.bar([c - width/2 for c in bin_centers], bin_accuracies, width,
389
+ label='Accuracy', color='#58a6ff', alpha=0.9, edgecolor='#58a6ff')
390
+ ax1.bar([c + width/2 for c in bin_centers], bin_confidences, width,
391
+ label='Avg Confidence', color='#f97583', alpha=0.9, edgecolor='#f97583')
392
+
393
+ ax1.plot([0, 1], [0, 1], '--', color='#8b949e', linewidth=2,
394
+ label='Perfect Calibration')
395
+ ax1.set_xlim(0, 1)
396
+ ax1.set_ylim(0, 1)
397
+ ax1.set_xlabel('Confidence', color='white', fontsize=12)
398
+ ax1.set_ylabel('Accuracy / Confidence', color='white', fontsize=12)
399
+ ax1.set_title(f'Reliability Diagram (ECE: {ece:.4f})',
400
+ color='white', fontsize=14, fontweight='bold', pad=15)
401
+ ax1.legend(facecolor='#161b22', edgecolor='#30363d', labelcolor='white', fontsize=10)
402
+ ax1.tick_params(colors='white')
403
+ for spine in ax1.spines.values():
404
+ spine.set_edgecolor('#30363d')
405
+ ax1.grid(True, alpha=0.1, color='white')
406
+
407
+ # Confidence histogram
408
+ ax2.set_facecolor('#161b22')
409
+ ax2.bar(bin_centers, [c / total for c in bin_counts], 0.08,
410
+ color='#56d364', alpha=0.9, edgecolor='#56d364')
411
+ ax2.set_xlim(0, 1)
412
+ ax2.set_xlabel('Confidence', color='white', fontsize=12)
413
+ ax2.set_ylabel('Fraction of Samples', color='white', fontsize=12)
414
+ ax2.set_title('Confidence Distribution',
415
+ color='white', fontsize=14, fontweight='bold', pad=15)
416
+ ax2.tick_params(colors='white')
417
+ for spine in ax2.spines.values():
418
+ spine.set_edgecolor('#30363d')
419
+ ax2.grid(True, alpha=0.1, color='white')
420
+
421
+ plt.tight_layout(pad=3)
422
+
423
+ fig.canvas.draw()
424
+ rgba_buffer = fig.canvas.buffer_rgba()
425
+ diagram = np.array(rgba_buffer)[:, :, :3]
426
+ plt.close(fig)
427
+
428
+ self._calibration_cache[cache_key] = CalibrationResult(
429
+ ece=ece,
430
+ bin_accuracies=bin_accuracies,
431
+ bin_confidences=bin_confidences,
432
+ bin_counts=bin_counts,
433
+ reliability_diagram=diagram,
434
+ source="Live computation",
435
+ )
436
+ return self._calibration_cache[cache_key]
437
+
438
+ except Exception as e:
439
+ print(f"Error computing calibration: {e}")
440
+ return None
assignments/assignment-1/app/main.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Deep Learning Assignment 1 - Application Demo
3
+ ===============================================
4
+ A modular Gradio application for demonstrating
5
+ trained models on Image, Text, and Multimodal datasets.
6
+
7
+ Features:
8
+ - Image classification with Grad-CAM / attention visualization
9
+ - Model Calibration analysis (ECE + Reliability Diagram)
10
+ - Easy to extend with new models/datasets
11
+
12
+ Usage:
13
+ python assignments/assignment-1/app/main.py
14
+ """
15
+
16
+ import sys
17
+ import os
18
+
19
+ # Add assignment root to path so `app.*` imports keep working.
20
+ ASSIGNMENT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
21
+ sys.path.insert(0, ASSIGNMENT_ROOT)
22
+
23
+ import gradio as gr
24
+ from typing import Dict
25
+
26
+ from app.shared.model_registry import (
27
+ register_model,
28
+ get_all_model_keys,
29
+ get_models_by_type,
30
+ BaseModelHandler,
31
+ )
32
+ from app.image.resnet18 import Cifar10ResNet18Handler
33
+ from app.image.vit_b16 import Cifar10ViTHandler
34
+
35
+
36
+ # ============================================================================
37
+ # CONFIGURATION
38
+ # ============================================================================
39
+
40
+ APP_TITLE = "🧠 Deep Learning Assignment 1 - Demo"
41
+ APP_DESCRIPTION = """
42
+ <div style="text-align: center; padding: 10px 0;">
43
+ <p style="font-size: 16px; color: #8b949e; margin: 5px 0;">
44
+ Classification on Images, Text, and Multimodal Data
45
+ </p>
46
+ <p style="font-size: 14px; color: #58a6ff; margin: 5px 0;">
47
+ CO3091 · HCM University of Technology · 2025-2026 Semester 2
48
+ </p>
49
+ </div>
50
+ """
51
+
52
+ # Load custom CSS from external file
53
+ CSS_PATH = os.path.join(os.path.dirname(__file__), "assets", "style.css")
54
+ if os.path.exists(CSS_PATH):
55
+ with open(CSS_PATH, "r", encoding="utf-8") as f:
56
+ CUSTOM_CSS = f.read()
57
+ else:
58
+ CUSTOM_CSS = ""
59
+
60
+
61
+ CUSTOM_THEME = gr.themes.Base(
62
+ primary_hue=gr.themes.colors.blue,
63
+ secondary_hue=gr.themes.colors.green,
64
+ neutral_hue=gr.themes.colors.gray,
65
+ font=[gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"],
66
+ ).set(
67
+ body_background_fill="#0d1117",
68
+ body_background_fill_dark="#0d1117",
69
+ block_background_fill="#161b22",
70
+ block_background_fill_dark="#161b22",
71
+ block_border_color="#30363d",
72
+ block_border_color_dark="#30363d",
73
+ block_label_text_color="#c9d1d9",
74
+ block_label_text_color_dark="#c9d1d9",
75
+ block_title_text_color="#f0f6fc",
76
+ block_title_text_color_dark="#f0f6fc",
77
+ body_text_color="#c9d1d9",
78
+ body_text_color_dark="#c9d1d9",
79
+ body_text_color_subdued="#8b949e",
80
+ body_text_color_subdued_dark="#8b949e",
81
+ button_primary_background_fill="#238636",
82
+ button_primary_background_fill_dark="#238636",
83
+ button_primary_background_fill_hover="#2ea043",
84
+ button_primary_background_fill_hover_dark="#2ea043",
85
+ button_primary_text_color="white",
86
+ button_primary_text_color_dark="white",
87
+ input_background_fill="#0d1117",
88
+ input_background_fill_dark="#0d1117",
89
+ input_border_color="#30363d",
90
+ input_border_color_dark="#30363d",
91
+ shadow_drop="none",
92
+ shadow_drop_lg="none",
93
+ )
94
+
95
+
96
+ # ============================================================================
97
+ # MODEL INITIALIZATION
98
+ # ============================================================================
99
+
100
+ def init_models():
101
+ """Initialize and register all available models."""
102
+ model_dir = os.path.join(ASSIGNMENT_ROOT, "image", "models")
103
+
104
+ # CIFAR-10 ResNet-18
105
+ resnet18_path = os.path.join(model_dir, "resnet18_cifar10.pth")
106
+ if os.path.exists(resnet18_path):
107
+ try:
108
+ handler = Cifar10ResNet18Handler(resnet18_path)
109
+ register_model("cifar10_resnet18", handler)
110
+ print(f"✅ Loaded: CIFAR-10 ResNet-18 from {resnet18_path}")
111
+ except Exception as e:
112
+ print(f"❌ Failed to load CIFAR-10 ResNet-18: {e}")
113
+ else:
114
+ print(f"⚠️ Model file not found: {resnet18_path}")
115
+
116
+ # CIFAR-10 ViT-B/16
117
+ vit_path = os.path.join(model_dir, "vit_b16_cifar10.pth")
118
+ if os.path.exists(vit_path):
119
+ try:
120
+ handler = Cifar10ViTHandler(vit_path)
121
+ register_model("cifar10_vit", handler)
122
+ print(f"✅ Loaded: CIFAR-10 ViT-B/16 from {vit_path}")
123
+ except Exception as e:
124
+ print(f"❌ Failed to load CIFAR-10 ViT-B/16: {e}")
125
+ else:
126
+ print(f"⚠️ Model file not found: {vit_path}")
127
+
128
+
129
+ # ============================================================================
130
+ # UI BUILDER FUNCTIONS
131
+ # ============================================================================
132
+
133
+ def format_confidence_label(labels, confidences, top_k=5):
134
+ """Format top-k predictions as a dictionary for gr.Label."""
135
+ paired = sorted(zip(labels, confidences), key=lambda x: x[1], reverse=True)
136
+ return {label: float(conf) for label, conf in paired[:top_k]}
137
+
138
+
139
+ def build_model_info_markdown(handler: BaseModelHandler) -> str:
140
+ """Build formatted model info markdown."""
141
+ info = handler.get_model_info()
142
+ lines = ["### 📋 Model Information\n"]
143
+ for key, val in info.items():
144
+ lines.append(f"| **{key}** | {val} |")
145
+
146
+ header = "| Property | Value |\n|:---|:---|\n"
147
+ table_lines = [line for line in lines[1:]]
148
+ return lines[0] + header + "\n".join(table_lines)
149
+
150
+
151
+ def build_image_prediction_tab(model_key: str, handler: BaseModelHandler):
152
+ """Build the prediction tab UI for image models."""
153
+ with gr.Row(equal_height=True):
154
+ with gr.Column(scale=1):
155
+ input_image = gr.Image(
156
+ label="📸 Upload Image",
157
+ type="numpy",
158
+ height=300,
159
+ sources=["upload", "clipboard"],
160
+ )
161
+ predict_btn = gr.Button(
162
+ "🔍 Predict & Explain",
163
+ variant="primary",
164
+ size="lg",
165
+ )
166
+ gr.Markdown(
167
+ f"*Classes: {', '.join(handler.get_class_labels())}*",
168
+ elem_classes=["text-sm"],
169
+ )
170
+
171
+ with gr.Column(scale=1):
172
+ output_label = gr.Label(
173
+ label="📊 Prediction Results (Top-5)",
174
+ num_top_classes=5,
175
+ )
176
+
177
+ with gr.Row():
178
+ explanation_image = gr.Image(
179
+ label="🔥 Model Explanation (Interpretability)",
180
+ interactive=False,
181
+ height=350,
182
+ )
183
+
184
+ def do_predict(image):
185
+ if image is None:
186
+ return None, None
187
+ try:
188
+ result = handler.predict(image)
189
+ conf_dict = format_confidence_label(
190
+ result.all_labels, result.all_confidences
191
+ )
192
+ return conf_dict, result.explanation_image
193
+ except Exception as e:
194
+ raise gr.Error(f"Prediction failed: {str(e)}")
195
+
196
+ predict_btn.click(
197
+ fn=do_predict,
198
+ inputs=[input_image],
199
+ outputs=[output_label, explanation_image],
200
+ )
201
+
202
+
203
+ def build_calibration_tab(model_key: str, handler: BaseModelHandler):
204
+ """Build the calibration analysis tab."""
205
+ gr.Markdown("""
206
+ ### 📐 Model Calibration Analysis
207
+
208
+ Calibration measures how well the model's confidence matches its actual accuracy.
209
+ A perfectly calibrated model has **confidence = accuracy** for all predictions.
210
+
211
+ - **ECE (Expected Calibration Error)**: Lower is better (0 = perfect calibration)
212
+ - **Reliability Diagram**: Compares predicted confidence vs actual accuracy per bin
213
+ - **Quick Preview**: Uses a very small subset for fast CPU demos
214
+ - **Full Test Set**: Uses notebook artifacts instantly when available
215
+ """)
216
+
217
+ calibration_mode = gr.Radio(
218
+ choices=[
219
+ "Quick Preview (64 samples)",
220
+ "Full Test Set (10,000 samples)",
221
+ ],
222
+ value="Quick Preview (64 samples)",
223
+ label="Calibration Mode",
224
+ )
225
+
226
+ compute_btn = gr.Button(
227
+ "📊 Compute Calibration",
228
+ variant="primary",
229
+ size="lg",
230
+ )
231
+
232
+ ece_display = gr.Markdown(visible=False)
233
+ calibration_plot = gr.Image(
234
+ label="📈 Calibration Analysis",
235
+ interactive=False,
236
+ visible=False,
237
+ height=450,
238
+ )
239
+
240
+ def compute_calibration(mode):
241
+ try:
242
+ max_samples = 64 if mode.startswith("Quick Preview") else None
243
+ result = handler.get_calibration_data(max_samples=max_samples)
244
+ if result is None:
245
+ raise gr.Error("Could not compute calibration data")
246
+
247
+ sample_note = (
248
+ "Approximate preview on 64 evenly spaced test images"
249
+ if max_samples is not None
250
+ else "Full CIFAR-10 test set"
251
+ )
252
+ source_note = result.source or "Live computation"
253
+ ece_md = f"""
254
+ ### Calibration Metrics
255
+
256
+ | Metric | Value |
257
+ |:---|:---|
258
+ | **Mode** | {sample_note} |
259
+ | **Source** | {source_note} |
260
+ | **Expected Calibration Error (ECE)** | `{result.ece:.6f}` |
261
+ | **Interpretation** | {'✅ Well calibrated' if result.ece < 0.05 else '⚠️ Moderately calibrated' if result.ece < 0.15 else '❌ Poorly calibrated'} |
262
+ | **Total evaluated samples** | {sum(result.bin_counts):,} |
263
+ """
264
+ return (
265
+ gr.update(value=ece_md, visible=True),
266
+ gr.update(value=result.reliability_diagram, visible=True),
267
+ )
268
+ except Exception as e:
269
+ raise gr.Error(f"Calibration computation failed: {str(e)}")
270
+
271
+ compute_btn.click(
272
+ fn=compute_calibration,
273
+ inputs=[calibration_mode],
274
+ outputs=[ece_display, calibration_plot],
275
+ )
276
+
277
+
278
+ def build_model_tabs(model_key: str, handler: BaseModelHandler):
279
+ """Build all tabs for a specific model."""
280
+ gr.Markdown(build_model_info_markdown(handler))
281
+
282
+ with gr.Tabs():
283
+ with gr.Tab("🎯 Predict & Explain", id="predict"):
284
+ data_type = handler.get_data_type()
285
+ if data_type == "image":
286
+ build_image_prediction_tab(model_key, handler)
287
+ elif data_type == "text":
288
+ gr.Markdown("### 📝 Text Classification\n*Coming soon...*")
289
+ elif data_type == "multimodal":
290
+ gr.Markdown("### 🖼️+📝 Multimodal Classification\n*Coming soon...*")
291
+
292
+ with gr.Tab("📐 Calibration", id="calibration"):
293
+ build_calibration_tab(model_key, handler)
294
+
295
+
296
+ # ============================================================================
297
+ # MAIN APPLICATION
298
+ # ============================================================================
299
+
300
+ def create_app() -> gr.Blocks:
301
+ """Create the main Gradio application."""
302
+ init_models()
303
+
304
+ with gr.Blocks(
305
+ title="DL Assignment 1 - Demo",
306
+ ) as app:
307
+ gr.Markdown(f"# {APP_TITLE}")
308
+ gr.Markdown(APP_DESCRIPTION)
309
+
310
+ model_keys = get_all_model_keys()
311
+
312
+ if not model_keys:
313
+ gr.Markdown("""
314
+ ## ⚠️ No Models Loaded
315
+
316
+ Please ensure model files are in the `image/models/` directory.
317
+ See the README for instructions on adding models.
318
+ """)
319
+ else:
320
+ image_models = get_models_by_type("image")
321
+ text_models = get_models_by_type("text")
322
+ multimodal_models = get_models_by_type("multimodal")
323
+
324
+ with gr.Tabs():
325
+ if image_models:
326
+ with gr.Tab("🖼️ Image Classification", id="image_tab"):
327
+ if len(image_models) > 1:
328
+ with gr.Tabs():
329
+ for key, handler in image_models.items():
330
+ tab_name = f"{handler.get_model_name()} ({handler.get_dataset_name()})"
331
+ with gr.Tab(tab_name):
332
+ build_model_tabs(key, handler)
333
+ else:
334
+ key, handler = next(iter(image_models.items()))
335
+ build_model_tabs(key, handler)
336
+
337
+ if text_models:
338
+ with gr.Tab("📝 Text Classification", id="text_tab"):
339
+ if len(text_models) > 1:
340
+ with gr.Tabs():
341
+ for key, handler in text_models.items():
342
+ tab_name = f"{handler.get_model_name()} ({handler.get_dataset_name()})"
343
+ with gr.Tab(tab_name):
344
+ build_model_tabs(key, handler)
345
+ else:
346
+ key, handler = next(iter(text_models.items()))
347
+ build_model_tabs(key, handler)
348
+
349
+ if multimodal_models:
350
+ with gr.Tab("🔀 Multimodal Classification", id="mm_tab"):
351
+ if len(multimodal_models) > 1:
352
+ with gr.Tabs():
353
+ for key, handler in multimodal_models.items():
354
+ tab_name = f"{handler.get_model_name()} ({handler.get_dataset_name()})"
355
+ with gr.Tab(tab_name):
356
+ build_model_tabs(key, handler)
357
+ else:
358
+ key, handler = next(iter(multimodal_models.items()))
359
+ build_model_tabs(key, handler)
360
+
361
+ if not text_models:
362
+ with gr.Tab("📝 Text Classification", id="text_tab"):
363
+ gr.Markdown("""
364
+ ### 📝 Text Classification Models
365
+
366
+ *No text models loaded yet. Add your text model handler
367
+ and register it in `app/main.py`.*
368
+ """)
369
+
370
+ if not multimodal_models:
371
+ with gr.Tab("🔀 Multimodal Classification", id="mm_tab"):
372
+ gr.Markdown("""
373
+ ### 🔀 Multimodal Classification Models
374
+
375
+ *No multimodal models loaded yet. Add your multimodal
376
+ model handler and register it in `app/main.py`.*
377
+ """)
378
+
379
+ gr.Markdown("""
380
+ <div class="app-footer">
381
+ <p>Deep Learning and Its Applications · Assignment 1</p>
382
+ <p>HCM University of Technology (HCMUT) · VNUHCM</p>
383
+ </div>
384
+ """)
385
+
386
+ return app
387
+
388
+
389
+ # ============================================================================
390
+ # ENTRY POINT
391
+ # ============================================================================
392
+
393
+ if __name__ == "__main__":
394
+ app = create_app()
395
+ app.launch(
396
+ server_name="127.0.0.1",
397
+ server_port=5555,
398
+ share=False,
399
+ show_error=True,
400
+ theme=CUSTOM_THEME,
401
+ css=CUSTOM_CSS,
402
+ allowed_paths=[os.path.join(ASSIGNMENT_ROOT, "image", "artifacts")],
403
+ )
assignments/assignment-1/app/multimodal/README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Multimodal App Modules
2
+
3
+ Place multimodal-specific inference handlers here.
4
+
5
+ Suggested additions:
6
+
7
+ - multimodal model wrapper classes
8
+ - joint preprocessing helpers
9
+ - prediction utilities
10
+ - demo-specific visualization helpers
11
+
12
+ After adding a handler, register it in `assignments/assignment-1/app/main.py`.
assignments/assignment-1/app/multimodal/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Multimodal model handlers for Assignment 1."""
assignments/assignment-1/app/requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Deep Learning Assignment 1 - Application Demo Dependencies
2
+ torch>=2.0.0
3
+ torchvision>=0.15.0
4
+ gradio>=5.0.0
5
+ numpy>=1.24.0
6
+ Pillow>=9.0.0
7
+ matplotlib>=3.7.0
assignments/assignment-1/app/shared/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Shared app utilities for Assignment 1."""
assignments/assignment-1/app/shared/artifact_utils.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Helpers for reading notebook-generated artifacts and training metadata.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import json
8
+ import os
9
+ from pathlib import Path
10
+ from typing import Any, Dict, Optional
11
+
12
+ from .model_registry import CalibrationResult
13
+
14
+
15
+ ASSIGNMENT_ROOT = Path(
16
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
17
+ )
18
+ ARTIFACTS_DIR = ASSIGNMENT_ROOT / "image" / "artifacts"
19
+
20
+
21
+ def get_best_accuracy_from_history(history: Optional[Dict[str, Any]]) -> Optional[float]:
22
+ """Return the best validation accuracy found in a checkpoint history."""
23
+ if not history:
24
+ return None
25
+
26
+ val_acc = history.get("val_acc")
27
+ if isinstance(val_acc, list) and val_acc:
28
+ return float(max(val_acc))
29
+
30
+ return None
31
+
32
+
33
+ def load_precomputed_calibration_result(
34
+ model_tag: str,
35
+ sample_tag: str = "full",
36
+ ) -> Optional[CalibrationResult]:
37
+ """
38
+ Load notebook-generated calibration metrics and figure from image/artifacts/.
39
+
40
+ The function searches recursively so nested folders like artifacts/cnn and
41
+ artifacts/vit are both supported.
42
+ """
43
+ if not ARTIFACTS_DIR.exists():
44
+ return None
45
+
46
+ metrics_name = f"{model_tag}_calibration_metrics_{sample_tag}.json"
47
+ image_name = f"{model_tag}_calibration_{sample_tag}.png"
48
+
49
+ metrics_path = next(ARTIFACTS_DIR.rglob(metrics_name), None)
50
+ image_path = next(ARTIFACTS_DIR.rglob(image_name), None)
51
+
52
+ if metrics_path is None or image_path is None:
53
+ return None
54
+
55
+ metrics = json.loads(metrics_path.read_text(encoding="utf-8"))
56
+ return CalibrationResult(
57
+ ece=float(metrics["ece"]),
58
+ bin_accuracies=[float(x) for x in metrics["bin_accuracies"]],
59
+ bin_confidences=[float(x) for x in metrics["bin_confidences"]],
60
+ bin_counts=[int(x) for x in metrics["bin_counts"]],
61
+ reliability_diagram=str(image_path),
62
+ source=f"Notebook artifact ({metrics_path.parent.name})",
63
+ )
assignments/assignment-1/app/shared/model_registry.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model Registry - Central place to register and manage all models.
3
+
4
+ This module makes it easy to add new models for different datasets.
5
+ Each model handler should implement the BaseModelHandler interface.
6
+ """
7
+
8
+ from abc import ABC, abstractmethod
9
+ from typing import Dict, List, Optional, Tuple, Any
10
+ import numpy as np
11
+ from PIL import Image
12
+
13
+
14
+ class PredictionResult:
15
+ """Container for prediction results from a model."""
16
+
17
+ def __init__(
18
+ self,
19
+ label: str,
20
+ confidence: float,
21
+ all_labels: List[str],
22
+ all_confidences: List[float],
23
+ explanation_image: Optional[np.ndarray] = None,
24
+ ):
25
+ self.label = label
26
+ self.confidence = confidence
27
+ self.all_labels = all_labels
28
+ self.all_confidences = all_confidences
29
+ self.explanation_image = explanation_image # Grad-CAM or attention map
30
+
31
+
32
+ class CalibrationResult:
33
+ """Container for model calibration analysis results."""
34
+
35
+ def __init__(
36
+ self,
37
+ ece: float,
38
+ bin_accuracies: List[float],
39
+ bin_confidences: List[float],
40
+ bin_counts: List[int],
41
+ reliability_diagram: Optional[Any] = None,
42
+ source: Optional[str] = None,
43
+ ):
44
+ self.ece = ece
45
+ self.bin_accuracies = bin_accuracies
46
+ self.bin_confidences = bin_confidences
47
+ self.bin_counts = bin_counts
48
+ self.reliability_diagram = reliability_diagram
49
+ self.source = source
50
+
51
+
52
+ class BaseModelHandler(ABC):
53
+ """
54
+ Abstract base class for model handlers.
55
+
56
+ To add a new model, create a subclass and implement all abstract methods.
57
+ Then register it in the MODEL_REGISTRY dictionary below.
58
+ """
59
+
60
+ @abstractmethod
61
+ def get_model_name(self) -> str:
62
+ """Return human-readable model name."""
63
+ pass
64
+
65
+ @abstractmethod
66
+ def get_dataset_name(self) -> str:
67
+ """Return the dataset name this model was trained on."""
68
+ pass
69
+
70
+ @abstractmethod
71
+ def get_data_type(self) -> str:
72
+ """Return data type: 'image', 'text', or 'multimodal'."""
73
+ pass
74
+
75
+ @abstractmethod
76
+ def get_class_labels(self) -> List[str]:
77
+ """Return list of class labels."""
78
+ pass
79
+
80
+ @abstractmethod
81
+ def get_model_info(self) -> Dict[str, str]:
82
+ """Return dict of model info for display (architecture, params, etc.)."""
83
+ pass
84
+
85
+ @abstractmethod
86
+ def predict(self, input_data) -> PredictionResult:
87
+ """
88
+ Run prediction on input data.
89
+
90
+ For image models: input_data is a PIL Image or numpy array
91
+ For text models: input_data is a string
92
+ For multimodal: input_data is a tuple (image, text)
93
+
94
+ Returns: PredictionResult
95
+ """
96
+ pass
97
+
98
+ @abstractmethod
99
+ def get_example_inputs(self) -> List[Any]:
100
+ """Return list of example inputs for the demo."""
101
+ pass
102
+
103
+ def get_calibration_data(
104
+ self, max_samples: Optional[int] = None
105
+ ) -> Optional[CalibrationResult]:
106
+ """
107
+ Optionally return calibration analysis result.
108
+ Override this in subclass if you want calibration display.
109
+ """
110
+ return None
111
+
112
+
113
+ # Global model registry - add new models here
114
+ MODEL_REGISTRY: Dict[str, BaseModelHandler] = {}
115
+
116
+
117
+ def register_model(key: str, handler: BaseModelHandler):
118
+ """Register a model handler in the global registry."""
119
+ MODEL_REGISTRY[key] = handler
120
+
121
+
122
+ def get_model_handler(key: str) -> Optional[BaseModelHandler]:
123
+ """Get a model handler by key."""
124
+ return MODEL_REGISTRY.get(key)
125
+
126
+
127
+ def get_all_model_keys() -> List[str]:
128
+ """Get all registered model keys."""
129
+ return list(MODEL_REGISTRY.keys())
130
+
131
+
132
+ def get_models_by_type(data_type: str) -> Dict[str, BaseModelHandler]:
133
+ """Get all models of a specific data type."""
134
+ return {k: v for k, v in MODEL_REGISTRY.items() if v.get_data_type() == data_type}
assignments/assignment-1/app/text/README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Text App Modules
2
+
3
+ Place text-specific inference handlers here.
4
+
5
+ Suggested additions:
6
+
7
+ - model wrapper classes
8
+ - preprocessing helpers
9
+ - prediction utilities
10
+ - calibration or explanation helpers if needed
11
+
12
+ After adding a handler, register it in `assignments/assignment-1/app/main.py`.
assignments/assignment-1/app/text/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Text model handlers for Assignment 1."""
assignments/assignment-1/image/artifacts/cnn/resnet18_calibration_full.png ADDED

Git LFS Details

  • SHA256: 1cb171cadaa51bc9800aeae468a823cb6e30799714c5d6c7cc39d6cbc32acc42
  • Pointer size: 131 Bytes
  • Size of remote file: 104 kB
assignments/assignment-1/image/artifacts/cnn/resnet18_calibration_metrics_full.json ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_tag": "resnet18",
3
+ "sample_tag": "full",
4
+ "ece": 0.020006245681643487,
5
+ "num_bins": 10,
6
+ "total_evaluated_samples": 10000,
7
+ "bin_accuracies": [
8
+ 0.0,
9
+ 0.0,
10
+ 1.0,
11
+ 0.0,
12
+ 0.6153846383094788,
13
+ 0.5058823823928833,
14
+ 0.5416666865348816,
15
+ 0.6666666865348816,
16
+ 0.6967741847038269,
17
+ 0.9816811680793762
18
+ ],
19
+ "bin_confidences": [
20
+ 0.0,
21
+ 0.0,
22
+ 0.2935279607772827,
23
+ 0.36219507455825806,
24
+ 0.46814653277397156,
25
+ 0.5451725721359253,
26
+ 0.6489876508712769,
27
+ 0.752896785736084,
28
+ 0.8511331677436829,
29
+ 0.9974784851074219
30
+ ],
31
+ "bin_counts": [
32
+ 0,
33
+ 0,
34
+ 1,
35
+ 4,
36
+ 13,
37
+ 85,
38
+ 72,
39
+ 117,
40
+ 155,
41
+ 9553
42
+ ]
43
+ }
assignments/assignment-1/image/artifacts/vit/vit_b16_calibration_full.png ADDED

Git LFS Details

  • SHA256: 9ec24a90630f7df2784bd3706eac25a4711021ed65c2dc58ec7a1ebccf6bd314
  • Pointer size: 131 Bytes
  • Size of remote file: 109 kB
assignments/assignment-1/image/artifacts/vit/vit_b16_calibration_metrics_full.json ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_tag": "vit_b16",
3
+ "sample_tag": "full",
4
+ "ece": 0.006916732695698738,
5
+ "num_bins": 10,
6
+ "total_evaluated_samples": 10000,
7
+ "bin_accuracies": [
8
+ 0.0,
9
+ 0.0,
10
+ 0.0,
11
+ 0.0,
12
+ 0.5,
13
+ 0.5714285969734192,
14
+ 0.6034482717514038,
15
+ 0.6901408433914185,
16
+ 0.7037037014961243,
17
+ 0.9934116005897522
18
+ ],
19
+ "bin_confidences": [
20
+ 0.0,
21
+ 0.0,
22
+ 0.2842116951942444,
23
+ 0.37363073229789734,
24
+ 0.46517834067344666,
25
+ 0.5469092130661011,
26
+ 0.6531615853309631,
27
+ 0.7513611912727356,
28
+ 0.8554656505584717,
29
+ 0.9979013204574585
30
+ ],
31
+ "bin_counts": [
32
+ 0,
33
+ 0,
34
+ 1,
35
+ 1,
36
+ 12,
37
+ 35,
38
+ 58,
39
+ 71,
40
+ 108,
41
+ 9714
42
+ ]
43
+ }
assignments/assignment-1/image/data/cifar-10-python.tar.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce
3
+ size 170498071
assignments/assignment-1/image/models/resnet18_cifar10.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a0076300593993e9e6e09a358c254f24b8ffda12f66ce566e50a289ee462cb10
3
+ size 44808651
assignments/assignment-1/image/models/vit_b16_cifar10.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2e4d76e9dcb5b3eb907a00782e9f8af05b9ee46e9f2d3e0e16484d351e63f382
3
+ size 343288191
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ gradio>=5.0.0
4
+ numpy>=1.24.0
5
+ Pillow>=9.0.0
6
+ matplotlib>=3.7.0