"""Interactive Gradio visualiser for Bayesian linear regression.

Synthetic data are drawn from ``y = true_w0 + true_w1 * x + noise`` and a Gaussian
prior over the weights ``(w0, w1)`` is updated with the first ``N`` observations.
The app shows prior / likelihood / posterior contours in parameter space and
regression lines sampled from the prior or posterior in data space.
"""

from __future__ import annotations

import argparse

import gradio as gr
import matplotlib
import numpy as np
from matplotlib.figure import Figure
from matplotlib.lines import Line2D
from numpy.typing import NDArray

# Use a headless backend so the app also works in terminal-only environments.
matplotlib.use("Agg")

# pyplot is imported after selecting the backend. The plotting functions below
# deliberately build standalone Figure objects instead of going through pyplot.
import matplotlib.pyplot as plt  # noqa: E402,F401

FloatArray = NDArray[np.float64]

APP_THEME = gr.themes.Soft(
    primary_hue="sky",
    secondary_hue="amber",
    neutral_hue="slate",
)


def make_prior_cov(std_w0: float, std_w1: float, rho: float) -> FloatArray:
    """Build the 2x2 prior covariance matrix of the weights ``(w0, w1)``.

    Args:
        std_w0: Prior standard deviation of the intercept ``w0``; must be > 0.
        std_w1: Prior standard deviation of the slope ``w1``; must be > 0.
        rho: Prior correlation between ``w0`` and ``w1``; must satisfy
            ``-0.999 < rho < 0.999`` (presumably kept away from +-1 to avoid a
            near-singular matrix).

    Returns:
        ``[[std_w0**2, rho*std_w0*std_w1], [rho*std_w0*std_w1, std_w1**2]]``.

    Raises:
        ValueError: If an argument is out of range or the resulting matrix is
            not positive definite.
    """
    # Phrased as ``not (x > 0)`` so that NaN inputs are rejected too.
    if not (std_w0 > 0 and std_w1 > 0):
        raise ValueError("事前標準偏差は正の値にしてください。")
    if not (-0.999 < rho < 0.999):
        # The message states the bound that is actually enforced above.
        raise ValueError("事前相関係数 rho は -0.999 より大きく 0.999 より小さい値にしてください。")

    off_diagonal = rho * std_w0 * std_w1
    cov = np.array(
        [
            [std_w0**2, off_diagonal],
            [off_diagonal, std_w1**2],
        ],
        dtype=float,
    )
    sign, _ = np.linalg.slogdet(cov)
    if sign <= 0:
        raise ValueError("事前共分散行列が正定値ではありません。標準偏差と相関係数を見直してください。")
    return cov


def generate_dataset(
    true_w0: float,
    true_w1: float,
    sigma: float,
    n_max: int,
    seed: int,
) -> tuple[FloatArray, FloatArray]:
    """Draw a reproducible synthetic dataset from the true linear model.

    ``x`` is sampled uniformly from ``[-1, 1)`` and
    ``y = true_w0 + true_w1 * x + eps`` with ``eps ~ N(0, sigma**2)``.
    The same ``seed`` always yields the same data.

    Args:
        true_w0: True intercept.
        true_w1: True slope.
        sigma: Observation-noise standard deviation; must be > 0.
        n_max: Number of points to generate; must be >= 1.
        seed: Seed for ``numpy.random.default_rng``.

    Returns:
        ``(x, y)``, each a float array of length ``n_max``.

    Raises:
        ValueError: If ``n_max < 1`` or ``sigma`` is not positive.
    """
    if n_max < 1:
        raise ValueError("N_max は 1 以上にしてください。")
    # Phrased as ``not (sigma > 0)`` so that NaN is rejected too.
    if not (sigma > 0):
        raise ValueError("観測ノイズ標準偏差 sigma は正の値にしてください。")

    rng = np.random.default_rng(seed)
    x = rng.uniform(-1.0, 1.0, size=n_max)
    noise = rng.normal(0.0, sigma, size=n_max)
    y = true_w0 + true_w1 * x + noise
    return x.astype(float), y.astype(float)


