| from __future__ import annotations |
|
|
| import argparse |
|
|
| import gradio as gr |
| import matplotlib |
| import numpy as np |
| from matplotlib.figure import Figure |
| from matplotlib.lines import Line2D |
| from numpy.typing import NDArray |
|
|
| |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
|
|
# Shorthand for the float64 NumPy arrays exchanged by the numeric helpers in this module.
FloatArray = NDArray[np.float64]

# Gradio theme shared by the whole app: sky primary, amber secondary, slate neutral hues.
APP_THEME = gr.themes.Soft(
    neutral_hue="slate",
    secondary_hue="amber",
    primary_hue="sky",
)
|
|
|
|
def make_prior_cov(std_w0: float, std_w1: float, rho: float) -> FloatArray:
    """Build the 2x2 prior covariance matrix of (w0, w1).

    Args:
        std_w0: prior standard deviation of the intercept w0 (finite, > 0).
        std_w1: prior standard deviation of the slope w1 (finite, > 0).
        rho: prior correlation between w0 and w1, strictly inside (-0.999, 0.999).

    Returns:
        ``[[std_w0**2, rho*std_w0*std_w1], [rho*std_w0*std_w1, std_w1**2]]`` as float64.

    Raises:
        ValueError: if a standard deviation is not a finite positive number, if ``rho`` is
            outside (-0.999, 0.999) or NaN, or if the resulting matrix is not numerically
            positive definite (e.g. the squared standard deviations underflow to zero).
    """
    # Written as `not (... > 0)` so NaN is rejected too; a plain `std <= 0` test lets NaN through.
    stds_ok = (
        np.isfinite(std_w0)
        and np.isfinite(std_w1)
        and std_w0 > 0
        and std_w1 > 0
    )
    if not stds_ok:
        raise ValueError("事前標準偏差は有限の正の値にしてください。")
    # |rho| is capped at 0.999 rather than 1 to keep the matrix well conditioned; the message
    # states the bound that is actually enforced. NaN fails the chained comparison as well.
    if not (-0.999 < rho < 0.999):
        raise ValueError("事前相関係数 rho は -0.999 より大きく 0.999 より小さい値にしてください。")

    cross = rho * std_w0 * std_w1
    cov = np.array(
        [
            [std_w0**2, cross],
            [cross, std_w1**2],
        ],
        dtype=float,
    )
    # Analytically PD given the checks above; this guards against floating-point underflow
    # (e.g. extremely small standard deviations squaring to 0).
    sign, _ = np.linalg.slogdet(cov)
    if sign <= 0:
        raise ValueError("事前共分散行列が正定値ではありません。標準偏差と相関係数を見直してください。")
    return cov
|
|
|
|
def generate_dataset(
    true_w0: float,
    true_w1: float,
    sigma: float,
    n_max: int,
    seed: int,
) -> tuple[FloatArray, FloatArray]:
    """Simulate ``n_max`` noisy observations of the line y = true_w0 + true_w1 * x.

    Inputs are drawn uniformly from [-1, 1); targets get additive N(0, sigma**2) noise. Both
    draws come from one ``np.random.default_rng(seed)`` generator (inputs first, then noise),
    so the same seed always reproduces the same dataset.

    Returns:
        ``(x, y)``, each a float array of length ``n_max``.

    Raises:
        ValueError: if ``n_max < 1`` or ``sigma <= 0``.
    """
    if n_max < 1:
        raise ValueError("N_max は 1 以上にしてください。")
    if sigma <= 0:
        raise ValueError("観測ノイズ標準偏差 sigma は正の値にしてください。")

    generator = np.random.default_rng(seed)
    inputs = generator.uniform(low=-1.0, high=1.0, size=n_max)
    clean = true_w0 + true_w1 * inputs
    targets = clean + generator.normal(loc=0.0, scale=sigma, size=n_max)
    return inputs.astype(float), targets.astype(float)
