Reinforcement Learning
stable-baselines3
Joblib
PyTorch
tabular-regression
xgboost
femtosecond-laser
hydrogel
GelMA
HAMA
laser-machining
SAC
materials-science
manufacturing
ml-intern
Instructions to use TWLab/femtosecond-laser-hydrogel-etching-model with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- stable-baselines3
How to use TWLab/femtosecond-laser-hydrogel-etching-model with stable-baselines3:
```python
from huggingface_sb3 import load_from_hub

checkpoint = load_from_hub(
    repo_id="TWLab/femtosecond-laser-hydrogel-etching-model",
    filename="{MODEL FILENAME}.zip",
)
```
- Notebooks
- Google Colab
- Kaggle
| """ | |
| Publication-quality visualization module. | |
| Generates figures suitable for peer-reviewed journals: | |
| - Predicted vs Actual parity plots | |
| - Residual analysis | |
| - Feature importance (XGBoost gain + SHAP) | |
| - Model comparison charts | |
| - Training curves | |
| - Per-material performance breakdown | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| import seaborn as sns | |
| from sklearn.metrics import r2_score | |
| logger = logging.getLogger(__name__) | |
# Publication style defaults: matplotlib rcParams overrides that
# set_publication_style() pushes into plt.rcParams.
STYLE_CONFIG: Dict[str, object] = {
    # Text sizes
    "font.size": 11,
    "axes.titlesize": 12,
    "axes.labelsize": 11,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 10,
    # Resolution for on-screen rendering and for saved files
    "figure.dpi": 300,
    "savefig.dpi": 300,
    # Crop saved figures to their content
    "savefig.bbox": "tight",
    # Typeface
    "font.family": "serif",
}
def set_publication_style() -> None:
    """Apply the shared publication look to seaborn and matplotlib.

    Selects seaborn's ``"Set2"`` colour palette and copies every entry of
    ``STYLE_CONFIG`` into ``plt.rcParams``. Both calls change global plotting
    state, so figures created afterwards pick the settings up.
    """
    sns.set_palette("Set2")
    plt.rcParams.update(STYLE_CONFIG)
def plot_predicted_vs_actual(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    target_names: List[str],
    model_name: str = "Ensemble",
    save_path: Optional[Path] = None,
    figsize: Tuple[float, float] = (14, 10),
) -> plt.Figure:
    """
    Create predicted vs actual parity plots for all targets.

    One subplot is drawn per entry of ``target_names`` on a three-column
    grid. Each subplot shows a scatter of predictions against ground truth,
    the ``y = x`` line, a ±10% band around it, and the R² score in the title.

    Parameters
    ----------
    y_true, y_pred : array-like, shape (n_samples, n_targets)
        Ground truth and predictions; column ``i`` belongs to
        ``target_names[i]``. 1-D input is treated as a single column.
    target_names : list of str
        Target identifiers. Names found in the internal label table get a
        human-readable axis label; any other name is used verbatim.
    model_name : str
        Shown in the figure title.
    save_path : Path, optional
        If given, the figure is written to this path.
    figsize : tuple
        Figure size passed to ``plt.subplots``.

    Returns
    -------
    matplotlib.Figure

    Raises
    ------
    ValueError
        If ``target_names`` is empty, or the arrays are not 2-D with matching
        sample counts and at least ``len(target_names)`` columns.
    """
    set_publication_style()

    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Accept plain vectors for the single-target case.
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
    if y_pred.ndim == 1:
        y_pred = y_pred.reshape(-1, 1)

    n_targets = len(target_names)
    if n_targets == 0:
        raise ValueError("target_names must contain at least one target")
    if y_true.ndim != 2 or y_pred.ndim != 2:
        raise ValueError("y_true and y_pred must be 1-D or 2-D arrays")
    if y_true.shape[0] != y_pred.shape[0]:
        raise ValueError(
            f"y_true and y_pred have different sample counts: "
            f"{y_true.shape[0]} vs {y_pred.shape[0]}"
        )
    if min(y_true.shape[1], y_pred.shape[1]) < n_targets:
        raise ValueError(
            f"Expected at least {n_targets} columns, got "
            f"{y_true.shape[1]} (y_true) and {y_pred.shape[1]} (y_pred)"
        )

    n_cols = 3
    n_rows = (n_targets + n_cols - 1) // n_cols  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    # Human-readable axis labels (with units) for the known targets
    units = {
        "etch_depth_um": "Etch Depth (µm)",
        "etch_width_um": "Etch Width (µm)",
        "surface_roughness_Sa_um": "Surface Roughness Sa (µm)",
        "aspect_ratio": "Aspect Ratio",
        "side_wall_angle_deg": "Side Wall Angle (°)",
    }

    for i, target in enumerate(target_names):
        ax = axes[i]
        yt, yp = y_true[:, i], y_pred[:, i]
        r2 = r2_score(yt, yp)
        ax.scatter(yt, yp, alpha=0.4, s=12, edgecolors="none", c="steelblue")

        # Shared range for both axes so the parity line is a true diagonal
        lim_min = min(yt.min(), yp.min())
        lim_max = max(yt.max(), yp.max())
        span = lim_max - lim_min
        if span > 0:
            margin = span * 0.05
        else:
            # All values identical: pad by a fraction of the value (or 1.0 at
            # zero) so the axis limits do not collapse to a single point.
            margin = abs(lim_min) * 0.05 or 1.0
        lower, upper = lim_min - margin, lim_max + margin

        # Perfect prediction line
        ax.plot(
            [lower, upper],
            [lower, upper],
            "r--", lw=1.5, label="Perfect prediction",
        )

        # ±10% error band, spanning the same padded range as the parity line
        x_range = np.linspace(lower, upper, 100)
        ax.fill_between(x_range, x_range * 0.9, x_range * 1.1,
                        alpha=0.1, color="red", label="±10% band")

        label = units.get(target, target)
        ax.set_xlabel(f"Actual {label}")
        ax.set_ylabel(f"Predicted {label}")
        ax.set_title(f"{label}\nR² = {r2:.4f}", fontweight="bold")
        ax.set_xlim(lower, upper)
        ax.set_ylim(lower, upper)
        ax.set_aspect("equal")
        # One legend is enough; every subplot uses the same artists.
        if i == 0:
            ax.legend(loc="lower right", fontsize=9)

    # Hide unused subplots in the last grid row
    for unused_ax in axes[n_targets:]:
        unused_ax.set_visible(False)

    fig.suptitle(f"{model_name} — Predicted vs Actual", fontsize=14, fontweight="bold", y=1.02)
    # Lay out this figure explicitly rather than whichever one is "current".
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, bbox_inches="tight")
        logger.info("Saved: %s", save_path)
    return fig
def plot_residual_analysis(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    target_names: List[str],
    save_path: Optional[Path] = None,
) -> plt.Figure:
    """
    Residual histograms and residual-vs-predicted scatter plots.

    The figure has two rows and one column per target: the top row shows the
    density histogram of the residuals (``y_true - y_pred``), the bottom row
    shows the residuals plotted against the predicted values. A dashed red
    line marks zero residual in both.

    Parameters
    ----------
    y_true, y_pred : array-like, shape (n_samples, n_targets)
        Ground truth and predictions; column ``i`` belongs to
        ``target_names[i]``. 1-D input is treated as a single column.
    target_names : list of str
        Target identifiers, used (with underscores replaced by spaces) as
        the column titles.
    save_path : Path, optional
        If given, the figure is written to this path.

    Returns
    -------
    matplotlib.Figure

    Raises
    ------
    ValueError
        If ``target_names`` is empty.
    """
    set_publication_style()

    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Accept plain vectors for the single-target case.
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
    if y_pred.ndim == 1:
        y_pred = y_pred.reshape(-1, 1)

    n_targets = len(target_names)
    if n_targets == 0:
        raise ValueError("target_names must contain at least one target")

    # squeeze=False keeps `axes` two-dimensional even for a single target,
    # so the same indexing works for any number of columns.
    fig, axes = plt.subplots(
        2, n_targets, figsize=(4 * n_targets, 8), squeeze=False
    )

    for i, target in enumerate(target_names):
        residuals = y_true[:, i] - y_pred[:, i]

        # Histogram of residuals
        hist_ax = axes[0, i]
        hist_ax.hist(residuals, bins=50, density=True, alpha=0.7,
                     color="steelblue", edgecolor="white")
        hist_ax.axvline(0, color="red", linestyle="--", lw=1)
        hist_ax.set_xlabel("Residual")
        hist_ax.set_ylabel("Density")
        hist_ax.set_title(target.replace("_", " "))

        # Residual vs predicted
        scatter_ax = axes[1, i]
        scatter_ax.scatter(y_pred[:, i], residuals, alpha=0.3, s=8, c="steelblue")
        scatter_ax.axhline(0, color="red", linestyle="--", lw=1)
        scatter_ax.set_xlabel("Predicted")
        scatter_ax.set_ylabel("Residual")

    # Lay out this figure explicitly rather than whichever one is "current".
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, bbox_inches="tight")
        logger.info("Saved: %s", save_path)
    return fig
def plot_feature_importance(
    importances: Dict[str, Dict[str, float]],
    top_n: int = 10,
    save_path: Optional[Path] = None,
) -> plt.Figure:
    """
    Feature importance heatmap (features x targets).

    Features are ranked by their mean importance across all targets and the
    ``top_n`` highest-ranked ones are shown, one row per feature and one
    column per target.

    Parameters
    ----------
    importances : dict
        {target_name: {feature_name: importance_value}}
    top_n : int
        Number of features to show, ranked by mean importance across targets
    save_path : Path, optional
        If given, the figure is written to this path.

    Returns
    -------
    matplotlib.Figure

    Raises
    ------
    ValueError
        If ``importances`` contains no targets or no features.
    """
    set_publication_style()

    # Rows are features, columns are targets
    imp_df = pd.DataFrame(importances)
    if imp_df.empty:
        raise ValueError(
            "importances must contain at least one target and one feature"
        )

    # Rank by mean importance across targets. The ranking is kept in a
    # separate Series instead of a temporary "Mean" column, so a target that
    # happens to be called "Mean" is neither overwritten nor dropped.
    mean_importance = imp_df.mean(axis=1).sort_values(ascending=False)
    top_features = imp_df.loc[mean_importance.head(top_n).index]

    # Heatmap of top features
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        top_features,
        annot=True, fmt=".3f",
        cmap="YlOrRd",
        linewidths=0.5,
        ax=ax,
    )
    ax.set_title("Feature Importance by Target (XGBoost Gain)", fontweight="bold")
    ax.set_xlabel("Target Variable")
    ax.set_ylabel("Feature")

    # Lay out this figure explicitly rather than whichever one is "current".
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, bbox_inches="tight")
        logger.info("Saved: %s", save_path)
    return fig
def plot_model_comparison(
    results_dict: Dict[str, pd.DataFrame],
    metric: str = "R²",
    target_names: Optional[List[str]] = None,
    save_path: Optional[Path] = None,
) -> plt.Figure:
    """
    Grouped bar chart comparing models across targets.

    Parameters
    ----------
    results_dict : dict
        {model_name: metrics_DataFrame}; each DataFrame is indexed by target
        name and has one column per metric.
    metric : str
        Which metric column to plot
    target_names : list, optional
        Targets to show, in plotting order. When omitted, the union of all
        DataFrame index labels is used, in first-seen order.
    save_path : Path, optional
        If given, the figure is written to this path.

    Returns
    -------
    matplotlib.Figure

    Raises
    ------
    ValueError
        If ``results_dict`` is empty.

    Notes
    -----
    A target missing from a model's DataFrame is drawn as a bar of height 0.
    """
    set_publication_style()

    n_models = len(results_dict)
    if n_models == 0:
        raise ValueError("results_dict must contain at least one model")

    if target_names is None:
        # dict.fromkeys de-duplicates while preserving first-seen order.
        target_names = list(
            dict.fromkeys(
                target for df in results_dict.values() for target in df.index
            )
        )

    fig, ax = plt.subplots(figsize=(12, 6))
    x = np.arange(len(target_names))
    # Each target group occupies 0.8 of its slot, split evenly between models
    width = 0.8 / n_models
    colors = sns.color_palette("Set2", n_models)

    for j, (model_name, df) in enumerate(results_dict.items()):
        values = [df.loc[t, metric] if t in df.index else 0 for t in target_names]
        # Centre the group of bars on the tick position
        offset = (j - n_models / 2 + 0.5) * width
        ax.bar(x + offset, values, width, label=model_name,
               color=colors[j], edgecolor="white")

    ax.set_xlabel("Target Variable")
    ax.set_ylabel(metric)
    ax.set_title(f"Model Comparison — {metric}", fontweight="bold")
    ax.set_xticks(x)
    # Break long snake_case names over several lines so they do not overlap
    ax.set_xticklabels([str(t).replace("_", "\n") for t in target_names], fontsize=9)
    ax.legend()
    ax.grid(axis="y", alpha=0.3)

    # Lay out this figure explicitly rather than whichever one is "current".
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, bbox_inches="tight")
        logger.info("Saved: %s", save_path)
    return fig
def plot_training_curves(
    train_losses: List[float],
    val_losses: List[float],
    save_path: Optional[Path] = None,
) -> plt.Figure:
    """
    Draw training and validation loss against epoch on a log-scaled y axis.

    Epochs are numbered from 1 to ``len(train_losses)`` and both loss
    sequences are plotted against that same range.

    Parameters
    ----------
    train_losses, val_losses : list of float
        One loss value per epoch.
    save_path : Path, optional
        If given, the figure is written to this path.

    Returns
    -------
    matplotlib.Figure
    """
    set_publication_style()
    figure, axis = plt.subplots(figsize=(8, 5))

    epoch_numbers = range(1, len(train_losses) + 1)
    curves = (
        (train_losses, "Training Loss", "steelblue"),
        (val_losses, "Validation Loss", "darkorange"),
    )
    for losses, curve_label, curve_color in curves:
        axis.plot(epoch_numbers, losses, label=curve_label, color=curve_color, lw=1.5)

    axis.set_xlabel("Epoch")
    axis.set_ylabel("MSE Loss")
    axis.set_title("Neural Network Training Curves", fontweight="bold")
    axis.set_yscale("log")
    axis.legend()
    axis.grid(alpha=0.3)
    plt.tight_layout()

    if save_path:
        figure.savefig(save_path, bbox_inches="tight")
    return figure
def plot_per_material_performance(
    material_results: pd.DataFrame,
    metric_prefix: str = "R²",
    target_names: Optional[List[str]] = None,
    save_path: Optional[Path] = None,
) -> plt.Figure:
    """
    Grouped bar chart showing performance per material type.

    Parameters
    ----------
    material_results : pd.DataFrame
        One row per material (the index supplies the x tick labels) and
        metric columns named ``"<metric>_<target>"``.
    metric_prefix : str
        Columns whose name starts with this prefix are plotted. If none
        match, columns containing ``"R²"`` are plotted instead.
    target_names : list, optional
        Currently unused; accepted for interface compatibility.
    save_path : Path, optional
        If given, the figure is written to this path.

    Returns
    -------
    matplotlib.Figure

    Raises
    ------
    ValueError
        If no column matches ``metric_prefix`` or the ``"R²"`` fallback.
    """
    set_publication_style()

    # Filter columns for the metric, remembering which metric actually
    # matched so the y axis can be labelled correctly.
    plotted_metric = metric_prefix
    cols = [c for c in material_results.columns if c.startswith(metric_prefix)]
    if not cols:
        plotted_metric = "R²"
        cols = [c for c in material_results.columns if "R²" in c]
    if not cols:
        raise ValueError(
            f"No columns starting with {metric_prefix!r} or containing 'R²' "
            "found in material_results"
        )

    plot_df = material_results[cols].copy()
    # Strip the metric prefix so the legend shows bare target names
    plot_df.columns = [c.replace(f"{metric_prefix}_", "").replace("R²_", "") for c in cols]

    fig, ax = plt.subplots(figsize=(12, 6))
    plot_df.plot(kind="bar", ax=ax, colormap="Set2", edgecolor="white")
    ax.set_xlabel("Material Type")
    ax.set_ylabel(plotted_metric)
    ax.set_title("Model Performance by Material Type", fontweight="bold")
    ax.legend(title="Target", bbox_to_anchor=(1.02, 1), loc="upper left", fontsize=9)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
    ax.grid(axis="y", alpha=0.3)

    # The fixed 0..1.05 range only makes sense for R²; other metrics keep
    # matplotlib's autoscaled limits. The lower bound is extended when a
    # negative R² is present so those bars are not clipped.
    if "R²" in plotted_metric:
        lowest = plot_df.min(numeric_only=True).min()
        lower = min(0.0, float(lowest) * 1.05) if np.isfinite(lowest) else 0.0
        ax.set_ylim(lower, 1.05)

    # Lay out this figure explicitly rather than whichever one is "current".
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, bbox_inches="tight")
        logger.info("Saved: %s", save_path)
    return fig