def compute_posterior(
    prior_mean: FloatArray,
    prior_cov: FloatArray,
    x: FloatArray,
    y: FloatArray,
    sigma: float,
    n_used: int,
) -> tuple[FloatArray, FloatArray]:
    """Return the Gaussian posterior over ``(w0, w1)`` given the first ``n_used`` points.

    With design matrix ``Phi = [1, x_i]`` and known noise ``sigma``, this is the
    conjugate update::

        S_N = (S_0^-1 + Phi^T Phi / sigma^2)^-1
        m_N = S_N (S_0^-1 m_0 + Phi^T y / sigma^2)

    Args:
        prior_mean: Prior mean ``m_0`` (length 2).
        prior_cov: Prior covariance ``S_0`` (2x2).
        x: Inputs; only the first ``n_used`` entries are used.
        y: Targets aligned with ``x``.
        sigma: Observation-noise standard deviation.
        n_used: Number of leading points to condition on; clipped to ``[0, len(x)]``.

    Returns:
        ``(m_N, S_N)``. When no data are used, copies of the prior are returned.
    """
    n_used = int(np.clip(n_used, 0, len(x)))
    if n_used == 0:
        return prior_mean.copy(), prior_cov.copy()

    phi = np.column_stack([np.ones(n_used), x[:n_used]])
    y_used = y[:n_used]

    prior_precision = np.linalg.inv(prior_cov)
    posterior_precision = prior_precision + (phi.T @ phi) / (sigma**2)
    posterior_cov = np.linalg.inv(posterior_precision)
    # inv() of a symmetric matrix is only symmetric up to rounding. Enforce exact
    # symmetry so downstream consumers (multivariate_normal, contour densities)
    # receive a valid covariance matrix.
    posterior_cov = 0.5 * (posterior_cov + posterior_cov.T)

    rhs = prior_precision @ prior_mean + (phi.T @ y_used) / (sigma**2)
    # Solving against the precision matrix is more accurate than multiplying by
    # its explicit inverse.
    posterior_mean = np.linalg.solve(posterior_precision, rhs)
    return posterior_mean, posterior_cov


def sample_weights(mean: FloatArray, cov: FloatArray, n_lines: int, seed: int) -> FloatArray:
    """Draw ``n_lines`` weight vectors ``(w0, w1)`` from ``N(mean, cov)``.

    Returns:
        A float array of shape ``(n_lines, 2)``.

    Raises:
        ValueError: If ``n_lines < 1``.
    """
    if n_lines < 1:
        raise ValueError("表示する直線本数 n_lines は 1 以上にしてください。")
    rng = np.random.default_rng(seed)
    return rng.multivariate_normal(mean=mean, cov=cov, size=n_lines).astype(float)


def _gaussian_density_grid(
    mean: FloatArray,
    cov: FloatArray,
    grid_w0: FloatArray,
    grid_w1: FloatArray,
) -> FloatArray:
    """Evaluate the bivariate normal density ``N(mean, cov)`` at every grid point.

    ``grid_w0`` / ``grid_w1`` are same-shaped coordinate arrays (e.g. from
    ``np.meshgrid``); the result has that same shape.

    Raises:
        ValueError: If ``cov`` is not positive definite.
    """
    cov_inv = np.linalg.inv(cov)
    sign, logdet = np.linalg.slogdet(cov)
    if sign <= 0:
        raise ValueError("共分散行列が正定値ではありません。")

    position = np.stack([grid_w0, grid_w1], axis=-1)
    diff = position - mean
    # Squared Mahalanobis distance of every grid point from the mean.
    quad = np.einsum("...i,ij,...j->...", diff, cov_inv, diff)
    log_density = -0.5 * (2 * np.log(2 * np.pi) + logdet + quad)
    return np.exp(log_density)