|
|
|
|
def compute_posterior(
    prior_mean: FloatArray,
    prior_cov: FloatArray,
    x: FloatArray,
    y: FloatArray,
    sigma: float,
    n_used: int,
) -> tuple[FloatArray, FloatArray]:
    """Return the Gaussian posterior over (w0, w1) after the first ``n_used`` observations.

    Model: y_i = w0 + w1 * x_i + eps_i, eps_i ~ N(0, sigma**2), prior w ~ N(m_0, S_0).
    With design matrix Phi whose rows are [1, x_i]:

        S_N^{-1} = S_0^{-1} + Phi^T Phi / sigma**2
        m_N      = S_N (S_0^{-1} m_0 + Phi^T y / sigma**2)

    Args:
        prior_mean: prior mean m_0 (length 2).
        prior_cov: prior covariance S_0 (2x2).
        x, y: observations; only the first ``n_used`` pairs are used.
        sigma: observation-noise standard deviation (only consulted when data is used).
        n_used: number of leading observations, clamped to [0, min(len(x), len(y))].

    Returns:
        ``(posterior_mean, posterior_cov)``; with no data used, copies of the prior.

    Raises:
        ValueError: if observations are used and ``sigma`` is not a positive number.
        numpy.linalg.LinAlgError: if the prior covariance or posterior precision is singular.
    """
    # Clamp to the number of complete (x, y) pairs so a shorter y cannot cause a shape mismatch.
    n_available = min(len(x), len(y))
    n_used = int(np.clip(n_used, 0, n_available))
    if n_used == 0:
        return prior_mean.copy(), prior_cov.copy()
    if not sigma > 0:
        raise ValueError("観測ノイズ標準偏差 sigma は正の値にしてください。")

    noise_precision = 1.0 / sigma**2
    phi = np.column_stack([np.ones(n_used), x[:n_used]])
    y_used = y[:n_used]

    prior_precision = np.linalg.inv(prior_cov)
    posterior_precision = prior_precision + noise_precision * (phi.T @ phi)
    rhs = prior_precision @ prior_mean + noise_precision * (phi.T @ y_used)

    # Solving the linear system is more accurate than multiplying by an explicit inverse.
    posterior_mean = np.linalg.solve(posterior_precision, rhs)
    posterior_cov = np.linalg.inv(posterior_precision)
    # inv() can leave round-off asymmetry; return an exactly symmetric covariance so consumers
    # such as multivariate-normal samplers do not see a non-symmetric matrix.
    posterior_cov = 0.5 * (posterior_cov + posterior_cov.T)
    return posterior_mean, posterior_cov
|
|
|
|
def sample_weights(mean: FloatArray, cov: FloatArray, n_lines: int, seed: int) -> FloatArray:
    """Draw ``n_lines`` weight vectors from N(mean, cov) using a freshly seeded generator.

    Returns:
        Float array of shape ``(n_lines, len(mean))``; each row is one (w0, w1) sample.

    Raises:
        ValueError: if ``n_lines < 1``.
    """
    if n_lines < 1:
        raise ValueError("表示する直線本数 n_lines は 1 以上にしてください。")

    generator = np.random.default_rng(seed)
    draws = generator.multivariate_normal(mean, cov, n_lines)
    return np.asarray(draws, dtype=float)
