budijuarto committed on
Commit
0258b57
·
verified ·
1 Parent(s): 5a7bdcb

Upload src/egg_damage/reporting.py

Browse files
Files changed (1) hide show
  1. src/egg_damage/reporting.py +307 -0
src/egg_damage/reporting.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Any, Iterable
5
+
6
+ import matplotlib
7
+
8
+ matplotlib.use("Agg")
9
+
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ import pandas as pd
13
+ import seaborn as sns
14
+ from PIL import Image, ImageOps
15
+ from sklearn.calibration import calibration_curve
16
+ from sklearn.metrics import auc, precision_recall_curve, roc_curve
17
+
18
+ from .data_discovery import CANONICAL_LABELS
19
+
20
+
21
+ sns.set_theme(style="whitegrid", context="notebook")
22
+
23
+
24
def markdown_table(df: pd.DataFrame) -> str:
    """Render a DataFrame as a pipe-delimited Markdown table.

    Args:
        df: Frame to render. Missing values are shown as empty cells.

    Returns:
        The table as a single string (header row, separator row, one row per
        record), or the placeholder ``"_No rows._"`` when ``df`` is empty.
    """
    if df.empty:
        return "_No rows._"

    def _cell(value: object) -> str:
        # A raw line break would terminate the table row and an unescaped
        # pipe would open an extra column, so both are neutralised here.
        text = str(value).replace("\r\n", " ").replace("\n", " ").replace("\r", " ")
        return text.replace("|", "\\|")

    safe = df.fillna("")
    headers = [_cell(col) for col in safe.columns]
    lines = [
        "| " + " | ".join(headers) + " |",
        "| " + " | ".join(["---"] * len(headers)) + " |",
    ]
    for row in safe.itertuples(index=False):
        lines.append("| " + " | ".join(_cell(value) for value in row) + " |")
    return "\n".join(lines)
35
+
36
+
37
def _savefig(path: str | Path) -> None:
    """Save the current pyplot figure to ``path`` and close it.

    The parent directory is created if it does not exist. The figure is
    written at 160 dpi with a tight bounding box.

    Args:
        path: Destination file for the rendered figure.
    """
    path = Path(path)
    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        plt.tight_layout()
        plt.savefig(path, dpi=160, bbox_inches="tight")
    finally:
        # Close even when layout or saving fails, so a failed report step
        # does not leave an open figure behind for the next plot to reuse.
        plt.close()
43
+
44
+
45
def plot_class_distribution(df: pd.DataFrame, output_path: str | Path) -> None:
    """Draw a count plot of ``label`` per ``split`` and save it.

    Only the splits named ``train``, ``val`` and ``test`` that actually occur
    in ``df["split"]`` are shown, in that order.

    Args:
        df: Frame with at least the columns ``split`` and ``label``.
        output_path: Destination file, written through ``_savefig``.
    """
    fig = plt.figure(figsize=(8, 4.8))
    present_splits = set(df["split"])
    split_order = [name for name in ("train", "val", "test") if name in present_splits]
    sns.countplot(data=df, x="split", hue="label", order=split_order)
    ax = fig.gca()
    ax.set_title("Class Distribution by Split")
    ax.set_xlabel("Split")
    ax.set_ylabel("Images")
    _savefig(output_path)
53
+
54
+
55
def plot_confusion_matrix(
    matrix: np.ndarray,
    output_path: str | Path,
    title: str,
    class_names: Iterable[str] = CANONICAL_LABELS,
) -> None:
    """Render a confusion matrix as an annotated heatmap and save it.

    Args:
        matrix: Confusion matrix values; rows are labelled "True" and columns
            "Predicted" on the plot.
        output_path: Destination file, written through ``_savefig``.
        title: Figure title.
        class_names: Tick labels used for both axes.
    """
    # Materialise once: ``class_names`` may be a one-shot iterator, and it is
    # needed for both axes.
    labels = list(class_names)
    values = np.asarray(matrix)
    # The integer format code raises on float data (e.g. a normalised
    # matrix), so fall back to two decimals for non-integer dtypes.
    cell_format = "d" if np.issubdtype(values.dtype, np.integer) else ".2f"
    plt.figure(figsize=(5.6, 4.8))
    sns.heatmap(
        values,
        annot=True,
        fmt=cell_format,
        cmap="Blues",
        xticklabels=labels,
        yticklabels=labels,
        cbar=False,
    )
    plt.title(title)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    _savefig(output_path)