def _likelihood_surface(
    grid_w0: FloatArray,
    grid_w1: FloatArray,
    x_used: FloatArray,
    y_used: FloatArray,
    sigma: float,
) -> FloatArray:
    """Return the Gaussian likelihood on a ``(w0, w1)`` grid, scaled so its maximum is 1.

    The residual sum of squares ``sum_i (y_i - w0 - w1 * x_i)**2`` is a quadratic
    in ``(w0, w1)``. It is therefore evaluated from the data's sufficient statistics
    rather than by materialising a ``grid x data-points`` residual array, which
    would cost tens of MB per redraw for a fine grid and many points.
    """
    n_points = float(np.size(x_used))
    sum_x = float(np.sum(x_used))
    sum_y = float(np.sum(y_used))
    sum_xx = float(np.dot(x_used, x_used))
    sum_xy = float(np.dot(x_used, y_used))
    sum_yy = float(np.dot(y_used, y_used))

    rss = (
        sum_yy
        - 2.0 * (grid_w0 * sum_y + grid_w1 * sum_xy)
        + n_points * grid_w0**2
        + 2.0 * sum_x * grid_w0 * grid_w1
        + sum_xx * grid_w1**2
    )
    log_likelihood = -0.5 * rss / (sigma**2)
    # Subtracting the maximum avoids underflow; only relative heights matter for contours.
    return np.exp(log_likelihood - np.max(log_likelihood))


def _contour_levels(surface: FloatArray) -> FloatArray:
    """Choose increasing contour levels as fixed fractions of the surface's grid maximum.

    The fractions are ``exp(-d2 / 2)`` for ``d2`` in ``{7, 4.5, 2.5, 1, 0.3}``. For a
    Gaussian surface whose peak lies on the grid, these correspond to those squared
    Mahalanobis distances from the mean. Falls back to ``[1.0]`` when the peak is
    non-finite or non-positive.
    """
    peak = float(np.max(surface))
    if not np.isfinite(peak) or peak <= 0:
        return np.array([1.0], dtype=float)
    relative_levels = np.exp(-0.5 * np.array([7.0, 4.5, 2.5, 1.0, 0.3], dtype=float))
    levels = np.sort(peak * relative_levels)
    # Keep every level strictly inside (0, peak) so matplotlib can draw it.
    return np.unique(np.clip(levels, peak * 1e-6, peak * 0.999))


def _parameter_limits(
    prior_mean: FloatArray,
    prior_cov: FloatArray,
    posterior_mean: FloatArray,
    posterior_cov: FloatArray,
    true_w: FloatArray,
) -> tuple[tuple[float, float], tuple[float, float]]:
    """Compute axis limits for the parameter-space plot.

    The box spans +-4 standard deviations of both prior and posterior (per axis)
    plus the true parameter. The span is widened to at least 1 per axis, and then
    15 % padding is added on each side.

    Returns:
        ``((w0_min, w0_max), (w1_min, w1_max))``.
    """
    prior_std = 4.0 * np.sqrt(np.diag(prior_cov))
    posterior_std = 4.0 * np.sqrt(np.diag(posterior_cov))
    lower = np.vstack(
        [
            prior_mean - prior_std,
            posterior_mean - posterior_std,
            true_w,
        ]
    ).min(axis=0)
    upper = np.vstack(
        [
            prior_mean + prior_std,
            posterior_mean + posterior_std,
            true_w,
        ]
    ).max(axis=0)
    span = np.maximum(upper - lower, np.array([1.0, 1.0], dtype=float))
    padding = 0.15 * span
    w0_limits = (float(lower[0] - padding[0]), float(upper[0] + padding[0]))
    w1_limits = (float(lower[1] - padding[1]), float(upper[1] + padding[1]))
    return w0_limits, w1_limits


