taxfree-python
Add control descriptions
cbc118e
from __future__ import annotations
import argparse
import gradio as gr
import matplotlib
import numpy as np
from matplotlib.figure import Figure
from matplotlib.lines import Line2D
from numpy.typing import NDArray
# Use a headless backend so the app also works in terminal-only environments.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# Alias for the float64 NumPy arrays passed between the helpers below.
FloatArray = NDArray[np.float64]

# Shared Gradio theme for the whole app: sky primary, amber secondary, slate neutrals.
APP_THEME = gr.themes.Soft(neutral_hue="slate", primary_hue="sky", secondary_hue="amber")
def make_prior_cov(std_w0: float, std_w1: float, rho: float) -> FloatArray:
    """Build the 2x2 prior covariance matrix of the weights (w0, w1).

    Args:
        std_w0: Prior standard deviation of the intercept w0; must be finite and > 0.
        std_w1: Prior standard deviation of the slope w1; must be finite and > 0.
        rho: Prior correlation between w0 and w1; must satisfy -0.999 < rho < 0.999.

    Returns:
        ``[[std_w0**2, rho*std_w0*std_w1], [rho*std_w0*std_w1, std_w1**2]]`` as float64.

    Raises:
        ValueError: If an argument is out of range or the matrix is not positive definite.
    """
    # `std <= 0` is False for NaN, so non-finite values need an explicit check;
    # otherwise NaN/inf would silently propagate into the covariance matrix.
    if not (np.isfinite(std_w0) and np.isfinite(std_w1)):
        raise ValueError("事前標準偏差は有限の値にしてください。")
    if std_w0 <= 0 or std_w1 <= 0:
        raise ValueError("事前標準偏差は正の値にしてください。")
    # Keep |rho| strictly below 1 (with a margin) so the matrix stays well conditioned.
    # The comparison also rejects a NaN rho.
    if not (-0.999 < rho < 0.999):
        raise ValueError("事前相関係数 rho は -0.999 より大きく 0.999 より小さい値にしてください。")
    off_diagonal = rho * std_w0 * std_w1
    cov = np.array(
        [
            [std_w0**2, off_diagonal],
            [off_diagonal, std_w1**2],
        ],
        dtype=float,
    )
    # Defensive guard; with the checks above det = s0^2 s1^2 (1 - rho^2) > 0.
    sign, _ = np.linalg.slogdet(cov)
    if sign <= 0:
        raise ValueError("事前共分散行列が正定値ではありません。標準偏差と相関係数を見直してください。")
    return cov
def generate_dataset(
    true_w0: float,
    true_w1: float,
    sigma: float,
    n_max: int,
    seed: int,
) -> tuple[FloatArray, FloatArray]:
    """Simulate ``n_max`` noisy observations of the line ``y = true_w0 + true_w1 * x``.

    Inputs are drawn uniformly from [-1, 1], and Gaussian noise with standard
    deviation ``sigma`` is added. The same ``seed`` always reproduces the same dataset.

    Returns:
        ``(x, y)``, two float arrays of length ``n_max``.

    Raises:
        ValueError: If ``n_max < 1`` or ``sigma <= 0``.
    """
    if n_max < 1:
        raise ValueError("N_max は 1 以上にしてください。")
    if sigma <= 0:
        raise ValueError("観測ノイズ標準偏差 sigma は正の値にしてください。")
    generator = np.random.default_rng(seed)
    # Draw the inputs before the noise: this order fixes the dataset for a given seed.
    inputs = generator.uniform(-1.0, 1.0, size=n_max)
    eps = generator.normal(0.0, sigma, size=n_max)
    targets = true_w0 + true_w1 * inputs + eps
    return inputs.astype(float), targets.astype(float)