|
|
|
|
| def _gaussian_density_grid( |
| mean: FloatArray, |
| cov: FloatArray, |
| grid_w0: FloatArray, |
| grid_w1: FloatArray, |
| ) -> FloatArray: |
| cov_inv = np.linalg.inv(cov) |
| sign, logdet = np.linalg.slogdet(cov) |
| if sign <= 0: |
| raise ValueError("共分散行列が正定値ではありません。") |
|
|
| position = np.stack([grid_w0, grid_w1], axis=-1) |
| diff = position - mean |
| quad = np.einsum("...i,ij,...j->...", diff, cov_inv, diff) |
| log_density = -0.5 * (2 * np.log(2 * np.pi) + logdet + quad) |
| return np.exp(log_density) |
|
|
|
|
| def _likelihood_surface( |
| grid_w0: FloatArray, |
| grid_w1: FloatArray, |
| x_used: FloatArray, |
| y_used: FloatArray, |
| sigma: float, |
| ) -> FloatArray: |
| predictions = grid_w0[..., None] + grid_w1[..., None] * x_used |
| residuals = y_used - predictions |
| rss = np.sum(residuals**2, axis=-1) |
| log_likelihood = -0.5 * rss / (sigma**2) |
| return np.exp(log_likelihood - np.max(log_likelihood)) |
|
|
|
|
| def _contour_levels(surface: FloatArray) -> FloatArray: |
| peak = float(np.max(surface)) |
| if not np.isfinite(peak) or peak <= 0: |
| return np.array([1.0], dtype=float) |
|
|
| relative_levels = np.exp(-0.5 * np.array([7.0, 4.5, 2.5, 1.0, 0.3], dtype=float)) |
| levels = np.sort(peak * relative_levels) |
| return np.unique(np.clip(levels, peak * 1e-6, peak * 0.999)) |
|
|
|
|
| def _parameter_limits( |
| prior_mean: FloatArray, |
| prior_cov: FloatArray, |
| posterior_mean: FloatArray, |
| posterior_cov: FloatArray, |
| true_w: FloatArray, |
| ) -> tuple[tuple[float, float], tuple[float, float]]: |
| prior_std = 4.0 * np.sqrt(np.diag(prior_cov)) |
| posterior_std = 4.0 * np.sqrt(np.diag(posterior_cov)) |
|
|
| lower = np.vstack( |
| [ |
| prior_mean - prior_std, |
| posterior_mean - posterior_std, |
| true_w, |
| ] |
| ).min(axis=0) |
| upper = np.vstack( |
| [ |
| prior_mean + prior_std, |
| posterior_mean + posterior_std, |
| true_w, |
| ] |
| ).max(axis=0) |
| span = np.maximum(upper - lower, np.array([1.0, 1.0], dtype=float)) |
| padding = 0.15 * span |
| w0_limits = (float(lower[0] - padding[0]), float(upper[0] + padding[0])) |
| w1_limits = (float(lower[1] - padding[1]), float(upper[1] + padding[1])) |
| return w0_limits, w1_limits |
|
|
|
|
def plot_parameter_space(
    prior_mean: FloatArray,
    prior_cov: FloatArray,
    posterior_mean: FloatArray,
    posterior_cov: FloatArray,
    true_w: FloatArray,
    x: FloatArray,
    y: FloatArray,
    sigma: float,
    n_used: int,
    show_likelihood: bool,
) -> Figure:
    """Plot prior, posterior and (optionally) likelihood contours in the (w0, w1) plane.

    The window comes from ``_parameter_limits`` and is sampled on a 180 x 180 grid. Contour
    levels for each surface are chosen by ``_contour_levels`` relative to that surface's
    maximum on the grid.

    Args:
        prior_mean, prior_cov: prior mean/covariance (dashed blue contours).
        posterior_mean, posterior_cov: posterior mean/covariance (solid red contours, red dot
            at the mean).
        true_w: true (w0, w1), marked with a black star.
        x, y: full dataset; only the first ``n_used`` points enter the likelihood.
        sigma: observation-noise standard deviation used for the likelihood surface.
        n_used: number of leading observations treated as observed.
        show_likelihood: draw dotted gray likelihood contours (only when ``n_used > 0``).

    Returns:
        A new standalone matplotlib Figure with a single Axes.
    """
    w0_limits, w1_limits = _parameter_limits(prior_mean, prior_cov, posterior_mean, posterior_cov, true_w)
    w0_grid = np.linspace(*w0_limits, 180)
    w1_grid = np.linspace(*w1_limits, 180)
    grid_w0, grid_w1 = np.meshgrid(w0_grid, w1_grid)

    prior_density = _gaussian_density_grid(prior_mean, prior_cov, grid_w0, grid_w1)
    posterior_density = _gaussian_density_grid(posterior_mean, posterior_cov, grid_w0, grid_w1)
    draw_likelihood = show_likelihood and n_used > 0

    # Build the Figure directly rather than through plt.subplots(): pyplot keeps every figure
    # it creates registered until plt.close() is called, and the app redraws on every control
    # change, so pyplot-managed figures would accumulate (see matplotlib's "Embedding in a web
    # application server" example).
    fig = Figure(figsize=(6.2, 5.2))
    ax = fig.subplots()

    if draw_likelihood:
        likelihood = _likelihood_surface(grid_w0, grid_w1, x[:n_used], y[:n_used], sigma)
        ax.contour(
            grid_w0,
            grid_w1,
            likelihood,
            levels=_contour_levels(likelihood),
            colors="0.55",
            linestyles="dotted",
            linewidths=1.1,
        )

    ax.contour(
        grid_w0,
        grid_w1,
        prior_density,
        levels=_contour_levels(prior_density),
        colors="tab:blue",
        linestyles="dashed",
        linewidths=1.5,
    )
    ax.contour(
        grid_w0,
        grid_w1,
        posterior_density,
        levels=_contour_levels(posterior_density),
        colors="tab:red",
        linewidths=1.8,
    )
    ax.scatter(true_w[0], true_w[1], marker="*", s=140, color="black", zorder=5)
    ax.scatter(posterior_mean[0], posterior_mean[1], s=44, color="tab:red", zorder=5)

    # contour() does not create legend entries, so build proxy artists by hand.
    handles = [
        Line2D([0], [0], color="tab:blue", linestyle="dashed", linewidth=1.5, label="prior"),
        Line2D([0], [0], color="tab:red", linewidth=1.8, label="posterior"),
        Line2D([0], [0], marker="o", color="tab:red", linewidth=0, markersize=7, label="posterior mean"),
        Line2D([0], [0], marker="*", color="black", linewidth=0, markersize=10, label="true parameter"),
    ]
    if draw_likelihood:
        handles.insert(
            0,
            Line2D([0], [0], color="0.55", linestyle="dotted", linewidth=1.2, label="likelihood"),
        )

    ax.set_title("Parameter Space")
    ax.set_xlabel(r"$w_0$")
    ax.set_ylabel(r"$w_1$")
    ax.set_xlim(*w0_limits)
    ax.set_ylim(*w1_limits)
    ax.grid(alpha=0.22)
    ax.legend(handles=handles, loc="best")
    fig.tight_layout()
    return fig
