""" Visualization utilities for multilabel protein localization training and evaluation. All plot functions save PNG files to ``output_dir`` (150 DPI) and return matplotlib Figure objects. """ from __future__ import annotations from pathlib import Path import json from typing import Any, Dict, List, Mapping, Optional, Sequence, Set, Tuple, Union import matplotlib.pyplot as plt import numpy as np try: import mlflow except ImportError: # pragma: no cover mlflow = None # type: ignore import seaborn as sns from matplotlib.figure import Figure from sklearn.metrics import confusion_matrix, f1_score # Clean default style sns.set_theme(style="whitegrid", context="notebook") plt.rcParams.update({"figure.facecolor": "white", "axes.facecolor": "white"}) def _to_numpy(x: Any) -> np.ndarray: if isinstance(x, np.ndarray): return x if hasattr(x, "detach") and hasattr(x, "cpu"): return x.detach().cpu().numpy() return np.asarray(x) def _ensure_output_dir(output_dir: str | Path) -> Path: out = Path(output_dir).expanduser().resolve() out.mkdir(parents=True, exist_ok=True) return out def _subplot_grid(n: int, max_cols: int = 5) -> Tuple[int, int]: if n <= 0: raise ValueError("n must be positive") ncols = min(max_cols, n) nrows = int(np.ceil(n / ncols)) return nrows, ncols def plot_training_curves( train_losses: Sequence[float], val_losses: Sequence[float], output_dir: str | Path, best_epoch: Optional[int] = None, filename: str = "training_curves.png", ) -> Figure: """Plot train vs validation loss per epoch; vertical line at best epoch (1-based, min val loss if omitted).""" out = _ensure_output_dir(output_dir) train_losses = list(train_losses) val_losses = list(val_losses) if len(train_losses) != len(val_losses): raise ValueError("train_losses and val_losses must have the same length") if best_epoch is None and len(val_losses) > 0: best_epoch = int(np.argmin(np.asarray(val_losses, dtype=np.float64))) + 1 epochs = np.arange(1, len(train_losses) + 1) fig, ax = plt.subplots(figsize=(10, 5)) ax.plot(epochs, train_losses, label="Train loss", color="#1f77b4", linewidth=2) ax.plot(epochs, val_losses, label="Validation loss", color="#ff7f0e", linewidth=2) if best_epoch is not None and 1 <= best_epoch <= len(train_losses): ax.axvline(best_epoch, color="gray", linestyle="--", linewidth=1.5, label=f"Best epoch ({best_epoch})") ax.set_xlabel("Epoch", fontsize=11) ax.set_ylabel("Loss", fontsize=11) ax.set_title("Training & Validation Loss", fontsize=13, fontweight="bold") ax.legend(loc="upper right", fontsize=10) ax.grid(True, alpha=0.35) fig.tight_layout() fig.savefig(out / filename, dpi=150, bbox_inches="tight") return fig def plot_per_class_f1_progression( epoch_metrics_list: Sequence[Mapping[str, float]], label_names: Sequence[str], output_dir: str | Path, filename: str = "per_class_f1_progression.png", ) -> Figure: """ ``epoch_metrics_list``: one dict per epoch, each mapping ``label_name -> f1`` for validation. """ out = _ensure_output_dir(output_dir) label_names = list(label_names) n_epochs = len(epoch_metrics_list) if n_epochs == 0: raise ValueError("epoch_metrics_list is empty") epochs = np.arange(1, n_epochs + 1) fig, ax = plt.subplots(figsize=(12, 6)) palette = sns.color_palette("husl", n_colors=max(len(label_names), 1)) for i, name in enumerate(label_names): series = [float(epoch_metrics_list[e].get(name, np.nan)) for e in range(n_epochs)] ax.plot(epochs, series, label=name, color=palette[i % len(palette)], linewidth=1.8) ax.set_xlabel("Epoch", fontsize=11) ax.set_ylabel("F1", fontsize=11) ax.set_title("Per-Class F1 Score Across Training", fontsize=13, fontweight="bold") ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left", fontsize=8, frameon=True) ax.grid(True, alpha=0.35) fig.tight_layout() fig.savefig(out / filename, dpi=150, bbox_inches="tight") return fig def _auroc_color(val: float) -> str: if np.isnan(val): return "#cccccc" if val < 0.7: return "#d62728" if val <= 0.85: return "#ffbf00" return "#2ca02c" def plot_auroc_bars( test_metrics: Mapping[str, Any], label_names: Sequence[str], output_dir: str | Path, filename: str = "test_auroc_bars.png", ) -> Figure: """Horizontal bar chart of per-class AUROC (from ``test_metrics['per_class'][name]['auroc']``).""" out = _ensure_output_dir(output_dir) per_class = test_metrics.get("per_class", {}) rows: List[Tuple[str, float]] = [] for name in label_names: d = per_class.get(name, {}) a = float(d.get("auroc", np.nan)) rows.append((name, a)) rows.sort( key=lambda x: (0 if not np.isnan(x[1]) else 1, -(x[1] if not np.isnan(x[1]) else 0.0), x[0]), ) labels = [r[0] for r in rows] values = [r[1] for r in rows] fig, ax = plt.subplots(figsize=(10, max(4, 0.45 * len(labels)))) colors = [_auroc_color(v) for v in values] y_pos = np.arange(len(labels)) bars = ax.barh(y_pos, values, color=colors, edgecolor="white", linewidth=0.5) ax.set_yticks(y_pos) ax.set_yticklabels(labels, fontsize=10) ax.set_xlabel("AUROC", fontsize=11) ax.set_xlim(0, 1.05) ax.set_title("Test Set AUROC by Subcellular Location", fontsize=13, fontweight="bold") for bar, v in zip(bars, values): if np.isnan(v): ax.text(0.02, bar.get_y() + bar.get_height() / 2, "nan", va="center", fontsize=9) else: ax.text(min(v + 0.02, 1.0), bar.get_y() + bar.get_height() / 2, f"{v:.3f}", va="center", fontsize=9) ax.grid(True, axis="x", alpha=0.35) fig.tight_layout() fig.savefig(out / filename, dpi=150, bbox_inches="tight") return fig def plot_confusion_matrices( y_true: Any, y_pred: Any, label_names: Sequence[str], output_dir: str | Path, filename: str = "confusion_matrices.png", ) -> Figure: """One 2×2 confusion matrix per class (binary multilabel).""" out = _ensure_output_dir(output_dir) yt = _to_numpy(y_true).astype(np.int64) yp = _to_numpy(y_pred).astype(np.int64) if yt.shape != yp.shape or yt.ndim != 2: raise ValueError("y_true and y_pred must be 2D arrays of the same shape") n = len(label_names) nrows, ncols = _subplot_grid(n, max_cols=5) fig, axes = plt.subplots(nrows, ncols, figsize=(3.2 * ncols, 3.0 * nrows)) axes_flat = np.atleast_1d(axes).ravel() for i, name in enumerate(label_names): ax = axes_flat[i] cm = confusion_matrix(yt[:, i], yp[:, i], labels=[0, 1]) sns.heatmap( cm, annot=True, fmt="d", cmap="Blues", cbar=False, ax=ax, xticklabels=["Pred 0", "Pred 1"], yticklabels=["True 0", "True 1"], ) ax.set_title(name, fontsize=10, fontweight="bold") for j in range(n, len(axes_flat)): axes_flat[j].set_visible(False) fig.tight_layout() fig.savefig(out / filename, dpi=150, bbox_inches="tight") return fig def _cooccurrence_matrix(y_bin: np.ndarray) -> np.ndarray: """Symmetric co-occurrence counts: (y.T @ y) for binary matrix y (n_samples, n_labels).""" return (y_bin.T @ y_bin).astype(np.float64) def plot_label_cooccurrence( y_true: Any, y_pred: Any, label_names: Sequence[str], output_dir: str | Path, filename: str = "label_cooccurrence.png", ) -> Figure: """Side-by-side heatmaps: ground-truth vs predicted label co-occurrence.""" out = _ensure_output_dir(output_dir) yt = _to_numpy(y_true).astype(np.int64) yp = _to_numpy(y_pred).astype(np.int64) if yt.shape != yp.shape: raise ValueError("y_true and y_pred must have the same shape") c_true = _cooccurrence_matrix(yt) c_pred = _cooccurrence_matrix(yp) vmax = max(c_true.max(), c_pred.max(), 1.0) fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6)) sns.heatmap( c_true, xticklabels=label_names, yticklabels=label_names, annot=False, fmt=".0f", cmap="YlOrRd", ax=ax1, vmin=0, vmax=vmax, square=True, ) ax1.set_title("Ground truth co-occurrence", fontsize=12, fontweight="bold") sns.heatmap( c_pred, xticklabels=label_names, yticklabels=label_names, annot=False, fmt=".0f", cmap="YlOrRd", ax=ax2, vmin=0, vmax=vmax, square=True, ) ax2.set_title("Predicted co-occurrence", fontsize=12, fontweight="bold") fig.tight_layout() fig.savefig(out / filename, dpi=150, bbox_inches="tight") return fig def plot_probability_distributions( y_true: Any, y_pred_proba: Any, label_names: Sequence[str], output_dir: str | Path, filename: str = "probability_distributions.png", ) -> Figure: """Per-class histograms: predicted probability for negatives (red) vs positives (blue).""" out = _ensure_output_dir(output_dir) yt = _to_numpy(y_true).astype(np.float64) yp = _to_numpy(y_pred_proba).astype(np.float64) n = len(label_names) nrows, ncols = _subplot_grid(n, max_cols=5) fig, axes = plt.subplots(nrows, ncols, figsize=(3.2 * ncols, 2.8 * nrows)) axes_flat = np.atleast_1d(axes).ravel() for i, name in enumerate(label_names): ax = axes_flat[i] pos_mask = yt[:, i] > 0.5 neg_mask = ~pos_mask pos_p = yp[pos_mask, i] neg_p = yp[neg_mask, i] ax.hist(neg_p, bins=20, alpha=0.65, color="#d62728", label="Negative", density=True) ax.hist(pos_p, bins=20, alpha=0.65, color="#1f77b4", label="Positive", density=True) ax.set_title(name, fontsize=9, fontweight="bold") ax.set_xlim(0, 1) ax.tick_params(labelsize=8) if i == 0: ax.legend(fontsize=7, loc="upper right") for j in range(n, len(axes_flat)): axes_flat[j].set_visible(False) fig.tight_layout() fig.savefig(out / filename, dpi=150, bbox_inches="tight") return fig def plot_threshold_analysis( y_true: Any, y_pred_proba: Any, label_names: Sequence[str], output_dir: str | Path, filename: str = "threshold_analysis.png", thresholds: np.ndarray | None = None, ) -> Figure: """Per-class F1 vs threshold; marks optimal threshold (max F1) with a red dot.""" out = _ensure_output_dir(output_dir) yt = _to_numpy(y_true).astype(np.int64) yp = _to_numpy(y_pred_proba).astype(np.float64) if thresholds is None: thresholds = np.linspace(0.01, 0.99, 99) n = len(label_names) nrows, ncols = _subplot_grid(n, max_cols=5) fig, axes = plt.subplots(nrows, ncols, figsize=(3.2 * ncols, 2.6 * nrows)) axes_flat = np.atleast_1d(axes).ravel() for i, name in enumerate(label_names): ax = axes_flat[i] f1s = [] for thr in thresholds: yb = (yp[:, i] >= thr).astype(np.int64) f1s.append(f1_score(yt[:, i], yb, zero_division=0)) f1s = np.asarray(f1s) best_idx = int(np.argmax(f1s)) best_thr = float(thresholds[best_idx]) best_f1 = float(f1s[best_idx]) ax.plot(thresholds, f1s, color="#1f77b4", linewidth=1.5) ax.scatter([best_thr], [best_f1], color="red", s=40, zorder=5) ax.set_title(name, fontsize=9, fontweight="bold") ax.set_xlabel("Threshold", fontsize=8) ax.set_ylabel("F1", fontsize=8) ax.tick_params(labelsize=7) ax.grid(True, alpha=0.35) for j in range(n, len(axes_flat)): axes_flat[j].set_visible(False) fig.tight_layout() fig.savefig(out / filename, dpi=150, bbox_inches="tight") return fig def generate_all_plots( *, output_dir: str | Path, train_losses: Sequence[float], val_losses: Sequence[float], epoch_metrics_list: Sequence[Mapping[str, float]], label_names: Sequence[str], test_metrics: Mapping[str, Any], y_true: Any, y_pred: Any, y_pred_proba: Any, best_epoch: Optional[int] = None, mlflow_log: bool = True, ) -> Dict[str, Figure]: """ Generate all analysis plots and optionally log each PNG to the active MLflow run. ``epoch_metrics_list``: one dict per epoch, each mapping ``label_name -> validation F1``. """ out = _ensure_output_dir(output_dir) figures: Dict[str, Figure] = {} figures["training_curves"] = plot_training_curves( train_losses, val_losses, out, best_epoch=best_epoch ) figures["per_class_f1"] = plot_per_class_f1_progression(epoch_metrics_list, label_names, out) figures["auroc_bars"] = plot_auroc_bars(test_metrics, label_names, out) figures["confusion"] = plot_confusion_matrices(y_true, y_pred, label_names, out) figures["cooccurrence"] = plot_label_cooccurrence(y_true, y_pred, label_names, out) figures["proba_hist"] = plot_probability_distributions(y_true, y_pred_proba, label_names, out) figures["threshold"] = plot_threshold_analysis(y_true, y_pred_proba, label_names, out) if mlflow_log and mlflow is not None: try: if mlflow.active_run() is not None: artifact_files = [ "training_curves.png", "per_class_f1_progression.png", "test_auroc_bars.png", "confusion_matrices.png", "label_cooccurrence.png", "probability_distributions.png", "threshold_analysis.png", ] for fname in artifact_files: p = out / fname if p.is_file(): mlflow.log_artifact(str(p), artifact_path="plots") except Exception: pass return figures def plot_training_curves_comparison( train_losses_a: Sequence[float], val_losses_a: Sequence[float], train_losses_b: Sequence[float], val_losses_b: Sequence[float], label_a: str, label_b: str, output_dir: str | Path, best_epoch_a: Optional[int] = None, best_epoch_b: Optional[int] = None, filename: str = "compare_training_val_loss.png", ) -> Figure: """Overlay train/val loss curves for two runs.""" out = _ensure_output_dir(output_dir) fig, ax = plt.subplots(figsize=(11, 5)) ea = np.arange(1, len(train_losses_a) + 1) eb = np.arange(1, len(train_losses_b) + 1) ax.plot(ea, train_losses_a, label=f"Train ({label_a})", color="#1f77b4", linewidth=2, linestyle="-") ax.plot(ea, val_losses_a, label=f"Val ({label_a})", color="#aec7e8", linewidth=2, linestyle="-") ax.plot(eb, train_losses_b, label=f"Train ({label_b})", color="#d62728", linewidth=2, linestyle="--") ax.plot(eb, val_losses_b, label=f"Val ({label_b})", color="#ff9896", linewidth=2, linestyle="--") if best_epoch_a is not None and 1 <= best_epoch_a <= len(val_losses_a): ax.axvline(best_epoch_a, color="#1f77b4", linestyle=":", linewidth=1.2, alpha=0.8) if best_epoch_b is not None and 1 <= best_epoch_b <= len(val_losses_b): ax.axvline(best_epoch_b, color="#d62728", linestyle=":", linewidth=1.2, alpha=0.8) ax.set_xlabel("Epoch", fontsize=11) ax.set_ylabel("Loss", fontsize=11) ax.set_title("Training & Validation Loss — comparison", fontsize=13, fontweight="bold") ax.legend(loc="upper right", fontsize=9) ax.grid(True, alpha=0.35) fig.tight_layout() fig.savefig(out / filename, dpi=150, bbox_inches="tight") return fig def plot_val_macro_f1_comparison( macro_f1_a: Sequence[float], macro_f1_b: Sequence[float], label_a: str, label_b: str, output_dir: str | Path, filename: str = "compare_val_macro_f1.png", ) -> Figure: out = _ensure_output_dir(output_dir) fig, ax = plt.subplots(figsize=(10, 4.5)) ax.plot(np.arange(1, len(macro_f1_a) + 1), macro_f1_a, label=label_a, color="#1f77b4", linewidth=2) ax.plot(np.arange(1, len(macro_f1_b) + 1), macro_f1_b, label=label_b, color="#d62728", linewidth=2, linestyle="--") ax.set_xlabel("Epoch", fontsize=11) ax.set_ylabel("Validation macro F1", fontsize=11) ax.set_title("Validation macro F1 — comparison", fontsize=13, fontweight="bold") ax.legend(loc="lower right", fontsize=10) ax.grid(True, alpha=0.35) fig.tight_layout() fig.savefig(out / filename, dpi=150, bbox_inches="tight") return fig def plot_per_class_f1_panels( epoch_metrics_a: Sequence[Mapping[str, float]], epoch_metrics_b: Sequence[Mapping[str, float]], label_names: Sequence[str], title_a: str, title_b: str, output_dir: str | Path, filename: str = "compare_per_class_f1_panels.png", ) -> Figure: """Side-by-side panels: per-class F1 vs epoch for each run.""" out = _ensure_output_dir(output_dir) fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6), sharey=True) palette = sns.color_palette("husl", n_colors=max(len(label_names), 1)) n_a, n_b = len(epoch_metrics_a), len(epoch_metrics_b) for i, name in enumerate(label_names): sa = [float(epoch_metrics_a[e].get(name, np.nan)) for e in range(n_a)] sb = [float(epoch_metrics_b[e].get(name, np.nan)) for e in range(n_b)] ax1.plot(np.arange(1, n_a + 1), sa, label=name, color=palette[i % len(palette)], linewidth=1.5) ax2.plot(np.arange(1, n_b + 1), sb, label=name, color=palette[i % len(palette)], linewidth=1.5) ax1.set_title(title_a, fontsize=12, fontweight="bold") ax2.set_title(title_b, fontsize=12, fontweight="bold") ax1.set_xlabel("Epoch") ax2.set_xlabel("Epoch") ax1.set_ylabel("F1") ax1.grid(True, alpha=0.35) ax2.grid(True, alpha=0.35) h1, l1 = ax1.get_legend_handles_labels() fig.legend(h1, l1, loc="center left", bbox_to_anchor=(1.02, 0.5), fontsize=7, title="Class") fig.suptitle("Per-class validation F1 — comparison", fontsize=14, fontweight="bold") fig.tight_layout() fig.savefig(out / filename, dpi=150, bbox_inches="tight") return fig def plot_auroc_comparison_bars( test_metrics_a: Mapping[str, Any], test_metrics_b: Mapping[str, Any], label_names: Sequence[str], label_a: str, label_b: str, output_dir: str | Path, filename: str = "compare_test_auroc.png", ) -> Figure: """Grouped horizontal bars: AUROC per class for two runs.""" out = _ensure_output_dir(output_dir) names = list(label_names) y = np.arange(len(names)) h = 0.35 va = [float(test_metrics_a.get("per_class", {}).get(n, {}).get("auroc", np.nan)) for n in names] vb = [float(test_metrics_b.get("per_class", {}).get(n, {}).get("auroc", np.nan)) for n in names] fig, ax = plt.subplots(figsize=(10, max(5, 0.42 * len(names)))) ax.barh(y - h / 2, va, h, label=label_a, color="#1f77b4", alpha=0.85) ax.barh(y + h / 2, vb, h, label=label_b, color="#d62728", alpha=0.85) ax.set_yticks(y) ax.set_yticklabels(names, fontsize=9) ax.set_xlabel("AUROC", fontsize=11) ax.set_xlim(0, 1.05) ax.set_title("Test AUROC by class — comparison", fontsize=13, fontweight="bold") ax.legend(loc="lower right") ax.grid(True, axis="x", alpha=0.35) fig.tight_layout() fig.savefig(out / filename, dpi=150, bbox_inches="tight") return fig def plot_confusion_matrices_two_runs( y_true_a: Any, y_pred_a: Any, y_true_b: Any, y_pred_b: Any, label_names: Sequence[str], row_title_a: str, row_title_b: str, output_dir: str | Path, filename: str = "compare_confusion_matrices.png", ) -> Figure: """Two rows of per-class confusion matrices (binary multilabel).""" out = _ensure_output_dir(output_dir) yt_a = _to_numpy(y_true_a).astype(np.int64) yp_a = _to_numpy(y_pred_a).astype(np.int64) yt_b = _to_numpy(y_true_b).astype(np.int64) yp_b = _to_numpy(y_pred_b).astype(np.int64) n = len(label_names) nrows, ncols = _subplot_grid(n, max_cols=5) fig, axes = plt.subplots(2 * nrows, ncols, figsize=(3.2 * ncols, 3.0 * 2 * nrows)) axes_arr = np.atleast_2d(axes) for i, name in enumerate(label_names): r0, c0 = divmod(i, ncols) for ver, (yt, yp, rtitle) in enumerate( [(yt_a, yp_a, row_title_a), (yt_b, yp_b, row_title_b)] ): ax = axes_arr[ver * nrows + r0, c0] cm = confusion_matrix(yt[:, i], yp[:, i], labels=[0, 1]) sns.heatmap( cm, annot=True, fmt="d", cmap="Blues", cbar=False, ax=ax, xticklabels=["P0", "P1"], yticklabels=["T0", "T1"], ) ax.set_title(f"{name}\n{rtitle}", fontsize=8, fontweight="bold") used: Set[Tuple[int, int]] = set() for i in range(n): r0, c0 = divmod(i, ncols) for ver in (0, 1): used.add((ver * nrows + r0, c0)) for r in range(2 * nrows): for c in range(ncols): if (r, c) not in used: axes_arr[r, c].set_visible(False) fig.tight_layout() fig.savefig(out / filename, dpi=150, bbox_inches="tight") return fig def plot_label_cooccurrence_four( y_true_a: Any, y_pred_a: Any, y_true_b: Any, y_pred_b: Any, label_names: Sequence[str], titles: Sequence[str], output_dir: str | Path, filename: str = "compare_label_cooccurrence.png", ) -> Figure: """Four heatmaps: GT/pred co-occurrence for each run.""" out = _ensure_output_dir(output_dir) mats = [] for yt, yp in [(y_true_a, y_pred_a), (y_true_b, y_pred_b)]: yt = _to_numpy(yt).astype(np.int64) yp = _to_numpy(yp).astype(np.int64) mats.append((_cooccurrence_matrix(yt), _cooccurrence_matrix(yp))) vmax = 1.0 for gt_mat, pred_mat in mats: vmax = max(vmax, float(gt_mat.max()), float(pred_mat.max())) fig, axes = plt.subplots(2, 2, figsize=(14, 12)) flat_titles = [titles[0], titles[1], titles[2], titles[3]] idx = 0 for row, pair in enumerate(mats): for col, cmat in enumerate(pair): ax = axes[row, col] sns.heatmap( cmat, xticklabels=label_names, yticklabels=label_names, cmap="YlOrRd", ax=ax, vmin=0, vmax=vmax, square=True, ) ax.set_title(flat_titles[idx], fontsize=11, fontweight="bold") idx += 1 fig.tight_layout() fig.savefig(out / filename, dpi=150, bbox_inches="tight") return fig def plot_probability_distributions_two_runs( y_true_a: Any, y_proba_a: Any, y_true_b: Any, y_proba_b: Any, label_names: Sequence[str], title_a: str, title_b: str, output_dir: str | Path, filename: str = "compare_probability_distributions.png", ) -> Figure: """Two rows of per-class probability histograms (neg vs pos).""" out = _ensure_output_dir(output_dir) n = len(label_names) nrows, ncols = _subplot_grid(n, max_cols=5) fig, axes = plt.subplots(2 * nrows, ncols, figsize=(3.2 * ncols, 2.6 * 2 * nrows)) axes_arr = np.atleast_2d(axes) def _row(yt: Any, yp: Any, ver: int) -> None: yt = _to_numpy(yt).astype(np.float64) yp = _to_numpy(yp).astype(np.float64) for i, name in enumerate(label_names): r0, c0 = divmod(i, ncols) ax = axes_arr[ver * nrows + r0, c0] pos_mask = yt[:, i] > 0.5 neg_mask = ~pos_mask ax.hist(yp[neg_mask, i], bins=20, alpha=0.6, color="#d62728", density=True, label="Neg") ax.hist(yp[pos_mask, i], bins=20, alpha=0.6, color="#1f77b4", density=True, label="Pos") ax.set_title(f"{name}", fontsize=8, fontweight="bold") ax.set_xlim(0, 1) _row(y_true_a, y_proba_a, 0) _row(y_true_b, y_proba_b, 1) for j in range(n, nrows * ncols): r0, c0 = divmod(j, ncols) for ver in range(2): axes_arr[ver * nrows + r0, c0].set_visible(False) fig.suptitle(f"Predicted probability distributions — {title_a} (top) vs {title_b} (bottom)", fontsize=12) fig.tight_layout() fig.savefig(out / filename, dpi=150, bbox_inches="tight") return fig def plot_threshold_analysis_two_runs( y_true_a: Any, y_proba_a: Any, y_true_b: Any, y_proba_b: Any, label_names: Sequence[str], title_a: str, title_b: str, output_dir: str | Path, filename: str = "compare_threshold_analysis.png", ) -> Figure: """Two rows: F1 vs threshold per class.""" out = _ensure_output_dir(output_dir) thresholds = np.linspace(0.01, 0.99, 99) n = len(label_names) nrows, ncols = _subplot_grid(n, max_cols=5) fig, axes = plt.subplots(2 * nrows, ncols, figsize=(3.2 * ncols, 2.4 * 2 * nrows)) axes_arr = np.atleast_2d(axes) def _fill(yt: Any, yp: Any, ver: int) -> None: yt = _to_numpy(yt).astype(np.int64) yp = _to_numpy(yp).astype(np.float64) for i, name in enumerate(label_names): r0, c0 = divmod(i, ncols) ax = axes_arr[ver * nrows + r0, c0] f1s = [f1_score(yt[:, i], (yp[:, i] >= thr).astype(np.int64), zero_division=0) for thr in thresholds] ax.plot(thresholds, f1s, color="#1f77b4", linewidth=1.2) bi = int(np.argmax(f1s)) ax.scatter([thresholds[bi]], [f1s[bi]], color="red", s=25, zorder=5) ax.set_title(f"{name}", fontsize=7, fontweight="bold") ax.set_xlabel("Thr", fontsize=7) ax.set_ylabel("F1", fontsize=7) _fill(y_true_a, y_proba_a, 0) _fill(y_true_b, y_proba_b, 1) for j in range(n, nrows * ncols): r0, c0 = divmod(j, ncols) for ver in range(2): axes_arr[ver * nrows + r0, c0].set_visible(False) fig.suptitle(f"F1 vs threshold — {title_a} (top) vs {title_b} (bottom)", fontsize=12) fig.tight_layout() fig.savefig(out / filename, dpi=150, bbox_inches="tight") return fig def _load_json(path: Path) -> Any: return json.loads(path.read_text(encoding="utf-8")) def generate_comparison_plots( *, v1_artifact_dir: Union[str, Path], v2_artifact_dir: Union[str, Path], label_names: Sequence[str], output_dir: Union[str, Path], label_v1: str = "baseline_v1", label_v2: str = "regularized_v2", mlflow_log: bool = True, ) -> Dict[str, Figure]: """ Load two run artifact folders (``train_history.json``, ``final_test_metrics.json``, ``test_predictions.npz``) and write comparison figures. """ d1 = Path(v1_artifact_dir).expanduser().resolve() d2 = Path(v2_artifact_dir).expanduser().resolve() out = _ensure_output_dir(output_dir) h1 = _load_json(d1 / "train_history.json") h2 = _load_json(d2 / "train_history.json") m1 = _load_json(d1 / "final_test_metrics.json") m2 = _load_json(d2 / "final_test_metrics.json") z1 = np.load(d1 / "test_predictions.npz") z2 = np.load(d2 / "test_predictions.npz") figures: Dict[str, Figure] = {} be1 = h1.get("best_epoch") be2 = h2.get("best_epoch") be1i = int(be1) if be1 is not None and int(be1) >= 1 else None be2i = int(be2) if be2 is not None and int(be2) >= 1 else None figures["compare_loss"] = plot_training_curves_comparison( h1.get("train_loss", []), h1.get("val_loss", []), h2.get("train_loss", []), h2.get("val_loss", []), label_v1, label_v2, out, best_epoch_a=be1i, best_epoch_b=be2i, ) figures["compare_macro_f1"] = plot_val_macro_f1_comparison( h1.get("val_macro_f1", []), h2.get("val_macro_f1", []), label_v1, label_v2, out, ) figures["compare_per_class_f1"] = plot_per_class_f1_panels( h1.get("val_per_class_f1", []), h2.get("val_per_class_f1", []), label_names, label_v1, label_v2, out, ) figures["compare_auroc"] = plot_auroc_comparison_bars(m1, m2, label_names, label_v1, label_v2, out) figures["compare_confusion"] = plot_confusion_matrices_two_runs( z1["y_true"], z1["y_pred_binary"], z2["y_true"], z2["y_pred_binary"], label_names, label_v1, label_v2, out, ) figures["compare_cooc"] = plot_label_cooccurrence_four( z1["y_true"], z1["y_pred_binary"], z2["y_true"], z2["y_pred_binary"], label_names, (f"{label_v1} GT", f"{label_v1} pred", f"{label_v2} GT", f"{label_v2} pred"), out, ) figures["compare_proba"] = plot_probability_distributions_two_runs( z1["y_true"], z1["y_pred_proba"], z2["y_true"], z2["y_pred_proba"], label_names, label_v1, label_v2, out, ) figures["compare_thr"] = plot_threshold_analysis_two_runs( z1["y_true"], z1["y_pred_proba"], z2["y_true"], z2["y_pred_proba"], label_names, label_v1, label_v2, out, ) if mlflow_log and mlflow is not None: try: if mlflow.active_run() is not None: for fname in ( "compare_training_val_loss.png", "compare_val_macro_f1.png", "compare_per_class_f1_panels.png", "compare_test_auroc.png", "compare_confusion_matrices.png", "compare_label_cooccurrence.png", "compare_probability_distributions.png", "compare_threshold_analysis.png", ): p = out / fname if p.is_file(): mlflow.log_artifact(str(p), artifact_path="comparison_plots") except Exception: pass return figures