def plot_parameter_space(
    prior_mean: FloatArray,
    prior_cov: FloatArray,
    posterior_mean: FloatArray,
    posterior_cov: FloatArray,
    true_w: FloatArray,
    x: FloatArray,
    y: FloatArray,
    sigma: float,
    n_used: int,
    show_likelihood: bool,
) -> Figure:
    """Plot prior, posterior and (optionally) likelihood contours over ``(w0, w1)``.

    Prior contours are blue dashed lines and posterior contours are red solid lines.
    When ``show_likelihood`` is set and ``n_used > 0``, the likelihood of the first
    ``n_used`` points is drawn as grey dotted contours. The true parameter (black
    star) and the posterior mean (red dot) are marked.

    Returns:
        A standalone matplotlib ``Figure`` that is not registered with pyplot.
    """
    w0_limits, w1_limits = _parameter_limits(prior_mean, prior_cov, posterior_mean, posterior_cov, true_w)
    w0_grid = np.linspace(*w0_limits, 180)
    w1_grid = np.linspace(*w1_limits, 180)
    grid_w0, grid_w1 = np.meshgrid(w0_grid, w1_grid)

    prior_density = _gaussian_density_grid(prior_mean, prior_cov, grid_w0, grid_w1)
    posterior_density = _gaussian_density_grid(posterior_mean, posterior_cov, grid_w0, grid_w1)
    draw_likelihood = show_likelihood and n_used > 0

    # Build the figure without pyplot. Figures made by plt.subplots stay in pyplot's
    # global registry until explicitly closed, so an app that creates two per update
    # would accumulate them for its whole lifetime.
    fig = Figure(figsize=(6.2, 5.2))
    ax = fig.subplots()

    if draw_likelihood:
        likelihood = _likelihood_surface(grid_w0, grid_w1, x[:n_used], y[:n_used], sigma)
        ax.contour(
            grid_w0,
            grid_w1,
            likelihood,
            levels=_contour_levels(likelihood),
            colors="0.55",
            linestyles="dotted",
            linewidths=1.1,
        )

    ax.contour(
        grid_w0,
        grid_w1,
        prior_density,
        levels=_contour_levels(prior_density),
        colors="tab:blue",
        linestyles="dashed",
        linewidths=1.5,
    )
    ax.contour(
        grid_w0,
        grid_w1,
        posterior_density,
        levels=_contour_levels(posterior_density),
        colors="tab:red",
        linewidths=1.8,
    )
    ax.scatter(true_w[0], true_w[1], marker="*", s=140, color="black", zorder=5)
    ax.scatter(posterior_mean[0], posterior_mean[1], s=44, color="tab:red", zorder=5)

    # Contour sets do not produce legend entries, so proxy artists are used instead.
    handles = [
        Line2D([0], [0], color="tab:blue", linestyle="dashed", linewidth=1.5, label="prior"),
        Line2D([0], [0], color="tab:red", linewidth=1.8, label="posterior"),
        Line2D([0], [0], marker="o", color="tab:red", linewidth=0, markersize=7, label="posterior mean"),
        Line2D([0], [0], marker="*", color="black", linewidth=0, markersize=10, label="true parameter"),
    ]
    if draw_likelihood:
        handles.insert(
            0,
            Line2D([0], [0], color="0.55", linestyle="dotted", linewidth=1.2, label="likelihood"),
        )

    ax.set_title("Parameter Space")
    ax.set_xlabel(r"$w_0$")
    ax.set_ylabel(r"$w_1$")
    ax.set_xlim(*w0_limits)
    ax.set_ylim(*w1_limits)
    ax.grid(alpha=0.22)
    ax.legend(handles=handles, loc="best")
    fig.tight_layout()
    return fig


def plot_data_space(
    x: FloatArray,
    y: FloatArray,
    n_used: int,
    true_w: FloatArray,
    posterior_mean: FloatArray,
    sampled_w: FloatArray,
    sample_label: str,
) -> Figure:
    """Plot the data with the true line, posterior-mean line and sampled lines.

    Points beyond the first ``n_used`` are drawn in light grey as unused data.
    Each row ``(w0, w1)`` of ``sampled_w`` is drawn as a faint orange line. Only
    the first of these lines carries ``sample_label`` in the legend.

    Returns:
        A standalone matplotlib ``Figure`` that is not registered with pyplot.
    """
    # Standalone Figure for the same reason as in plot_parameter_space: avoid
    # accumulating figures in pyplot's global registry.
    fig = Figure(figsize=(6.2, 5.2))
    ax = fig.subplots()

    if n_used < len(x):
        ax.scatter(x[n_used:], y[n_used:], color="0.83", s=36, label="unused data", zorder=2)
    if n_used > 0:
        ax.scatter(x[:n_used], y[:n_used], color="tab:blue", s=42, label="used data", zorder=3)

    x_line = np.linspace(-1.1, 1.1, 240)
    true_line = true_w[0] + true_w[1] * x_line
    posterior_line = posterior_mean[0] + posterior_mean[1] * x_line

    ax.plot(x_line, true_line, color="black", linewidth=2.2, label="true line")
    ax.plot(x_line, posterior_line, color="tab:red", linewidth=2.0, label="posterior mean")

    for index, weights in enumerate(sampled_w):
        # Label only one sampled line so the legend gets a single entry.
        label = sample_label if index == 0 else None
        ax.plot(
            x_line,
            weights[0] + weights[1] * x_line,
            color="tab:orange",
            alpha=0.18,
            linewidth=1.15,
            label=label,
            zorder=1,
        )

    ax.set_title("Data Space")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_xlim(-1.1, 1.1)
    ax.grid(alpha=0.22)
    ax.legend(loc="best")
    fig.tight_layout()
    return fig


