| from __future__ import annotations |
|
|
| from pathlib import Path |
| from typing import Any, Iterable |
|
|
| import matplotlib |
|
|
| matplotlib.use("Agg") |
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
| import pandas as pd |
| import seaborn as sns |
| from PIL import Image, ImageOps |
| from sklearn.calibration import calibration_curve |
| from sklearn.metrics import auc, precision_recall_curve, roc_curve |
|
|
| from .data_discovery import CANONICAL_LABELS |
|
|
|
|
# Module-level (import-time) call: seaborn's set_theme changes global matplotlib
# rcParams, so this white-grid / notebook-scale look applies to every figure
# created afterwards in the process.
sns.set_theme(context="notebook", style="whitegrid")
|
|
|
|
def _markdown_cell(value: Any) -> str:
    """Return *value* as single-line text that cannot break a Markdown table row."""
    text = str(value)
    # Any line break would terminate the table row early; flatten all flavours.
    for newline in ("\r\n", "\r", "\n"):
        text = text.replace(newline, " ")
    # An unescaped pipe is parsed as a column separator and shifts every later cell.
    return text.replace("|", "\\|")


def markdown_table(df: pd.DataFrame) -> str:
    """Render *df* as a GitHub-flavoured Markdown table (index not included).

    Missing values are shown as empty cells, line breaks inside a value or
    header become spaces, and ``|`` characters are escaped as ``\\|`` so they
    stay inside their cell.

    Args:
        df: Frame to render; its column labels form the header row.

    Returns:
        The table as a newline-joined string, or ``"_No rows._"`` when *df*
        is empty (no rows or no columns).
    """
    if df.empty:
        return "_No rows._"
    # fillna returns a new frame, so the caller's data is never mutated.
    safe = df.fillna("")
    headers = [_markdown_cell(col) for col in safe.columns]
    lines = [
        "| " + " | ".join(headers) + " |",
        "| " + " | ".join(["---"] * len(headers)) + " |",
    ]
    # name=None yields plain tuples, so column labels need not be identifiers.
    for row in safe.itertuples(index=False, name=None):
        lines.append("| " + " | ".join(_markdown_cell(value) for value in row) + " |")
    return "\n".join(lines)
|
|
|
|
def _savefig(path: str | Path) -> None:
    """Save the current pyplot figure to *path* and then close it.

    Missing parent directories are created first. The figure is laid out
    with ``tight_layout`` and written at 160 dpi with a tight bounding box.

    Args:
        path: Destination image file (str or Path).
    """
    target = Path(path)
    target.parent.mkdir(exist_ok=True, parents=True)
    plt.tight_layout()
    plt.savefig(target, bbox_inches="tight", dpi=160)
    # Release the figure so repeated plotting does not accumulate open figures.
    plt.close()
|
|
|
|
def plot_class_distribution(df: pd.DataFrame, output_path: str | Path) -> None:
    """Save a bar chart of row counts per split, with one bar per label.

    Splits are ordered train / val / test. Any other split names present in
    ``df["split"]`` are appended afterwards in first-seen order instead of
    being silently dropped from the chart.

    Args:
        df: Frame with ``split`` and ``label`` columns; each row is counted once.
        output_path: Destination image file; parent directories are created.
    """
    canonical = ["train", "val", "test"]
    # Unique non-null split names in order of appearance, computed once.
    present = df["split"].dropna().unique().tolist()
    present_set = set(present)
    order = [split for split in canonical if split in present_set]
    order += [split for split in present if split not in canonical]
    plt.figure(figsize=(8, 4.8))
    sns.countplot(data=df, x="split", hue="label", order=order)
    plt.title("Class Distribution by Split")
    plt.xlabel("Split")
    plt.ylabel("Images")
    _savefig(output_path)
|
|
|
|
def plot_confusion_matrix(
    matrix: np.ndarray,
    output_path: str | Path,
    title: str,
    class_names: Iterable[str] = CANONICAL_LABELS,
) -> None:
    """Save an annotated confusion-matrix heatmap (rows = true, columns = predicted).

    Args:
        matrix: Square matrix of counts (integer) or rates (float). Integer
            cells are annotated as whole numbers, anything else with 2 decimals.
        output_path: Destination image file; parent directories are created.
        title: Figure title.
        class_names: Tick labels used for both axes, in matrix order. Any
            iterable is accepted, including one-shot iterators.
    """
    matrix = np.asarray(matrix)
    # Materialise once: calling list() twice on a generator would exhaust it on
    # the x-axis and leave the y-axis without labels.
    names = list(class_names)
    # The "d" format raises ValueError for float cells (e.g. a normalised matrix).
    fmt = "d" if np.issubdtype(matrix.dtype, np.integer) else ".2f"
    plt.figure(figsize=(5.6, 4.8))
    sns.heatmap(
        matrix,
        annot=True,
        fmt=fmt,
        cmap="Blues",
        xticklabels=names,
        yticklabels=names,
        cbar=False,
    )
    plt.title(title)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    _savefig(output_path)
