# NOTE(review): removed page-scrape residue ("Spaces: / Sleeping / Sleeping" —
# Hugging Face Spaces status chrome); it was never part of this module.
"""
Helpers for reading notebook-generated artifacts and training metadata.
"""
from __future__ import annotations

import json
import os
from pathlib import Path
from typing import Any, Dict, Optional

import numpy as np
from PIL import Image

from .model_registry import CalibrationResult
# Repository root: three directory levels up from this file.  `abspath` is
# kept (rather than `resolve()`) so symlinks are not followed, matching the
# original os.path.dirname chain exactly.
ASSIGNMENT_ROOT = Path(os.path.abspath(__file__)).parents[2]

# Where the training notebooks export their calibration artifacts.
ARTIFACTS_DIR = ASSIGNMENT_ROOT / "image" / "artifacts"
def _render_reliability_diagram_from_metrics(metrics: Dict[str, Any]) -> np.ndarray:
    """Render a reliability diagram directly from saved calibration metrics.

    ``metrics`` must contain parallel per-bin lists ``"bin_accuracies"``,
    ``"bin_confidences"`` and ``"bin_counts"``, plus a scalar ``"ece"``.
    Returns the rendered two-panel figure (reliability diagram + confidence
    histogram) as an RGB uint8 array of shape (height, width, 3).
    """
    # Imported lazily so the module can be loaded without matplotlib; the
    # Agg backend renders off-screen (no display required).
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    # Coerce JSON-loaded values to plain Python numbers.
    bin_accuracies = [float(x) for x in metrics["bin_accuracies"]]
    bin_confidences = [float(x) for x in metrics["bin_confidences"]]
    bin_counts = [int(x) for x in metrics["bin_counts"]]
    ece = float(metrics["ece"])
    n_bins = len(bin_accuracies)
    # Bins are treated as a uniform partition of [0, 1]; bars are drawn at
    # each bin's centre.
    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    bin_centers = [
        (bin_boundaries[i] + bin_boundaries[i + 1]) / 2 for i in range(n_bins)
    ]
    # Guard against division by zero when normalising an all-empty histogram.
    total = max(sum(bin_counts), 1)
    # Dark GitHub-style theme for both panels.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
    fig.patch.set_facecolor("#0d1117")
    ax1.set_facecolor("#161b22")
    # Left panel: accuracy vs. average-confidence bars, offset half a bar
    # width either side of the bin centre so the pair sits side by side.
    width = 0.08
    ax1.bar(
        [c - width / 2 for c in bin_centers],
        bin_accuracies,
        width,
        label="Accuracy",
        color="#58a6ff",
        alpha=0.9,
        edgecolor="#58a6ff",
    )
    ax1.bar(
        [c + width / 2 for c in bin_centers],
        bin_confidences,
        width,
        label="Avg Confidence",
        color="#f97583",
        alpha=0.9,
        edgecolor="#f97583",
    )
    # The y = x diagonal marks perfect calibration.
    ax1.plot(
        [0, 1],
        [0, 1],
        "--",
        color="#8b949e",
        linewidth=2,
        label="Perfect Calibration",
    )
    ax1.set_xlim(0, 1)
    ax1.set_ylim(0, 1)
    ax1.set_xlabel("Confidence", color="white", fontsize=12)
    ax1.set_ylabel("Accuracy / Confidence", color="white", fontsize=12)
    ax1.set_title(
        f"Reliability Diagram (ECE: {ece:.4f})",
        color="white",
        fontsize=14,
        fontweight="bold",
        pad=15,
    )
    ax1.legend(
        facecolor="#161b22",
        edgecolor="#30363d",
        labelcolor="white",
        fontsize=10,
    )
    ax1.tick_params(colors="white")
    for spine in ax1.spines.values():
        spine.set_edgecolor("#30363d")
    ax1.grid(True, alpha=0.1, color="white")
    # Right panel: fraction of samples that fell into each confidence bin.
    ax2.set_facecolor("#161b22")
    ax2.bar(
        bin_centers,
        [count / total for count in bin_counts],
        0.08,
        color="#56d364",
        alpha=0.9,
        edgecolor="#56d364",
    )
    ax2.set_xlim(0, 1)
    ax2.set_xlabel("Confidence", color="white", fontsize=12)
    ax2.set_ylabel("Fraction of Samples", color="white", fontsize=12)
    ax2.set_title(
        "Confidence Distribution",
        color="white",
        fontsize=14,
        fontweight="bold",
        pad=15,
    )
    ax2.tick_params(colors="white")
    for spine in ax2.spines.values():
        spine.set_edgecolor("#30363d")
    ax2.grid(True, alpha=0.1, color="white")
    plt.tight_layout(pad=3)
    # Rasterise the figure, copy out the RGB channels (dropping alpha), and
    # free the figure so repeated calls do not accumulate open figures.
    fig.canvas.draw()
    rgba_buffer = fig.canvas.buffer_rgba()
    diagram = np.array(rgba_buffer)[:, :, :3]
    plt.close(fig)
    return diagram
def get_best_accuracy_from_history(history: Optional[Dict[str, Any]]) -> Optional[float]:
    """Return the best validation accuracy found in a checkpoint history.

    Only the ``"val_acc"`` entry is consulted; it is expected to be a list of
    per-epoch values.  Returns ``None`` when the history is absent/empty or
    when ``"val_acc"`` is missing, not a list, or an empty list.
    """
    if history:
        accuracies = history.get("val_acc")
        if isinstance(accuracies, list) and accuracies:
            return float(max(accuracies))
    return None
def load_precomputed_calibration_result(
    model_tag: str,
    sample_tag: str = "full",
) -> Optional[CalibrationResult]:
    """
    Load notebook-generated calibration metrics and figure from image/artifacts/.

    The search is recursive, so nested folders such as artifacts/cnn and
    artifacts/vit are both supported.  The figure PNG is looked up *next to*
    the metrics JSON first, so the diagram and the numbers always come from
    the same artifact directory (two independent recursive searches could
    otherwise pair files from different runs); a recursive fallback keeps
    older layouts working.

    Returns None when the artifacts directory or the metrics file is missing.
    A present-but-corrupt metrics file deliberately raises (JSONDecodeError /
    KeyError) so broken artifacts are noticed rather than silently skipped.
    """
    if not ARTIFACTS_DIR.exists():
        return None
    metrics_name = f"{model_tag}_calibration_metrics_{sample_tag}.json"
    metrics_path = next(ARTIFACTS_DIR.rglob(metrics_name), None)
    if metrics_path is None:
        return None
    metrics = json.loads(metrics_path.read_text(encoding="utf-8"))
    image_name = f"{model_tag}_calibration_{sample_tag}.png"
    # Prefer the figure sitting beside the metrics file (same notebook run).
    image_path: Optional[Path] = metrics_path.parent / image_name
    if not image_path.exists():
        image_path = next(ARTIFACTS_DIR.rglob(image_name), None)
    if image_path is not None:
        reliability_diagram = np.array(Image.open(image_path).convert("RGB"))
    else:
        # No saved figure: rebuild the diagram from the numeric metrics.
        reliability_diagram = _render_reliability_diagram_from_metrics(metrics)
    return CalibrationResult(
        ece=float(metrics["ece"]),
        bin_accuracies=[float(x) for x in metrics["bin_accuracies"]],
        bin_confidences=[float(x) for x in metrics["bin_confidences"]],
        bin_counts=[int(x) for x in metrics["bin_counts"]],
        reliability_diagram=reliability_diagram,
        source=f"Notebook artifact ({metrics_path.parent.name})",
    )