Spaces:
Running
Running
| """ | |
| Visualization utilities for multilabel protein localization training and evaluation. | |
| All plot functions save PNG files to ``output_dir`` (150 DPI) and return matplotlib Figure objects. | |
| """ | |
| from __future__ import annotations | |
| from pathlib import Path | |
| import json | |
| from typing import Any, Dict, List, Mapping, Optional, Sequence, Set, Tuple, Union | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| try: | |
| import mlflow | |
| except ImportError: # pragma: no cover | |
| mlflow = None # type: ignore | |
| import seaborn as sns | |
| from matplotlib.figure import Figure | |
| from sklearn.metrics import confusion_matrix, f1_score | |
# Module-wide plotting defaults: seaborn "whitegrid" theme in notebook context,
# with white figure and axes backgrounds.
sns.set_theme(context="notebook", style="whitegrid")
plt.rcParams.update(dict.fromkeys(("figure.facecolor", "axes.facecolor"), "white"))
def _to_numpy(x: Any) -> np.ndarray:
    """Coerce ``x`` to a NumPy array.

    An ``np.ndarray`` is returned as-is. An object exposing both ``detach`` and
    ``cpu`` attributes is converted with ``x.detach().cpu().numpy()``. Anything
    else goes through ``np.asarray``.
    """
    if isinstance(x, np.ndarray):
        return x
    tensor_like = hasattr(x, "detach") and hasattr(x, "cpu")
    return x.detach().cpu().numpy() if tensor_like else np.asarray(x)
def _ensure_output_dir(output_dir: str | Path) -> Path:
    """Return ``output_dir`` as a user-expanded, resolved path, creating it if needed.

    Missing parent directories are created too; an existing directory is not an error.
    """
    target = Path(output_dir)
    target = target.expanduser()
    target = target.resolve()
    target.mkdir(exist_ok=True, parents=True)
    return target