|
|
|
|
| def plot_roc_curve_single( |
| y_true: np.ndarray, |
| y_prob: np.ndarray, |
| output_path: str | Path, |
| title: str, |
| ) -> float | None: |
| if len(np.unique(y_true)) < 2: |
| return None |
| fpr, tpr, _ = roc_curve(y_true, y_prob) |
| roc_auc = auc(fpr, tpr) |
| plt.figure(figsize=(5.8, 4.8)) |
| plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.3f}", linewidth=2) |
| plt.plot([0, 1], [0, 1], linestyle="--", color="gray", linewidth=1) |
| plt.title(title) |
| plt.xlabel("False Positive Rate") |
| plt.ylabel("True Positive Rate") |
| plt.legend(loc="lower right") |
| _savefig(output_path) |
| return float(roc_auc) |
|
|
|
|
| def plot_precision_recall_curve_single( |
| y_true: np.ndarray, |
| y_prob: np.ndarray, |
| output_path: str | Path, |
| title: str, |
| ) -> float | None: |
| if len(np.unique(y_true)) < 2: |
| return None |
| precision, recall, _ = precision_recall_curve(y_true, y_prob) |
| pr_auc = auc(recall, precision) |
| plt.figure(figsize=(5.8, 4.8)) |
| plt.plot(recall, precision, label=f"PR AUC = {pr_auc:.3f}", linewidth=2) |
| plt.title(title) |
| plt.xlabel("Recall") |
| plt.ylabel("Precision") |
| plt.legend(loc="lower left") |
| _savefig(output_path) |
| return float(pr_auc) |
|
|
|
|
def plot_combined_roc(prediction_frames: list[tuple[str, pd.DataFrame]], output_path: str | Path) -> None:
    """Overlay the ROC curves of several models on one figure.

    Models whose ``y_true`` contains a single class are skipped, because ROC is
    undefined for them. If no curve can be drawn, an explanatory message is
    placed on the figure, and the file is still written.

    Args:
        prediction_frames: ``(model_name, frame)`` pairs; each frame needs
            ``y_true`` and ``prob_damaged`` columns.
        output_path: Destination image file; parent directories are created.
    """
    plt.figure(figsize=(7.2, 5.6))
    plotted = False
    for model_name, frame in prediction_frames:
        y_true = frame["y_true"].to_numpy()
        y_prob = frame["prob_damaged"].to_numpy()
        if len(np.unique(y_true)) < 2:
            continue
        fpr, tpr, _ = roc_curve(y_true, y_prob)
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, linewidth=2, label=f"{model_name} ({roc_auc:.3f})")
        plotted = True
    if not plotted:
        plt.text(0.5, 0.5, "ROC unavailable: only one class present", ha="center", va="center")
    # Chance-level reference diagonal (unlabelled, so it never enters the legend).
    plt.plot([0, 1], [0, 1], linestyle="--", color="gray", linewidth=1)
    plt.title("Test ROC Comparison")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    # Only build a legend when a labelled curve exists; otherwise matplotlib
    # warns "No artists with labels found" and draws an empty box.
    if plotted:
        plt.legend(loc="lower right", fontsize=8)
    _savefig(output_path)