def _format_array(value: FloatArray) -> str:
    """Format an array with 3 fixed decimals for display in a textbox."""
    return np.array2string(value, precision=3, suppress_small=True, floatmode="fixed")


def _select_sampling_distribution(
    sample_mode: str,
    n_used: int,
    prior_mean: FloatArray,
    prior_cov: FloatArray,
    posterior_mean: FloatArray,
    posterior_cov: FloatArray,
) -> tuple[FloatArray, FloatArray, str]:
    """Pick the distribution that regression lines are sampled from.

    Returns ``(mean, cov, label)``. ``"posterior samples"`` uses the posterior when
    at least one data point is used, and otherwise falls back to the prior (which
    equals the posterior when ``N = 0``) with an explanatory label. Any other mode
    uses the prior.
    """
    if sample_mode == "posterior samples" and n_used > 0:
        return posterior_mean, posterior_cov, "posterior samples"
    if sample_mode == "posterior samples":
        return prior_mean, prior_cov, "prior samples (N=0 fallback)"
    return prior_mean, prior_cov, "prior samples"


def sync_n_slider(n_max: float, n_used: float) -> dict:
    """Rescale the ``N`` slider after ``N_max`` changes.

    Sets the slider maximum to ``max(1, int(n_max))`` and clamps its current value
    into ``[0, maximum]``.

    Returns:
        The update produced by ``gr.update`` for the ``N`` slider.
    """
    max_value = max(1, int(n_max))
    current_value = min(max(0, int(n_used)), max_value)
    return gr.update(maximum=max_value, value=current_value)


def update(
    true_w0: float,
    true_w1: float,
    sigma: float,
    prior_mean_w0: float,
    prior_mean_w1: float,
    prior_std_w0: float,
    prior_std_w1: float,
    prior_rho: float,
    n_max: float,
    n_used: float,
    seed: float,
    n_lines: float,
    sample_mode: str,
    show_likelihood: bool,
) -> tuple[Figure, Figure, str, str, str]:
    """Recompute the dataset, posterior and both figures from the current UI values.

    Slider values arrive as floats and are converted to ints where needed.
    ``n_used`` is clamped to ``[0, n_max]``.

    Returns:
        ``(parameter_fig, data_fig, posterior_mean_text, posterior_cov_text,
        summary_markdown)``, in the order of the ``outputs`` wired in ``build_app``.

    Raises:
        gr.Error: Wrapping any ``ValueError`` / ``LinAlgError``, so the message is
            shown in the UI.
    """
    try:
        n_max_int = max(1, int(n_max))
        n_used_int = min(max(0, int(n_used)), n_max_int)
        seed_int = int(seed)
        n_lines_int = max(1, int(n_lines))

        true_w = np.array([true_w0, true_w1], dtype=float)
        prior_mean = np.array([prior_mean_w0, prior_mean_w1], dtype=float)
        prior_cov = make_prior_cov(prior_std_w0, prior_std_w1, prior_rho)

        x, y = generate_dataset(true_w0, true_w1, sigma, n_max_int, seed_int)
        posterior_mean, posterior_cov = compute_posterior(
            prior_mean=prior_mean,
            prior_cov=prior_cov,
            x=x,
            y=y,
            sigma=sigma,
            n_used=n_used_int,
        )
        sample_mean, sample_cov, sample_label = _select_sampling_distribution(
            sample_mode=sample_mode,
            n_used=n_used_int,
            prior_mean=prior_mean,
            prior_cov=prior_cov,
            posterior_mean=posterior_mean,
            posterior_cov=posterior_cov,
        )
        # Derive the line-sampling seed from the data seed, N and the source
        # distribution, so the sampled lines are deterministic for a given UI state.
        sample_seed = seed_int + 10_000 * n_used_int + (1 if sample_label.startswith("posterior") else 0)
        sampled_w = sample_weights(sample_mean, sample_cov, n_lines_int, sample_seed)

        parameter_fig = plot_parameter_space(
            prior_mean=prior_mean,
            prior_cov=prior_cov,
            posterior_mean=posterior_mean,
            posterior_cov=posterior_cov,
            true_w=true_w,
            x=x,
            y=y,
            sigma=sigma,
            n_used=n_used_int,
            show_likelihood=show_likelihood,
        )
        data_fig = plot_data_space(
            x=x,
            y=y,
            n_used=n_used_int,
            true_w=true_w,
            posterior_mean=posterior_mean,
            sampled_w=sampled_w,
            sample_label=sample_label,
        )

        summary = "\n".join(
            [
                "### Current State",
                f"- 使用データ数: `{n_used_int} / {n_max_int}`",
                f"- 直線サンプル元: `{sample_label}`",
                f"- 尤度等高線: `{'on' if show_likelihood and n_used_int > 0 else 'off'}`",
            ]
        )
        return (
            parameter_fig,
            data_fig,
            _format_array(posterior_mean),
            _format_array(posterior_cov),
            summary,
        )
    except (ValueError, np.linalg.LinAlgError) as exc:
        raise gr.Error(str(exc)) from exc


