Spaces:
Sleeping
Sleeping
Add Gradio image demo
Browse files- .gitattributes +2 -34
- app.py +12 -0
- assignments/assignment-1/app/__init__.py +1 -0
- assignments/assignment-1/app/assets/style.css +135 -0
- assignments/assignment-1/app/image/__init__.py +1 -0
- assignments/assignment-1/app/image/data.py +74 -0
- assignments/assignment-1/app/image/resnet18.py +430 -0
- assignments/assignment-1/app/image/vit_b16.py +440 -0
- assignments/assignment-1/app/main.py +403 -0
- assignments/assignment-1/app/multimodal/README.md +12 -0
- assignments/assignment-1/app/multimodal/__init__.py +1 -0
- assignments/assignment-1/app/requirements.txt +7 -0
- assignments/assignment-1/app/shared/__init__.py +1 -0
- assignments/assignment-1/app/shared/artifact_utils.py +63 -0
- assignments/assignment-1/app/shared/model_registry.py +134 -0
- assignments/assignment-1/app/text/README.md +12 -0
- assignments/assignment-1/app/text/__init__.py +1 -0
- assignments/assignment-1/image/artifacts/cnn/resnet18_calibration_full.png +3 -0
- assignments/assignment-1/image/artifacts/cnn/resnet18_calibration_metrics_full.json +43 -0
- assignments/assignment-1/image/artifacts/vit/vit_b16_calibration_full.png +3 -0
- assignments/assignment-1/image/artifacts/vit/vit_b16_calibration_metrics_full.json +43 -0
- assignments/assignment-1/image/data/cifar-10-python.tar.gz +3 -0
- assignments/assignment-1/image/models/resnet18_cifar10.pth +3 -0
- assignments/assignment-1/image/models/vit_b16_cifar10.pth +3 -0
- requirements.txt +6 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,3 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.
|
| 25 |
-
*.
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.tar.gz filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import runpy
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
ns = runpy.run_path(
|
| 5 |
+
"assignments/assignment-1/app/main.py",
|
| 6 |
+
run_name="hf_space_app",
|
| 7 |
+
)
|
| 8 |
+
demo = ns["create_app"]()
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
if __name__ == "__main__":
|
| 12 |
+
demo.launch()
|
assignments/assignment-1/app/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# DL Assignment 1 - Application Demo
|
assignments/assignment-1/app/assets/style.css
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
Deep Learning Assignment - Custom Style System
|
| 3 |
+
=============================================
|
| 4 |
+
Designed for a premium, dark-themed experience.
|
| 5 |
+
*/
|
| 6 |
+
|
| 7 |
+
/* Main container stabilization */
|
| 8 |
+
.gradio-container {
|
| 9 |
+
max-width: 1200px !important;
|
| 10 |
+
margin: 0 auto !important;
|
| 11 |
+
font-family: 'Inter', system-ui, -apple-system, sans-serif !important;
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
/* Header & Title Styling */
|
| 15 |
+
.app-header {
|
| 16 |
+
text-align: center;
|
| 17 |
+
padding: 30px 0;
|
| 18 |
+
border-bottom: 1px solid #30363d;
|
| 19 |
+
margin-bottom: 30px;
|
| 20 |
+
background: linear-gradient(to bottom, #161b22, #0d1117);
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
.app-header h1 {
|
| 24 |
+
font-weight: 800 !important;
|
| 25 |
+
letter-spacing: -0.02em;
|
| 26 |
+
background: linear-gradient(135deg, #58a6ff 0%, #bc8cff 100%);
|
| 27 |
+
-webkit-background-clip: text;
|
| 28 |
+
background-clip: text;
|
| 29 |
+
-webkit-text-fill-color: transparent;
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
/* Model Info Display */
|
| 33 |
+
.model-info-box {
|
| 34 |
+
background: #161b22;
|
| 35 |
+
border: 1px solid #30363d;
|
| 36 |
+
border-radius: 12px;
|
| 37 |
+
padding: 24px;
|
| 38 |
+
margin: 15px 0;
|
| 39 |
+
box-shadow: 0 4px 20px rgba(0,0,0,0.3);
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
/* Prediction Result Premium Card */
|
| 43 |
+
.prediction-label {
|
| 44 |
+
font-size: 26px !important;
|
| 45 |
+
font-weight: 700 !important;
|
| 46 |
+
text-align: center;
|
| 47 |
+
padding: 20px;
|
| 48 |
+
background: linear-gradient(135deg, #238636 0%, #2ea043 100%);
|
| 49 |
+
border-radius: 12px;
|
| 50 |
+
color: white !important;
|
| 51 |
+
margin: 15px 0;
|
| 52 |
+
box-shadow: 0 8px 32px rgba(35, 134, 54, 0.2);
|
| 53 |
+
border: 1px solid rgba(255,255,255,0.1);
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
/* Confidence Bars & Progress */
|
| 57 |
+
.confidence-bar {
|
| 58 |
+
height: 32px;
|
| 59 |
+
border-radius: 8px;
|
| 60 |
+
background-color: #21262d;
|
| 61 |
+
overflow: hidden;
|
| 62 |
+
margin: 8px 0;
|
| 63 |
+
border: 1px solid #30363d;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
/* Modern Tabs Navigation */
|
| 67 |
+
.tab-nav {
|
| 68 |
+
border-bottom: 1px solid #30363d !important;
|
| 69 |
+
margin-bottom: 20px !important;
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
.tab-nav button {
|
| 73 |
+
font-size: 15px !important;
|
| 74 |
+
font-weight: 600 !important;
|
| 75 |
+
padding: 14px 28px !important;
|
| 76 |
+
color: #8b949e !important;
|
| 77 |
+
transition: all 0.2s ease !important;
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
.tab-nav button:hover {
|
| 81 |
+
color: #f0f6fc !important;
|
| 82 |
+
background-color: rgba(139, 148, 158, 0.1) !important;
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
.tab-nav button.selected {
|
| 86 |
+
color: #58a6ff !important;
|
| 87 |
+
border-bottom: 2px solid #58a6ff !important;
|
| 88 |
+
background: transparent !important;
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
/* Calibration Metric Cards */
|
| 92 |
+
.metric-card {
|
| 93 |
+
background: #161b22;
|
| 94 |
+
border: 1px solid #30363d;
|
| 95 |
+
border-radius: 12px;
|
| 96 |
+
padding: 25px;
|
| 97 |
+
text-align: center;
|
| 98 |
+
transition: transform 0.2s ease;
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
.metric-card:hover {
|
| 102 |
+
transform: translateY(-2px);
|
| 103 |
+
border-color: #58a6ff;
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
/* Custom Buttons Styling */
|
| 107 |
+
.gr-button-primary {
|
| 108 |
+
background: linear-gradient(135deg, #1f6feb 0%, #58a6ff 100%) !important;
|
| 109 |
+
border: none !important;
|
| 110 |
+
font-weight: 600 !important;
|
| 111 |
+
box-shadow: 0 4px 12px rgba(31, 111, 235, 0.3) !important;
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
.gr-button-primary:hover {
|
| 115 |
+
filter: brightness(1.1);
|
| 116 |
+
transform: translateY(-1px);
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
/* Footer Section */
|
| 120 |
+
.app-footer {
|
| 121 |
+
text-align: center;
|
| 122 |
+
padding: 40px 20px;
|
| 123 |
+
color: #8b949e;
|
| 124 |
+
font-size: 14px;
|
| 125 |
+
border-top: 1px solid #30363d;
|
| 126 |
+
margin-top: 40px;
|
| 127 |
+
opacity: 0.8;
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
/* Glassmorphism utility */
|
| 131 |
+
.glass {
|
| 132 |
+
background: rgba(22, 27, 34, 0.7) !important;
|
| 133 |
+
backdrop-filter: blur(10px) !important;
|
| 134 |
+
border: 1px solid rgba(48, 54, 61, 0.5) !important;
|
| 135 |
+
}
|
assignments/assignment-1/app/image/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Image handlers for Assignment 1."""
|
assignments/assignment-1/app/image/data.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utilities for loading the CIFAR-10 test split from local project assets.
|
| 3 |
+
|
| 4 |
+
The workspace keeps the archive at ``image/data/cifar-10-python.tar.gz``.
|
| 5 |
+
Reading the test batch directly from that archive avoids permission issues
|
| 6 |
+
with extracted files while keeping calibration fully offline.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import pickle
|
| 11 |
+
import tarfile
|
| 12 |
+
from functools import lru_cache
|
| 13 |
+
from typing import Tuple
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
from PIL import Image
|
| 17 |
+
from torch.utils.data import Dataset
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
ASSIGNMENT_ROOT = os.path.dirname(
|
| 21 |
+
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 22 |
+
)
|
| 23 |
+
DEFAULT_DATA_DIR = os.path.join(ASSIGNMENT_ROOT, "image", "data")
|
| 24 |
+
DEFAULT_ARCHIVE_PATH = os.path.join(DEFAULT_DATA_DIR, "cifar-10-python.tar.gz")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@lru_cache(maxsize=1)
|
| 28 |
+
def load_cifar10_test_arrays(
|
| 29 |
+
archive_path: str = DEFAULT_ARCHIVE_PATH,
|
| 30 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
| 31 |
+
"""Load CIFAR-10 test images and labels from the local archive."""
|
| 32 |
+
if not os.path.exists(archive_path):
|
| 33 |
+
raise FileNotFoundError(
|
| 34 |
+
f"CIFAR-10 archive not found at {archive_path}. "
|
| 35 |
+
"Expected image/data/cifar-10-python.tar.gz to exist."
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
with tarfile.open(archive_path, "r:gz") as tar:
|
| 39 |
+
member = tar.extractfile("cifar-10-batches-py/test_batch")
|
| 40 |
+
if member is None:
|
| 41 |
+
raise FileNotFoundError(
|
| 42 |
+
"Could not find cifar-10-batches-py/test_batch inside the archive."
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
batch = pickle.load(member, encoding="bytes")
|
| 46 |
+
|
| 47 |
+
images = batch[b"data"].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
|
| 48 |
+
labels = np.asarray(batch[b"labels"], dtype=np.int64)
|
| 49 |
+
return images, labels
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class LocalCIFAR10TestDataset(Dataset):
|
| 53 |
+
"""Dataset wrapper that serves the CIFAR-10 test split from local files."""
|
| 54 |
+
|
| 55 |
+
def __init__(self, transform=None, archive_path: str = DEFAULT_ARCHIVE_PATH):
|
| 56 |
+
self.transform = transform
|
| 57 |
+
self.images, self.labels = load_cifar10_test_arrays(archive_path)
|
| 58 |
+
|
| 59 |
+
def __len__(self) -> int:
|
| 60 |
+
return len(self.labels)
|
| 61 |
+
|
| 62 |
+
def __getitem__(self, idx: int):
|
| 63 |
+
image = Image.fromarray(self.images[idx])
|
| 64 |
+
label = int(self.labels[idx])
|
| 65 |
+
|
| 66 |
+
if self.transform is not None:
|
| 67 |
+
image = self.transform(image)
|
| 68 |
+
|
| 69 |
+
return image, label
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def create_cifar10_test_dataset(transform=None) -> LocalCIFAR10TestDataset:
|
| 73 |
+
"""Create the CIFAR-10 test dataset used by the calibration tab."""
|
| 74 |
+
return LocalCIFAR10TestDataset(transform=transform)
|
assignments/assignment-1/app/image/resnet18.py
ADDED
|
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CIFAR-10 ResNet-18 Model Handler
|
| 3 |
+
|
| 4 |
+
Handles prediction, Grad-CAM visualization, and calibration
|
| 5 |
+
for the ResNet-18 model trained on CIFAR-10.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import numpy as np
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from typing import Dict, List, Optional, Any
|
| 14 |
+
import torchvision.transforms as transforms
|
| 15 |
+
from torchvision.models import resnet18
|
| 16 |
+
|
| 17 |
+
from app.shared.model_registry import (
|
| 18 |
+
BaseModelHandler,
|
| 19 |
+
PredictionResult,
|
| 20 |
+
CalibrationResult,
|
| 21 |
+
)
|
| 22 |
+
from app.shared.artifact_utils import (
|
| 23 |
+
get_best_accuracy_from_history,
|
| 24 |
+
load_precomputed_calibration_result,
|
| 25 |
+
)
|
| 26 |
+
from app.image.data import create_cifar10_test_dataset
|
| 27 |
+
|
| 28 |
+
# CIFAR-10 class labels
|
| 29 |
+
CIFAR10_LABELS = [
|
| 30 |
+
'Airplane', 'Automobile', 'Bird', 'Cat', 'Deer',
|
| 31 |
+
'Dog', 'Frog', 'Horse', 'Ship', 'Truck'
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
# CIFAR-10 normalization values
|
| 35 |
+
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
|
| 36 |
+
CIFAR10_STD = (0.2470, 0.2435, 0.2616)
|
| 37 |
+
|
| 38 |
+
# Image size ResNet expects
|
| 39 |
+
IMAGE_SIZE = 224
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def create_resnet18_cifar10(num_classes=10):
|
| 43 |
+
"""Create ResNet-18 with modified classifier for CIFAR-10."""
|
| 44 |
+
model = resnet18(weights=None)
|
| 45 |
+
num_features = model.fc.in_features
|
| 46 |
+
model.fc = nn.Linear(num_features, num_classes)
|
| 47 |
+
return model
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class GradCAM:
|
| 51 |
+
"""
|
| 52 |
+
Grad-CAM implementation for visual explanation.
|
| 53 |
+
Generates heatmap showing which regions the model focuses on.
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
def __init__(self, model, target_layer):
|
| 57 |
+
self.model = model
|
| 58 |
+
self.target_layer = target_layer
|
| 59 |
+
self.gradients = None
|
| 60 |
+
self.activations = None
|
| 61 |
+
self._register_hooks()
|
| 62 |
+
|
| 63 |
+
def _register_hooks(self):
|
| 64 |
+
def forward_hook(module, input, output):
|
| 65 |
+
self.activations = output.detach()
|
| 66 |
+
|
| 67 |
+
def backward_hook(module, grad_input, grad_output):
|
| 68 |
+
self.gradients = grad_output[0].detach()
|
| 69 |
+
|
| 70 |
+
self.target_layer.register_forward_hook(forward_hook)
|
| 71 |
+
self.target_layer.register_full_backward_hook(backward_hook)
|
| 72 |
+
|
| 73 |
+
def generate(self, input_tensor, target_class=None):
|
| 74 |
+
"""Generate Grad-CAM heatmap."""
|
| 75 |
+
self.model.eval()
|
| 76 |
+
output = self.model(input_tensor)
|
| 77 |
+
|
| 78 |
+
if target_class is None:
|
| 79 |
+
target_class = output.argmax(dim=1).item()
|
| 80 |
+
|
| 81 |
+
self.model.zero_grad()
|
| 82 |
+
one_hot = torch.zeros_like(output)
|
| 83 |
+
one_hot[0, target_class] = 1.0
|
| 84 |
+
output.backward(gradient=one_hot, retain_graph=True)
|
| 85 |
+
|
| 86 |
+
# Pool gradients across spatial dimensions
|
| 87 |
+
weights = self.gradients.mean(dim=[2, 3], keepdim=True)
|
| 88 |
+
cam = (weights * self.activations).sum(dim=1, keepdim=True)
|
| 89 |
+
cam = torch.relu(cam)
|
| 90 |
+
|
| 91 |
+
# Normalize
|
| 92 |
+
cam = cam - cam.min()
|
| 93 |
+
if cam.max() > 0:
|
| 94 |
+
cam = cam / cam.max()
|
| 95 |
+
|
| 96 |
+
# Resize to input size
|
| 97 |
+
cam = torch.nn.functional.interpolate(
|
| 98 |
+
cam, size=(IMAGE_SIZE, IMAGE_SIZE), mode='bilinear', align_corners=False
|
| 99 |
+
)
|
| 100 |
+
return cam.squeeze().cpu().numpy()
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def create_gradcam_overlay(image_np, heatmap, alpha=0.5):
|
| 104 |
+
"""Create overlay of Grad-CAM heatmap on original image."""
|
| 105 |
+
import matplotlib
|
| 106 |
+
matplotlib.use('Agg')
|
| 107 |
+
import matplotlib.pyplot as plt
|
| 108 |
+
import matplotlib.cm as cm
|
| 109 |
+
|
| 110 |
+
# Apply colormap to heatmap
|
| 111 |
+
colormap = cm.jet(heatmap)[:, :, :3] # Remove alpha channel
|
| 112 |
+
colormap = (colormap * 255).astype(np.uint8)
|
| 113 |
+
|
| 114 |
+
# Resize image to match heatmap
|
| 115 |
+
if image_np.shape[:2] != (IMAGE_SIZE, IMAGE_SIZE):
|
| 116 |
+
img_pil = Image.fromarray(image_np).resize((IMAGE_SIZE, IMAGE_SIZE))
|
| 117 |
+
image_np = np.array(img_pil)
|
| 118 |
+
|
| 119 |
+
# Create overlay
|
| 120 |
+
overlay = (alpha * colormap + (1 - alpha) * image_np).astype(np.uint8)
|
| 121 |
+
|
| 122 |
+
# Create figure with original + heatmap + overlay
|
| 123 |
+
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
| 124 |
+
fig.patch.set_facecolor('#0d1117')
|
| 125 |
+
|
| 126 |
+
titles = ['Original Image', 'Grad-CAM Heatmap', 'Overlay']
|
| 127 |
+
images = [image_np, colormap, overlay]
|
| 128 |
+
|
| 129 |
+
for ax, img, title in zip(axes, images, titles):
|
| 130 |
+
ax.imshow(img)
|
| 131 |
+
ax.set_title(title, color='white', fontsize=14, fontweight='bold', pad=10)
|
| 132 |
+
ax.axis('off')
|
| 133 |
+
ax.set_facecolor('#0d1117')
|
| 134 |
+
|
| 135 |
+
plt.tight_layout(pad=2)
|
| 136 |
+
|
| 137 |
+
# Convert figure to numpy array
|
| 138 |
+
fig.canvas.draw()
|
| 139 |
+
# Use buffer_rgba() which is more robust in newer matplotlib versions
|
| 140 |
+
rgba_buffer = fig.canvas.buffer_rgba()
|
| 141 |
+
result = np.array(rgba_buffer)[:, :, :3] # Strip alpha channel
|
| 142 |
+
plt.close(fig)
|
| 143 |
+
|
| 144 |
+
return result
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class Cifar10ResNet18Handler(BaseModelHandler):
|
| 148 |
+
"""Model handler for CIFAR-10 ResNet-18."""
|
| 149 |
+
|
| 150 |
+
def __init__(self, model_path: str):
|
| 151 |
+
self.model_path = model_path
|
| 152 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 153 |
+
self.model = None
|
| 154 |
+
self.grad_cam = None
|
| 155 |
+
self.history = {}
|
| 156 |
+
self.config = {}
|
| 157 |
+
self.best_accuracy = None
|
| 158 |
+
self._calibration_cache = {}
|
| 159 |
+
self.transform = transforms.Compose([
|
| 160 |
+
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
|
| 161 |
+
transforms.ToTensor(),
|
| 162 |
+
transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
|
| 163 |
+
])
|
| 164 |
+
self._load_model()
|
| 165 |
+
|
| 166 |
+
def _load_model(self):
|
| 167 |
+
"""Load the trained model."""
|
| 168 |
+
self.model = create_resnet18_cifar10(num_classes=10)
|
| 169 |
+
|
| 170 |
+
if os.path.exists(self.model_path):
|
| 171 |
+
checkpoint = torch.load(self.model_path, map_location=self.device,
|
| 172 |
+
weights_only=True)
|
| 173 |
+
if isinstance(checkpoint, dict):
|
| 174 |
+
self.history = checkpoint.get('history', {}) or {}
|
| 175 |
+
self.config = checkpoint.get('config', {}) or {}
|
| 176 |
+
self.best_accuracy = get_best_accuracy_from_history(self.history)
|
| 177 |
+
# Handle both state_dict and full model saves
|
| 178 |
+
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
|
| 179 |
+
self.model.load_state_dict(checkpoint['model_state_dict'])
|
| 180 |
+
elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
| 181 |
+
self.model.load_state_dict(checkpoint['state_dict'])
|
| 182 |
+
else:
|
| 183 |
+
self.model.load_state_dict(checkpoint)
|
| 184 |
+
|
| 185 |
+
self.model = self.model.to(self.device)
|
| 186 |
+
self.model.eval()
|
| 187 |
+
|
| 188 |
+
# Initialize Grad-CAM with the last conv layer
|
| 189 |
+
self.grad_cam = GradCAM(self.model, self.model.layer4[-1])
|
| 190 |
+
|
| 191 |
+
precomputed_full = load_precomputed_calibration_result("resnet18")
|
| 192 |
+
if precomputed_full is not None:
|
| 193 |
+
self._calibration_cache["full"] = precomputed_full
|
| 194 |
+
|
| 195 |
+
def get_model_name(self) -> str:
|
| 196 |
+
return "ResNet-18"
|
| 197 |
+
|
| 198 |
+
def get_dataset_name(self) -> str:
|
| 199 |
+
return "CIFAR-10"
|
| 200 |
+
|
| 201 |
+
def get_data_type(self) -> str:
|
| 202 |
+
return "image"
|
| 203 |
+
|
| 204 |
+
def get_class_labels(self) -> List[str]:
|
| 205 |
+
return CIFAR10_LABELS
|
| 206 |
+
|
| 207 |
+
def get_model_info(self) -> Dict[str, str]:
|
| 208 |
+
total_params = sum(p.numel() for p in self.model.parameters())
|
| 209 |
+
best_accuracy = (
|
| 210 |
+
f"{self.best_accuracy:.2f}%"
|
| 211 |
+
if self.best_accuracy is not None
|
| 212 |
+
else "N/A"
|
| 213 |
+
)
|
| 214 |
+
info = {
|
| 215 |
+
"Architecture": "ResNet-18 (Transfer Learning from ImageNet)",
|
| 216 |
+
"Dataset": "CIFAR-10 (10 classes, 60,000 images)",
|
| 217 |
+
"Parameters": f"{total_params:,}",
|
| 218 |
+
"Input Size": f"{IMAGE_SIZE}×{IMAGE_SIZE}×3",
|
| 219 |
+
"Training": "Full fine-tune, AdamW, Cosine Annealing LR",
|
| 220 |
+
"Best Accuracy": best_accuracy,
|
| 221 |
+
"Device": str(self.device),
|
| 222 |
+
}
|
| 223 |
+
if "epochs" in self.config:
|
| 224 |
+
info["Epochs"] = str(self.config["epochs"])
|
| 225 |
+
full_result = self._calibration_cache.get("full")
|
| 226 |
+
if full_result is not None:
|
| 227 |
+
info["Full-Test ECE"] = f"{full_result.ece:.6f}"
|
| 228 |
+
return info
|
| 229 |
+
|
| 230 |
+
def predict(self, input_data) -> PredictionResult:
|
| 231 |
+
"""Run prediction with Grad-CAM visualization."""
|
| 232 |
+
if input_data is None:
|
| 233 |
+
raise ValueError("No input image provided")
|
| 234 |
+
|
| 235 |
+
# Convert to PIL Image if numpy array
|
| 236 |
+
if isinstance(input_data, np.ndarray):
|
| 237 |
+
original_image = input_data.copy()
|
| 238 |
+
pil_image = Image.fromarray(input_data).convert('RGB')
|
| 239 |
+
else:
|
| 240 |
+
pil_image = input_data.convert('RGB')
|
| 241 |
+
original_image = np.array(pil_image)
|
| 242 |
+
|
| 243 |
+
# Preprocess
|
| 244 |
+
input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)
|
| 245 |
+
|
| 246 |
+
# Forward pass
|
| 247 |
+
with torch.no_grad():
|
| 248 |
+
output = self.model(input_tensor)
|
| 249 |
+
probabilities = torch.softmax(output, dim=1)[0]
|
| 250 |
+
|
| 251 |
+
probs = probabilities.cpu().numpy()
|
| 252 |
+
pred_idx = probs.argmax()
|
| 253 |
+
pred_label = CIFAR10_LABELS[pred_idx]
|
| 254 |
+
pred_conf = float(probs[pred_idx])
|
| 255 |
+
|
| 256 |
+
# Generate Grad-CAM
|
| 257 |
+
# Need to re-run with gradients enabled
|
| 258 |
+
input_tensor_grad = self.transform(pil_image).unsqueeze(0).to(self.device)
|
| 259 |
+
input_tensor_grad.requires_grad_(True)
|
| 260 |
+
|
| 261 |
+
heatmap = self.grad_cam.generate(input_tensor_grad, target_class=pred_idx)
|
| 262 |
+
explanation_image = create_gradcam_overlay(original_image, heatmap)
|
| 263 |
+
|
| 264 |
+
return PredictionResult(
|
| 265 |
+
label=pred_label,
|
| 266 |
+
confidence=pred_conf,
|
| 267 |
+
all_labels=CIFAR10_LABELS,
|
| 268 |
+
all_confidences=probs.tolist(),
|
| 269 |
+
explanation_image=explanation_image,
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
def get_example_inputs(self) -> List[Any]:
|
| 273 |
+
"""Return example images from CIFAR-10 test set if available."""
|
| 274 |
+
return []
|
| 275 |
+
|
| 276 |
+
def get_calibration_data(
|
| 277 |
+
self, max_samples: Optional[int] = None
|
| 278 |
+
) -> Optional[CalibrationResult]:
|
| 279 |
+
"""
|
| 280 |
+
Compute calibration metrics on test set.
|
| 281 |
+
This runs evaluation on the full test set - can be slow on CPU.
|
| 282 |
+
"""
|
| 283 |
+
cache_key = "full" if max_samples is None else f"subset:{max_samples}"
|
| 284 |
+
if cache_key in self._calibration_cache:
|
| 285 |
+
return self._calibration_cache[cache_key]
|
| 286 |
+
|
| 287 |
+
try:
|
| 288 |
+
import matplotlib
|
| 289 |
+
matplotlib.use('Agg')
|
| 290 |
+
import matplotlib.pyplot as plt
|
| 291 |
+
|
| 292 |
+
test_dataset = create_cifar10_test_dataset(transform=self.transform)
|
| 293 |
+
if max_samples is not None and 0 < max_samples < len(test_dataset):
|
| 294 |
+
indices = np.linspace(
|
| 295 |
+
0, len(test_dataset) - 1, num=max_samples, dtype=int
|
| 296 |
+
).tolist()
|
| 297 |
+
test_dataset = torch.utils.data.Subset(test_dataset, indices)
|
| 298 |
+
|
| 299 |
+
test_loader = torch.utils.data.DataLoader(
|
| 300 |
+
test_dataset, batch_size=128, shuffle=False, num_workers=0
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
all_probs = []
|
| 304 |
+
all_preds = []
|
| 305 |
+
all_targets = []
|
| 306 |
+
|
| 307 |
+
self.model.eval()
|
| 308 |
+
with torch.inference_mode():
|
| 309 |
+
for inputs, targets in test_loader:
|
| 310 |
+
inputs = inputs.to(self.device)
|
| 311 |
+
outputs = self.model(inputs)
|
| 312 |
+
probs = torch.softmax(outputs, dim=1)
|
| 313 |
+
preds = outputs.argmax(1)
|
| 314 |
+
|
| 315 |
+
all_probs.extend(probs.cpu().numpy())
|
| 316 |
+
all_preds.extend(preds.cpu().numpy())
|
| 317 |
+
all_targets.extend(targets.numpy())
|
| 318 |
+
|
| 319 |
+
all_probs = np.array(all_probs)
|
| 320 |
+
all_preds = np.array(all_preds)
|
| 321 |
+
all_targets = np.array(all_targets)
|
| 322 |
+
|
| 323 |
+
# Compute ECE (Expected Calibration Error)
|
| 324 |
+
n_bins = 10
|
| 325 |
+
max_probs = np.max(all_probs, axis=1)
|
| 326 |
+
correctness = (all_preds == all_targets).astype(float)
|
| 327 |
+
|
| 328 |
+
bin_boundaries = np.linspace(0, 1, n_bins + 1)
|
| 329 |
+
bin_accuracies = []
|
| 330 |
+
bin_confidences = []
|
| 331 |
+
bin_counts = []
|
| 332 |
+
|
| 333 |
+
for i in range(n_bins):
|
| 334 |
+
lower = bin_boundaries[i]
|
| 335 |
+
upper = bin_boundaries[i + 1]
|
| 336 |
+
mask = (max_probs > lower) & (max_probs <= upper)
|
| 337 |
+
count = mask.sum()
|
| 338 |
+
bin_counts.append(int(count))
|
| 339 |
+
|
| 340 |
+
if count > 0:
|
| 341 |
+
bin_acc = correctness[mask].mean()
|
| 342 |
+
bin_conf = max_probs[mask].mean()
|
| 343 |
+
else:
|
| 344 |
+
bin_acc = 0.0
|
| 345 |
+
bin_conf = 0.0
|
| 346 |
+
|
| 347 |
+
bin_accuracies.append(float(bin_acc))
|
| 348 |
+
bin_confidences.append(float(bin_conf))
|
| 349 |
+
|
| 350 |
+
# Compute ECE
|
| 351 |
+
total = len(all_preds)
|
| 352 |
+
ece = sum(
|
| 353 |
+
(count / total) * abs(acc - conf)
|
| 354 |
+
for count, acc, conf in zip(bin_counts, bin_accuracies, bin_confidences)
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
# Create reliability diagram
|
| 358 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
|
| 359 |
+
fig.patch.set_facecolor('#0d1117')
|
| 360 |
+
|
| 361 |
+
# Reliability Diagram
|
| 362 |
+
ax1.set_facecolor('#161b22')
|
| 363 |
+
bin_centers = [(bin_boundaries[i] + bin_boundaries[i + 1]) / 2 for i in range(n_bins)]
|
| 364 |
+
width = 0.08
|
| 365 |
+
|
| 366 |
+
bars1 = ax1.bar(
|
| 367 |
+
[c - width/2 for c in bin_centers], bin_accuracies, width,
|
| 368 |
+
label='Accuracy', color='#58a6ff', alpha=0.9, edgecolor='#58a6ff'
|
| 369 |
+
)
|
| 370 |
+
bars2 = ax1.bar(
|
| 371 |
+
[c + width/2 for c in bin_centers], bin_confidences, width,
|
| 372 |
+
label='Avg Confidence', color='#f97583', alpha=0.9, edgecolor='#f97583'
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
ax1.plot([0, 1], [0, 1], '--', color='#8b949e', linewidth=2,
|
| 376 |
+
label='Perfect Calibration')
|
| 377 |
+
ax1.set_xlim(0, 1)
|
| 378 |
+
ax1.set_ylim(0, 1)
|
| 379 |
+
ax1.set_xlabel('Confidence', color='white', fontsize=12)
|
| 380 |
+
ax1.set_ylabel('Accuracy / Confidence', color='white', fontsize=12)
|
| 381 |
+
ax1.set_title(
|
| 382 |
+
f'Reliability Diagram (ECE: {ece:.4f})',
|
| 383 |
+
color='white', fontsize=14, fontweight='bold', pad=15
|
| 384 |
+
)
|
| 385 |
+
ax1.legend(facecolor='#161b22', edgecolor='#30363d',
|
| 386 |
+
labelcolor='white', fontsize=10)
|
| 387 |
+
ax1.tick_params(colors='white')
|
| 388 |
+
for spine in ax1.spines.values():
|
| 389 |
+
spine.set_edgecolor('#30363d')
|
| 390 |
+
ax1.grid(True, alpha=0.1, color='white')
|
| 391 |
+
|
| 392 |
+
# Confidence histogram
|
| 393 |
+
ax2.set_facecolor('#161b22')
|
| 394 |
+
ax2.bar(
|
| 395 |
+
bin_centers, [c / total for c in bin_counts], 0.08,
|
| 396 |
+
color='#56d364', alpha=0.9, edgecolor='#56d364'
|
| 397 |
+
)
|
| 398 |
+
ax2.set_xlim(0, 1)
|
| 399 |
+
ax2.set_xlabel('Confidence', color='white', fontsize=12)
|
| 400 |
+
ax2.set_ylabel('Fraction of Samples', color='white', fontsize=12)
|
| 401 |
+
ax2.set_title(
|
| 402 |
+
'Confidence Distribution',
|
| 403 |
+
color='white', fontsize=14, fontweight='bold', pad=15
|
| 404 |
+
)
|
| 405 |
+
ax2.tick_params(colors='white')
|
| 406 |
+
for spine in ax2.spines.values():
|
| 407 |
+
spine.set_edgecolor('#30363d')
|
| 408 |
+
ax2.grid(True, alpha=0.1, color='white')
|
| 409 |
+
|
| 410 |
+
plt.tight_layout(pad=3)
|
| 411 |
+
|
| 412 |
+
# Convert to numpy
|
| 413 |
+
fig.canvas.draw()
|
| 414 |
+
rgba_buffer = fig.canvas.buffer_rgba()
|
| 415 |
+
diagram = np.array(rgba_buffer)[:, :, :3] # Strip alpha channel
|
| 416 |
+
plt.close(fig)
|
| 417 |
+
|
| 418 |
+
self._calibration_cache[cache_key] = CalibrationResult(
|
| 419 |
+
ece=ece,
|
| 420 |
+
bin_accuracies=bin_accuracies,
|
| 421 |
+
bin_confidences=bin_confidences,
|
| 422 |
+
bin_counts=bin_counts,
|
| 423 |
+
reliability_diagram=diagram,
|
| 424 |
+
source="Live computation",
|
| 425 |
+
)
|
| 426 |
+
return self._calibration_cache[cache_key]
|
| 427 |
+
|
| 428 |
+
except Exception as e:
|
| 429 |
+
print(f"Error computing calibration: {e}")
|
| 430 |
+
return None
|
assignments/assignment-1/app/image/vit_b16.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CIFAR-10 ViT-B/16 Model Handler
|
| 3 |
+
|
| 4 |
+
Handles prediction, Grad-CAM visualization, and calibration
|
| 5 |
+
for the ViT-B/16 model trained on CIFAR-10.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import types
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import numpy as np
|
| 13 |
+
from PIL import Image
|
| 14 |
+
from typing import Dict, List, Optional, Any
|
| 15 |
+
import torchvision.transforms as transforms
|
| 16 |
+
from torchvision.models import vit_b_16
|
| 17 |
+
|
| 18 |
+
from app.shared.model_registry import (
|
| 19 |
+
BaseModelHandler,
|
| 20 |
+
PredictionResult,
|
| 21 |
+
CalibrationResult,
|
| 22 |
+
)
|
| 23 |
+
from app.shared.artifact_utils import (
|
| 24 |
+
get_best_accuracy_from_history,
|
| 25 |
+
load_precomputed_calibration_result,
|
| 26 |
+
)
|
| 27 |
+
from app.image.data import create_cifar10_test_dataset
|
| 28 |
+
|
| 29 |
+
# CIFAR-10 class labels (index order matches the dataset's integer targets,
# so probs[i] corresponds to CIFAR10_LABELS[i])
CIFAR10_LABELS = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
]

# CIFAR-10 normalization values — presumably per-channel mean/std of the
# CIFAR-10 training split; must match what was used during training
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Image size ViT expects (ViT-B/16 consumes 224x224 inputs → 14x14 patch grid)
IMAGE_SIZE = 224
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def create_vit_model(num_classes=10):
    """Build a ViT-B/16 backbone with a classifier head sized for CIFAR-10.

    The architecture is instantiated without pretrained weights; trained
    parameters are loaded afterwards from a checkpoint.

    Args:
        num_classes: Number of output classes for the new linear head.

    Returns:
        A torchvision ``VisionTransformer`` ready for state-dict loading.
    """
    vit = vit_b_16(weights=None)
    # Swap the ImageNet-1k classification head for one matching num_classes.
    head_in_features = vit.heads.head.in_features
    vit.heads.head = nn.Linear(head_in_features, num_classes)
    return vit
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class ViTAttentionVisualizer:
    """
    Attention visualization for ViT.
    Shows which patches the model attends to.

    Works by replacing the forward method of the model's last encoder block
    with a version that requests attention weights, which are captured as a
    side effect of every forward pass through the model.
    """

    def __init__(self, model):
        # model: torchvision ViT whose last encoder block gets patched in place.
        self.model = model
        # Latest captured attention weights; None until the first forward pass.
        self.attentions = None
        self._patch_last_encoder_block()

    def _patch_last_encoder_block(self):
        """
        Torchvision's ViT encoder block calls MultiheadAttention with
        need_weights=False, so a normal forward hook never receives attention
        maps. We patch only the last block to request weights during inference.
        """
        last_block = self.model.encoder.layers[-1]
        # Captured in the closure so the patched forward can write back here.
        visualizer = self

        def forward_with_attention(block, input_tensor):
            # Re-implements EncoderBlock.forward, but with need_weights=True.
            torch._assert(
                input_tensor.dim() == 3,
                f"Expected (batch_size, seq_length, hidden_dim) got {input_tensor.shape}",
            )

            x = block.ln_1(input_tensor)
            # average_attn_weights=False keeps the per-head dimension:
            # attn_weights has shape (batch, heads, seq_len, seq_len).
            attn_output, attn_weights = block.self_attention(
                x,
                x,
                x,
                need_weights=True,
                average_attn_weights=False,
            )
            visualizer.attentions = attn_weights.detach()

            # Residual connection around attention, as in the original block.
            x = block.dropout(attn_output)
            x = x + input_tensor

            # MLP sub-block with its own residual connection.
            y = block.ln_2(x)
            y = block.mlp(y)
            return x + y

        # Bind as a method so `block` is passed automatically as `self`.
        last_block.forward = types.MethodType(forward_with_attention, last_block)

    def generate_attention_map(self, input_tensor):
        """Generate attention map from input tensor.

        Args:
            input_tensor: Preprocessed image batch; expected shape
                (1, 3, IMAGE_SIZE, IMAGE_SIZE) — only batch index 0 is used.

        Returns:
            A 2-D numpy array in [0, 1] over the patch grid, a raw 1-D array
            if the sequence length is not a perfect square, or None if no
            attention was captured.
        """
        self.model.eval()

        # Forward pass (output discarded — we only need the side-effect
        # capture of attention weights by the patched encoder block).
        with torch.no_grad():
            _ = self.model(input_tensor)

        if self.attentions is None:
            return None

        # Get the [CLS] token attention across all heads
        # Shape: (batch, heads, seq_len, seq_len) -> take cls token row;
        # index 1: skips the [CLS] token itself, keeping only image patches.
        cls_attention = self.attentions[0, :, 0, 1:].mean(dim=0)  # Average over heads

        # Reshape to patch grid (assuming 16x16 patches for 224x224 image)
        num_patches = int(cls_attention.shape[0] ** 0.5)

        if num_patches * num_patches != cls_attention.shape[0]:
            # Fallback: just return raw attention
            return cls_attention.cpu().numpy()

        # Reshape to 2D grid
        attention_map = cls_attention.reshape(num_patches, num_patches).cpu().numpy()

        # Normalize to [0, 1] (guard against a constant map dividing by zero).
        attention_map = attention_map - attention_map.min()
        if attention_map.max() > 0:
            attention_map = attention_map / attention_map.max()

        return attention_map
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def create_attention_overlay(image_np, attention_map, alpha=0.5):
    """Render a 3-panel figure: original image, attention heatmap, overlay.

    Args:
        image_np: Input image as an H x W (grayscale), H x W x 3 (RGB) or
            H x W x 4 (RGBA) uint8 array; it is converted to RGB at
            IMAGE_SIZE x IMAGE_SIZE before blending.
        attention_map: 2-D array of attention weights in [0, 1], or None
            when no attention is available. Non-2-D fallback maps (raw 1-D
            attention) are treated as "no attention".
        alpha: Blend weight of the heatmap in the overlay panel.

    Returns:
        RGB numpy array of the rendered figure, or the unmodified input
        image when no usable attention map was provided.
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    from PIL import Image as PILImage

    # No attention, or the 1-D fallback map: nothing we can overlay.
    # (PIL cannot build an image from a 1-D array, so guard it here.)
    if attention_map is None or getattr(attention_map, 'ndim', 0) != 2:
        return image_np

    # Normalize the photo to RGB at the model resolution so the blend below
    # is always (H, W, 3) + (H, W, 3). Grayscale or RGBA uploads would
    # otherwise broadcast incorrectly or raise during the addition.
    pil_image = PILImage.fromarray(image_np).convert('RGB')
    if pil_image.size != (IMAGE_SIZE, IMAGE_SIZE):
        pil_image = pil_image.resize((IMAGE_SIZE, IMAGE_SIZE), PILImage.BILINEAR)
    image_np = np.array(pil_image)

    # Upsample the patch-level attention grid to pixel resolution.
    attention_uint8 = (attention_map * 255).astype(np.uint8)
    attention_resized = PILImage.fromarray(attention_uint8).resize(
        (IMAGE_SIZE, IMAGE_SIZE), PILImage.BILINEAR
    )
    attention_resized = np.array(attention_resized).astype(np.float32) / 255.0

    # Map attention strength through the jet colormap; drop the alpha channel.
    colormap = cm.jet(attention_resized)[:, :, :3]
    colormap = (colormap * 255).astype(np.uint8)

    # Alpha-blend the heatmap over the photo.
    overlay = (alpha * colormap + (1 - alpha) * image_np).astype(np.uint8)

    # Render the three panels on a dark background matching the app theme.
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    fig.patch.set_facecolor('#0d1117')

    titles = ['Original Image', 'Attention Map', 'Overlay']
    images = [image_np, colormap, overlay]

    for ax, img, title in zip(axes, images, titles):
        ax.imshow(img)
        ax.set_title(title, color='white', fontsize=14, fontweight='bold', pad=10)
        ax.axis('off')
        ax.set_facecolor('#0d1117')

    plt.tight_layout(pad=2)

    # Rasterize the figure into an RGB numpy array (strip alpha channel).
    fig.canvas.draw()
    rgba_buffer = fig.canvas.buffer_rgba()
    result = np.array(rgba_buffer)[:, :, :3]
    plt.close(fig)

    return result
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class Cifar10ViTHandler(BaseModelHandler):
    """Model handler for CIFAR-10 ViT-B/16.

    Wraps a fine-tuned torchvision ViT-B/16 checkpoint and exposes the
    registry interface: metadata, single-image prediction with attention
    visualization, and calibration analysis (ECE + reliability diagram).
    """

    def __init__(self, model_path: str):
        # Path to the .pth checkpoint; a missing file is tolerated
        # (the model then runs with randomly initialized weights).
        self.model_path = model_path
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = None
        self.attention_viz = None
        # Training history dict from the checkpoint (e.g. "val_acc" list).
        self.history = {}
        self.best_accuracy = None
        # Keyed by "full" or "subset:<n>"; may be pre-seeded from artifacts.
        self._calibration_cache = {}
        # Preprocessing must mirror the training pipeline: resize to the
        # ViT input size, then normalize with CIFAR-10 statistics.
        self.transform = transforms.Compose([
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
        ])
        self._load_model()

    def _load_model(self):
        """Load the trained model."""
        self.model = create_vit_model(num_classes=10)

        if os.path.exists(self.model_path):
            # weights_only=True avoids arbitrary-code pickle loading;
            # NOTE(review): this requires the checkpoint's 'history' payload
            # to be plain tensors/containers — confirm against the notebook
            # that produced the .pth file.
            checkpoint = torch.load(self.model_path, map_location=self.device,
                                    weights_only=True)
            if isinstance(checkpoint, dict):
                self.history = checkpoint.get('history', {}) or {}
                self.best_accuracy = get_best_accuracy_from_history(self.history)
            # Support both {'model_state_dict': ...} checkpoints and raw
            # state dicts saved directly.
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                self.model.load_state_dict(checkpoint['model_state_dict'])
            else:
                self.model.load_state_dict(checkpoint)

        self.model = self.model.to(self.device)
        self.model.eval()

        # Initialize attention visualizer (patches the last encoder block
        # in place, so it must wrap the final, device-placed model).
        self.attention_viz = ViTAttentionVisualizer(self.model)

        # Seed the cache with notebook-precomputed full-test-set calibration
        # so the "Full Test Set" mode is instant when artifacts exist.
        precomputed_full = load_precomputed_calibration_result("vit_b16")
        if precomputed_full is not None:
            self._calibration_cache["full"] = precomputed_full

    def get_model_name(self) -> str:
        return "ViT-B/16"

    def get_dataset_name(self) -> str:
        return "CIFAR-10"

    def get_data_type(self) -> str:
        return "image"

    def get_class_labels(self) -> List[str]:
        return CIFAR10_LABELS

    def get_model_info(self) -> Dict[str, str]:
        """Return display-ready metadata rows for the model-info table."""
        total_params = sum(p.numel() for p in self.model.parameters())
        best_accuracy = (
            f"{self.best_accuracy:.2f}%"
            if self.best_accuracy is not None
            else "N/A"
        )
        info = {
            "Architecture": "ViT-B/16 (Transfer Learning from ImageNet)",
            "Dataset": "CIFAR-10 (10 classes, 60,000 images)",
            "Parameters": f"{total_params:,}",
            "Input Size": f"{IMAGE_SIZE}×{IMAGE_SIZE}×3",
            "Training": "Full fine-tune, AdamW, Cosine Annealing LR",
            "Best Accuracy": best_accuracy,
            "Device": str(self.device),
        }
        if self.history:
            info["Epochs"] = str(len(self.history.get("val_acc", [])))
        full_result = self._calibration_cache.get("full")
        if full_result is not None:
            info["Full-Test ECE"] = f"{full_result.ece:.6f}"
        return info

    def predict(self, input_data) -> PredictionResult:
        """Run prediction with attention visualization.

        Args:
            input_data: A numpy image array or a PIL Image.

        Returns:
            PredictionResult with the top label, per-class confidences, and
            the rendered attention-overlay figure.

        Raises:
            ValueError: If no input image is provided.
        """
        if input_data is None:
            raise ValueError("No input image provided")

        # Convert to PIL Image if numpy array
        if isinstance(input_data, np.ndarray):
            # NOTE(review): a non-RGB numpy input (grayscale/RGBA) is kept
            # as-is here and handed to the overlay renderer unconverted —
            # confirm create_attention_overlay normalizes channels.
            original_image = input_data.copy()
            pil_image = Image.fromarray(input_data).convert('RGB')
        else:
            pil_image = input_data.convert('RGB')
            original_image = np.array(pil_image)

        # Preprocess to a (1, 3, IMAGE_SIZE, IMAGE_SIZE) batch on device.
        input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)

        # Forward pass
        with torch.no_grad():
            output = self.model(input_tensor)
            probabilities = torch.softmax(output, dim=1)[0]

        probs = probabilities.cpu().numpy()
        pred_idx = probs.argmax()
        pred_label = CIFAR10_LABELS[pred_idx]
        pred_conf = float(probs[pred_idx])

        # Generate attention visualization (uses the weights captured as a
        # side effect of the forward pass above, plus one extra pass).
        attention_map = self.attention_viz.generate_attention_map(input_tensor)
        explanation_image = create_attention_overlay(original_image, attention_map)

        return PredictionResult(
            label=pred_label,
            confidence=pred_conf,
            all_labels=CIFAR10_LABELS,
            all_confidences=probs.tolist(),
            explanation_image=explanation_image,
        )

    def get_example_inputs(self) -> List[Any]:
        # No bundled example images for this handler.
        return []

    def get_calibration_data(
        self, max_samples: Optional[int] = None
    ) -> Optional[CalibrationResult]:
        """Compute calibration metrics on test set.

        Args:
            max_samples: If given, evaluate only this many evenly spaced
                test images (fast CPU preview); None means the full set.

        Returns:
            A cached or freshly computed CalibrationResult, or None if the
            computation failed (error is printed, not raised).
        """
        cache_key = "full" if max_samples is None else f"subset:{max_samples}"
        if cache_key in self._calibration_cache:
            return self._calibration_cache[cache_key]

        try:
            # Imported lazily with the non-GUI backend so this works headless.
            import matplotlib
            matplotlib.use('Agg')
            import matplotlib.pyplot as plt

            test_dataset = create_cifar10_test_dataset(transform=self.transform)
            if max_samples is not None and 0 < max_samples < len(test_dataset):
                # Evenly spaced subset rather than a prefix, to keep the
                # class mix representative.
                indices = np.linspace(
                    0, len(test_dataset) - 1, num=max_samples, dtype=int
                ).tolist()
                test_dataset = torch.utils.data.Subset(test_dataset, indices)

            test_loader = torch.utils.data.DataLoader(
                test_dataset, batch_size=128, shuffle=False, num_workers=0
            )

            all_probs = []
            all_preds = []
            all_targets = []

            self.model.eval()
            with torch.inference_mode():
                for inputs, targets in test_loader:
                    inputs = inputs.to(self.device)
                    outputs = self.model(inputs)
                    probs = torch.softmax(outputs, dim=1)
                    preds = outputs.argmax(1)

                    all_probs.extend(probs.cpu().numpy())
                    all_preds.extend(preds.cpu().numpy())
                    all_targets.extend(targets.numpy())

            all_probs = np.array(all_probs)
            all_preds = np.array(all_preds)
            all_targets = np.array(all_targets)

            # Compute ECE: bin predictions by max softmax confidence.
            n_bins = 10
            max_probs = np.max(all_probs, axis=1)
            correctness = (all_preds == all_targets).astype(float)

            bin_boundaries = np.linspace(0, 1, n_bins + 1)
            bin_accuracies = []
            bin_confidences = []
            bin_counts = []

            for i in range(n_bins):
                lower = bin_boundaries[i]
                upper = bin_boundaries[i + 1]
                # Half-open bins (lower, upper]; note confidence exactly 0
                # would fall in no bin (cannot occur after softmax).
                mask = (max_probs > lower) & (max_probs <= upper)
                count = mask.sum()
                bin_counts.append(int(count))

                if count > 0:
                    bin_acc = correctness[mask].mean()
                    bin_conf = max_probs[mask].mean()
                else:
                    bin_acc = 0.0
                    bin_conf = 0.0

                bin_accuracies.append(float(bin_acc))
                bin_confidences.append(float(bin_conf))

            # Compute ECE: count-weighted mean |accuracy - confidence|.
            total = len(all_preds)
            ece = sum(
                (count / total) * abs(acc - conf)
                for count, acc, conf in zip(bin_counts, bin_accuracies, bin_confidences)
            )

            # Create reliability diagram (dark theme matching the app UI).
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
            fig.patch.set_facecolor('#0d1117')

            # Reliability Diagram
            ax1.set_facecolor('#161b22')
            bin_centers = [(bin_boundaries[i] + bin_boundaries[i + 1]) / 2 for i in range(n_bins)]
            width = 0.08

            ax1.bar([c - width/2 for c in bin_centers], bin_accuracies, width,
                    label='Accuracy', color='#58a6ff', alpha=0.9, edgecolor='#58a6ff')
            ax1.bar([c + width/2 for c in bin_centers], bin_confidences, width,
                    label='Avg Confidence', color='#f97583', alpha=0.9, edgecolor='#f97583')

            # Diagonal = perfect calibration reference line.
            ax1.plot([0, 1], [0, 1], '--', color='#8b949e', linewidth=2,
                     label='Perfect Calibration')
            ax1.set_xlim(0, 1)
            ax1.set_ylim(0, 1)
            ax1.set_xlabel('Confidence', color='white', fontsize=12)
            ax1.set_ylabel('Accuracy / Confidence', color='white', fontsize=12)
            ax1.set_title(f'Reliability Diagram (ECE: {ece:.4f})',
                          color='white', fontsize=14, fontweight='bold', pad=15)
            ax1.legend(facecolor='#161b22', edgecolor='#30363d', labelcolor='white', fontsize=10)
            ax1.tick_params(colors='white')
            for spine in ax1.spines.values():
                spine.set_edgecolor('#30363d')
            ax1.grid(True, alpha=0.1, color='white')

            # Confidence histogram
            ax2.set_facecolor('#161b22')
            ax2.bar(bin_centers, [c / total for c in bin_counts], 0.08,
                    color='#56d364', alpha=0.9, edgecolor='#56d364')
            ax2.set_xlim(0, 1)
            ax2.set_xlabel('Confidence', color='white', fontsize=12)
            ax2.set_ylabel('Fraction of Samples', color='white', fontsize=12)
            ax2.set_title('Confidence Distribution',
                          color='white', fontsize=14, fontweight='bold', pad=15)
            ax2.tick_params(colors='white')
            for spine in ax2.spines.values():
                spine.set_edgecolor('#30363d')
            ax2.grid(True, alpha=0.1, color='white')

            plt.tight_layout(pad=3)

            # Rasterize the figure to an RGB array for gr.Image display.
            fig.canvas.draw()
            rgba_buffer = fig.canvas.buffer_rgba()
            diagram = np.array(rgba_buffer)[:, :, :3]
            plt.close(fig)

            self._calibration_cache[cache_key] = CalibrationResult(
                ece=ece,
                bin_accuracies=bin_accuracies,
                bin_confidences=bin_confidences,
                bin_counts=bin_counts,
                reliability_diagram=diagram,
                source="Live computation",
            )
            return self._calibration_cache[cache_key]

        except Exception as e:
            # Best-effort: calibration failure should not crash the demo UI.
            print(f"Error computing calibration: {e}")
            return None
|
assignments/assignment-1/app/main.py
ADDED
|
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Deep Learning Assignment 1 - Application Demo
|
| 3 |
+
===============================================
|
| 4 |
+
A modular Gradio application for demonstrating
|
| 5 |
+
trained models on Image, Text, and Multimodal datasets.
|
| 6 |
+
|
| 7 |
+
Features:
|
| 8 |
+
- Image classification with Grad-CAM / attention visualization
|
| 9 |
+
- Model Calibration analysis (ECE + Reliability Diagram)
|
| 10 |
+
- Easy to extend with new models/datasets
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
python assignments/assignment-1/app/main.py
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import sys
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
# Add assignment root to path so `app.*` imports keep working.
|
| 20 |
+
ASSIGNMENT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 21 |
+
sys.path.insert(0, ASSIGNMENT_ROOT)
|
| 22 |
+
|
| 23 |
+
import gradio as gr
|
| 24 |
+
from typing import Dict
|
| 25 |
+
|
| 26 |
+
from app.shared.model_registry import (
|
| 27 |
+
register_model,
|
| 28 |
+
get_all_model_keys,
|
| 29 |
+
get_models_by_type,
|
| 30 |
+
BaseModelHandler,
|
| 31 |
+
)
|
| 32 |
+
from app.image.resnet18 import Cifar10ResNet18Handler
|
| 33 |
+
from app.image.vit_b16 import Cifar10ViTHandler
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# ============================================================================
|
| 37 |
+
# CONFIGURATION
|
| 38 |
+
# ============================================================================
|
| 39 |
+
|
| 40 |
+
# App header shown in the Gradio UI.
APP_TITLE = "🧠 Deep Learning Assignment 1 - Demo"
APP_DESCRIPTION = """
<div style="text-align: center; padding: 10px 0;">
    <p style="font-size: 16px; color: #8b949e; margin: 5px 0;">
        Classification on Images, Text, and Multimodal Data
    </p>
    <p style="font-size: 14px; color: #58a6ff; margin: 5px 0;">
        CO3091 · HCM University of Technology · 2025-2026 Semester 2
    </p>
</div>
"""

# Load custom CSS from external file; a missing stylesheet degrades
# gracefully to Gradio's defaults instead of raising at import time.
CSS_PATH = os.path.join(os.path.dirname(__file__), "assets", "style.css")
if os.path.exists(CSS_PATH):
    with open(CSS_PATH, "r", encoding="utf-8") as f:
        CUSTOM_CSS = f.read()
else:
    CUSTOM_CSS = ""


# GitHub-dark inspired theme. Each property is set for both light and
# dark variants (_dark) so the palette is identical in either browser mode.
CUSTOM_THEME = gr.themes.Base(
    primary_hue=gr.themes.colors.blue,
    secondary_hue=gr.themes.colors.green,
    neutral_hue=gr.themes.colors.gray,
    font=[gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"],
).set(
    body_background_fill="#0d1117",
    body_background_fill_dark="#0d1117",
    block_background_fill="#161b22",
    block_background_fill_dark="#161b22",
    block_border_color="#30363d",
    block_border_color_dark="#30363d",
    block_label_text_color="#c9d1d9",
    block_label_text_color_dark="#c9d1d9",
    block_title_text_color="#f0f6fc",
    block_title_text_color_dark="#f0f6fc",
    body_text_color="#c9d1d9",
    body_text_color_dark="#c9d1d9",
    body_text_color_subdued="#8b949e",
    body_text_color_subdued_dark="#8b949e",
    button_primary_background_fill="#238636",
    button_primary_background_fill_dark="#238636",
    button_primary_background_fill_hover="#2ea043",
    button_primary_background_fill_hover_dark="#2ea043",
    button_primary_text_color="white",
    button_primary_text_color_dark="white",
    input_background_fill="#0d1117",
    input_background_fill_dark="#0d1117",
    input_border_color="#30363d",
    input_border_color_dark="#30363d",
    shadow_drop="none",
    shadow_drop_lg="none",
)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# ============================================================================
|
| 97 |
+
# MODEL INITIALIZATION
|
| 98 |
+
# ============================================================================
|
| 99 |
+
|
| 100 |
+
def init_models():
    """Initialize and register all available models.

    Each model is loaded independently: a missing checkpoint or a failing
    handler only disables that single model and prints a status line,
    instead of preventing the rest of the app from starting.
    """
    model_dir = os.path.join(ASSIGNMENT_ROOT, "image", "models")

    # (registry key, display name, checkpoint filename, handler class).
    # Data-driven so adding a model is a one-line change and the two load
    # paths cannot drift apart.
    model_specs = [
        ("cifar10_resnet18", "CIFAR-10 ResNet-18", "resnet18_cifar10.pth",
         Cifar10ResNet18Handler),
        ("cifar10_vit", "CIFAR-10 ViT-B/16", "vit_b16_cifar10.pth",
         Cifar10ViTHandler),
    ]

    for key, display_name, filename, handler_cls in model_specs:
        model_path = os.path.join(model_dir, filename)
        if not os.path.exists(model_path):
            print(f"⚠️ Model file not found: {model_path}")
            continue
        try:
            handler = handler_cls(model_path)
            register_model(key, handler)
            print(f"✅ Loaded: {display_name} from {model_path}")
        except Exception as e:
            # Keep the app alive with whatever models did load.
            print(f"❌ Failed to load {display_name}: {e}")
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# ============================================================================
|
| 130 |
+
# UI BUILDER FUNCTIONS
|
| 131 |
+
# ============================================================================
|
| 132 |
+
|
| 133 |
+
def format_confidence_label(labels, confidences, top_k=5):
    """Format top-k predictions as a dictionary for gr.Label.

    Args:
        labels: Class names, aligned index-wise with *confidences*.
        confidences: Per-class scores.
        top_k: Maximum number of entries to keep.

    Returns:
        Dict mapping each of the top-k labels to its score as a float,
        ordered from most to least confident.
    """
    # Rank (label, confidence) pairs by descending confidence.
    ranked = sorted(
        zip(labels, confidences),
        key=lambda pair: pair[1],
        reverse=True,
    )
    top = {}
    for label, confidence in ranked[:top_k]:
        top[label] = float(confidence)
    return top
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def build_model_info_markdown(handler: BaseModelHandler) -> str:
    """Render a handler's metadata as a Markdown table.

    Args:
        handler: Model handler; its ``get_model_info()`` dict supplies the
            Property/Value rows in insertion order.

    Returns:
        Markdown string: a title line followed by a two-column table.
    """
    info = handler.get_model_info()
    title = "### 📋 Model Information\n"
    header = "| Property | Value |\n|:---|:---|\n"
    # Build the rows directly — the original appended rows after the title
    # and then re-copied the list just to strip the title back off.
    rows = [f"| **{key}** | {val} |" for key, val in info.items()]
    return title + header + "\n".join(rows)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def build_image_prediction_tab(model_key: str, handler: BaseModelHandler):
    """Build the prediction tab UI for image models.

    Lays out an upload column and a results column side by side, with the
    interpretability figure in a full-width row below, and wires the
    Predict button to the handler. Must be called inside a gr.Blocks/Tab
    context so the components attach to the surrounding layout.

    Args:
        model_key: Registry key of the model (currently unused here, kept
            for a uniform tab-builder signature).
        handler: The model handler whose predict() backs the button.
    """
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="📸 Upload Image",
                type="numpy",  # handler.predict expects a numpy array
                height=300,
                sources=["upload", "clipboard"],
            )
            predict_btn = gr.Button(
                "🔍 Predict & Explain",
                variant="primary",
                size="lg",
            )
            gr.Markdown(
                f"*Classes: {', '.join(handler.get_class_labels())}*",
                elem_classes=["text-sm"],
            )

        with gr.Column(scale=1):
            output_label = gr.Label(
                label="📊 Prediction Results (Top-5)",
                num_top_classes=5,
            )

    with gr.Row():
        explanation_image = gr.Image(
            label="🔥 Model Explanation (Interpretability)",
            interactive=False,
            height=350,
        )

    def do_predict(image):
        # Click callback: returns (label dict, explanation image).
        # Empty input clears both outputs instead of erroring.
        if image is None:
            return None, None
        try:
            result = handler.predict(image)
            conf_dict = format_confidence_label(
                result.all_labels, result.all_confidences
            )
            return conf_dict, result.explanation_image
        except Exception as e:
            # Surface failures as a Gradio toast rather than a stack trace.
            raise gr.Error(f"Prediction failed: {str(e)}")

    predict_btn.click(
        fn=do_predict,
        inputs=[input_image],
        outputs=[output_label, explanation_image],
    )
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def build_calibration_tab(model_key: str, handler: BaseModelHandler):
    """Build the calibration analysis tab.

    Offers a quick-preview / full-test-set mode toggle and a compute
    button; results (metrics table + reliability diagram) start hidden
    and are revealed on first successful computation. Must be called
    inside a gr.Blocks/Tab context.

    Args:
        model_key: Registry key of the model (kept for a uniform
            tab-builder signature).
        handler: The model handler whose get_calibration_data() is used.
    """
    gr.Markdown("""
### 📐 Model Calibration Analysis

Calibration measures how well the model's confidence matches its actual accuracy.
A perfectly calibrated model has **confidence = accuracy** for all predictions.

- **ECE (Expected Calibration Error)**: Lower is better (0 = perfect calibration)
- **Reliability Diagram**: Compares predicted confidence vs actual accuracy per bin
- **Quick Preview**: Uses a very small subset for fast CPU demos
- **Full Test Set**: Uses notebook artifacts instantly when available
""")

    calibration_mode = gr.Radio(
        choices=[
            "Quick Preview (64 samples)",
            "Full Test Set (10,000 samples)",
        ],
        value="Quick Preview (64 samples)",
        label="Calibration Mode",
    )

    compute_btn = gr.Button(
        "📊 Compute Calibration",
        variant="primary",
        size="lg",
    )

    # Hidden until the first computation succeeds.
    ece_display = gr.Markdown(visible=False)
    calibration_plot = gr.Image(
        label="📈 Calibration Analysis",
        interactive=False,
        visible=False,
        height=450,
    )

    def compute_calibration(mode):
        # Click callback: maps the radio choice to a sample budget and
        # renders the resulting metrics as Markdown.
        try:
            max_samples = 64 if mode.startswith("Quick Preview") else None
            result = handler.get_calibration_data(max_samples=max_samples)
            if result is None:
                # NOTE(review): this gr.Error is caught by the except below
                # and re-wrapped with a second prefix — confirm intended.
                raise gr.Error("Could not compute calibration data")

            sample_note = (
                "Approximate preview on 64 evenly spaced test images"
                if max_samples is not None
                else "Full CIFAR-10 test set"
            )
            source_note = result.source or "Live computation"
            ece_md = f"""
### Calibration Metrics

| Metric | Value |
|:---|:---|
| **Mode** | {sample_note} |
| **Source** | {source_note} |
| **Expected Calibration Error (ECE)** | `{result.ece:.6f}` |
| **Interpretation** | {'✅ Well calibrated' if result.ece < 0.05 else '⚠️ Moderately calibrated' if result.ece < 0.15 else '❌ Poorly calibrated'} |
| **Total evaluated samples** | {sum(result.bin_counts):,} |
"""
            # Reveal both output components along with their new values.
            return (
                gr.update(value=ece_md, visible=True),
                gr.update(value=result.reliability_diagram, visible=True),
            )
        except Exception as e:
            raise gr.Error(f"Calibration computation failed: {str(e)}")

    compute_btn.click(
        fn=compute_calibration,
        inputs=[calibration_mode],
        outputs=[ece_display, calibration_plot],
    )
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def build_model_tabs(model_key: str, handler: BaseModelHandler):
    """Render the per-model tab set: info header, predict tab, calibration tab."""
    gr.Markdown(build_model_info_markdown(handler))

    # Placeholder copy for modalities whose prediction UI is not built yet.
    placeholder_md = {
        "text": "### 📝 Text Classification\n*Coming soon...*",
        "multimodal": "### 🖼️+📝 Multimodal Classification\n*Coming soon...*",
    }

    with gr.Tabs():
        with gr.Tab("🎯 Predict & Explain", id="predict"):
            kind = handler.get_data_type()
            if kind == "image":
                build_image_prediction_tab(model_key, handler)
            elif kind in placeholder_md:
                gr.Markdown(placeholder_md[kind])

        with gr.Tab("📐 Calibration", id="calibration"):
            build_calibration_tab(model_key, handler)
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
# ============================================================================
|
| 297 |
+
# MAIN APPLICATION
|
| 298 |
+
# ============================================================================
|
| 299 |
+
|
| 300 |
+
def _render_handler_group(models):
    """Render one group of model handlers.

    When the group holds several models, each gets its own sub-tab named
    "<model> (<dataset>)"; a single-model group is rendered flat without
    the extra tab level. Extracted because the original repeated this
    verbatim for the image, text, and multimodal branches.
    """
    if len(models) > 1:
        with gr.Tabs():
            for key, handler in models.items():
                tab_name = f"{handler.get_model_name()} ({handler.get_dataset_name()})"
                with gr.Tab(tab_name):
                    build_model_tabs(key, handler)
    else:
        key, handler = next(iter(models.items()))
        build_model_tabs(key, handler)


def create_app() -> gr.Blocks:
    """Create the main Gradio application.

    Loads every registered model via ``init_models`` and lays out one
    top-level tab per data modality (image / text / multimodal).
    Modalities without a registered model get a placeholder tab so the
    layout stays stable.

    Returns:
        The assembled ``gr.Blocks`` application (not yet launched).
    """
    init_models()

    with gr.Blocks(
        title="DL Assignment 1 - Demo",
    ) as app:
        gr.Markdown(f"# {APP_TITLE}")
        gr.Markdown(APP_DESCRIPTION)

        model_keys = get_all_model_keys()

        if not model_keys:
            gr.Markdown("""
            ## ⚠️ No Models Loaded

            Please ensure model files are in the `image/models/` directory.
            See the README for instructions on adding models.
            """)
        else:
            image_models = get_models_by_type("image")
            text_models = get_models_by_type("text")
            multimodal_models = get_models_by_type("multimodal")

            with gr.Tabs():
                if image_models:
                    with gr.Tab("🖼️ Image Classification", id="image_tab"):
                        _render_handler_group(image_models)

                if text_models:
                    with gr.Tab("📝 Text Classification", id="text_tab"):
                        _render_handler_group(text_models)

                if multimodal_models:
                    with gr.Tab("🔀 Multimodal Classification", id="mm_tab"):
                        _render_handler_group(multimodal_models)

                # Keep a visible placeholder tab for modalities that have
                # no registered models yet.
                if not text_models:
                    with gr.Tab("📝 Text Classification", id="text_tab"):
                        gr.Markdown("""
                        ### 📝 Text Classification Models

                        *No text models loaded yet. Add your text model handler
                        and register it in `app/main.py`.*
                        """)

                if not multimodal_models:
                    with gr.Tab("🔀 Multimodal Classification", id="mm_tab"):
                        gr.Markdown("""
                        ### 🔀 Multimodal Classification Models

                        *No multimodal models loaded yet. Add your multimodal
                        model handler and register it in `app/main.py`.*
                        """)

        gr.Markdown("""
        <div class="app-footer">
        <p>Deep Learning and Its Applications · Assignment 1</p>
        <p>HCM University of Technology (HCMUT) · VNUHCM</p>
        </div>
        """)

    return app
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
# ============================================================================
|
| 390 |
+
# ENTRY POINT
|
| 391 |
+
# ============================================================================
|
| 392 |
+
|
| 393 |
+
# Run the demo locally (Hugging Face Spaces typically launches via a
# separate top-level app.py instead of this guard).
if __name__ == "__main__":
    app = create_app()
    app.launch(
        server_name="127.0.0.1",
        server_port=5555,
        share=False,
        show_error=True,
        # NOTE(review): theme/css are documented as gr.Blocks() constructor
        # arguments, not Blocks.launch() arguments; recent Gradio versions
        # may raise TypeError on these kwargs. Confirm, and if so move them
        # into the gr.Blocks(...) call inside create_app().
        theme=CUSTOM_THEME,
        css=CUSTOM_CSS,
        # Whitelist the artifacts dir so Gradio can serve the precomputed
        # calibration PNGs that live outside the working directory.
        allowed_paths=[os.path.join(ASSIGNMENT_ROOT, "image", "artifacts")],
    )
|
assignments/assignment-1/app/multimodal/README.md
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Multimodal App Modules
|
| 2 |
+
|
| 3 |
+
Place multimodal-specific inference handlers here.
|
| 4 |
+
|
| 5 |
+
Suggested additions:
|
| 6 |
+
|
| 7 |
+
- multimodal model wrapper classes
|
| 8 |
+
- joint preprocessing helpers
|
| 9 |
+
- prediction utilities
|
| 10 |
+
- demo-specific visualization helpers
|
| 11 |
+
|
| 12 |
+
After adding a handler, register it in `assignments/assignment-1/app/main.py`.
|
assignments/assignment-1/app/multimodal/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Multimodal model handlers for Assignment 1."""
|
assignments/assignment-1/app/requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Deep Learning Assignment 1 - Application Demo Dependencies
|
| 2 |
+
torch>=2.0.0
|
| 3 |
+
torchvision>=0.15.0
|
| 4 |
+
gradio>=5.0.0
|
| 5 |
+
numpy>=1.24.0
|
| 6 |
+
Pillow>=9.0.0
|
| 7 |
+
matplotlib>=3.7.0
|
assignments/assignment-1/app/shared/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Shared app utilities for Assignment 1."""
|
assignments/assignment-1/app/shared/artifact_utils.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Helpers for reading notebook-generated artifacts and training metadata.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Any, Dict, Optional
|
| 11 |
+
|
| 12 |
+
from .model_registry import CalibrationResult
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Resolve the assignment root from this file's location:
# app/shared/artifact_utils.py -> three directory levels up.
ASSIGNMENT_ROOT = Path(os.path.abspath(__file__)).parents[2]
ARTIFACTS_DIR = ASSIGNMENT_ROOT / "image" / "artifacts"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_best_accuracy_from_history(history: Optional[Dict[str, Any]]) -> Optional[float]:
    """Return the best validation accuracy found in a checkpoint history.

    Returns None when the history is missing, has no ``val_acc`` list, or
    the list is empty.
    """
    accuracies = (history or {}).get("val_acc")
    if isinstance(accuracies, list) and accuracies:
        return float(max(accuracies))
    return None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def load_precomputed_calibration_result(
    model_tag: str,
    sample_tag: str = "full",
) -> Optional[CalibrationResult]:
    """
    Load notebook-generated calibration metrics and figure from image/artifacts/.

    Searches recursively, so nested folders such as artifacts/cnn and
    artifacts/vit are both supported. Returns None when the artifacts
    directory, the metrics JSON, or the diagram PNG is missing.
    """
    if not ARTIFACTS_DIR.exists():
        return None

    # First recursive match wins for each artifact file.
    metrics_path = next(
        ARTIFACTS_DIR.rglob(f"{model_tag}_calibration_metrics_{sample_tag}.json"),
        None,
    )
    image_path = next(
        ARTIFACTS_DIR.rglob(f"{model_tag}_calibration_{sample_tag}.png"),
        None,
    )
    if metrics_path is None or image_path is None:
        return None

    payload = json.loads(metrics_path.read_text(encoding="utf-8"))
    return CalibrationResult(
        ece=float(payload["ece"]),
        bin_accuracies=[float(v) for v in payload["bin_accuracies"]],
        bin_confidences=[float(v) for v in payload["bin_confidences"]],
        bin_counts=[int(v) for v in payload["bin_counts"]],
        reliability_diagram=str(image_path),
        source=f"Notebook artifact ({metrics_path.parent.name})",
    )
|
assignments/assignment-1/app/shared/model_registry.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model Registry - Central place to register and manage all models.
|
| 3 |
+
|
| 4 |
+
This module makes it easy to add new models for different datasets.
|
| 5 |
+
Each model handler should implement the BaseModelHandler interface.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from PIL import Image
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class PredictionResult:
    """Container for prediction results from a model.

    Attributes:
        label: Top-1 predicted class label.
        confidence: Confidence (probability) of the top-1 label.
        all_labels: Every class label, aligned with ``all_confidences``.
        all_confidences: Per-class confidence scores.
        explanation_image: Optional Grad-CAM / attention visualization.
    """

    label: str
    confidence: float
    all_labels: List[str]
    all_confidences: List[float]
    explanation_image: Optional[np.ndarray] = None
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass
class CalibrationResult:
    """Container for model calibration analysis results.

    Attributes:
        ece: Expected Calibration Error (0 = perfectly calibrated).
        bin_accuracies: Mean accuracy per confidence bin.
        bin_confidences: Mean predicted confidence per bin.
        bin_counts: Number of evaluated samples per bin.
        reliability_diagram: Rendered diagram (image array or file path).
        source: Human-readable provenance (e.g. notebook artifact vs live).
    """

    ece: float
    bin_accuracies: List[float]
    bin_confidences: List[float]
    bin_counts: List[int]
    reliability_diagram: Optional[Any] = None
    source: Optional[str] = None
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class BaseModelHandler(ABC):
    """
    Abstract interface every demo model must implement.

    Subclass this, implement the abstract methods, then register the
    instance in the MODEL_REGISTRY dictionary below so the UI picks it up.
    """

    @abstractmethod
    def get_model_name(self) -> str:
        """Human-readable model name shown in the UI."""
        ...

    @abstractmethod
    def get_dataset_name(self) -> str:
        """Name of the dataset this model was trained on."""
        ...

    @abstractmethod
    def get_data_type(self) -> str:
        """One of 'image', 'text', or 'multimodal'."""
        ...

    @abstractmethod
    def get_class_labels(self) -> List[str]:
        """Ordered list of class labels."""
        ...

    @abstractmethod
    def get_model_info(self) -> Dict[str, str]:
        """Display metadata (architecture, parameter count, etc.)."""
        ...

    @abstractmethod
    def predict(self, input_data) -> PredictionResult:
        """
        Run inference on *input_data*.

        The expected input depends on the modality: image models take a
        PIL Image or numpy array, text models a string, and multimodal
        models an (image, text) tuple.

        Returns a PredictionResult.
        """
        ...

    @abstractmethod
    def get_example_inputs(self) -> List[Any]:
        """Example inputs used to pre-populate the demo."""
        ...

    def get_calibration_data(
        self, max_samples: Optional[int] = None
    ) -> Optional[CalibrationResult]:
        """
        Optional hook: return calibration analysis for this model.

        The base implementation returns None; override in subclasses that
        support the calibration tab.
        """
        return None
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# Global model registry - add new models here
MODEL_REGISTRY: Dict[str, BaseModelHandler] = {}


def register_model(key: str, handler: BaseModelHandler):
    """Register *handler* under *key* in the global registry."""
    MODEL_REGISTRY[key] = handler


def get_model_handler(key: str) -> Optional[BaseModelHandler]:
    """Look up a registered handler; None when *key* is unknown."""
    return MODEL_REGISTRY.get(key)


def get_all_model_keys() -> List[str]:
    """List every registered model key."""
    return [*MODEL_REGISTRY]


def get_models_by_type(data_type: str) -> Dict[str, BaseModelHandler]:
    """Return the registered handlers whose data type equals *data_type*."""
    return {
        key: handler
        for key, handler in MODEL_REGISTRY.items()
        if handler.get_data_type() == data_type
    }
|
assignments/assignment-1/app/text/README.md
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Text App Modules
|
| 2 |
+
|
| 3 |
+
Place text-specific inference handlers here.
|
| 4 |
+
|
| 5 |
+
Suggested additions:
|
| 6 |
+
|
| 7 |
+
- model wrapper classes
|
| 8 |
+
- preprocessing helpers
|
| 9 |
+
- prediction utilities
|
| 10 |
+
- calibration or explanation helpers if needed
|
| 11 |
+
|
| 12 |
+
After adding a handler, register it in `assignments/assignment-1/app/main.py`.
|
assignments/assignment-1/app/text/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Text model handlers for Assignment 1."""
|
assignments/assignment-1/image/artifacts/cnn/resnet18_calibration_full.png
ADDED
|
Git LFS Details
|
assignments/assignment-1/image/artifacts/cnn/resnet18_calibration_metrics_full.json
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_tag": "resnet18",
|
| 3 |
+
"sample_tag": "full",
|
| 4 |
+
"ece": 0.020006245681643487,
|
| 5 |
+
"num_bins": 10,
|
| 6 |
+
"total_evaluated_samples": 10000,
|
| 7 |
+
"bin_accuracies": [
|
| 8 |
+
0.0,
|
| 9 |
+
0.0,
|
| 10 |
+
1.0,
|
| 11 |
+
0.0,
|
| 12 |
+
0.6153846383094788,
|
| 13 |
+
0.5058823823928833,
|
| 14 |
+
0.5416666865348816,
|
| 15 |
+
0.6666666865348816,
|
| 16 |
+
0.6967741847038269,
|
| 17 |
+
0.9816811680793762
|
| 18 |
+
],
|
| 19 |
+
"bin_confidences": [
|
| 20 |
+
0.0,
|
| 21 |
+
0.0,
|
| 22 |
+
0.2935279607772827,
|
| 23 |
+
0.36219507455825806,
|
| 24 |
+
0.46814653277397156,
|
| 25 |
+
0.5451725721359253,
|
| 26 |
+
0.6489876508712769,
|
| 27 |
+
0.752896785736084,
|
| 28 |
+
0.8511331677436829,
|
| 29 |
+
0.9974784851074219
|
| 30 |
+
],
|
| 31 |
+
"bin_counts": [
|
| 32 |
+
0,
|
| 33 |
+
0,
|
| 34 |
+
1,
|
| 35 |
+
4,
|
| 36 |
+
13,
|
| 37 |
+
85,
|
| 38 |
+
72,
|
| 39 |
+
117,
|
| 40 |
+
155,
|
| 41 |
+
9553
|
| 42 |
+
]
|
| 43 |
+
}
|
assignments/assignment-1/image/artifacts/vit/vit_b16_calibration_full.png
ADDED
|
Git LFS Details
|
assignments/assignment-1/image/artifacts/vit/vit_b16_calibration_metrics_full.json
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_tag": "vit_b16",
|
| 3 |
+
"sample_tag": "full",
|
| 4 |
+
"ece": 0.006916732695698738,
|
| 5 |
+
"num_bins": 10,
|
| 6 |
+
"total_evaluated_samples": 10000,
|
| 7 |
+
"bin_accuracies": [
|
| 8 |
+
0.0,
|
| 9 |
+
0.0,
|
| 10 |
+
0.0,
|
| 11 |
+
0.0,
|
| 12 |
+
0.5,
|
| 13 |
+
0.5714285969734192,
|
| 14 |
+
0.6034482717514038,
|
| 15 |
+
0.6901408433914185,
|
| 16 |
+
0.7037037014961243,
|
| 17 |
+
0.9934116005897522
|
| 18 |
+
],
|
| 19 |
+
"bin_confidences": [
|
| 20 |
+
0.0,
|
| 21 |
+
0.0,
|
| 22 |
+
0.2842116951942444,
|
| 23 |
+
0.37363073229789734,
|
| 24 |
+
0.46517834067344666,
|
| 25 |
+
0.5469092130661011,
|
| 26 |
+
0.6531615853309631,
|
| 27 |
+
0.7513611912727356,
|
| 28 |
+
0.8554656505584717,
|
| 29 |
+
0.9979013204574585
|
| 30 |
+
],
|
| 31 |
+
"bin_counts": [
|
| 32 |
+
0,
|
| 33 |
+
0,
|
| 34 |
+
1,
|
| 35 |
+
1,
|
| 36 |
+
12,
|
| 37 |
+
35,
|
| 38 |
+
58,
|
| 39 |
+
71,
|
| 40 |
+
108,
|
| 41 |
+
9714
|
| 42 |
+
]
|
| 43 |
+
}
|
assignments/assignment-1/image/data/cifar-10-python.tar.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce
|
| 3 |
+
size 170498071
|
assignments/assignment-1/image/models/resnet18_cifar10.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a0076300593993e9e6e09a358c254f24b8ffda12f66ce566e50a289ee462cb10
|
| 3 |
+
size 44808651
|
assignments/assignment-1/image/models/vit_b16_cifar10.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2e4d76e9dcb5b3eb907a00782e9f8af05b9ee46e9f2d3e0e16484d351e63f382
|
| 3 |
+
size 343288191
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
torchvision>=0.15.0
|
| 3 |
+
gradio>=5.0.0
|
| 4 |
+
numpy>=1.24.0
|
| 5 |
+
Pillow>=9.0.0
|
| 6 |
+
matplotlib>=3.7.0
|