"""
Deep Learning Assignment 1 - Application Demo
===============================================
A modular Gradio application for demonstrating
trained models on Image, Text, and Multimodal datasets.
Features:
- Image classification with Grad-CAM / attention visualization
- Model Calibration analysis (ECE + Reliability Diagram)
- Easy to extend with new models/datasets
Usage:
python assignments/assignment-1/app/main.py
"""
import sys
import os
# Running `python app/main.py` only puts the app/ directory itself on
# sys.path, so prepend its parent (the assignment root) to make the
# `app.*` package imports below resolve.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
ASSIGNMENT_ROOT = os.path.dirname(_APP_DIR)
sys.path.insert(0, ASSIGNMENT_ROOT)
import gradio as gr
from typing import Dict
from app.shared.model_registry import (
register_model,
get_all_model_keys,
get_models_by_type,
BaseModelHandler,
)
from app.image.resnet18 import Cifar10ResNet18Handler
from app.image.vit_b16 import Cifar10ViTHandler
# ============================================================================
# CONFIGURATION
# ============================================================================
# Header text rendered at the top of the app (APP_TITLE as an H1, the
# description as Markdown). The blank line keeps the course line from being
# merged into the same Markdown paragraph as the subtitle.
APP_TITLE = "🧠 Deep Learning Assignment 1 - Demo"
APP_DESCRIPTION = """
Classification on Images, Text, and Multimodal Data

CO3091 · HCM University of Technology · 2025-2026 Semester 2
"""
# Optional stylesheet shipped next to this file; when it is absent the app
# simply falls back to the theme's default styling (empty CSS string).
CSS_PATH = os.path.join(os.path.dirname(__file__), "assets", "style.css")
CUSTOM_CSS = ""
if os.path.exists(CSS_PATH):
    with open(CSS_PATH, "r", encoding="utf-8") as css_file:
        CUSTOM_CSS = css_file.read()
# Dark, GitHub-like palette. Each entry is applied to both the plain and the
# `_dark` variant of the theme variable, so light and dark mode share the
# exact same colours.
_MODE_INDEPENDENT_COLORS = {
    "body_background_fill": "#0d1117",
    "block_background_fill": "#161b22",
    "block_border_color": "#30363d",
    "block_label_text_color": "#c9d1d9",
    "block_title_text_color": "#f0f6fc",
    "body_text_color": "#c9d1d9",
    "body_text_color_subdued": "#8b949e",
    "button_primary_background_fill": "#238636",
    "button_primary_background_fill_hover": "#2ea043",
    "button_primary_text_color": "white",
    "input_background_fill": "#0d1117",
    "input_border_color": "#30363d",
}