def compute_posterior(
    prior_mean: FloatArray,
    prior_cov: FloatArray,
    x: FloatArray,
    y: FloatArray,
    sigma: float,
    n_used: int,
) -> tuple[FloatArray, FloatArray]:
    """Return the Gaussian posterior (m_N, S_N) for the model y = w0 + w1 * x + noise.

    Only the first ``n_used`` points of ``x``/``y`` are used. ``n_used`` is clipped
    to ``[0, len(x)]``. With zero points, copies of the prior are returned.

    Args:
        prior_mean: Prior mean of (w0, w1).
        prior_cov: Prior covariance of (w0, w1).
        x, y: Observed inputs and targets.
        sigma: Known observation-noise standard deviation.
        n_used: Number of leading observations to condition on.

    Returns:
        ``(posterior_mean, posterior_cov)``. The covariance is exactly symmetric.

    Raises:
        numpy.linalg.LinAlgError: If a precision matrix is singular.
    """
    n_used = int(np.clip(n_used, 0, len(x)))
    if n_used == 0:
        return prior_mean.copy(), prior_cov.copy()
    # Design matrix with a bias column: rows are [1, x_i].
    phi = np.column_stack([np.ones(n_used), x[:n_used]])
    y_used = y[:n_used]
    prior_precision = np.linalg.inv(prior_cov)
    posterior_precision = prior_precision + (phi.T @ phi) / (sigma**2)
    rhs = prior_precision @ prior_mean + (phi.T @ y_used) / (sigma**2)
    # Solving S_N^{-1} m_N = rhs is more accurate than multiplying by an explicit inverse.
    posterior_mean = np.linalg.solve(posterior_precision, rhs)
    posterior_cov = np.linalg.inv(posterior_precision)
    # inv() of a symmetric matrix is not guaranteed to be bit-exactly symmetric;
    # enforce symmetry so the displayed S_N and the sampler see a proper covariance.
    posterior_cov = 0.5 * (posterior_cov + posterior_cov.T)
    return posterior_mean, posterior_cov
def sample_weights(mean: FloatArray, cov: FloatArray, n_lines: int, seed: int) -> FloatArray:
    """Draw ``n_lines`` weight vectors from N(mean, cov) using a seeded generator.

    Returns:
        An ``(n_lines, len(mean))`` float array. Each row is one sampled (w0, w1).

    Raises:
        ValueError: If ``n_lines < 1``.
    """
    if n_lines < 1:
        raise ValueError("表示する直線本数 n_lines は 1 以上にしてください。")
    generator = np.random.default_rng(seed)
    draws = generator.multivariate_normal(mean, cov, size=n_lines)
    return draws.astype(float)
def _gaussian_density_grid(
    mean: FloatArray,
    cov: FloatArray,
    grid_w0: FloatArray,
    grid_w1: FloatArray,
) -> FloatArray:
    """Evaluate the bivariate normal pdf N(mean, cov) at every point of a meshgrid.

    ``grid_w0`` and ``grid_w1`` hold the w0/w1 coordinates of each grid point.
    The result has the same shape as those grids.

    Raises:
        numpy.linalg.LinAlgError: If ``cov`` is singular (raised by ``inv``).
        ValueError: If ``cov`` has a non-positive determinant.
    """
    precision = np.linalg.inv(cov)
    det_sign, log_det = np.linalg.slogdet(cov)
    if det_sign <= 0:
        raise ValueError("共分散行列が正定値ではありません。")
    offsets = np.stack([grid_w0, grid_w1], axis=-1) - mean
    # Squared Mahalanobis distance of each grid point from the mean.
    mahalanobis_sq = np.einsum("...i,ij,...j->...", offsets, precision, offsets)
    # Log of the 2-D normalising constant: d*log(2*pi) + log|cov| with d = 2.
    log_norm = 2 * np.log(2 * np.pi) + log_det
    return np.exp(-0.5 * (log_norm + mahalanobis_sq))
def _likelihood_surface(
    grid_w0: FloatArray,
    grid_w1: FloatArray,
    x_used: FloatArray,
    y_used: FloatArray,
    sigma: float,
) -> FloatArray:
    """Return the Gaussian likelihood of the data at each grid point, scaled to max 1.

    Rescaling by the maximum (in log space) keeps the surface away from underflow.
    Contour plotting only needs its shape, not the absolute likelihood values.
    """
    # Broadcast every grid point against every data point: shape (..., n_points).
    fitted = grid_w0[..., None] + grid_w1[..., None] * x_used
    sq_err = np.square(y_used - fitted).sum(axis=-1)
    log_lik = -0.5 * sq_err / (sigma**2)
    return np.exp(log_lik - log_lik.max())
