# Source: src/egg_damage/reporting.py (uploaded by budijuarto, revision 0258b57, verified)
from __future__ import annotations
from pathlib import Path
from typing import Any, Iterable
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from PIL import Image, ImageOps
from sklearn.calibration import calibration_curve
from sklearn.metrics import auc, precision_recall_curve, roc_curve
from .data_discovery import CANONICAL_LABELS
# Module-level side effect: configure seaborn's theme once at import time so
# the plotting helpers below share the same "whitegrid" / "notebook" styling.
sns.set_theme(style="whitegrid", context="notebook")
def markdown_table(df: pd.DataFrame) -> str:
    """Render a DataFrame as a pipe-delimited Markdown table.

    Missing values become empty cells. Newlines inside a header or cell are
    replaced by a space and literal ``|`` characters are backslash-escaped,
    so a single value can never split a cell or break a row.

    Args:
        df: Frame to render. The index is not included in the output.

    Returns:
        The table as one string without a trailing newline, or the
        placeholder ``"_No rows._"`` when ``df`` is empty.
    """
    if df.empty:
        return "_No rows._"

    def _cell(value: object) -> str:
        # A raw pipe would be read as a column separator and a raw newline
        # would terminate the row, so neutralize both.
        return str(value).replace("\n", " ").replace("|", "\\|")

    # fillna already returns a new frame, so the caller's data is untouched.
    filled = df.fillna("")
    headers = [_cell(col) for col in filled.columns]
    lines = [
        "| " + " | ".join(headers) + " |",
        "| " + " | ".join(["---"] * len(headers)) + " |",
    ]
    lines.extend(
        "| " + " | ".join(_cell(value) for value in row) + " |"
        for row in filled.itertuples(index=False)
    )
    return "\n".join(lines)