CUSTOM_THEME = gr.themes.Base(
    primary_hue=gr.themes.colors.blue,
    secondary_hue=gr.themes.colors.green,
    neutral_hue=gr.themes.colors.gray,
    font=[gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"],
).set(
    **{
        variant: color
        for base_name, color in _MODE_INDEPENDENT_COLORS.items()
        for variant in (base_name, f"{base_name}_dark")
    },
    # Flat look: drop shadows disabled.
    shadow_drop="none",
    shadow_drop_lg="none",
)
# ============================================================================
# MODEL INITIALIZATION
# ============================================================================
def init_models():
    """Initialize and register all available models.

    Every known checkpoint is looked up under ``<ASSIGNMENT_ROOT>/image/models``.
    A missing file, or a handler that raises while loading, is reported on
    stdout and skipped, so the app still starts with whichever models did load.
    """
    model_dir = os.path.join(ASSIGNMENT_ROOT, "image", "models")

    # (registry key, human-readable name, handler class, checkpoint file name)
    model_specs = (
        ("cifar10_resnet18", "CIFAR-10 ResNet-18",
         Cifar10ResNet18Handler, "resnet18_cifar10.pth"),
        ("cifar10_vit", "CIFAR-10 ViT-B/16",
         Cifar10ViTHandler, "vit_b16_cifar10.pth"),
    )

    for key, display_name, handler_cls, filename in model_specs:
        checkpoint_path = os.path.join(model_dir, filename)
        if not os.path.exists(checkpoint_path):
            print(f"⚠️ Model file not found: {checkpoint_path}")
            continue
        try:
            register_model(key, handler_cls(checkpoint_path))
        except Exception as e:
            # Best effort: one broken checkpoint must not prevent the app from starting.
            print(f"❌ Failed to load {display_name}: {e}")
        else:
            print(f"✅ Loaded: {display_name} from {checkpoint_path}")
# ============================================================================
# UI BUILDER FUNCTIONS
# ============================================================================
def format_confidence_label(labels, confidences, top_k=5):
    """Return the ``top_k`` most confident predictions as ``{label: confidence}``.

    ``labels`` and ``confidences`` are paired positionally. The result is
    ordered from most to least confident (ties keep their input order) and
    every confidence is converted to a plain ``float``, the format ``gr.Label``
    accepts.
    """
    ranked = list(zip(labels, confidences))
    ranked.sort(key=lambda pair: pair[1], reverse=True)

    top_predictions = {}
    for label, score in ranked[:top_k]:
        top_predictions[label] = float(score)
    return top_predictions
def _escape_markdown_cell(value) -> str:
    """Render *value* as text that cannot break out of a Markdown table cell."""
    # A raw "|" would open a new column and a newline would terminate the row.
    escaped = str(value).replace("|", "\\|")
    return " ".join(escaped.splitlines())


def build_model_info_markdown(handler: "BaseModelHandler") -> str:
    """Build a Markdown "Model Information" section for *handler*.

    The section is a heading followed by a two-column table with one row per
    entry of ``handler.get_model_info()`` (property name in bold). Pipes and
    line breaks inside keys/values are escaped so they cannot corrupt the table.

    Args:
        handler: Any object exposing ``get_model_info()`` returning a mapping.

    Returns:
        The Markdown string, ready for ``gr.Markdown``.
    """
    rows = [
        f"| **{_escape_markdown_cell(key)}** | {_escape_markdown_cell(val)} |"
        for key, val in handler.get_model_info().items()
    ]
    return (
        "### 📋 Model Information\n"
        "| Property | Value |\n"
        "|:---|:---|\n"
        + "\n".join(rows)
    )
def build_image_prediction_tab(model_key: str, handler: BaseModelHandler):
    """Build the prediction tab UI for image models.

    Lays out an image upload widget and a "Predict & Explain" button next to a
    top-k prediction label, with the model's explanation image underneath, and
    wires the button to ``handler.predict``.

    Args:
        model_key: Registry key of the model (not used by this tab).
        handler: Loaded model handler providing ``get_class_labels()`` and
            ``predict(image)``; the prediction result must expose
            ``all_labels``, ``all_confidences`` and ``explanation_image``.
    """
    # Single source of truth so the label widget and the formatted dict agree.
    top_k = 5

    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="📸 Upload Image",
                type="numpy",
                height=300,
                sources=["upload", "clipboard"],
            )
            predict_btn = gr.Button(
                "🔍 Predict & Explain",
                variant="primary",
                size="lg",
            )
            gr.Markdown(
                f"*Classes: {', '.join(handler.get_class_labels())}*",
                elem_classes=["text-sm"],
            )
        with gr.Column(scale=1):
            output_label = gr.Label(
                label=f"📊 Prediction Results (Top-{top_k})",
                num_top_classes=top_k,
            )
    with gr.Row():
        explanation_image = gr.Image(
            label="🔥 Model Explanation (Interpretability)",
            interactive=False,
            height=350,
        )

    def do_predict(image):
        # No image (e.g. the upload was cleared): just clear both outputs.
        if image is None:
            return None, None
        try:
            result = handler.predict(image)
            conf_dict = format_confidence_label(
                result.all_labels, result.all_confidences, top_k=top_k
            )
            return conf_dict, result.explanation_image
        except Exception as e:
            # Show the failure in the UI; keep the original traceback chained.
            raise gr.Error(f"Prediction failed: {e}") from e

    predict_btn.click(
        fn=do_predict,
        inputs=[input_image],
        outputs=[output_label, explanation_image],
    )