def build_app() -> gr.Blocks:
    """Build the Gradio Blocks UI and wire every control to ``update``.

    Returns:
        The unlaunched ``gr.Blocks`` application.
    """
    default_n_max = 60
    default_n_used = 12

    with gr.Blocks(title="Bayesian Linear Regression Visualizer", theme=APP_THEME) as demo:
        gr.Markdown(
            """
# Bayesian Linear Regression Visualizer

事前分布・尤度・事後分布の関係と、パラメータ分布からサンプルした回帰直線群の変化を 2 つの図で確認できます。
"""
        )

        with gr.Row():
            with gr.Column(scale=4):
                gr.Markdown("## Controls")

                with gr.Group():
                    gr.Markdown("### 真のモデル")
                    true_w0 = gr.Slider(
                        -3.0,
                        3.0,
                        value=-0.3,
                        step=0.1,
                        label="true_w0",
                        info="真の切片。黒い真の回帰直線の上下位置を決めます。",
                    )
                    true_w1 = gr.Slider(
                        -3.0,
                        3.0,
                        value=1.2,
                        step=0.1,
                        label="true_w1",
                        info="真の傾き。黒い真の回帰直線の傾きを決めます。",
                    )
                    sigma = gr.Slider(
                        0.05,
                        1.2,
                        value=0.25,
                        step=0.05,
                        label="sigma",
                        info="観測ノイズの標準偏差。大きいほどデータ点が真の直線から散らばります。",
                    )

                with gr.Group():
                    gr.Markdown("### 事前分布")
                    prior_mean_w0 = gr.Slider(
                        -3.0,
                        3.0,
                        value=0.0,
                        step=0.1,
                        label="prior_mean_w0",
                        info="事前分布での切片の平均です。",
                    )
                    prior_mean_w1 = gr.Slider(
                        -3.0,
                        3.0,
                        value=0.0,
                        step=0.1,
                        label="prior_mean_w1",
                        info="事前分布での傾きの平均です。",
                    )
                    prior_std_w0 = gr.Slider(
                        0.1,
                        3.0,
                        value=1.2,
                        step=0.1,
                        label="prior_std_w0",
                        info="事前分布での切片方向の広がりです。大きいほど切片に自信がありません。",
                    )
                    prior_std_w1 = gr.Slider(
                        0.1,
                        3.0,
                        value=1.2,
                        step=0.1,
                        label="prior_std_w1",
                        info="事前分布での傾き方向の広がりです。大きいほど傾きに自信がありません。",
                    )
                    prior_rho = gr.Slider(
                        -0.95,
                        0.95,
                        value=-0.25,
                        step=0.05,
                        label="prior_rho",
                        info="事前分布での切片と傾きの相関です。0 なら軸に沿い、正負で等高線の傾きが変わります。",
                    )

                with gr.Group():
                    gr.Markdown("### データと描画")
                    n_max = gr.Slider(
                        10,
                        200,
                        value=default_n_max,
                        step=1,
                        label="N_max",
                        info="先に生成しておく総データ数です。",
                    )
                    n_used = gr.Slider(
                        0,
                        default_n_max,
                        value=default_n_used,
                        step=1,
                        label="N",
                        info="事後分布の計算に使うデータ数です。先頭から N 個だけ使います。",
                    )
                    seed = gr.Slider(
                        0,
                        9999,
                        value=7,
                        step=1,
                        label="seed",
                        info="データ生成の乱数シードです。同じ値なら同じデータになります。",
                    )
                    n_lines = gr.Slider(
                        1,
                        50,
                        value=20,
                        step=1,
                        label="n_lines",
                        info="分布からサンプルして描く回帰直線の本数です。",
                    )
                    sample_mode = gr.Radio(
                        choices=["prior samples", "posterior samples"],
                        value="posterior samples",
                        label="表示モード",
                        info="回帰直線を事前分布から引くか、事後分布から引くかを選びます。",
                    )
                    show_likelihood = gr.Checkbox(
                        value=True,
                        label="パラメータ空間に尤度等高線を表示",
                        info="灰色の点線で尤度の等高線を重ねます。",
                    )

            with gr.Column(scale=6):
                with gr.Row():
                    parameter_plot = gr.Plot(label="パラメータ空間")
                    data_plot = gr.Plot(label="データ空間")
                with gr.Row():
                    posterior_mean_box = gr.Textbox(label="事後平均 m_N", lines=2)
                    posterior_cov_box = gr.Textbox(label="事後共分散 S_N", lines=4)
                summary_box = gr.Markdown()

        # Order must match the positional parameters of ``update``.
        inputs = [
            true_w0,
            true_w1,
            sigma,
            prior_mean_w0,
            prior_mean_w1,
            prior_std_w0,
            prior_std_w1,
            prior_rho,
            n_max,
            n_used,
            seed,
            n_lines,
            sample_mode,
            show_likelihood,
        ]
        # Order must match the tuple returned by ``update``.
        outputs = [parameter_plot, data_plot, posterior_mean_box, posterior_cov_box, summary_box]

        # Changing N_max changes both the dataset and the valid range of N, so clamp
        # the N slider first and then redraw.
        n_max_event = n_max.change(sync_n_slider, inputs=[n_max, n_used], outputs=n_used)
        n_max_event.then(update, inputs=inputs, outputs=outputs)

        # Every other control redraws directly. n_max is excluded because its
        # handler above already chains ``update``.
        for component in [
            true_w0,
            true_w1,
            sigma,
            prior_mean_w0,
            prior_mean_w1,
            prior_std_w0,
            prior_std_w1,
            prior_rho,
            n_used,
            seed,
            n_lines,
            sample_mode,
            show_likelihood,
        ]:
            component.change(update, inputs=inputs, outputs=outputs)

        # Render the initial state when the page first loads.
        demo.load(update, inputs=inputs, outputs=outputs)

    return demo


def main() -> None:
    """Parse command-line options and launch the Gradio app with queuing enabled."""
    parser = argparse.ArgumentParser(description="Launch the Bayesian linear regression visualizer.")
    parser.add_argument("--server-name", default=None, help="Host for the Gradio server.")
    parser.add_argument("--server-port", type=int, default=None, help="Port for the Gradio server.")
    parser.add_argument("--share", action="store_true", help="Create a public Gradio share link.")
    parser.add_argument("--browser", action="store_true", help="Automatically open the app in a browser.")
    args = parser.parse_args()

    app = build_app()
    launch_kwargs: dict[str, object] = {
        "share": args.share,
        "inbrowser": args.browser,
    }
    # Only forward host/port when given, so Gradio's own defaults apply otherwise.
    if args.server_name is not None:
        launch_kwargs["server_name"] = args.server_name
    if args.server_port is not None:
        launch_kwargs["server_port"] = args.server_port

    app.queue().launch(**launch_kwargs)


if __name__ == "__main__":
    main()