|
|
|
|
def plot_data_space(
    x: FloatArray,
    y: FloatArray,
    n_used: int,
    true_w: FloatArray,
    posterior_mean: FloatArray,
    sampled_w: FloatArray,
    sample_label: str,
) -> Figure:
    """Plot observations and regression lines in (x, y) space.

    Args:
        x, y: full dataset; points from index ``n_used`` on are drawn in light gray as
            "unused data", the first ``n_used`` in blue as "used data".
        n_used: number of leading observations treated as observed.
        true_w: true (w0, w1); drawn as a black line.
        posterior_mean: posterior mean (w0, w1); drawn as a red line.
        sampled_w: sampled (w0, w1) rows; each is drawn as a translucent orange line.
        sample_label: legend label given to the first sampled line only.

    Returns:
        A new standalone matplotlib Figure with a single Axes, x-range fixed to [-1.1, 1.1].
    """
    # Build the Figure directly rather than through plt.subplots(): pyplot keeps every figure
    # it creates registered until plt.close() is called, and the app redraws on every control
    # change, so pyplot-managed figures would accumulate.
    fig = Figure(figsize=(6.2, 5.2))
    ax = fig.subplots()

    if n_used < len(x):
        ax.scatter(x[n_used:], y[n_used:], color="0.83", s=36, label="unused data", zorder=2)
    if n_used > 0:
        ax.scatter(x[:n_used], y[:n_used], color="tab:blue", s=42, label="used data", zorder=3)

    x_line = np.linspace(-1.1, 1.1, 240)
    true_line = true_w[0] + true_w[1] * x_line
    posterior_line = posterior_mean[0] + posterior_mean[1] * x_line

    ax.plot(x_line, true_line, color="black", linewidth=2.2, label="true line")
    ax.plot(x_line, posterior_line, color="tab:red", linewidth=2.0, label="posterior mean")

    for index, weights in enumerate(sampled_w):
        # Label only the first sample so the legend shows a single entry for the whole bundle.
        label = sample_label if index == 0 else None
        ax.plot(
            x_line,
            weights[0] + weights[1] * x_line,
            color="tab:orange",
            alpha=0.18,
            linewidth=1.15,
            label=label,
            zorder=1,
        )

    ax.set_title("Data Space")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_xlim(-1.1, 1.1)
    ax.grid(alpha=0.22)
    ax.legend(loc="best")
    fig.tight_layout()
    return fig