def build_calibration_tab(model_key: str, handler: BaseModelHandler):
    """Build the calibration analysis tab.

    Renders an explanation of calibration, a mode selector (quick preview vs.
    full test set) and a button that calls
    ``handler.get_calibration_data(max_samples=...)``. Once computed, an ECE
    summary table and the result's ``reliability_diagram`` image are revealed.

    Args:
        model_key: Registry key of the model (not used by this tab).
        handler: Loaded model handler providing ``get_calibration_data``; the
            returned object must expose ``ece``, ``bin_counts``, ``source``
            and ``reliability_diagram``.
    """
    quick_samples = 64
    quick_mode = f"Quick Preview ({quick_samples} samples)"
    full_mode = "Full Test Set (10,000 samples)"

    gr.Markdown(
        "### 📈 Model Calibration Analysis\n"
        "Calibration measures how well the model's confidence matches its actual accuracy.\n"
        "A perfectly calibrated model has **confidence = accuracy** for all predictions.\n"
        "\n"
        "- **ECE (Expected Calibration Error)**: Lower is better (0 = perfect calibration)\n"
        "- **Reliability Diagram**: Compares predicted confidence vs actual accuracy per bin\n"
        "- **Quick Preview**: Uses a very small subset for fast CPU demos\n"
        "- **Full Test Set**: Uses notebook artifacts instantly when available\n"
    )
    calibration_mode = gr.Radio(
        choices=[quick_mode, full_mode],
        value=quick_mode,
        label="Calibration Mode",
    )
    compute_btn = gr.Button(
        "📏 Compute Calibration",
        variant="primary",
        size="lg",
    )
    ece_display = gr.Markdown(visible=False)
    calibration_plot = gr.Image(
        label="📉 Calibration Analysis",
        interactive=False,
        visible=False,
        height=450,
    )

    def interpret_ece(ece):
        """Map an ECE value to a coarse, human-readable verdict."""
        if ece < 0.05:
            return "✅ Well calibrated"
        if ece < 0.15:
            return "⚠️ Moderately calibrated"
        return "❌ Poorly calibrated"

    def compute_calibration(mode):
        # Only the explicit full-test-set choice runs the expensive pass;
        # anything else falls back to the cheap preview.
        max_samples = None if mode == full_mode else quick_samples
        try:
            result = handler.get_calibration_data(max_samples=max_samples)
            if result is None:
                raise gr.Error("Could not compute calibration data")
            sample_note = (
                f"Approximate preview on {quick_samples} evenly spaced test images"
                if max_samples is not None
                else "Full CIFAR-10 test set"
            )
            # `source` is optional on the result; fall back to a generic label.
            source_note = result.source or "Live computation"
            ece_md = "\n".join([
                "### Calibration Metrics",
                "| Metric | Value |",
                "|:---|:---|",
                f"| **Mode** | {sample_note} |",
                f"| **Source** | {source_note} |",
                f"| **Expected Calibration Error (ECE)** | `{result.ece:.6f}` |",
                f"| **Interpretation** | {interpret_ece(result.ece)} |",
                f"| **Total evaluated samples** | {sum(result.bin_counts):,} |",
            ])
            return (
                gr.update(value=ece_md, visible=True),
                gr.update(value=result.reliability_diagram, visible=True),
            )
        except gr.Error:
            # Already a user-facing message; don't wrap it a second time.
            raise
        except Exception as e:
            raise gr.Error(f"Calibration computation failed: {e}") from e

    compute_btn.click(
        fn=compute_calibration,
        inputs=[calibration_mode],
        outputs=[ece_display, calibration_plot],
    )