def _contour_levels(surface: FloatArray) -> FloatArray:
    """Pick ascending contour levels as fixed fractions of the surface's peak.

    The fractions are ``exp(-0.5 * k)`` for k in (7, 4.5, 2.5, 1, 0.3), i.e. the
    heights at those squared Mahalanobis radii for a Gaussian-shaped surface.
    If the peak is not finite or not positive, ``[1.0]`` is returned as a placeholder.
    """
    peak = float(np.max(surface))
    if np.isfinite(peak) and peak > 0:
        fractions = np.exp(-0.5 * np.array([7.0, 4.5, 2.5, 1.0, 0.3], dtype=float))
        # np.unique sorts and de-duplicates, so the levels come out strictly increasing.
        clipped = np.clip(peak * fractions, peak * 1e-6, peak * 0.999)
        return np.unique(clipped)
    return np.array([1.0], dtype=float)
def _parameter_limits(
    prior_mean: FloatArray,
    prior_cov: FloatArray,
    posterior_mean: FloatArray,
    posterior_cov: FloatArray,
    true_w: FloatArray,
) -> tuple[tuple[float, float], tuple[float, float]]:
    """Choose (w0, w1) axis limits that show the prior, the posterior and the truth.

    Each axis covers mean +/- 4 standard deviations of both distributions plus the
    true parameter. 15% padding is then added on each side, computed from a span of
    at least 1 so that very narrow distributions still get a usable window.

    Returns:
        ``((w0_min, w0_max), (w1_min, w1_max))`` as Python floats.
    """
    half_prior = 4.0 * np.sqrt(np.diag(prior_cov))
    half_post = 4.0 * np.sqrt(np.diag(posterior_cov))
    lows = np.vstack([prior_mean - half_prior, posterior_mean - half_post, true_w])
    highs = np.vstack([prior_mean + half_prior, posterior_mean + half_post, true_w])
    lower = lows.min(axis=0)
    upper = highs.max(axis=0)
    pad = 0.15 * np.maximum(upper - lower, np.full(2, 1.0))
    limits = [(float(lower[axis] - pad[axis]), float(upper[axis] + pad[axis])) for axis in (0, 1)]
    return limits[0], limits[1]
def plot_parameter_space(
    prior_mean: FloatArray,
    prior_cov: FloatArray,
    posterior_mean: FloatArray,
    posterior_cov: FloatArray,
    true_w: FloatArray,
    x: FloatArray,
    y: FloatArray,
    sigma: float,
    n_used: int,
    show_likelihood: bool,
) -> Figure:
    """Draw prior, posterior and (optionally) likelihood contours over (w0, w1).

    Args:
        prior_mean, prior_cov: Mean and covariance of the Gaussian prior (dashed blue).
        posterior_mean, posterior_cov: Mean and covariance of the posterior (solid red).
        true_w: Data-generating parameters, drawn as a black star.
        x, y: Full dataset. Only the first ``n_used`` points enter the likelihood.
        sigma: Observation-noise standard deviation used for the likelihood.
        n_used: Number of leading data points treated as observed.
        show_likelihood: Overlay dotted grey likelihood contours. This has no
            effect when ``n_used == 0``.

    Returns:
        A new matplotlib ``Figure`` that is not registered with pyplot.
    """
    w0_limits, w1_limits = _parameter_limits(prior_mean, prior_cov, posterior_mean, posterior_cov, true_w)
    w0_grid = np.linspace(*w0_limits, 180)
    w1_grid = np.linspace(*w1_limits, 180)
    grid_w0, grid_w1 = np.meshgrid(w0_grid, w1_grid)
    prior_density = _gaussian_density_grid(prior_mean, prior_cov, grid_w0, grid_w1)
    posterior_density = _gaussian_density_grid(posterior_mean, posterior_cov, grid_w0, grid_w1)
    # Build the Figure directly rather than via plt.subplots(). pyplot keeps every
    # figure it creates alive until plt.close(), and this runs on every control
    # change, so pyplot-managed figures would pile up for the life of the server.
    fig = Figure(figsize=(6.2, 5.2))
    ax = fig.subplots()
    draw_likelihood = show_likelihood and n_used > 0
    if draw_likelihood:
        likelihood = _likelihood_surface(grid_w0, grid_w1, x[:n_used], y[:n_used], sigma)
        ax.contour(
            grid_w0,
            grid_w1,
            likelihood,
            levels=_contour_levels(likelihood),
            colors="0.55",
            linestyles="dotted",
            linewidths=1.1,
        )
    ax.contour(
        grid_w0,
        grid_w1,
        prior_density,
        levels=_contour_levels(prior_density),
        colors="tab:blue",
        linestyles="dashed",
        linewidths=1.5,
    )
    ax.contour(
        grid_w0,
        grid_w1,
        posterior_density,
        levels=_contour_levels(posterior_density),
        colors="tab:red",
        linewidths=1.8,
    )
    ax.scatter(true_w[0], true_w[1], marker="*", s=140, color="black", zorder=5)
    ax.scatter(posterior_mean[0], posterior_mean[1], s=44, color="tab:red", zorder=5)
    # Contour sets carry no legend labels, so proxy artists describe each layer.
    handles = [
        Line2D([0], [0], color="tab:blue", linestyle="dashed", linewidth=1.5, label="prior"),
        Line2D([0], [0], color="tab:red", linewidth=1.8, label="posterior"),
        Line2D([0], [0], marker="o", color="tab:red", linewidth=0, markersize=7, label="posterior mean"),
        Line2D([0], [0], marker="*", color="black", linewidth=0, markersize=10, label="true parameter"),
    ]
    if draw_likelihood:
        handles.insert(
            0,
            Line2D([0], [0], color="0.55", linestyle="dotted", linewidth=1.2, label="likelihood"),
        )
    ax.set_title("Parameter Space")
    ax.set_xlabel(r"$w_0$")
    ax.set_ylabel(r"$w_1$")
    ax.set_xlim(*w0_limits)
    ax.set_ylim(*w1_limits)
    ax.grid(alpha=0.22)
    ax.legend(handles=handles, loc="best")
    fig.tight_layout()
    return fig