|
|
|
|
| def _format_array(value: FloatArray) -> str: |
| return np.array2string(value, precision=3, suppress_small=True, floatmode="fixed") |
|
|
|
|
| def _select_sampling_distribution( |
| sample_mode: str, |
| n_used: int, |
| prior_mean: FloatArray, |
| prior_cov: FloatArray, |
| posterior_mean: FloatArray, |
| posterior_cov: FloatArray, |
| ) -> tuple[FloatArray, FloatArray, str]: |
| if sample_mode == "posterior samples" and n_used > 0: |
| return posterior_mean, posterior_cov, "posterior samples" |
| if sample_mode == "posterior samples": |
| return prior_mean, prior_cov, "prior samples (N=0 fallback)" |
| return prior_mean, prior_cov, "prior samples" |
|
|
|
|
def sync_n_slider(n_max: float, n_used: float) -> dict[str, object]:
    """Resize the N slider after N_max changes and clamp its value into the new range.

    Args:
        n_max: new N_max slider value; truncated to int and floored at 1.
        n_used: current N slider value; truncated to int and clamped to [0, N_max].

    Returns:
        The ``gr.update(maximum=..., value=...)`` payload for the N slider.
    """
    # NOTE(review): gr.update(...) yields an update payload (a dict), not a Slider instance,
    # hence the dict return annotation — confirm against the installed Gradio version.
    max_value = max(1, int(n_max))
    current_value = min(max(0, int(n_used)), max_value)
    return gr.update(maximum=max_value, value=current_value)
|
|
|
|
def update(
    true_w0: float,
    true_w1: float,
    sigma: float,
    prior_mean_w0: float,
    prior_mean_w1: float,
    prior_std_w0: float,
    prior_std_w1: float,
    prior_rho: float,
    n_max: float,
    n_used: float,
    seed: float,
    n_lines: float,
    sample_mode: str,
    show_likelihood: bool,
) -> tuple[Figure, Figure, str, str, str]:
    """Recompute every UI output from the current control values.

    Slider values are truncated to ints where counts or seeds are needed: N_max is floored at
    1, N is clamped to [0, N_max], and the number of lines is floored at 1.

    Returns:
        ``(parameter-space figure, data-space figure, formatted posterior mean,
        formatted posterior covariance, Markdown status summary)``.

    Raises:
        gr.Error: any ValueError or numpy.linalg.LinAlgError raised along the way is re-raised
            as a gr.Error carrying the same message.
    """
    try:
        total_points = max(1, int(n_max))
        used_points = min(max(0, int(n_used)), total_points)
        base_seed = int(seed)
        line_count = max(1, int(n_lines))

        true_w = np.array([true_w0, true_w1], dtype=float)
        prior_mean = np.array([prior_mean_w0, prior_mean_w1], dtype=float)
        prior_cov = make_prior_cov(prior_std_w0, prior_std_w1, prior_rho)

        x, y = generate_dataset(true_w0, true_w1, sigma, total_points, base_seed)
        posterior_mean, posterior_cov = compute_posterior(
            prior_mean=prior_mean,
            prior_cov=prior_cov,
            x=x,
            y=y,
            sigma=sigma,
            n_used=used_points,
        )

        draw_mean, draw_cov, draw_label = _select_sampling_distribution(
            sample_mode=sample_mode,
            n_used=used_points,
            prior_mean=prior_mean,
            prior_cov=prior_cov,
            posterior_mean=posterior_mean,
            posterior_cov=posterior_cov,
        )
        # Sampling seed is derived from the data seed, N and whether posterior samples are
        # drawn, so identical settings reproduce identical lines.
        posterior_offset = int(draw_label.startswith("posterior"))
        draw_seed = base_seed + 10_000 * used_points + posterior_offset
        sampled_w = sample_weights(draw_mean, draw_cov, line_count, draw_seed)

        parameter_fig = plot_parameter_space(
            prior_mean=prior_mean,
            prior_cov=prior_cov,
            posterior_mean=posterior_mean,
            posterior_cov=posterior_cov,
            true_w=true_w,
            x=x,
            y=y,
            sigma=sigma,
            n_used=used_points,
            show_likelihood=show_likelihood,
        )
        data_fig = plot_data_space(
            x=x,
            y=y,
            n_used=used_points,
            true_w=true_w,
            posterior_mean=posterior_mean,
            sampled_w=sampled_w,
            sample_label=draw_label,
        )

        likelihood_state = "on" if show_likelihood and used_points > 0 else "off"
        summary_lines = (
            "### Current State",
            f"- 使用データ数: `{used_points} / {total_points}`",
            f"- 直線サンプル元: `{draw_label}`",
            f"- 尤度等高線: `{likelihood_state}`",
        )
        summary = "\n".join(summary_lines)

        mean_text = _format_array(posterior_mean)
        cov_text = _format_array(posterior_cov)
        return parameter_fig, data_fig, mean_text, cov_text, summary
    except (ValueError, np.linalg.LinAlgError) as exc:
        raise gr.Error(str(exc)) from exc