|
|
|
|
def plot_metric_bars(metrics_df: pd.DataFrame, output_path: str | Path) -> None:
    """Save a grouped bar chart comparing headline metrics across models.

    Rows with ``split == "test"`` are used when there are any. Otherwise, or
    when there is no ``split`` column, every row is plotted. Only those metric
    columns from the fixed list below that exist in *metrics_df* are drawn.

    Args:
        metrics_df: Metrics table with a ``model_name`` column, an optional
            ``split`` column, and any of the known metric columns.
        output_path: Destination image file; parent directories are created.
    """
    metric_cols = ["accuracy", "precision", "recall", "f1", "roc_auc", "balanced_accuracy"]
    # Skip metrics that were not computed instead of letting melt() raise
    # KeyError (the report's metrics table applies the same filtering).
    present_cols = [c for c in metric_cols if c in metrics_df.columns]
    test_df = metrics_df
    if "split" in metrics_df.columns:
        test_only = metrics_df[metrics_df["split"] == "test"]
        if not test_only.empty:
            test_df = test_only
    plt.figure(figsize=(11, 6))
    if present_cols and not test_df.empty:
        melt = test_df.melt(
            id_vars=["model_name"],
            value_vars=present_cols,
            var_name="metric",
            value_name="score",
        )
        sns.barplot(data=melt, x="metric", y="score", hue="model_name")
        plt.legend(loc="lower right", fontsize=8)
    else:
        # Nothing to draw: still emit a placeholder figure rather than crashing.
        plt.text(0.5, 0.5, "No metrics available", ha="center", va="center")
    plt.ylim(0, 1.02)
    plt.title("Model Metrics Comparison")
    plt.xlabel("")
    plt.ylabel("Score")
    plt.xticks(rotation=25, ha="right")
    _savefig(output_path)
|
|
|
|
| def plot_training_curves(history_df: pd.DataFrame, output_path: str | Path, model_name: str) -> None: |
| if history_df.empty: |
| return |
| fig, axes = plt.subplots(1, 2, figsize=(11, 4.5)) |
| axes[0].plot(history_df["epoch"], history_df["train_loss"], marker="o", label="Train") |
| axes[0].plot(history_df["epoch"], history_df["val_loss"], marker="o", label="Validation") |
| axes[0].set_title(f"{model_name}: Loss") |
| axes[0].set_xlabel("Epoch") |
| axes[0].set_ylabel("Loss") |
| axes[0].legend() |
| axes[1].plot(history_df["epoch"], history_df["train_accuracy"], marker="o", label="Train") |
| axes[1].plot(history_df["epoch"], history_df["val_accuracy"], marker="o", label="Validation") |
| if "val_f1" in history_df: |
| axes[1].plot(history_df["epoch"], history_df["val_f1"], marker="s", label="Val F1") |
| axes[1].set_title(f"{model_name}: Accuracy / F1") |
| axes[1].set_xlabel("Epoch") |
| axes[1].set_ylabel("Score") |
| axes[1].set_ylim(0, 1.02) |
| axes[1].legend() |
| Path(output_path).parent.mkdir(parents=True, exist_ok=True) |
| fig.tight_layout() |
| fig.savefig(output_path, dpi=160, bbox_inches="tight") |
| plt.close(fig) |
|
|
|
|
def plot_sample_grid(
    pred_df: pd.DataFrame,
    output_path: str | Path,
    title: str,
    max_images: int = 12,
) -> None:
    """Save a grid (at most 4 columns) of sample images with true/predicted labels.

    Each panel is titled ``T: <label>`` / ``P: <pred_label> (<confidence>)``.
    An image that cannot be opened or decoded is replaced by an
    "Image unavailable" note, so one bad file does not abort the whole grid.

    Args:
        pred_df: Frame with ``filepath``, ``label``, ``pred_label`` and
            ``confidence`` columns; the first *max_images* rows are shown.
        output_path: Destination image file; parent directories are created.
        title: Figure super-title.
        max_images: Maximum number of panels; values <= 0 show no samples.
    """
    # head() with a negative n drops rows from the end instead of limiting them.
    sample = pred_df.head(max(max_images, 0))
    n_panels = max(len(sample), 1)  # always draw at least one panel
    cols = min(4, n_panels)
    rows = int(np.ceil(n_panels / cols))
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 3.2, rows * 3.4))
    # subplots returns a bare Axes for a 1x1 grid and an array otherwise.
    axes_arr = np.asarray(axes).reshape(-1)
    for ax in axes_arr:
        ax.axis("off")
    if sample.empty:
        axes_arr[0].text(0.5, 0.5, "No samples", ha="center", va="center")
    for ax, row in zip(axes_arr, sample.itertuples(index=False)):
        try:
            # Context manager closes the file handle; convert() produces an
            # independent, fully loaded copy before the file is closed.
            with Image.open(row.filepath) as raw:
                img = ImageOps.exif_transpose(raw).convert("RGB")
        except OSError:
            # Covers missing files and PIL's UnidentifiedImageError.
            ax.text(0.5, 0.5, "Image unavailable", ha="center", va="center", fontsize=9)
        else:
            ax.imshow(img)
        ax.set_title(
            f"T: {row.label}\nP: {row.pred_label} ({row.confidence:.2f})",
            fontsize=9,
        )
        ax.axis("off")
    fig.suptitle(title, fontsize=13)
    target = Path(output_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(target, dpi=160, bbox_inches="tight")
    plt.close(fig)
|
|
|
|
| def plot_calibration( |
| y_true: np.ndarray, |
| y_prob: np.ndarray, |
| output_path: str | Path, |
| title: str, |
| ) -> None: |
| if len(np.unique(y_true)) < 2: |
| return |
| prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=8, strategy="uniform") |
| plt.figure(figsize=(5.8, 4.8)) |
| plt.plot(prob_pred, prob_true, marker="o", linewidth=2) |
| plt.plot([0, 1], [0, 1], linestyle="--", color="gray", linewidth=1) |
| plt.title(title) |
| plt.xlabel("Mean Predicted Probability") |
| plt.ylabel("Fraction of Positives") |
| _savefig(output_path) |
|
|
|
|
def write_markdown_report(
    config: dict[str, Any],
    splits_df: pd.DataFrame,
    metrics_df: pd.DataFrame,
    leaderboard_df: pd.DataFrame,
    misclassified_df: pd.DataFrame,
    output_path: str | Path,
) -> None:
    """Write the run summary as a Markdown report to *output_path* (UTF-8).

    Args:
        config: Run configuration. ``config["paths"]["data_dir"]`` and
            ``config["preprocessing"]["image_size"]`` are required. The
            optional ``config["balance"]["strategy"]`` is reported as
            ``disabled`` when absent or when ``balance`` is null.
        splits_df: One row per sample with ``split`` and ``label`` columns.
        metrics_df: Per-model metrics; only the known metric columns that are
            present are rendered.
        leaderboard_df: Ranked models; row 0 is reported as the best model.
        misclassified_df: Misclassified samples. When non-empty it needs a
            ``model_name`` column; missing example columns are skipped.
        output_path: Destination file; parent directories are created.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    best = leaderboard_df.iloc[0].to_dict() if not leaderboard_df.empty else {}
    # `balance:` left empty in a YAML config loads as None; treat it like absent.
    balance_strategy = (config.get("balance") or {}).get("strategy", "disabled")
    split_summary = markdown_table(
        splits_df.groupby(["split", "label"]).size().unstack(fill_value=0).reset_index()
    )
    metric_cols = [
        "model_name",
        "split",
        "accuracy",
        "precision",
        "recall",
        "f1",
        "roc_auc",
        "balanced_accuracy",
        "specificity",
        "sensitivity",
    ]
    metrics_md = markdown_table(metrics_df[[c for c in metric_cols if c in metrics_df.columns]])
    error_text = "No misclassified test samples were recorded."
    if not misclassified_df.empty:
        by_model = misclassified_df.groupby("model_name").size().sort_values(ascending=False)
        example_cols = ["model_name", "label", "pred_label", "confidence", "filepath"]
        # Same tolerance as the metrics table above: render only the columns the
        # frame actually has instead of raising KeyError.
        examples = misclassified_df.head(8)[
            [c for c in example_cols if c in misclassified_df.columns]
        ]
        error_text = (
            "Misclassifications by model:\n\n"
            + markdown_table(by_model.reset_index(name="count"))
            + "\n\nExample errors:\n\n"
            + markdown_table(examples)
        )
    content = f"""# Egg Damage Classification Report

