tanh1c's picture
Add Gradio image demo without binary calibration PNGs
9c763c6
"""
Helpers for reading notebook-generated artifacts and training metadata.
"""
from __future__ import annotations

import glob
import json
import os
from pathlib import Path
from typing import Any, Dict, Optional

import numpy as np
from PIL import Image

from .model_registry import CalibrationResult
# Project root: this module's absolute path with its last three components
# stripped (file name, package directory, and the directory above that).
ASSIGNMENT_ROOT = Path(os.path.abspath(__file__)).parent.parent.parent
# Notebook-generated outputs (calibration metrics/figures) live under here.
ARTIFACTS_DIR = ASSIGNMENT_ROOT.joinpath("image", "artifacts")
def _render_reliability_diagram_from_metrics(metrics: Dict[str, Any]) -> np.ndarray:
"""Render a reliability diagram directly from saved calibration metrics."""
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
bin_accuracies = [float(x) for x in metrics["bin_accuracies"]]
bin_confidences = [float(x) for x in metrics["bin_confidences"]]
bin_counts = [int(x) for x in metrics["bin_counts"]]
ece = float(metrics["ece"])
n_bins = len(bin_accuracies)
bin_boundaries = np.linspace(0, 1, n_bins + 1)
bin_centers = [
(bin_boundaries[i] + bin_boundaries[i + 1]) / 2 for i in range(n_bins)
]
total = max(sum(bin_counts), 1)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
fig.patch.set_facecolor("#0d1117")
ax1.set_facecolor("#161b22")
width = 0.08
ax1.bar(
[c - width / 2 for c in bin_centers],
bin_accuracies,
width,
label="Accuracy",
color="#58a6ff",
alpha=0.9,
edgecolor="#58a6ff",
)
ax1.bar(
[c + width / 2 for c in bin_centers],
bin_confidences,
width,
label="Avg Confidence",
color="#f97583",
alpha=0.9,
edgecolor="#f97583",
)
ax1.plot(
[0, 1],
[0, 1],
"--",
color="#8b949e",
linewidth=2,
label="Perfect Calibration",
)
ax1.set_xlim(0, 1)
ax1.set_ylim(0, 1)
ax1.set_xlabel("Confidence", color="white", fontsize=12)
ax1.set_ylabel("Accuracy / Confidence", color="white", fontsize=12)
ax1.set_title(
f"Reliability Diagram (ECE: {ece:.4f})",
color="white",
fontsize=14,
fontweight="bold",
pad=15,
)
ax1.legend(
facecolor="#161b22",
edgecolor="#30363d",
labelcolor="white",
fontsize=10,
)
ax1.tick_params(colors="white")
for spine in ax1.spines.values():
spine.set_edgecolor("#30363d")
ax1.grid(True, alpha=0.1, color="white")
ax2.set_facecolor("#161b22")
ax2.bar(
bin_centers,
[count / total for count in bin_counts],
0.08,
color="#56d364",
alpha=0.9,
edgecolor="#56d364",
)
ax2.set_xlim(0, 1)
ax2.set_xlabel("Confidence", color="white", fontsize=12)
ax2.set_ylabel("Fraction of Samples", color="white", fontsize=12)
ax2.set_title(
"Confidence Distribution",
color="white",
fontsize=14,
fontweight="bold",
pad=15,
)
ax2.tick_params(colors="white")
for spine in ax2.spines.values():
spine.set_edgecolor("#30363d")
ax2.grid(True, alpha=0.1, color="white")
plt.tight_layout(pad=3)
fig.canvas.draw()
rgba_buffer = fig.canvas.buffer_rgba()
diagram = np.array(rgba_buffer)[:, :, :3]
plt.close(fig)
return diagram
def get_best_accuracy_from_history(history: Optional[Dict[str, Any]]) -> Optional[float]:
    """Return the best validation accuracy found in a checkpoint history.

    Args:
        history: Training-history mapping, or ``None``. Only the ``"val_acc"``
            entry is read; it is expected to be a list or tuple of numbers.

    Returns:
        The largest recorded value as a ``float``, or ``None`` when there is no
        history, ``"val_acc"`` is missing or not a list/tuple, or it contains no
        usable (non-``None``, non-NaN) value.
    """
    if not history:
        return None
    val_acc = history.get("val_acc")
    if not isinstance(val_acc, (list, tuple)):
        return None
    # Skip missing and NaN entries: NaN compares False against everything, so
    # a leading NaN would otherwise make ``max`` return NaN.
    as_floats = (float(value) for value in val_acc if value is not None)
    usable = [value for value in as_floats if not np.isnan(value)]
    return max(usable) if usable else None
def _find_artifact(file_name: str, preferred_dir: Optional[Path] = None) -> Optional[Path]:
    """Locate ``file_name`` under ARTIFACTS_DIR, checking ``preferred_dir`` first.

    Recursive matches are sorted so the result does not depend on filesystem
    iteration order. The name is glob-escaped so characters such as ``[`` in
    a tag are matched literally.
    """
    if preferred_dir is not None:
        candidate = preferred_dir / file_name
        if candidate.is_file():
            return candidate
    matches = sorted(ARTIFACTS_DIR.rglob(glob.escape(file_name)))
    return next((path for path in matches if path.is_file()), None)


def load_precomputed_calibration_result(
    model_tag: str,
    sample_tag: str = "full",
) -> Optional[CalibrationResult]:
    """
    Load notebook-generated calibration metrics and figure from image/artifacts/.

    The artifact tree is searched recursively, so nested folders like
    artifacts/cnn and artifacts/vit are both supported. The figure is looked
    up next to the metrics file first, then anywhere in the tree. If no
    readable figure exists, one is rendered from the metrics.

    Args:
        model_tag: Filename prefix of the artifacts to load.
        sample_tag: Filename suffix of the artifacts to load (default "full").

    Returns:
        A CalibrationResult, or None when the artifacts directory or the
        ``{model_tag}_calibration_metrics_{sample_tag}.json`` file is missing.

    Raises:
        json.JSONDecodeError: if the metrics file is not valid JSON.
        KeyError: if the metrics JSON lacks a required field.
    """
    if not ARTIFACTS_DIR.is_dir():
        return None

    metrics_path = _find_artifact(
        f"{model_tag}_calibration_metrics_{sample_tag}.json"
    )
    if metrics_path is None:
        # Nothing to load; skip the (recursive) image search entirely.
        return None
    metrics = json.loads(metrics_path.read_text(encoding="utf-8"))

    image_path = _find_artifact(
        f"{model_tag}_calibration_{sample_tag}.png",
        preferred_dir=metrics_path.parent,
    )
    reliability_diagram = None
    if image_path is not None:
        try:
            with Image.open(image_path) as image:
                reliability_diagram = np.array(image.convert("RGB"))
        except OSError:
            # Unreadable/corrupt PNG (PIL's UnidentifiedImageError is an
            # OSError): fall back to rendering from the metrics below.
            reliability_diagram = None
    if reliability_diagram is None:
        reliability_diagram = _render_reliability_diagram_from_metrics(metrics)

    return CalibrationResult(
        ece=float(metrics["ece"]),
        bin_accuracies=[float(x) for x in metrics["bin_accuracies"]],
        bin_confidences=[float(x) for x in metrics["bin_confidences"]],
        bin_counts=[int(x) for x in metrics["bin_counts"]],
        reliability_diagram=reliability_diagram,
        source=f"Notebook artifact ({metrics_path.parent.name})",
    )