75
+
76
+
77
def plot_roc_curve_single(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    output_path: str | Path,
    title: str,
) -> float | None:
    """Plot one ROC curve with a chance diagonal and return its area.

    Args:
        y_true: Ground-truth labels.
        y_prob: Scores passed to ``roc_curve`` as the second argument.
        output_path: Destination file, written through ``_savefig``.
        title: Figure title.

    Returns:
        The area under the ROC curve as a ``float``, or ``None`` (and no
        figure) when ``y_true`` holds fewer than two distinct values.
    """
    distinct_labels = np.unique(y_true)
    if len(distinct_labels) < 2:
        return None

    false_pos_rate, true_pos_rate, _thresholds = roc_curve(y_true, y_prob)
    area = auc(false_pos_rate, true_pos_rate)

    fig = plt.figure(figsize=(5.8, 4.8))
    ax = fig.gca()
    ax.plot(false_pos_rate, true_pos_rate, label=f"AUC = {area:.3f}", linewidth=2)
    # Dashed diagonal as the chance-level reference.
    ax.plot([0, 1], [0, 1], linestyle="--", color="gray", linewidth=1)
    ax.set_title(title)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.legend(loc="lower right")
    _savefig(output_path)
    return float(area)
96
+
97
+
98
def plot_precision_recall_curve_single(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    output_path: str | Path,
    title: str,
) -> float | None:
    """Plot one precision-recall curve and return the area beneath it.

    Args:
        y_true: Ground-truth labels.
        y_prob: Scores passed to ``precision_recall_curve`` as the second
            argument.
        output_path: Destination file, written through ``_savefig``.
        title: Figure title.

    Returns:
        ``auc(recall, precision)`` as a ``float``, or ``None`` (and no
        figure) when ``y_true`` holds fewer than two distinct values.
    """
    distinct_labels = np.unique(y_true)
    if len(distinct_labels) < 2:
        return None

    precision_vals, recall_vals, _thresholds = precision_recall_curve(y_true, y_prob)
    area = auc(recall_vals, precision_vals)

    fig = plt.figure(figsize=(5.8, 4.8))
    ax = fig.gca()
    ax.plot(recall_vals, precision_vals, label=f"PR AUC = {area:.3f}", linewidth=2)
    ax.set_title(title)
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.legend(loc="lower left")
    _savefig(output_path)
    return float(area)
116
+
117
+
118
def plot_combined_roc(prediction_frames: list[tuple[str, pd.DataFrame]], output_path: str | Path) -> None:
    """Overlay the ROC curves of several models on one figure and save it.

    Args:
        prediction_frames: ``(model_name, frame)`` pairs; each frame must
            provide the columns ``y_true`` and ``prob_damaged``. Frames whose
            ``y_true`` holds fewer than two distinct values are skipped.
        output_path: Destination file, written through ``_savefig``.
    """
    plt.figure(figsize=(7.2, 5.6))
    plotted = False
    for model_name, frame in prediction_frames:
        y_true = frame["y_true"].to_numpy()
        y_prob = frame["prob_damaged"].to_numpy()
        if len(np.unique(y_true)) < 2:
            continue
        fpr, tpr, _ = roc_curve(y_true, y_prob)
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, linewidth=2, label=f"{model_name} ({roc_auc:.3f})")
        plotted = True
    if not plotted:
        plt.text(0.5, 0.5, "ROC unavailable: only one class present", ha="center", va="center")
    # Dashed diagonal as the chance-level reference.
    plt.plot([0, 1], [0, 1], linestyle="--", color="gray", linewidth=1)
    plt.title("Test ROC Comparison")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    if plotted:
        # Without any labelled curve there is nothing to list, and asking for
        # a legend would only produce a matplotlib warning.
        plt.legend(loc="lower right", fontsize=8)
    _savefig(output_path)
138
+
139
+
140
def plot_metric_bars(metrics_df: pd.DataFrame, output_path: str | Path) -> None:
    """Draw a grouped bar chart comparing models across metrics and save it.

    Rows with ``split == "test"`` are used; when there are none, every row of
    ``metrics_df`` is plotted instead.

    Args:
        metrics_df: Frame with the columns ``model_name`` and ``split`` plus
            any of ``accuracy``, ``precision``, ``recall``, ``f1``,
            ``roc_auc`` and ``balanced_accuracy``. Metric columns that are
            absent are skipped.
        output_path: Destination file, written through ``_savefig``.
    """
    metric_cols = ["accuracy", "precision", "recall", "f1", "roc_auc", "balanced_accuracy"]
    # Only melt the metrics that exist; a missing column would otherwise
    # raise a KeyError.
    available = [col for col in metric_cols if col in metrics_df.columns]
    test_df = metrics_df[metrics_df["split"] == "test"]
    if test_df.empty:
        test_df = metrics_df

    plt.figure(figsize=(11, 6))
    has_data = bool(available) and not test_df.empty
    if has_data:
        melt = test_df.melt(id_vars=["model_name"], value_vars=available, var_name="metric", value_name="score")
        # Metrics recorded as None make the column object-typed; coerce so
        # the bar heights are numeric and undefined scores become gaps.
        melt["score"] = pd.to_numeric(melt["score"], errors="coerce")
        sns.barplot(data=melt, x="metric", y="score", hue="model_name")
    else:
        plt.text(0.5, 0.5, "No metrics available", ha="center", va="center")
    plt.ylim(0, 1.02)
    plt.title("Model Metrics Comparison")
    plt.xlabel("")
    plt.ylabel("Score")
    plt.xticks(rotation=25, ha="right")
    if has_data:
        plt.legend(loc="lower right", fontsize=8)
    _savefig(output_path)