def _subplot_grid(n: int, max_cols: int = 5) -> Tuple[int, int]:
    """Return ``(nrows, ncols)`` for ``n`` panels with at most ``max_cols`` per row.

    Raises:
        ValueError: if ``n`` is zero or negative.
    """
    if n <= 0:
        raise ValueError("n must be positive")
    ncols = min(max_cols, n)
    # Ceiling division: enough rows to hold every panel.
    nrows = int(-(-n // ncols))
    return nrows, ncols
def plot_training_curves(
    train_losses: Sequence[float],
    val_losses: Sequence[float],
    output_dir: str | Path,
    best_epoch: Optional[int] = None,
    filename: str = "training_curves.png",
) -> Figure:
    """Plot train and validation loss per epoch and save the figure as a PNG.

    A dashed vertical line marks ``best_epoch`` (1-based) when it lies within the
    plotted range. If ``best_epoch`` is omitted and there is at least one epoch,
    the epoch with the smallest validation loss is used.

    Args:
        train_losses: Training loss, one value per epoch.
        val_losses: Validation loss, one value per epoch.
        output_dir: Directory the PNG is written to (created if missing).
        best_epoch: Optional 1-based epoch to highlight.
        filename: Name of the PNG inside ``output_dir``.

    Returns:
        The matplotlib Figure that was saved.

    Raises:
        ValueError: if the two loss sequences differ in length.
    """
    target_dir = _ensure_output_dir(output_dir)
    train_series, val_series = list(train_losses), list(val_losses)
    n_epochs = len(train_series)
    if n_epochs != len(val_series):
        raise ValueError("train_losses and val_losses must have the same length")
    if best_epoch is None and val_series:
        # argmin is 0-based; epochs are reported 1-based.
        best_epoch = 1 + int(np.argmin(np.asarray(val_series, dtype=np.float64)))

    xs = np.arange(1, n_epochs + 1)
    fig, ax = plt.subplots(figsize=(10, 5))
    curves = (
        (train_series, "Train loss", "#1f77b4"),
        (val_series, "Validation loss", "#ff7f0e"),
    )
    for ys, legend_label, line_color in curves:
        ax.plot(xs, ys, label=legend_label, color=line_color, linewidth=2)

    marker_in_range = best_epoch is not None and 1 <= best_epoch <= n_epochs
    if marker_in_range:
        ax.axvline(
            best_epoch,
            color="gray",
            linestyle="--",
            linewidth=1.5,
            label=f"Best epoch ({best_epoch})",
        )

    ax.set_xlabel("Epoch", fontsize=11)
    ax.set_ylabel("Loss", fontsize=11)
    ax.set_title("Training & Validation Loss", fontsize=13, fontweight="bold")
    ax.legend(loc="upper right", fontsize=10)
    ax.grid(True, alpha=0.35)
    fig.tight_layout()
    fig.savefig(target_dir / filename, dpi=150, bbox_inches="tight")
    return fig
def plot_per_class_f1_progression(
    epoch_metrics_list: Sequence[Mapping[str, float]],
    label_names: Sequence[str],
    output_dir: str | Path,
    filename: str = "per_class_f1_progression.png",
) -> Figure:
    """Plot one F1-vs-epoch line per label and save the figure as a PNG.

    Args:
        epoch_metrics_list: One mapping per epoch, ``label_name -> f1``. A label
            missing from an epoch's mapping is plotted as NaN for that epoch.
        label_names: Labels to draw, one line each.
        output_dir: Directory the PNG is written to (created if missing).
        filename: Name of the PNG inside ``output_dir``.

    Returns:
        The matplotlib Figure that was saved.

    Raises:
        ValueError: if ``epoch_metrics_list`` is empty.
    """
    target_dir = _ensure_output_dir(output_dir)
    names = list(label_names)
    n_epochs = len(epoch_metrics_list)
    if not n_epochs:
        raise ValueError("epoch_metrics_list is empty")

    xs = np.arange(1, n_epochs + 1)
    fig, ax = plt.subplots(figsize=(12, 6))
    # The palette always has at least as many colours as labels, so zip never truncates.
    colors = sns.color_palette("husl", n_colors=max(len(names), 1))
    for name, line_color in zip(names, colors):
        f1_curve = [float(metrics.get(name, np.nan)) for metrics in epoch_metrics_list]
        ax.plot(xs, f1_curve, label=name, color=line_color, linewidth=1.8)

    ax.set_xlabel("Epoch", fontsize=11)
    ax.set_ylabel("F1", fontsize=11)
    ax.set_title("Per-Class F1 Score Across Training", fontsize=13, fontweight="bold")
    # Legend sits outside the axes on the right.
    ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left", fontsize=8, frameon=True)
    ax.grid(True, alpha=0.35)
    fig.tight_layout()
    fig.savefig(target_dir / filename, dpi=150, bbox_inches="tight")
    return fig
def _auroc_color(val: float) -> str:
    """Map an AUROC value to a hex bar colour.

    NaN -> grey, below 0.7 -> red, 0.7 to 0.85 inclusive -> amber, above 0.85 -> green.
    """
    if np.isnan(val):
        colour = "#cccccc"
    elif val < 0.7:
        colour = "#d62728"
    elif val <= 0.85:
        colour = "#ffbf00"
    else:
        colour = "#2ca02c"
    return colour
def plot_auroc_bars(
    test_metrics: Mapping[str, Any],
    label_names: Sequence[str],
    output_dir: str | Path,
    filename: str = "test_auroc_bars.png",
) -> Figure:
    """Draw a horizontal bar chart of per-class AUROC and save it as a PNG.

    Values are read from ``test_metrics['per_class'][name]['auroc']``; a missing
    entry is treated as NaN. Rows are sorted with valid scores first (descending,
    ties broken by label name) and NaN rows last. Bars are coloured by
    ``_auroc_color`` and annotated with the value (or ``nan``).

    Args:
        test_metrics: Metrics mapping holding a ``per_class`` sub-mapping.
        label_names: Labels to include.
        output_dir: Directory the PNG is written to (created if missing).
        filename: Name of the PNG inside ``output_dir``.

    Returns:
        The matplotlib Figure that was saved.
    """
    target_dir = _ensure_output_dir(output_dir)
    per_class = test_metrics.get("per_class", {})

    def _rank(row: Tuple[str, float]) -> Tuple[int, float, str]:
        label, score = row
        if np.isnan(score):
            return (1, 0.0, label)
        return (0, -score, label)

    scored: List[Tuple[str, float]] = sorted(
        ((name, float(per_class.get(name, {}).get("auroc", np.nan))) for name in label_names),
        key=_rank,
    )
    labels = [label for label, _ in scored]
    values = [score for _, score in scored]

    fig, ax = plt.subplots(figsize=(10, max(4, 0.45 * len(labels))))
    bar_colors = [_auroc_color(score) for score in values]
    positions = np.arange(len(labels))
    bars = ax.barh(positions, values, color=bar_colors, edgecolor="white", linewidth=0.5)
    ax.set_yticks(positions)
    ax.set_yticklabels(labels, fontsize=10)
    ax.set_xlabel("AUROC", fontsize=11)
    ax.set_xlim(0, 1.05)
    ax.set_title("Test Set AUROC by Subcellular Location", fontsize=13, fontweight="bold")

    for bar, score in zip(bars, values):
        y_mid = bar.get_y() + bar.get_height() / 2
        if np.isnan(score):
            x_text, caption = 0.02, "nan"
        else:
            # Keep the annotation inside the axes for scores close to 1.
            x_text, caption = min(score + 0.02, 1.0), f"{score:.3f}"
        ax.text(x_text, y_mid, caption, va="center", fontsize=9)

    ax.grid(True, axis="x", alpha=0.35)
    fig.tight_layout()
    fig.savefig(target_dir / filename, dpi=150, bbox_inches="tight")
    return fig
def plot_confusion_matrices(
    y_true: Any,
    y_pred: Any,
    label_names: Sequence[str],
    output_dir: str | Path,
    filename: str = "confusion_matrices.png",
) -> Figure:
    """Draw one 2x2 confusion-matrix heatmap per label and save the grid as a PNG.

    Column ``i`` of ``y_true`` / ``y_pred`` is paired with ``label_names[i]``.
    Panels are laid out with at most five per row; leftover panels are hidden.

    Args:
        y_true: Ground-truth indicator matrix (array-like or tensor-like), cast to int64.
        y_pred: Predicted indicator matrix of the same shape, cast to int64.
        label_names: One name per plotted column.
        output_dir: Directory the PNG is written to (created if missing).
        filename: Name of the PNG inside ``output_dir``.

    Returns:
        The matplotlib Figure that was saved.

    Raises:
        ValueError: if the inputs are not 2D or their shapes differ.
    """
    target_dir = _ensure_output_dir(output_dir)
    truth = _to_numpy(y_true).astype(np.int64)
    preds = _to_numpy(y_pred).astype(np.int64)
    if truth.ndim != 2 or truth.shape != preds.shape:
        raise ValueError("y_true and y_pred must be 2D arrays of the same shape")

    n_labels = len(label_names)
    nrows, ncols = _subplot_grid(n_labels, max_cols=5)
    fig, axes = plt.subplots(nrows, ncols, figsize=(3.2 * ncols, 3.0 * nrows))
    panels = np.atleast_1d(axes).ravel()

    for col, (ax, name) in enumerate(zip(panels, label_names)):
        counts = confusion_matrix(truth[:, col], preds[:, col], labels=[0, 1])
        sns.heatmap(
            counts,
            ax=ax,
            annot=True,
            fmt="d",
            cmap="Blues",
            cbar=False,
            xticklabels=["Pred 0", "Pred 1"],
            yticklabels=["True 0", "True 1"],
        )
        ax.set_title(name, fontsize=10, fontweight="bold")

    # Hide grid cells beyond the last label.
    for spare in panels[n_labels:]:
        spare.set_visible(False)

    fig.tight_layout()
    fig.savefig(target_dir / filename, dpi=150, bbox_inches="tight")
    return fig
def _cooccurrence_matrix(y_bin: np.ndarray) -> np.ndarray:
    """Return ``y_bin.T @ y_bin`` as float64.

    For a binary ``(n_samples, n_labels)`` matrix, entry ``(i, j)`` counts the
    samples carrying both label ``i`` and label ``j``.
    """
    pair_counts = np.matmul(y_bin.T, y_bin)
    return pair_counts.astype(np.float64)
def plot_label_cooccurrence(
    y_true: Any,
    y_pred: Any,
    label_names: Sequence[str],
    output_dir: str | Path,
    filename: str = "label_cooccurrence.png",
) -> Figure:
    """Draw ground-truth and predicted label co-occurrence heatmaps side by side.

    Both heatmaps share one colour scale, capped at the larger of the two matrix
    maxima (and at least 1.0).

    Args:
        y_true: Ground-truth indicator matrix (array-like or tensor-like), cast to int64.
        y_pred: Predicted indicator matrix of the same shape, cast to int64.
        label_names: Tick labels for both heatmap axes.
        output_dir: Directory the PNG is written to (created if missing).
        filename: Name of the PNG inside ``output_dir``.

    Returns:
        The matplotlib Figure that was saved.

    Raises:
        ValueError: if the two inputs differ in shape.
    """
    target_dir = _ensure_output_dir(output_dir)
    truth = _to_numpy(y_true).astype(np.int64)
    preds = _to_numpy(y_pred).astype(np.int64)
    if truth.shape != preds.shape:
        raise ValueError("y_true and y_pred must have the same shape")

    truth_counts = _cooccurrence_matrix(truth)
    pred_counts = _cooccurrence_matrix(preds)
    color_max = max(truth_counts.max(), pred_counts.max(), 1.0)

    fig, axes_pair = plt.subplots(1, 2, figsize=(14, 6))
    panels = (
        (truth_counts, "Ground truth co-occurrence"),
        (pred_counts, "Predicted co-occurrence"),
    )
    for ax, (matrix, heading) in zip(axes_pair, panels):
        sns.heatmap(
            matrix,
            xticklabels=label_names,
            yticklabels=label_names,
            annot=False,
            fmt=".0f",
            cmap="YlOrRd",
            ax=ax,
            vmin=0,
            vmax=color_max,
            square=True,
        )
        ax.set_title(heading, fontsize=12, fontweight="bold")

    fig.tight_layout()
    fig.savefig(target_dir / filename, dpi=150, bbox_inches="tight")
    return fig
def plot_probability_distributions(
    y_true: Any,
    y_pred_proba: Any,
    label_names: Sequence[str],
    output_dir: str | Path,
    filename: str = "probability_distributions.png",
) -> Figure:
    """Draw per-label histograms of predicted probability, split by true class.

    For each label, samples whose ground truth exceeds 0.5 are "Positive" (blue)
    and the rest "Negative" (red). Both histograms use 20 bins and are
    density-normalised. Only the first panel carries a legend.

    Args:
        y_true: Ground-truth matrix (array-like or tensor-like), cast to float64.
        y_pred_proba: Predicted probabilities, cast to float64; column ``i`` is
            paired with ``label_names[i]``.
        label_names: One name per plotted column.
        output_dir: Directory the PNG is written to (created if missing).
        filename: Name of the PNG inside ``output_dir``.

    Returns:
        The matplotlib Figure that was saved.
    """
    target_dir = _ensure_output_dir(output_dir)
    truth = _to_numpy(y_true).astype(np.float64)
    proba = _to_numpy(y_pred_proba).astype(np.float64)

    n_labels = len(label_names)
    nrows, ncols = _subplot_grid(n_labels, max_cols=5)
    fig, axes = plt.subplots(nrows, ncols, figsize=(3.2 * ncols, 2.8 * nrows))
    panels = np.atleast_1d(axes).ravel()

    for col, (ax, name) in enumerate(zip(panels, label_names)):
        is_positive = truth[:, col] > 0.5
        groups = (
            (~is_positive, "#d62728", "Negative"),
            (is_positive, "#1f77b4", "Positive"),
        )
        for mask, bar_color, tag in groups:
            ax.hist(proba[mask, col], bins=20, alpha=0.65, color=bar_color, label=tag, density=True)
        ax.set_title(name, fontsize=9, fontweight="bold")
        ax.set_xlim(0, 1)
        ax.tick_params(labelsize=8)
        if col == 0:
            ax.legend(fontsize=7, loc="upper right")

    # Hide grid cells beyond the last label.
    for spare in panels[n_labels:]:
        spare.set_visible(False)

    fig.tight_layout()
    fig.savefig(target_dir / filename, dpi=150, bbox_inches="tight")
    return fig
def plot_threshold_analysis(
    y_true: Any,
    y_pred_proba: Any,
    label_names: Sequence[str],
    output_dir: str | Path,
    filename: str = "threshold_analysis.png",
    thresholds: np.ndarray | None = None,
) -> Figure:
    """Plot F1 against decision threshold for every label and mark the best one.

    For each threshold, predictions are ``proba >= threshold``. The first
    threshold reaching the maximum F1 is marked with a red dot.

    Args:
        y_true: Ground-truth indicator matrix (array-like or tensor-like), cast to int64.
        y_pred_proba: Predicted probabilities, cast to float64; column ``i`` is
            paired with ``label_names[i]``.
        label_names: One name per plotted column.
        output_dir: Directory the PNG is written to (created if missing).
        filename: Name of the PNG inside ``output_dir``.
        thresholds: Thresholds to sweep; defaults to 99 evenly spaced values
            from 0.01 to 0.99.

    Returns:
        The matplotlib Figure that was saved.
    """
    target_dir = _ensure_output_dir(output_dir)
    truth = _to_numpy(y_true).astype(np.int64)
    proba = _to_numpy(y_pred_proba).astype(np.float64)
    sweep = np.linspace(0.01, 0.99, 99) if thresholds is None else thresholds

    n_labels = len(label_names)
    nrows, ncols = _subplot_grid(n_labels, max_cols=5)
    fig, axes = plt.subplots(nrows, ncols, figsize=(3.2 * ncols, 2.6 * nrows))
    panels = np.atleast_1d(axes).ravel()

    for col, (ax, name) in enumerate(zip(panels, label_names)):
        scores = np.asarray(
            [
                f1_score(truth[:, col], (proba[:, col] >= cut).astype(np.int64), zero_division=0)
                for cut in sweep
            ]
        )
        peak = int(np.argmax(scores))
        peak_thr, peak_f1 = float(sweep[peak]), float(scores[peak])
        ax.plot(sweep, scores, color="#1f77b4", linewidth=1.5)
        ax.scatter([peak_thr], [peak_f1], color="red", s=40, zorder=5)
        ax.set_title(name, fontsize=9, fontweight="bold")
        ax.set_xlabel("Threshold", fontsize=8)
        ax.set_ylabel("F1", fontsize=8)
        ax.tick_params(labelsize=7)
        ax.grid(True, alpha=0.35)

    # Hide grid cells beyond the last label.
    for spare in panels[n_labels:]:
        spare.set_visible(False)

    fig.tight_layout()
    fig.savefig(target_dir / filename, dpi=150, bbox_inches="tight")
    return fig
def generate_all_plots(
    *,
    output_dir: str | Path,
    train_losses: Sequence[float],
    val_losses: Sequence[float],
    epoch_metrics_list: Sequence[Mapping[str, float]],
    label_names: Sequence[str],
    test_metrics: Mapping[str, Any],
    y_true: Any,
    y_pred: Any,
    y_pred_proba: Any,
    best_epoch: Optional[int] = None,
    mlflow_log: bool = True,
) -> Dict[str, Figure]:
    """Generate every single-run analysis plot and optionally log the PNGs to MLflow.

    Each plot is written to ``output_dir`` under its default filename. When
    ``mlflow_log`` is true, MLflow is importable and a run is active, the PNGs
    that exist are logged under the ``plots`` artifact path. MLflow logging is
    best-effort: a failure is reported as a warning and never raised.

    Args:
        output_dir: Directory the PNGs are written to (created if missing).
        train_losses: Training loss per epoch.
        val_losses: Validation loss per epoch.
        epoch_metrics_list: One mapping per epoch, ``label_name -> validation F1``.
        label_names: Label names, in column order of the prediction matrices.
        test_metrics: Metrics mapping with a ``per_class`` sub-mapping holding ``auroc``.
        y_true: Ground-truth indicator matrix.
        y_pred: Binary prediction matrix.
        y_pred_proba: Predicted probability matrix.
        best_epoch: Optional 1-based epoch to highlight on the loss plot.
        mlflow_log: Whether to attempt MLflow artifact logging.

    Returns:
        Mapping from a short plot key to its matplotlib Figure.
    """
    import logging

    out = _ensure_output_dir(output_dir)
    figures: Dict[str, Figure] = {}
    figures["training_curves"] = plot_training_curves(
        train_losses, val_losses, out, best_epoch=best_epoch
    )
    figures["per_class_f1"] = plot_per_class_f1_progression(epoch_metrics_list, label_names, out)
    figures["auroc_bars"] = plot_auroc_bars(test_metrics, label_names, out)
    figures["confusion"] = plot_confusion_matrices(y_true, y_pred, label_names, out)
    figures["cooccurrence"] = plot_label_cooccurrence(y_true, y_pred, label_names, out)
    figures["proba_hist"] = plot_probability_distributions(y_true, y_pred_proba, label_names, out)
    figures["threshold"] = plot_threshold_analysis(y_true, y_pred_proba, label_names, out)

    if mlflow_log and mlflow is not None:
        # Default filenames of the plot functions called above.
        artifact_files = (
            "training_curves.png",
            "per_class_f1_progression.png",
            "test_auroc_bars.png",
            "confusion_matrices.png",
            "label_cooccurrence.png",
            "probability_distributions.png",
            "threshold_analysis.png",
        )
        try:
            if mlflow.active_run() is not None:
                for fname in artifact_files:
                    p = out / fname
                    if p.is_file():
                        mlflow.log_artifact(str(p), artifact_path="plots")
        except Exception:
            # Artifact logging must not fail the caller, but the failure should be visible.
            logging.getLogger(__name__).warning(
                "Could not log plots to MLflow; continuing without artifacts.", exc_info=True
            )
    return figures
def plot_training_curves_comparison(
    train_losses_a: Sequence[float],
    val_losses_a: Sequence[float],
    train_losses_b: Sequence[float],
    val_losses_b: Sequence[float],
    label_a: str,
    label_b: str,
    output_dir: str | Path,
    best_epoch_a: Optional[int] = None,
    best_epoch_b: Optional[int] = None,
    filename: str = "compare_training_val_loss.png",
) -> Figure:
    """Overlay train and validation loss curves of two runs on one axis.

    Run A is drawn in blue with solid lines, run B in red with dashed lines.
    A dotted vertical line marks each run's best epoch (1-based) when it is given
    and lies within that run's validation range.

    Args:
        train_losses_a: Training loss per epoch, run A.
        val_losses_a: Validation loss per epoch, run A.
        train_losses_b: Training loss per epoch, run B.
        val_losses_b: Validation loss per epoch, run B.
        label_a: Legend tag for run A.
        label_b: Legend tag for run B.
        output_dir: Directory the PNG is written to (created if missing).
        best_epoch_a: Optional 1-based best epoch of run A.
        best_epoch_b: Optional 1-based best epoch of run B.
        filename: Name of the PNG inside ``output_dir``.

    Returns:
        The matplotlib Figure that was saved.
    """
    target_dir = _ensure_output_dir(output_dir)
    fig, ax = plt.subplots(figsize=(11, 5))

    # (tag, train, val, best epoch, train colour, val colour, line style)
    runs = (
        (label_a, train_losses_a, val_losses_a, best_epoch_a, "#1f77b4", "#aec7e8", "-"),
        (label_b, train_losses_b, val_losses_b, best_epoch_b, "#d62728", "#ff9896", "--"),
    )
    for tag, train, val, _best, train_color, val_color, dash in runs:
        xs = np.arange(1, len(train) + 1)
        ax.plot(xs, train, label=f"Train ({tag})", color=train_color, linewidth=2, linestyle=dash)
        ax.plot(xs, val, label=f"Val ({tag})", color=val_color, linewidth=2, linestyle=dash)

    # Markers are added after all curves so the artist order matches the curves-first layout.
    for _tag, _train, val, best, train_color, _val_color, _dash in runs:
        if best is not None and 1 <= best <= len(val):
            ax.axvline(best, color=train_color, linestyle=":", linewidth=1.2, alpha=0.8)

    ax.set_xlabel("Epoch", fontsize=11)
    ax.set_ylabel("Loss", fontsize=11)
    ax.set_title("Training & Validation Loss — comparison", fontsize=13, fontweight="bold")
    ax.legend(loc="upper right", fontsize=9)
    ax.grid(True, alpha=0.35)
    fig.tight_layout()
    fig.savefig(target_dir / filename, dpi=150, bbox_inches="tight")
    return fig
def plot_val_macro_f1_comparison(
    macro_f1_a: Sequence[float],
    macro_f1_b: Sequence[float],
    label_a: str,
    label_b: str,
    output_dir: str | Path,
    filename: str = "compare_val_macro_f1.png",
) -> Figure:
    """Plot validation macro F1 per epoch for two runs on one axis and save a PNG.

    Run A is a solid blue line, run B a dashed red line; each run uses its own
    1-based epoch axis, so the two series may differ in length.

    Args:
        macro_f1_a: Macro F1 per epoch, run A.
        macro_f1_b: Macro F1 per epoch, run B.
        label_a: Legend label for run A.
        label_b: Legend label for run B.
        output_dir: Directory the PNG is written to (created if missing).
        filename: Name of the PNG inside ``output_dir``.

    Returns:
        The matplotlib Figure that was saved.
    """
    target_dir = _ensure_output_dir(output_dir)
    fig, ax = plt.subplots(figsize=(10, 4.5))

    series = (
        (macro_f1_a, label_a, "#1f77b4", {}),
        (macro_f1_b, label_b, "#d62728", {"linestyle": "--"}),
    )
    for scores, tag, line_color, extra in series:
        xs = np.arange(1, len(scores) + 1)
        ax.plot(xs, scores, label=tag, color=line_color, linewidth=2, **extra)

    ax.set_xlabel("Epoch", fontsize=11)
    ax.set_ylabel("Validation macro F1", fontsize=11)
    ax.set_title("Validation macro F1 — comparison", fontsize=13, fontweight="bold")
    ax.legend(loc="lower right", fontsize=10)
    ax.grid(True, alpha=0.35)
    fig.tight_layout()
    fig.savefig(target_dir / filename, dpi=150, bbox_inches="tight")
    return fig
def plot_per_class_f1_panels(
    epoch_metrics_a: Sequence[Mapping[str, float]],
    epoch_metrics_b: Sequence[Mapping[str, float]],
    label_names: Sequence[str],
    title_a: str,
    title_b: str,
    output_dir: str | Path,
    filename: str = "compare_per_class_f1_panels.png",
) -> Figure:
    """Draw per-class F1-vs-epoch curves for two runs in side-by-side panels.

    Both panels share the y axis and use the same colour per label. One shared
    legend, built from the left panel, is placed to the right of the figure.
    A label missing from an epoch's mapping is plotted as NaN.

    Args:
        epoch_metrics_a: One ``label_name -> f1`` mapping per epoch, run A.
        epoch_metrics_b: One ``label_name -> f1`` mapping per epoch, run B.
        label_names: Labels to draw.
        title_a: Title of the left panel.
        title_b: Title of the right panel.
        output_dir: Directory the PNG is written to (created if missing).
        filename: Name of the PNG inside ``output_dir``.

    Returns:
        The matplotlib Figure that was saved.
    """
    target_dir = _ensure_output_dir(output_dir)
    fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(16, 6), sharey=True)
    # The palette always has at least as many colours as labels, so zip never truncates.
    colors = sns.color_palette("husl", n_colors=max(len(label_names), 1))

    panels = ((ax_left, epoch_metrics_a), (ax_right, epoch_metrics_b))
    for name, line_color in zip(label_names, colors):
        for ax, history in panels:
            curve = [float(metrics.get(name, np.nan)) for metrics in history]
            xs = np.arange(1, len(history) + 1)
            ax.plot(xs, curve, label=name, color=line_color, linewidth=1.5)

    ax_left.set_title(title_a, fontsize=12, fontweight="bold")
    ax_right.set_title(title_b, fontsize=12, fontweight="bold")
    for ax in (ax_left, ax_right):
        ax.set_xlabel("Epoch")
    ax_left.set_ylabel("F1")
    for ax in (ax_left, ax_right):
        ax.grid(True, alpha=0.35)

    handles, legend_labels = ax_left.get_legend_handles_labels()
    fig.legend(
        handles,
        legend_labels,
        loc="center left",
        bbox_to_anchor=(1.02, 0.5),
        fontsize=7,
        title="Class",
    )
    fig.suptitle("Per-class validation F1 — comparison", fontsize=14, fontweight="bold")
    fig.tight_layout()
    fig.savefig(target_dir / filename, dpi=150, bbox_inches="tight")
    return fig
def plot_auroc_comparison_bars(
    test_metrics_a: Mapping[str, Any],
    test_metrics_b: Mapping[str, Any],
    label_names: Sequence[str],
    label_a: str,
    label_b: str,
    output_dir: str | Path,
    filename: str = "compare_test_auroc.png",
) -> Figure:
    """Draw grouped horizontal bars comparing per-class AUROC of two runs.

    Values are read from ``test_metrics['per_class'][name]['auroc']``; a missing
    entry is treated as NaN. Run A bars are blue, run B bars red.

    Args:
        test_metrics_a: Metrics mapping of run A.
        test_metrics_b: Metrics mapping of run B.
        label_names: Labels to include, in plotting order.
        label_a: Legend label for run A.
        label_b: Legend label for run B.
        output_dir: Directory the PNG is written to (created if missing).
        filename: Name of the PNG inside ``output_dir``.

    Returns:
        The matplotlib Figure that was saved.
    """
    target_dir = _ensure_output_dir(output_dir)
    names = list(label_names)
    centers = np.arange(len(names))
    bar_height = 0.35

    def _scores(metrics: Mapping[str, Any]) -> List[float]:
        return [float(metrics.get("per_class", {}).get(n, {}).get("auroc", np.nan)) for n in names]

    scores_a = _scores(test_metrics_a)
    scores_b = _scores(test_metrics_b)

    fig, ax = plt.subplots(figsize=(10, max(5, 0.42 * len(names))))
    # Offset the two bars of each class by half a bar height either side of the tick.
    ax.barh(centers - bar_height / 2, scores_a, bar_height, label=label_a, color="#1f77b4", alpha=0.85)
    ax.barh(centers + bar_height / 2, scores_b, bar_height, label=label_b, color="#d62728", alpha=0.85)
    ax.set_yticks(centers)
    ax.set_yticklabels(names, fontsize=9)
    ax.set_xlabel("AUROC", fontsize=11)
    ax.set_xlim(0, 1.05)
    ax.set_title("Test AUROC by class — comparison", fontsize=13, fontweight="bold")
    ax.legend(loc="lower right")
    ax.grid(True, axis="x", alpha=0.35)
    fig.tight_layout()
    fig.savefig(target_dir / filename, dpi=150, bbox_inches="tight")
    return fig
def plot_confusion_matrices_two_runs(
    y_true_a: Any,
    y_pred_a: Any,
    y_true_b: Any,
    y_pred_b: Any,
    label_names: Sequence[str],
    row_title_a: str,
    row_title_b: str,
    output_dir: str | Path,
    filename: str = "compare_confusion_matrices.png",
) -> Figure:
    """Draw per-class 2x2 confusion matrices for two runs, stacked vertically.

    The grid has ``2 * nrows`` rows: the first ``nrows`` hold run A and the last
    ``nrows`` hold run B, with each label in the same column for both runs.
    Grid cells beyond the last label are hidden.

    Args:
        y_true_a: Ground-truth indicator matrix of run A, cast to int64.
        y_pred_a: Binary prediction matrix of run A, cast to int64.
        y_true_b: Ground-truth indicator matrix of run B, cast to int64.
        y_pred_b: Binary prediction matrix of run B, cast to int64.
        label_names: One name per plotted column.
        row_title_a: Subtitle added under each label name for run A.
        row_title_b: Subtitle added under each label name for run B.
        output_dir: Directory the PNG is written to (created if missing).
        filename: Name of the PNG inside ``output_dir``.

    Returns:
        The matplotlib Figure that was saved.
    """
    out = _ensure_output_dir(output_dir)
    runs = (
        (_to_numpy(y_true_a).astype(np.int64), _to_numpy(y_pred_a).astype(np.int64), row_title_a),
        (_to_numpy(y_true_b).astype(np.int64), _to_numpy(y_pred_b).astype(np.int64), row_title_b),
    )
    n = len(label_names)
    nrows, ncols = _subplot_grid(n, max_cols=5)
    # squeeze=False always yields a (2 * nrows, ncols) array. With a single label
    # the squeezed result is 1D of shape (2,), which np.atleast_2d would turn into
    # (1, 2) and make the run-B row index fall out of range.
    fig, axes_arr = plt.subplots(
        2 * nrows, ncols, figsize=(3.2 * ncols, 3.0 * 2 * nrows), squeeze=False
    )

    for i, name in enumerate(label_names):
        r0, c0 = divmod(i, ncols)
        for ver, (yt, yp, rtitle) in enumerate(runs):
            ax = axes_arr[ver * nrows + r0, c0]
            cm = confusion_matrix(yt[:, i], yp[:, i], labels=[0, 1])
            sns.heatmap(
                cm,
                annot=True,
                fmt="d",
                cmap="Blues",
                cbar=False,
                ax=ax,
                xticklabels=["P0", "P1"],
                yticklabels=["T0", "T1"],
            )
            ax.set_title(f"{name}\n{rtitle}", fontsize=8, fontweight="bold")

    # Hide the unused trailing cells in both the run-A and run-B halves.
    for j in range(n, nrows * ncols):
        r0, c0 = divmod(j, ncols)
        for ver in range(2):
            axes_arr[ver * nrows + r0, c0].set_visible(False)

    fig.tight_layout()
    fig.savefig(out / filename, dpi=150, bbox_inches="tight")
    return fig
def plot_label_cooccurrence_four(
    y_true_a: Any,
    y_pred_a: Any,
    y_true_b: Any,
    y_pred_b: Any,
    label_names: Sequence[str],
    titles: Sequence[str],
    output_dir: str | Path,
    filename: str = "compare_label_cooccurrence.png",
) -> Figure:
    """Draw a 2x2 grid of co-occurrence heatmaps: ground truth and prediction per run.

    The top row is run A and the bottom row run B, each with ground truth on the
    left and predictions on the right. ``titles`` supplies the four panel titles
    in row-major order. All panels share one colour scale (maximum at least 1.0).

    Args:
        y_true_a: Ground-truth indicator matrix of run A, cast to int64.
        y_pred_a: Binary prediction matrix of run A, cast to int64.
        y_true_b: Ground-truth indicator matrix of run B, cast to int64.
        y_pred_b: Binary prediction matrix of run B, cast to int64.
        label_names: Tick labels for every heatmap.
        titles: Four panel titles; the first four entries are used.
        output_dir: Directory the PNG is written to (created if missing).
        filename: Name of the PNG inside ``output_dir``.

    Returns:
        The matplotlib Figure that was saved.
    """
    target_dir = _ensure_output_dir(output_dir)

    # Row-major list of matrices: A truth, A pred, B truth, B pred.
    matrices = []
    for truth_raw, pred_raw in ((y_true_a, y_pred_a), (y_true_b, y_pred_b)):
        truth = _to_numpy(truth_raw).astype(np.int64)
        preds = _to_numpy(pred_raw).astype(np.int64)
        matrices.extend((_cooccurrence_matrix(truth), _cooccurrence_matrix(preds)))

    color_max = max(1.0, *(float(matrix.max()) for matrix in matrices))

    fig, axes = plt.subplots(2, 2, figsize=(14, 12))
    panel_titles = [titles[k] for k in range(4)]
    for k, (matrix, heading) in enumerate(zip(matrices, panel_titles)):
        row, col = divmod(k, 2)
        ax = axes[row, col]
        sns.heatmap(
            matrix,
            xticklabels=label_names,
            yticklabels=label_names,
            cmap="YlOrRd",
            ax=ax,
            vmin=0,
            vmax=color_max,
            square=True,
        )
        ax.set_title(heading, fontsize=11, fontweight="bold")

    fig.tight_layout()
    fig.savefig(target_dir / filename, dpi=150, bbox_inches="tight")
    return fig
def plot_probability_distributions_two_runs(
    y_true_a: Any,
    y_proba_a: Any,
    y_true_b: Any,
    y_proba_b: Any,
    label_names: Sequence[str],
    title_a: str,
    title_b: str,
    output_dir: str | Path,
    filename: str = "compare_probability_distributions.png",
) -> Figure:
    """Draw per-class predicted-probability histograms for two runs, stacked vertically.

    The first ``nrows`` grid rows hold run A and the last ``nrows`` hold run B.
    In each panel, samples whose ground truth exceeds 0.5 are "Pos" (blue) and the
    rest "Neg" (red); histograms use 20 bins and are density-normalised.
    Grid cells beyond the last label are hidden.

    Args:
        y_true_a: Ground-truth matrix of run A, cast to float64.
        y_proba_a: Predicted probabilities of run A, cast to float64.
        y_true_b: Ground-truth matrix of run B, cast to float64.
        y_proba_b: Predicted probabilities of run B, cast to float64.
        label_names: One name per plotted column.
        title_a: Name of run A, used in the figure title.
        title_b: Name of run B, used in the figure title.
        output_dir: Directory the PNG is written to (created if missing).
        filename: Name of the PNG inside ``output_dir``.

    Returns:
        The matplotlib Figure that was saved.
    """
    out = _ensure_output_dir(output_dir)
    n = len(label_names)
    nrows, ncols = _subplot_grid(n, max_cols=5)
    # squeeze=False always yields a (2 * nrows, ncols) array. With a single label
    # the squeezed result is 1D of shape (2,), which np.atleast_2d would turn into
    # (1, 2) and make the run-B row index fall out of range.
    fig, axes_arr = plt.subplots(
        2 * nrows, ncols, figsize=(3.2 * ncols, 2.6 * 2 * nrows), squeeze=False
    )

    def _row(yt: Any, yp: Any, ver: int) -> None:
        """Fill the grid half ``ver`` (0 = run A, 1 = run B) with one histogram pair per label."""
        yt = _to_numpy(yt).astype(np.float64)
        yp = _to_numpy(yp).astype(np.float64)
        for i, name in enumerate(label_names):
            r0, c0 = divmod(i, ncols)
            ax = axes_arr[ver * nrows + r0, c0]
            pos_mask = yt[:, i] > 0.5
            neg_mask = ~pos_mask
            ax.hist(yp[neg_mask, i], bins=20, alpha=0.6, color="#d62728", density=True, label="Neg")
            ax.hist(yp[pos_mask, i], bins=20, alpha=0.6, color="#1f77b4", density=True, label="Pos")
            ax.set_title(f"{name}", fontsize=8, fontweight="bold")
            ax.set_xlim(0, 1)

    _row(y_true_a, y_proba_a, 0)
    _row(y_true_b, y_proba_b, 1)

    # Hide the unused trailing cells in both halves of the grid.
    for j in range(n, nrows * ncols):
        r0, c0 = divmod(j, ncols)
        for ver in range(2):
            axes_arr[ver * nrows + r0, c0].set_visible(False)

    fig.suptitle(
        f"Predicted probability distributions — {title_a} (top) vs {title_b} (bottom)",
        fontsize=12,
    )
    fig.tight_layout()
    fig.savefig(out / filename, dpi=150, bbox_inches="tight")
    return fig
def plot_threshold_analysis_two_runs(
    y_true_a: Any,
    y_proba_a: Any,
    y_true_b: Any,
    y_proba_b: Any,
    label_names: Sequence[str],
    title_a: str,
    title_b: str,
    output_dir: str | Path,
    filename: str = "compare_threshold_analysis.png",
) -> Figure:
    """Plot per-class F1 against decision threshold for two runs, stacked vertically.

    Thresholds are 99 evenly spaced values from 0.01 to 0.99; predictions are
    ``proba >= threshold``. The first threshold reaching the maximum F1 is marked
    with a red dot. The first ``nrows`` grid rows hold run A and the last ``nrows``
    hold run B; grid cells beyond the last label are hidden.

    Args:
        y_true_a: Ground-truth indicator matrix of run A, cast to int64.
        y_proba_a: Predicted probabilities of run A, cast to float64.
        y_true_b: Ground-truth indicator matrix of run B, cast to int64.
        y_proba_b: Predicted probabilities of run B, cast to float64.
        label_names: One name per plotted column.
        title_a: Name of run A, used in the figure title.
        title_b: Name of run B, used in the figure title.
        output_dir: Directory the PNG is written to (created if missing).
        filename: Name of the PNG inside ``output_dir``.

    Returns:
        The matplotlib Figure that was saved.
    """
    out = _ensure_output_dir(output_dir)
    thresholds = np.linspace(0.01, 0.99, 99)
    n = len(label_names)
    nrows, ncols = _subplot_grid(n, max_cols=5)
    # squeeze=False always yields a (2 * nrows, ncols) array. With a single label
    # the squeezed result is 1D of shape (2,), which np.atleast_2d would turn into
    # (1, 2) and make the run-B row index fall out of range.
    fig, axes_arr = plt.subplots(
        2 * nrows, ncols, figsize=(3.2 * ncols, 2.4 * 2 * nrows), squeeze=False
    )

    def _fill(yt: Any, yp: Any, ver: int) -> None:
        """Fill the grid half ``ver`` (0 = run A, 1 = run B) with one F1 curve per label."""
        yt = _to_numpy(yt).astype(np.int64)
        yp = _to_numpy(yp).astype(np.float64)
        for i, name in enumerate(label_names):
            r0, c0 = divmod(i, ncols)
            ax = axes_arr[ver * nrows + r0, c0]
            f1s = [
                f1_score(yt[:, i], (yp[:, i] >= thr).astype(np.int64), zero_division=0)
                for thr in thresholds
            ]
            ax.plot(thresholds, f1s, color="#1f77b4", linewidth=1.2)
            bi = int(np.argmax(f1s))
            ax.scatter([thresholds[bi]], [f1s[bi]], color="red", s=25, zorder=5)
            ax.set_title(f"{name}", fontsize=7, fontweight="bold")
            ax.set_xlabel("Thr", fontsize=7)
            ax.set_ylabel("F1", fontsize=7)

    _fill(y_true_a, y_proba_a, 0)
    _fill(y_true_b, y_proba_b, 1)

    # Hide the unused trailing cells in both halves of the grid.
    for j in range(n, nrows * ncols):
        r0, c0 = divmod(j, ncols)
        for ver in range(2):
            axes_arr[ver * nrows + r0, c0].set_visible(False)

    fig.suptitle(f"F1 vs threshold — {title_a} (top) vs {title_b} (bottom)", fontsize=12)
    fig.tight_layout()
    fig.savefig(out / filename, dpi=150, bbox_inches="tight")
    return fig
def _load_json(path: Path) -> Any:
    """Read the UTF-8 text file at ``path`` and return its parsed JSON content."""
    with path.open("r", encoding="utf-8") as handle:
        return json.load(handle)
def generate_comparison_plots(
    *,
    v1_artifact_dir: Union[str, Path],
    v2_artifact_dir: Union[str, Path],
    label_names: Sequence[str],
    output_dir: Union[str, Path],
    label_v1: str = "baseline_v1",
    label_v2: str = "regularized_v2",
    mlflow_log: bool = True,
) -> Dict[str, Figure]:
    """Load two run artifact folders and write the comparison figures.

    Each folder must contain ``train_history.json``, ``final_test_metrics.json``
    and ``test_predictions.npz`` (arrays ``y_true``, ``y_pred_binary``,
    ``y_pred_proba``). The history JSON is read for the keys ``train_loss``,
    ``val_loss``, ``val_macro_f1``, ``val_per_class_f1`` and ``best_epoch``;
    missing list keys default to empty and a missing or ``< 1`` best epoch is
    ignored.

    When ``mlflow_log`` is true, MLflow is importable and a run is active, the
    PNGs that exist are logged under the ``comparison_plots`` artifact path.
    MLflow logging is best-effort: a failure is reported as a warning and never
    raised.

    Args:
        v1_artifact_dir: Artifact folder of the first run.
        v2_artifact_dir: Artifact folder of the second run.
        label_names: Label names, in column order of the prediction matrices.
        output_dir: Directory the PNGs are written to (created if missing).
        label_v1: Display name of the first run.
        label_v2: Display name of the second run.
        mlflow_log: Whether to attempt MLflow artifact logging.

    Returns:
        Mapping from a short plot key to its matplotlib Figure.
    """
    import logging

    def _valid_epoch(raw: Any) -> Optional[int]:
        """Return ``raw`` as an int when it is a 1-based epoch, otherwise None."""
        if raw is None:
            return None
        epoch = int(raw)
        return epoch if epoch >= 1 else None

    d1 = Path(v1_artifact_dir).expanduser().resolve()
    d2 = Path(v2_artifact_dir).expanduser().resolve()
    out = _ensure_output_dir(output_dir)
    h1 = _load_json(d1 / "train_history.json")
    h2 = _load_json(d2 / "train_history.json")
    m1 = _load_json(d1 / "final_test_metrics.json")
    m2 = _load_json(d2 / "final_test_metrics.json")

    # NpzFile keeps the archive open and loads members lazily, so read every
    # array inside the context manager and let it close the file handles.
    with np.load(d1 / "test_predictions.npz") as z1, np.load(d2 / "test_predictions.npz") as z2:
        y_true_1, y_bin_1, y_proba_1 = z1["y_true"], z1["y_pred_binary"], z1["y_pred_proba"]
        y_true_2, y_bin_2, y_proba_2 = z2["y_true"], z2["y_pred_binary"], z2["y_pred_proba"]

    figures: Dict[str, Figure] = {}
    figures["compare_loss"] = plot_training_curves_comparison(
        h1.get("train_loss", []),
        h1.get("val_loss", []),
        h2.get("train_loss", []),
        h2.get("val_loss", []),
        label_v1,
        label_v2,
        out,
        best_epoch_a=_valid_epoch(h1.get("best_epoch")),
        best_epoch_b=_valid_epoch(h2.get("best_epoch")),
    )
    figures["compare_macro_f1"] = plot_val_macro_f1_comparison(
        h1.get("val_macro_f1", []),
        h2.get("val_macro_f1", []),
        label_v1,
        label_v2,
        out,
    )
    figures["compare_per_class_f1"] = plot_per_class_f1_panels(
        h1.get("val_per_class_f1", []),
        h2.get("val_per_class_f1", []),
        label_names,
        label_v1,
        label_v2,
        out,
    )
    figures["compare_auroc"] = plot_auroc_comparison_bars(
        m1, m2, label_names, label_v1, label_v2, out
    )
    figures["compare_confusion"] = plot_confusion_matrices_two_runs(
        y_true_1,
        y_bin_1,
        y_true_2,
        y_bin_2,
        label_names,
        label_v1,
        label_v2,
        out,
    )
    figures["compare_cooc"] = plot_label_cooccurrence_four(
        y_true_1,
        y_bin_1,
        y_true_2,
        y_bin_2,
        label_names,
        (f"{label_v1} GT", f"{label_v1} pred", f"{label_v2} GT", f"{label_v2} pred"),
        out,
    )
    figures["compare_proba"] = plot_probability_distributions_two_runs(
        y_true_1,
        y_proba_1,
        y_true_2,
        y_proba_2,
        label_names,
        label_v1,
        label_v2,
        out,
    )
    figures["compare_thr"] = plot_threshold_analysis_two_runs(
        y_true_1,
        y_proba_1,
        y_true_2,
        y_proba_2,
        label_names,
        label_v1,
        label_v2,
        out,
    )

    if mlflow_log and mlflow is not None:
        # Default filenames of the comparison plot functions called above.
        artifact_files = (
            "compare_training_val_loss.png",
            "compare_val_macro_f1.png",
            "compare_per_class_f1_panels.png",
            "compare_test_auroc.png",
            "compare_confusion_matrices.png",
            "compare_label_cooccurrence.png",
            "compare_probability_distributions.png",
            "compare_threshold_analysis.png",
        )
        try:
            if mlflow.active_run() is not None:
                for fname in artifact_files:
                    p = out / fname
                    if p.is_file():
                        mlflow.log_artifact(str(p), artifact_path="comparison_plots")
        except Exception:
            # Artifact logging must not fail the caller, but the failure should be visible.
            logging.getLogger(__name__).warning(
                "Could not log comparison plots to MLflow; continuing without artifacts.",
                exc_info=True,
            )
    return figures