def plot_data_space(
    x: FloatArray,
    y: FloatArray,
    n_used: int,
    true_w: FloatArray,
    posterior_mean: FloatArray,
    sampled_w: FloatArray,
    sample_label: str,
) -> Figure:
    """Plot the data together with the true line, the posterior-mean line and sampled lines.

    Args:
        x, y: Full dataset. The first ``n_used`` points are drawn blue (used), the
            rest light grey (unused).
        n_used: Number of leading data points treated as observed.
        true_w: (w0, w1) of the true line, drawn in black.
        posterior_mean: (w0, w1) of the posterior-mean line, drawn in red.
        sampled_w: Rows of sampled (w0, w1), each drawn as a faint orange line.
        sample_label: Legend label attached to the first sampled line only.

    Returns:
        A new matplotlib ``Figure`` that is not registered with pyplot.
    """
    # Build the Figure directly rather than via plt.subplots(). pyplot keeps every
    # figure it creates alive until plt.close(), and this runs on every control
    # change, so pyplot-managed figures would pile up for the life of the server.
    fig = Figure(figsize=(6.2, 5.2))
    ax = fig.subplots()
    if n_used < len(x):
        ax.scatter(x[n_used:], y[n_used:], color="0.83", s=36, label="unused data", zorder=2)
    if n_used > 0:
        ax.scatter(x[:n_used], y[:n_used], color="tab:blue", s=42, label="used data", zorder=3)
    # Slightly wider than the [-1, 1] input range so the lines reach the axes edges.
    x_line = np.linspace(-1.1, 1.1, 240)
    true_line = true_w[0] + true_w[1] * x_line
    posterior_line = posterior_mean[0] + posterior_mean[1] * x_line
    ax.plot(x_line, true_line, color="black", linewidth=2.2, label="true line")
    ax.plot(x_line, posterior_line, color="tab:red", linewidth=2.0, label="posterior mean")
    for index, weights in enumerate(sampled_w):
        # Label only the first sample so the legend shows a single entry for the bundle.
        label = sample_label if index == 0 else None
        ax.plot(
            x_line,
            weights[0] + weights[1] * x_line,
            color="tab:orange",
            alpha=0.18,
            linewidth=1.15,
            label=label,
            zorder=1,
        )
    ax.set_title("Data Space")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_xlim(-1.1, 1.1)
    ax.grid(alpha=0.22)
    ax.legend(loc="best")
    fig.tight_layout()
    return fig
def _format_array(value: FloatArray) -> str:
    """Render an array with three fixed decimals for the read-only text boxes."""
    options = {"precision": 3, "suppress_small": True, "floatmode": "fixed"}
    return np.array2string(value, **options)