155
+
156
+
157
def plot_training_curves(history_df: pd.DataFrame, output_path: str | Path, model_name: str) -> None:
    """Plot per-epoch loss and accuracy curves side by side and save them.

    Nothing is drawn or written when ``history_df`` is empty.

    Args:
        history_df: Frame with the columns ``epoch``, ``train_loss``,
            ``val_loss``, ``train_accuracy`` and ``val_accuracy``; an optional
            ``val_f1`` column adds a third line to the right-hand panel.
        output_path: Destination file; its parent directory is created.
        model_name: Prefix used in both panel titles.
    """
    if history_df.empty:
        return

    fig, (loss_ax, score_ax) = plt.subplots(1, 2, figsize=(11, 4.5))
    epochs = history_df["epoch"]

    # Left panel: training and validation loss.
    for column, legend_label in (("train_loss", "Train"), ("val_loss", "Validation")):
        loss_ax.plot(epochs, history_df[column], marker="o", label=legend_label)
    loss_ax.set_title(f"{model_name}: Loss")
    loss_ax.set_xlabel("Epoch")
    loss_ax.set_ylabel("Loss")
    loss_ax.legend()

    # Right panel: accuracies, plus validation F1 when it was recorded.
    for column, legend_label in (("train_accuracy", "Train"), ("val_accuracy", "Validation")):
        score_ax.plot(epochs, history_df[column], marker="o", label=legend_label)
    if "val_f1" in history_df:
        score_ax.plot(epochs, history_df["val_f1"], marker="s", label="Val F1")
    score_ax.set_title(f"{model_name}: Accuracy / F1")
    score_ax.set_xlabel("Epoch")
    score_ax.set_ylabel("Score")
    score_ax.set_ylim(0, 1.02)
    score_ax.legend()

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(output_path, dpi=160, bbox_inches="tight")
    plt.close(fig)
180
+
181
+
182
def plot_sample_grid(
    pred_df: pd.DataFrame,
    output_path: str | Path,
    title: str,
    max_images: int = 12,
) -> None:
    """Save a grid of sample images annotated with true and predicted labels.

    The first ``max_images`` rows of ``pred_df`` are shown, at most four per
    grid row. An empty frame produces a single panel reading "No samples".

    Args:
        pred_df: Frame with the columns ``filepath``, ``label``,
            ``pred_label`` and ``confidence``.
        output_path: Destination file; its parent directory is created.
        title: Figure-level title.
        max_images: Number of leading rows to display.
    """
    sample = pred_df.head(max_images).copy()
    cols = min(4, max(len(sample), 1))
    rows = int(np.ceil(max(len(sample), 1) / cols))
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 3.2, rows * 3.4))
    try:
        # ``plt.subplots`` returns a bare Axes for a 1x1 grid; flatten so the
        # code below can treat every layout uniformly.
        axes_arr = np.asarray(axes).reshape(-1)
        for ax in axes_arr:
            ax.axis("off")
        if sample.empty:
            axes_arr[0].text(0.5, 0.5, "No samples", ha="center", va="center")
        for ax, row in zip(axes_arr, sample.itertuples(index=False)):
            # The context manager releases the file handle as soon as the
            # pixel data has been copied into the converted image.
            with Image.open(row.filepath) as raw:
                img = ImageOps.exif_transpose(raw).convert("RGB")
            ax.imshow(img)
            ax.set_title(
                f"T: {row.label}\nP: {row.pred_label} ({row.confidence:.2f})",
                fontsize=9,
            )
            ax.axis("off")
        fig.suptitle(title, fontsize=13)
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        fig.tight_layout()
        fig.savefig(output_path, dpi=160, bbox_inches="tight")
    finally:
        # Close the figure even if an image cannot be read or saving fails.
        plt.close(fig)