|
|
|
|
def build_app() -> gr.Blocks:
    """Assemble the Gradio UI and wire every control to ``update``.

    Layout: a controls column (true model, prior, data/drawing settings) on the left, and a
    results column (parameter-space plot, data-space plot, posterior mean/covariance text
    boxes, Markdown summary) on the right.

    Events:
      * Changing N_max first runs ``sync_n_slider`` to resize/clamp the N slider, then ``update``.
      * Changing any other input runs ``update`` directly.
      * ``demo.load`` runs ``update`` once when the page loads.

    Returns:
        The configured ``gr.Blocks`` app (not yet launched).
    """
    default_n_max = 60
    default_n_used = 12

    with gr.Blocks(title="Bayesian Linear Regression Visualizer", theme=APP_THEME) as demo:
        gr.Markdown(
            """
# Bayesian Linear Regression Visualizer
事前分布・尤度・事後分布の関係と、パラメータ分布からサンプルした回帰直線群の変化を 2 つの図で確認できます。
"""
        )

        with gr.Row():
            with gr.Column(scale=4):
                gr.Markdown("## Controls")

                with gr.Group():
                    gr.Markdown("### 真のモデル")
                    true_w0 = gr.Slider(
                        -3.0,
                        3.0,
                        value=-0.3,
                        step=0.1,
                        label="true_w0",
                        info="真の切片。黒い真の回帰直線の上下位置を決めます。",
                    )
                    true_w1 = gr.Slider(
                        -3.0,
                        3.0,
                        value=1.2,
                        step=0.1,
                        label="true_w1",
                        info="真の傾き。黒い真の回帰直線の傾きを決めます。",
                    )
                    sigma = gr.Slider(
                        0.05,
                        1.2,
                        value=0.25,
                        step=0.05,
                        label="sigma",
                        info="観測ノイズの標準偏差。大きいほどデータ点が真の直線から散らばります。",
                    )

                with gr.Group():
                    gr.Markdown("### 事前分布")
                    prior_mean_w0 = gr.Slider(
                        -3.0,
                        3.0,
                        value=0.0,
                        step=0.1,
                        label="prior_mean_w0",
                        info="事前分布での切片の平均です。",
                    )
                    prior_mean_w1 = gr.Slider(
                        -3.0,
                        3.0,
                        value=0.0,
                        step=0.1,
                        label="prior_mean_w1",
                        info="事前分布での傾きの平均です。",
                    )
                    prior_std_w0 = gr.Slider(
                        0.1,
                        3.0,
                        value=1.2,
                        step=0.1,
                        label="prior_std_w0",
                        info="事前分布での切片方向の広がりです。大きいほど切片に自信がありません。",
                    )
                    prior_std_w1 = gr.Slider(
                        0.1,
                        3.0,
                        value=1.2,
                        step=0.1,
                        label="prior_std_w1",
                        info="事前分布での傾き方向の広がりです。大きいほど傾きに自信がありません。",
                    )
                    # Kept within (-0.95, 0.95), inside the |rho| < 0.999 bound make_prior_cov enforces.
                    prior_rho = gr.Slider(
                        -0.95,
                        0.95,
                        value=-0.25,
                        step=0.05,
                        label="prior_rho",
                        info="事前分布での切片と傾きの相関です。0 なら軸に沿い、正負で等高線の傾きが変わります。",
                    )

                with gr.Group():
                    gr.Markdown("### データと描画")
                    n_max = gr.Slider(
                        10,
                        200,
                        value=default_n_max,
                        step=1,
                        label="N_max",
                        info="先に生成しておく総データ数です。",
                    )
                    # Initial maximum matches N_max's default; sync_n_slider keeps it in step afterwards.
                    n_used = gr.Slider(
                        0,
                        default_n_max,
                        value=default_n_used,
                        step=1,
                        label="N",
                        info="事後分布の計算に使うデータ数です。先頭から N 個だけ使います。",
                    )
                    seed = gr.Slider(
                        0,
                        9999,
                        value=7,
                        step=1,
                        label="seed",
                        info="データ生成の乱数シードです。同じ値なら同じデータになります。",
                    )
                    n_lines = gr.Slider(
                        1,
                        50,
                        value=20,
                        step=1,
                        label="n_lines",
                        info="分布からサンプルして描く回帰直線の本数です。",
                    )
                    sample_mode = gr.Radio(
                        choices=["prior samples", "posterior samples"],
                        value="posterior samples",
                        label="表示モード",
                        info="回帰直線を事前分布から引くか、事後分布から引くかを選びます。",
                    )
                    show_likelihood = gr.Checkbox(
                        value=True,
                        label="パラメータ空間に尤度等高線を表示",
                        info="灰色の点線で尤度の等高線を重ねます。",
                    )

            with gr.Column(scale=6):
                with gr.Row():
                    parameter_plot = gr.Plot(label="パラメータ空間")
                    data_plot = gr.Plot(label="データ空間")
                with gr.Row():
                    posterior_mean_box = gr.Textbox(label="事後平均 m_N", lines=2)
                    posterior_cov_box = gr.Textbox(label="事後共分散 S_N", lines=4)
                summary_box = gr.Markdown()

        # Order must match update()'s positional parameters.
        inputs = [
            true_w0,
            true_w1,
            sigma,
            prior_mean_w0,
            prior_mean_w1,
            prior_std_w0,
            prior_std_w1,
            prior_rho,
            n_max,
            n_used,
            seed,
            n_lines,
            sample_mode,
            show_likelihood,
        ]
        # Order must match update()'s return tuple.
        outputs = [parameter_plot, data_plot, posterior_mean_box, posterior_cov_box, summary_box]

        # N_max first re-syncs the N slider's range/value, then recomputes.
        n_max_event = n_max.change(sync_n_slider, inputs=[n_max, n_used], outputs=n_used)
        n_max_event.then(update, inputs=inputs, outputs=outputs)

        # Every other input recomputes directly. Deriving this list from `inputs` (instead of a
        # hand-maintained copy) keeps newly added controls wired automatically.
        direct_triggers = [component for component in inputs if component is not n_max]
        for component in direct_triggers:
            component.change(update, inputs=inputs, outputs=outputs)

        demo.load(update, inputs=inputs, outputs=outputs)

    return demo
|
|
|
|
def main() -> None:
    """Parse command-line options and launch the visualizer with Gradio's request queue enabled.

    ``--server-name`` / ``--server-port`` are forwarded to ``launch`` only when given;
    ``--share`` and ``--browser`` map to ``share`` and ``inbrowser``.
    """
    parser = argparse.ArgumentParser(description="Launch the Bayesian linear regression visualizer.")
    parser.add_argument("--server-name", default=None, help="Host for the Gradio server.")
    parser.add_argument("--server-port", type=int, default=None, help="Port for the Gradio server.")
    parser.add_argument("--share", action="store_true", help="Create a public Gradio share link.")
    parser.add_argument("--browser", action="store_true", help="Automatically open the app in a browser.")
    args = parser.parse_args()

    app = build_app()

    # Omitted keys leave launch() on its own defaults.
    optional_settings = {"server_name": args.server_name, "server_port": args.server_port}
    launch_kwargs: dict[str, object] = {
        "share": args.share,
        "inbrowser": args.browser,
        **{key: value for key, value in optional_settings.items() if value is not None},
    }
    app.queue().launch(**launch_kwargs)
|
|
|
|
# Start the app only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|
|