def _select_sampling_distribution(
    sample_mode: str,
    n_used: int,
    prior_mean: FloatArray,
    prior_cov: FloatArray,
    posterior_mean: FloatArray,
    posterior_cov: FloatArray,
) -> tuple[FloatArray, FloatArray, str]:
    """Choose which Gaussian to draw regression lines from, plus a label for it.

    The posterior is used only when ``sample_mode`` is ``"posterior samples"`` and at
    least one data point is used. With zero points, posterior mode falls back to the
    prior and says so in the label. Any other mode uses the prior.

    Returns:
        ``(mean, cov, label)``.
    """
    if sample_mode != "posterior samples":
        return prior_mean, prior_cov, "prior samples"
    if n_used > 0:
        return posterior_mean, posterior_cov, "posterior samples"
    return prior_mean, prior_cov, "prior samples (N=0 fallback)"
def sync_n_slider(n_max: float, n_used: float) -> dict:
    """Keep the N slider consistent with N_max.

    The N slider's maximum becomes ``max(1, int(n_max))``, and its current value is
    clamped into ``[0, maximum]`` so it never exceeds the new bound.

    Returns:
        The ``gr.update(...)`` payload for the N slider. ``gr.update`` produces an
        update mapping rather than a ``Slider`` component, hence the ``dict``
        annotation (the previous ``gr.components.Slider`` annotation was wrong).
    """
    max_value = max(1, int(n_max))
    current_value = min(max(0, int(n_used)), max_value)
    return gr.update(maximum=max_value, value=current_value)
def update(
    true_w0: float,
    true_w1: float,
    sigma: float,
    prior_mean_w0: float,
    prior_mean_w1: float,
    prior_std_w0: float,
    prior_std_w1: float,
    prior_rho: float,
    n_max: float,
    n_used: float,
    seed: float,
    n_lines: float,
    sample_mode: str,
    show_likelihood: bool,
) -> tuple[Figure, Figure, str, str, str]:
    """Recompute the posterior and redraw everything from the current control values.

    Slider values arrive as floats and are coerced to ints here. N is clamped into
    ``[0, N_max]``, and N_max and n_lines are forced to be at least 1.

    Returns:
        ``(parameter_fig, data_fig, posterior_mean_text, posterior_cov_text,
        summary_markdown)``, in the order of the Gradio outputs.

    Raises:
        gr.Error: Wraps any ``ValueError`` / ``LinAlgError`` from the computation so
            that Gradio shows the message to the user.
    """
    try:
        total = max(1, int(n_max))
        used = min(max(0, int(n_used)), total)
        base_seed = int(seed)
        line_count = max(1, int(n_lines))

        true_w = np.array([true_w0, true_w1], dtype=float)
        prior_mean = np.array([prior_mean_w0, prior_mean_w1], dtype=float)
        prior_cov = make_prior_cov(prior_std_w0, prior_std_w1, prior_rho)

        x, y = generate_dataset(true_w0, true_w1, sigma, total, base_seed)
        posterior_mean, posterior_cov = compute_posterior(
            prior_mean=prior_mean,
            prior_cov=prior_cov,
            x=x,
            y=y,
            sigma=sigma,
            n_used=used,
        )

        sample_mean, sample_cov, sample_label = _select_sampling_distribution(
            sample_mode=sample_mode,
            n_used=used,
            prior_mean=prior_mean,
            prior_cov=prior_cov,
            posterior_mean=posterior_mean,
            posterior_cov=posterior_cov,
        )
        # Derive the sampling seed from N and the source distribution, so the line
        # bundle changes with N and differs between prior and posterior sampling.
        source_offset = int(sample_label.startswith("posterior"))
        sample_seed = base_seed + 10_000 * used + source_offset
        sampled_w = sample_weights(sample_mean, sample_cov, line_count, sample_seed)

        parameter_fig = plot_parameter_space(
            prior_mean=prior_mean,
            prior_cov=prior_cov,
            posterior_mean=posterior_mean,
            posterior_cov=posterior_cov,
            true_w=true_w,
            x=x,
            y=y,
            sigma=sigma,
            n_used=used,
            show_likelihood=show_likelihood,
        )
        data_fig = plot_data_space(
            x=x,
            y=y,
            n_used=used,
            true_w=true_w,
            posterior_mean=posterior_mean,
            sampled_w=sampled_w,
            sample_label=sample_label,
        )

        likelihood_state = "on" if show_likelihood and used > 0 else "off"
        summary_lines = [
            "### Current State",
            f"- 使用データ数: `{used} / {total}`",
            f"- 直線サンプル元: `{sample_label}`",
            f"- 尤度等高線: `{likelihood_state}`",
        ]
        return (
            parameter_fig,
            data_fig,
            _format_array(posterior_mean),
            _format_array(posterior_cov),
            "\n".join(summary_lines),
        )
    except (ValueError, np.linalg.LinAlgError) as exc:
        raise gr.Error(str(exc)) from exc