211
+
212
+
213
def plot_calibration(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    output_path: str | Path,
    title: str,
) -> None:
    """Plot a reliability curve (8 uniform bins) and save it.

    Nothing is drawn or written when ``y_true`` holds fewer than two distinct
    values.

    Args:
        y_true: Ground-truth labels.
        y_prob: Predicted probabilities passed to ``calibration_curve``.
        output_path: Destination file, written through ``_savefig``.
        title: Figure title.
    """
    distinct_labels = np.unique(y_true)
    if len(distinct_labels) < 2:
        return

    observed_frac, mean_predicted = calibration_curve(y_true, y_prob, n_bins=8, strategy="uniform")

    fig = plt.figure(figsize=(5.8, 4.8))
    ax = fig.gca()
    ax.plot(mean_predicted, observed_frac, marker="o", linewidth=2)
    # Dashed diagonal as the perfect-calibration reference.
    ax.plot([0, 1], [0, 1], linestyle="--", color="gray", linewidth=1)
    ax.set_title(title)
    ax.set_xlabel("Mean Predicted Probability")
    ax.set_ylabel("Fraction of Positives")
    _savefig(output_path)
229
+
230
+
231
def write_markdown_report(
    config: dict[str, Any],
    splits_df: pd.DataFrame,
    metrics_df: pd.DataFrame,
    leaderboard_df: pd.DataFrame,
    misclassified_df: pd.DataFrame,
    output_path: str | Path,
) -> None:
    """Assemble the Markdown experiment report and write it as UTF-8.

    Args:
        config: Run configuration. Reads ``config["paths"]["data_dir"]`` and
            ``config["preprocessing"]["image_size"]`` (both required) and the
            optional ``config["balance"]["strategy"]``.
        splits_df: Frame with ``split`` and ``label`` columns, summarised as
            a count table.
        metrics_df: Per-model metrics; only the known metric columns that are
            present are included in the report.
        leaderboard_df: Ranked models; the first row is reported as the best
            model.
        misclassified_df: Misclassified samples with the columns
            ``model_name``, ``label``, ``pred_label``, ``confidence`` and
            ``filepath``; may be empty.
        output_path: Destination file; its parent directory is created.

    Raises:
        KeyError: If a required ``config`` entry is missing.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    best = leaderboard_df.iloc[0].to_dict() if not leaderboard_df.empty else {}

    def _best_value(key: str) -> Any:
        # Present-but-missing values (None / NaN, e.g. an undefined ROC-AUC)
        # get the same placeholder as absent keys instead of "nan" or "None".
        value = best.get(key, "N/A")
        if value is None:
            return "N/A"
        try:
            if pd.isna(value):
                return "N/A"
        except (TypeError, ValueError):
            # Non-scalar values have no single missing/not-missing answer.
            pass
        return value

    split_summary = markdown_table(
        splits_df.groupby(["split", "label"]).size().unstack(fill_value=0).reset_index()
    )
    metric_cols = [
        "model_name",
        "split",
        "accuracy",
        "precision",
        "recall",
        "f1",
        "roc_auc",
        "balanced_accuracy",
        "specificity",
        "sensitivity",
    ]
    metrics_md = markdown_table(metrics_df[[c for c in metric_cols if c in metrics_df.columns]])
    error_text = "No misclassified test samples were recorded."
    if not misclassified_df.empty:
        by_model = misclassified_df.groupby("model_name").size().sort_values(ascending=False)
        examples = misclassified_df.head(8)[["model_name", "label", "pred_label", "confidence", "filepath"]]
        error_text = (
            "Misclassifications by model:\n\n"
            + markdown_table(by_model.reset_index(name="count"))
            + "\n\nExample errors:\n\n"
            + markdown_table(examples)
        )
    content = f"""# Egg Damage Classification Report

## Dataset Overview

- Dataset path: `{config['paths']['data_dir']}`
- Task: binary classification, `Damaged` vs `Not Damaged`
- Split strategy: existing split folders are respected; otherwise stratified 70/15/15 splitting is used.
- Training balance: `{config.get('balance', {}).get('strategy', 'disabled')}`.

## Split Summary

{split_summary}

## Preprocessing

- Classical models: grayscale resize to {config['preprocessing']['image_size']}x{config['preprocessing']['image_size']}, HOG or LBP features, standardized SVM inputs.
- Deep models: ImageNet normalization, realistic train-only flips, small rotations, mild affine jitter, and light color jitter.
- SVM training curves are marked N/A because these models are not epoch-trained.

## Metrics

{metrics_md}

## Best Model

- Model: `{_best_value('model_name')}`
- Test F1: `{_best_value('f1')}`
- Test ROC-AUC: `{_best_value('roc_auc')}`
- Test balanced accuracy: `{_best_value('balanced_accuracy')}`
- Model path: `{_best_value('model_path')}`

## Error Patterns

{error_text}

## Deployment

Run `python scripts/launch_gradio.py --config configs/default.yaml` to launch the local Gradio app. The app loads the best ranked model automatically and can switch among trained models.
"""
    output_path.write_text(content, encoding="utf-8")