def build_model_tabs(model_key: str, handler: BaseModelHandler):
    """Build all tabs for a specific model.

    Renders the model-info table, then a "Predict & Explain" tab whose content
    depends on ``handler.get_data_type()`` and a "Calibration" tab.

    Args:
        model_key: Registry key of the model, forwarded to the sub-tab builders.
        handler: Loaded model handler.
    """
    gr.Markdown(build_model_info_markdown(handler))
    data_type = handler.get_data_type()
    with gr.Tabs():
        with gr.Tab("🎯 Predict & Explain", id="predict"):
            if data_type == "image":
                build_image_prediction_tab(model_key, handler)
            elif data_type == "text":
                gr.Markdown("### 📝 Text Classification\n*Coming soon...*")
            elif data_type == "multimodal":
                gr.Markdown("### 🖼️+📝 Multimodal Classification\n*Coming soon...*")
            else:
                # Previously this left an empty tab; make the problem visible.
                gr.Markdown(f"*Unsupported data type: `{data_type}`*")
        with gr.Tab("📈 Calibration", id="calibration"):
            build_calibration_tab(model_key, handler)
# ============================================================================
# MAIN APPLICATION
# ============================================================================
def _build_model_selector(models: Dict[str, BaseModelHandler]) -> None:
    """Render a single model inline, or one sub-tab per model when several are loaded."""
    if len(models) == 1:
        key, handler = next(iter(models.items()))
        build_model_tabs(key, handler)
        return
    with gr.Tabs():
        for key, handler in models.items():
            tab_name = f"{handler.get_model_name()} ({handler.get_dataset_name()})"
            with gr.Tab(tab_name):
                build_model_tabs(key, handler)


def create_app() -> gr.Blocks:
    """Create the main Gradio application.

    Loads the available models, then builds one top-level tab per data type in
    a fixed order (image, text, multimodal). Text and multimodal tabs show a
    placeholder when no model of that type is registered; the image tab is only
    shown when at least one image model is loaded. If no model loaded at all, a
    help message is shown instead of the tabs.

    Returns:
        The assembled (not yet launched) ``gr.Blocks`` app.
    """
    init_models()

    # (data type, tab label, tab id, placeholder Markdown; None hides an empty tab)
    categories = (
        ("image", "🖼️ Image Classification", "image_tab", None),
        ("text", "📝 Text Classification", "text_tab",
         "### 📝 Text Classification Models\n"
         "*No text models loaded yet. Add your text model handler "
         "and register it in `app/main.py`.*"),
        ("multimodal", "🔗 Multimodal Classification", "mm_tab",
         "### 🔗 Multimodal Classification Models\n"
         "*No multimodal models loaded yet. Add your multimodal "
         "model handler and register it in `app/main.py`.*"),
    )

    with gr.Blocks(
        title="DL Assignment 1 - Demo",
    ) as app:
        gr.Markdown(f"# {APP_TITLE}")
        gr.Markdown(APP_DESCRIPTION)

        if not get_all_model_keys():
            gr.Markdown(
                "## ⚠️ No Models Loaded\n"
                "Please ensure model files are in the `image/models/` directory.\n"
                "See the README for instructions on adding models."
            )
        else:
            with gr.Tabs():
                for data_type, label, tab_id, placeholder in categories:
                    models = get_models_by_type(data_type)
                    if not models and placeholder is None:
                        continue
                    with gr.Tab(label, id=tab_id):
                        if models:
                            _build_model_selector(models)
                        else:
                            gr.Markdown(placeholder)
    return app
# ============================================================================
# ENTRY POINT
# ============================================================================
def main() -> None:
    """Build the demo app and serve it locally on 127.0.0.1:5555."""
    demo = create_app()
    demo.launch(
        server_name="127.0.0.1",
        server_port=5555,
        share=False,       # local only, no public share link
        show_error=True,   # surface handler exceptions in the UI
        theme=CUSTOM_THEME,
        css=CUSTOM_CSS,
        # Whitelist the artifacts directory so Gradio may serve files from it.
        allowed_paths=[os.path.join(ASSIGNMENT_ROOT, "image", "artifacts")],
    )


if __name__ == "__main__":
    main()