def build_app() -> gr.Blocks:
    """Build the Gradio Blocks UI for the visualizer and wire up its events.

    The left column holds the grouped controls. The right column holds the two
    plots, text boxes for the posterior mean m_N and covariance S_N, and a
    Markdown summary. Every control re-runs ``update``. A change to ``N_max`` first
    goes through ``sync_n_slider`` so that the ``N`` slider's maximum follows it.

    Returns:
        The ``gr.Blocks`` application (not yet launched).
    """
    # Initial values: the N slider's upper bound starts at default_n_max.
    default_n_max = 60
    default_n_used = 12
    # Components appear in the page in the order they are created inside these
    # context managers, so statement order here defines the layout.
    with gr.Blocks(title="Bayesian Linear Regression Visualizer", theme=APP_THEME) as demo:
        gr.Markdown(
            """
# Bayesian Linear Regression Visualizer
事前分布・尤度・事後分布の関係と、パラメータ分布からサンプルした回帰直線群の変化を 2 つの図で確認できます。
"""
        )
        with gr.Row():
            # Left column: input controls, grouped by topic.
            with gr.Column(scale=4):
                gr.Markdown("## Controls")
                # Parameters of the data-generating ("true") model.
                with gr.Group():
                    gr.Markdown("### 真のモデル")
                    true_w0 = gr.Slider(
                        -3.0,
                        3.0,
                        value=-0.3,
                        step=0.1,
                        label="true_w0",
                        info="真の切片。黒い真の回帰直線の上下位置を決めます。",
                    )
                    true_w1 = gr.Slider(
                        -3.0,
                        3.0,
                        value=1.2,
                        step=0.1,
                        label="true_w1",
                        info="真の傾き。黒い真の回帰直線の傾きを決めます。",
                    )
                    sigma = gr.Slider(
                        0.05,
                        1.2,
                        value=0.25,
                        step=0.05,
                        label="sigma",
                        info="観測ノイズの標準偏差。大きいほどデータ点が真の直線から散らばります。",
                    )
                # Gaussian prior over (w0, w1): means, standard deviations, correlation.
                with gr.Group():
                    gr.Markdown("### 事前分布")
                    prior_mean_w0 = gr.Slider(
                        -3.0,
                        3.0,
                        value=0.0,
                        step=0.1,
                        label="prior_mean_w0",
                        info="事前分布での切片の平均です。",
                    )
                    prior_mean_w1 = gr.Slider(
                        -3.0,
                        3.0,
                        value=0.0,
                        step=0.1,
                        label="prior_mean_w1",
                        info="事前分布での傾きの平均です。",
                    )
                    prior_std_w0 = gr.Slider(
                        0.1,
                        3.0,
                        value=1.2,
                        step=0.1,
                        label="prior_std_w0",
                        info="事前分布での切片方向の広がりです。大きいほど切片に自信がありません。",
                    )
                    prior_std_w1 = gr.Slider(
                        0.1,
                        3.0,
                        value=1.2,
                        step=0.1,
                        label="prior_std_w1",
                        info="事前分布での傾き方向の広がりです。大きいほど傾きに自信がありません。",
                    )
                    # Slider range stays inside the (-0.999, 0.999) bound checked by make_prior_cov.
                    prior_rho = gr.Slider(
                        -0.95,
                        0.95,
                        value=-0.25,
                        step=0.05,
                        label="prior_rho",
                        info="事前分布での切片と傾きの相関です。0 なら軸に沿い、正負で等高線の傾きが変わります。",
                    )
                # Dataset size, seed and drawing options.
                with gr.Group():
                    gr.Markdown("### データと描画")
                    n_max = gr.Slider(
                        10,
                        200,
                        value=default_n_max,
                        step=1,
                        label="N_max",
                        info="先に生成しておく総データ数です。",
                    )
                    # Initial maximum is default_n_max; sync_n_slider adjusts it when N_max changes.
                    n_used = gr.Slider(
                        0,
                        default_n_max,
                        value=default_n_used,
                        step=1,
                        label="N",
                        info="事後分布の計算に使うデータ数です。先頭から N 個だけ使います。",
                    )
                    seed = gr.Slider(
                        0,
                        9999,
                        value=7,
                        step=1,
                        label="seed",
                        info="データ生成の乱数シードです。同じ値なら同じデータになります。",
                    )
                    n_lines = gr.Slider(
                        1,
                        50,
                        value=20,
                        step=1,
                        label="n_lines",
                        info="分布からサンプルして描く回帰直線の本数です。",
                    )
                    sample_mode = gr.Radio(
                        choices=["prior samples", "posterior samples"],
                        value="posterior samples",
                        label="表示モード",
                        info="回帰直線を事前分布から引くか、事後分布から引くかを選びます。",
                    )
                    show_likelihood = gr.Checkbox(
                        value=True,
                        label="パラメータ空間に尤度等高線を表示",
                        info="灰色の点線で尤度の等高線を重ねます。",
                    )
            # Right column: the two figures, then the numeric posterior and a summary.
            with gr.Column(scale=6):
                with gr.Row():
                    parameter_plot = gr.Plot(label="パラメータ空間")
                    data_plot = gr.Plot(label="データ空間")
                with gr.Row():
                    posterior_mean_box = gr.Textbox(label="事後平均 m_N", lines=2)
                    posterior_cov_box = gr.Textbox(label="事後共分散 S_N", lines=4)
                summary_box = gr.Markdown()
        # Order must match update()'s positional parameters.
        inputs = [
            true_w0,
            true_w1,
            sigma,
            prior_mean_w0,
            prior_mean_w1,
            prior_std_w0,
            prior_std_w1,
            prior_rho,
            n_max,
            n_used,
            seed,
            n_lines,
            sample_mode,
            show_likelihood,
        ]
        # Order must match the 5-tuple returned by update().
        outputs = [parameter_plot, data_plot, posterior_mean_box, posterior_cov_box, summary_box]
        # N_max is handled separately: first resize/clamp the N slider, then redraw.
        # NOTE(review): if sync_n_slider changes N's value, N's own .change listener
        # below may also fire, running update() twice — confirm against Gradio's semantics.
        n_max_event = n_max.change(sync_n_slider, inputs=[n_max, n_used], outputs=n_used)
        n_max_event.then(update, inputs=inputs, outputs=outputs)
        # Every other control triggers a full redraw directly.
        for component in [
            true_w0,
            true_w1,
            sigma,
            prior_mean_w0,
            prior_mean_w1,
            prior_std_w0,
            prior_std_w1,
            prior_rho,
            n_used,
            seed,
            n_lines,
            sample_mode,
            show_likelihood,
        ]:
            component.change(update, inputs=inputs, outputs=outputs)
        # Render the initial state when the page loads.
        demo.load(update, inputs=inputs, outputs=outputs)
    return demo
def main() -> None:
    """Parse command-line options and launch the Gradio app (with queueing enabled).

    ``--server-name`` / ``--server-port`` are forwarded to ``launch`` only when given,
    so Gradio's own defaults apply otherwise.
    """
    cli = argparse.ArgumentParser(description="Launch the Bayesian linear regression visualizer.")
    cli.add_argument("--server-name", default=None, help="Host for the Gradio server.")
    cli.add_argument("--server-port", type=int, default=None, help="Port for the Gradio server.")
    cli.add_argument("--share", action="store_true", help="Create a public Gradio share link.")
    cli.add_argument("--browser", action="store_true", help="Automatically open the app in a browser.")
    options = cli.parse_args()

    demo = build_app()
    launch_kwargs: dict[str, object] = {"share": options.share, "inbrowser": options.browser}
    optional_settings = {"server_name": options.server_name, "server_port": options.server_port}
    launch_kwargs.update({key: val for key, val in optional_settings.items() if val is not None})
    demo.queue().launch(**launch_kwargs)


if __name__ == "__main__":
    main()