def _savefig(path: str | Path) -> None:
    """Save the current pyplot figure to ``path`` and close it.

    Missing parent directories are created first. The figure is written at
    160 dpi with a tight bounding box.

    Args:
        path: Destination image file.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    try:
        plt.tight_layout()
        plt.savefig(target, dpi=160, bbox_inches="tight")
    finally:
        # Close the figure even when layout or saving fails; otherwise it
        # stays registered with pyplot and accumulates across calls.
        plt.close()
def plot_class_distribution(df: pd.DataFrame, output_path: str | Path) -> None:
    """Draw a count plot of images per split, colored by label, and save it.

    Args:
        df: Frame with a ``split`` column (x axis) and a ``label`` column (hue).
        output_path: Destination image file.
    """
    plt.figure(figsize=(8, 4.8))

    # Keep the train -> val -> test ordering, dropping splits that are absent.
    present_splits = set(df["split"])
    split_order = [name for name in ("train", "val", "test") if name in present_splits]

    sns.countplot(data=df, x="split", hue="label", order=split_order)
    plt.title("Class Distribution by Split")
    plt.xlabel("Split")
    plt.ylabel("Images")
    _savefig(output_path)
def plot_confusion_matrix(
    matrix: np.ndarray,
    output_path: str | Path,
    title: str,
    class_names: Iterable[str] = CANONICAL_LABELS,
) -> None:
    """Render a confusion matrix as an annotated heatmap and save it.

    Args:
        matrix: Square matrix of counts (or rates); rows are labelled "True"
            and columns "Predicted".
        output_path: Destination image file.
        title: Figure title.
        class_names: Tick labels used for both axes. Any iterable is
            accepted, including one-shot iterables such as generators.
    """
    # Materialize once: a generator would be exhausted by the first use and
    # leave the second axis without labels.
    labels = list(class_names)

    # The "d" format only accepts integers; fall back to two decimals so a
    # float (e.g. normalized) matrix is rendered instead of raising.
    cell_format = "d" if np.issubdtype(np.asarray(matrix).dtype, np.integer) else ".2f"

    plt.figure(figsize=(5.6, 4.8))
    sns.heatmap(
        matrix,
        annot=True,
        fmt=cell_format,
        cmap="Blues",
        xticklabels=labels,
        yticklabels=labels,
        cbar=False,
    )
    plt.title(title)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    _savefig(output_path)
def plot_roc_curve_single(
y_true: np.ndarray,
y_prob: np.ndarray,
output_path: str | Path,
title: str,
) -> float | None:
if len(np.unique(y_true)) < 2:
return None
fpr, tpr, _ = roc_curve(y_true, y_prob)
roc_auc = auc(fpr, tpr)
plt.figure(figsize=(5.8, 4.8))
plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.3f}", linewidth=2)
plt.plot([0, 1], [0, 1], linestyle="--", color="gray", linewidth=1)
plt.title(title)
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.legend(loc="lower right")
_savefig(output_path)
return float(roc_auc)
def plot_precision_recall_curve_single(
y_true: np.ndarray,
y_prob: np.ndarray,
output_path: str | Path,
title: str,
) -> float | None:
if len(np.unique(y_true)) < 2:
return None
precision, recall, _ = precision_recall_curve(y_true, y_prob)
pr_auc = auc(recall, precision)
plt.figure(figsize=(5.8, 4.8))
plt.plot(recall, precision, label=f"PR AUC = {pr_auc:.3f}", linewidth=2)
plt.title(title)
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.legend(loc="lower left")
_savefig(output_path)
return float(pr_auc)
def plot_combined_roc(prediction_frames: list[tuple[str, pd.DataFrame]], output_path: str | Path) -> None:
    """Overlay the ROC curves of several models in one figure and save it.

    Frames whose ``y_true`` column holds fewer than two distinct values are
    skipped. If no curve could be drawn, a notice is shown in the plot
    instead and no legend is added.

    Args:
        prediction_frames: ``(model_name, frame)`` pairs; each frame must
            provide ``y_true`` and ``prob_damaged`` columns.
        output_path: Destination image file.
    """
    plt.figure(figsize=(7.2, 5.6))
    plotted = False
    for model_name, frame in prediction_frames:
        y_true = frame["y_true"].to_numpy()
        y_prob = frame["prob_damaged"].to_numpy()
        if len(np.unique(y_true)) < 2:
            continue
        fpr, tpr, _ = roc_curve(y_true, y_prob)
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, linewidth=2, label=f"{model_name} ({roc_auc:.3f})")
        plotted = True
    if not plotted:
        plt.text(0.5, 0.5, "ROC unavailable: only one class present", ha="center", va="center")
    # Dashed diagonal marks the chance-level reference line (unlabelled).
    plt.plot([0, 1], [0, 1], linestyle="--", color="gray", linewidth=1)
    plt.title("Test ROC Comparison")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    if plotted:
        # Only request a legend when a labelled curve exists; otherwise
        # matplotlib warns about having no labelled artists.
        plt.legend(loc="lower right", fontsize=8)
    _savefig(output_path)
def plot_metric_bars(metrics_df: pd.DataFrame, output_path: str | Path) -> None:
    """Plot a grouped bar chart comparing metric scores across models.

    Only rows whose ``split`` equals ``"test"`` are plotted; when there are
    none, every row of ``metrics_df`` is used instead. Metric columns that
    are absent from ``metrics_df`` are skipped rather than aborting the plot.

    Args:
        metrics_df: Frame with ``model_name`` and ``split`` columns plus any
            of the metric columns listed below.
        output_path: Destination image file.

    Raises:
        KeyError: If none of the expected metric columns is present.
    """
    metric_cols = ["accuracy", "precision", "recall", "f1", "roc_auc", "balanced_accuracy"]
    # Plot whichever metrics are available so that one missing column does
    # not make the whole chart fail.
    available = [col for col in metric_cols if col in metrics_df.columns]
    if not available:
        raise KeyError(f"metrics_df has none of the expected metric columns: {metric_cols}")

    test_df = metrics_df[metrics_df["split"] == "test"]
    if test_df.empty:
        test_df = metrics_df
    melt = test_df.melt(id_vars=["model_name"], value_vars=available, var_name="metric", value_name="score")

    plt.figure(figsize=(11, 6))
    sns.barplot(data=melt, x="metric", y="score", hue="model_name")
    plt.ylim(0, 1.02)
    plt.title("Model Metrics Comparison")
    plt.xlabel("")
    plt.ylabel("Score")
    plt.xticks(rotation=25, ha="right")
    plt.legend(loc="lower right", fontsize=8)
    _savefig(output_path)
def plot_training_curves(history_df: pd.DataFrame, output_path: str | Path, model_name: str) -> None:
if history_df.empty:
return
fig, axes = plt.subplots(1, 2, figsize=(11, 4.5))
axes[0].plot(history_df["epoch"], history_df["train_loss"], marker="o", label="Train")
axes[0].plot(history_df["epoch"], history_df["val_loss"], marker="o", label="Validation")
axes[0].set_title(f"{model_name}: Loss")
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("Loss")
axes[0].legend()
axes[1].plot(history_df["epoch"], history_df["train_accuracy"], marker="o", label="Train")
axes[1].plot(history_df["epoch"], history_df["val_accuracy"], marker="o", label="Validation")
if "val_f1" in history_df:
axes[1].plot(history_df["epoch"], history_df["val_f1"], marker="s", label="Val F1")
axes[1].set_title(f"{model_name}: Accuracy / F1")
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("Score")
axes[1].set_ylim(0, 1.02)
axes[1].legend()
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
fig.tight_layout()
fig.savefig(output_path, dpi=160, bbox_inches="tight")
plt.close(fig)
def plot_sample_grid(
    pred_df: pd.DataFrame,
    output_path: str | Path,
    title: str,
    max_images: int = 12,
) -> None:
    """Save a grid of sample images annotated with true and predicted labels.

    The first ``max_images`` rows of ``pred_df`` are shown, at most four per
    row. When there are no rows, a single panel reading "No samples" is
    written instead.

    Args:
        pred_df: Frame with ``filepath``, ``label``, ``pred_label`` and
            ``confidence`` columns.
        output_path: Destination image file; parent folders are created.
        title: Figure-level title.
        max_images: Maximum number of rows taken from the top of ``pred_df``.
    """
    sample = pred_df.head(max_images)
    # Always lay out at least one cell so the "No samples" notice has an axis.
    n_cells = max(len(sample), 1)
    cols = min(4, n_cells)
    rows = int(np.ceil(n_cells / cols))
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 3.2, rows * 3.4))
    try:
        # plt.subplots returns a bare Axes for a 1x1 grid; normalize to 1-D.
        axes_arr = np.asarray(axes).reshape(-1)
        for ax in axes_arr:
            ax.axis("off")
        if sample.empty:
            axes_arr[0].text(0.5, 0.5, "No samples", ha="center", va="center")
        for ax, row in zip(axes_arr, sample.itertuples(index=False)):
            # Close the file handle deterministically; the converted RGB
            # copy is fully loaded and stays valid after the file is closed.
            with Image.open(row.filepath) as raw:
                img = ImageOps.exif_transpose(raw).convert("RGB")
            ax.imshow(img)
            ax.set_title(
                f"T: {row.label}\nP: {row.pred_label} ({row.confidence:.2f})",
                fontsize=9,
            )
            ax.axis("off")
        fig.suptitle(title, fontsize=13)
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        fig.tight_layout()
        fig.savefig(output_path, dpi=160, bbox_inches="tight")
    finally:
        # Release the figure even if an image is missing or saving fails.
        plt.close(fig)
def plot_calibration(
y_true: np.ndarray,
y_prob: np.ndarray,
output_path: str | Path,
title: str,
) -> None:
if len(np.unique(y_true)) < 2:
return
prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=8, strategy="uniform")
plt.figure(figsize=(5.8, 4.8))
plt.plot(prob_pred, prob_true, marker="o", linewidth=2)
plt.plot([0, 1], [0, 1], linestyle="--", color="gray", linewidth=1)
plt.title(title)
plt.xlabel("Mean Predicted Probability")
plt.ylabel("Fraction of Positives")
_savefig(output_path)
def write_markdown_report(
    config: dict[str, Any],
    splits_df: pd.DataFrame,
    metrics_df: pd.DataFrame,
    leaderboard_df: pd.DataFrame,
    misclassified_df: pd.DataFrame,
    output_path: str | Path,
) -> None:
    """Assemble the Markdown classification report and write it to disk.

    Args:
        config: Run configuration. ``config["paths"]["data_dir"]`` and
            ``config["preprocessing"]["image_size"]`` are required;
            ``config["balance"]["strategy"]`` is optional and reported as
            ``disabled`` when missing.
        splits_df: One row per image with ``split`` and ``label`` columns;
            summarized as a count table.
        metrics_df: Per-model metrics; only the known metric columns that
            are present are shown.
        leaderboard_df: Its first row is reported as the best model; an
            empty frame yields ``N/A`` placeholders.
        misclassified_df: Misclassified samples with ``model_name``,
            ``label``, ``pred_label``, ``confidence`` and ``filepath``
            columns; may be empty.
        output_path: Destination Markdown file (UTF-8); parent folders are
            created.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    if leaderboard_df.empty:
        best = {}
    else:
        best = leaderboard_df.iloc[0].to_dict()

    # Pivot to one row per split with one count column per label.
    split_counts = splits_df.groupby(["split", "label"]).size().unstack(fill_value=0)
    split_summary = markdown_table(split_counts.reset_index())

    wanted_columns = (
        "model_name",
        "split",
        "accuracy",
        "precision",
        "recall",
        "f1",
        "roc_auc",
        "balanced_accuracy",
        "specificity",
        "sensitivity",
    )
    shown_columns = [name for name in wanted_columns if name in metrics_df.columns]
    metrics_md = markdown_table(metrics_df[shown_columns])

    if misclassified_df.empty:
        error_text = "No misclassified test samples were recorded."
    else:
        errors_per_model = misclassified_df.groupby("model_name").size().sort_values(ascending=False)
        first_errors = misclassified_df.head(8)[["model_name", "label", "pred_label", "confidence", "filepath"]]
        error_text = "\n\n".join(
            [
                "Misclassifications by model:",
                markdown_table(errors_per_model.reset_index(name="count")),
                "Example errors:",
                markdown_table(first_errors),
            ]
        )

    # The template lines start at column 0 on purpose so that no source
    # indentation ends up in the rendered Markdown.
    content = f"""# Egg Damage Classification Report
## Dataset Overview
- Dataset path: `{config['paths']['data_dir']}`
- Task: binary classification, `Damaged` vs `Not Damaged`
- Split strategy: existing split folders are respected; otherwise stratified 70/15/15 splitting is used.
- Training balance: `{config.get('balance', {}).get('strategy', 'disabled')}`.
## Split Summary
{split_summary}
## Preprocessing
- Classical models: grayscale resize to {config['preprocessing']['image_size']}x{config['preprocessing']['image_size']}, HOG or LBP features, standardized SVM inputs.
- Deep models: ImageNet normalization, realistic train-only flips, small rotations, mild affine jitter, and light color jitter.
- SVM training curves are marked N/A because these models are not epoch-trained.
## Metrics
{metrics_md}
## Best Model
- Model: `{best.get('model_name', 'N/A')}`
- Test F1: `{best.get('f1', 'N/A')}`
- Test ROC-AUC: `{best.get('roc_auc', 'N/A')}`
- Test balanced accuracy: `{best.get('balanced_accuracy', 'N/A')}`
- Model path: `{best.get('model_path', 'N/A')}`
## Error Patterns
{error_text}
## Deployment
Run `python scripts/launch_gradio.py --config configs/default.yaml` to launch the local Gradio app. The app loads the best ranked model automatically and can switch among trained models.
"""
    output_path.write_text(content, encoding="utf-8")