## Dataset Overview

- Dataset path: `{config['paths']['data_dir']}`
- Task: binary classification, `Damaged` vs `Not Damaged`
- Split strategy: existing split folders are respected; otherwise stratified 70/15/15 splitting is used.
- Training balance: `{balance_strategy}`.

## Split Summary

{split_summary}

## Preprocessing

- Classical models: grayscale resize to {config['preprocessing']['image_size']}x{config['preprocessing']['image_size']}, HOG or LBP features, standardized SVM inputs.
- Deep models: ImageNet normalization, realistic train-only flips, small rotations, mild affine jitter, and light color jitter.
- SVM training curves are marked N/A because these models are not epoch-trained.

## Metrics

{metrics_md}

## Best Model

- Model: `{best.get('model_name', 'N/A')}`
- Test F1: `{best.get('f1', 'N/A')}`
- Test ROC-AUC: `{best.get('roc_auc', 'N/A')}`
- Test balanced accuracy: `{best.get('balanced_accuracy', 'N/A')}`
- Model path: `{best.get('model_path', 'N/A')}`

## Error Patterns

{error_text}

## Deployment

Run `python scripts/launch_gradio.py --config configs/default.yaml` to launch the local Gradio app. The app loads the best ranked model automatically and can switch among trained models.
"""
    output_path.write_text(content, encoding="utf-8")
|
|