taxfree-python commited on
Commit
6421da7
·
0 Parent(s):

Prepare Hugging Face Space deployment

Browse files
Files changed (7) hide show
  1. .gitignore +14 -0
  2. .python-version +1 -0
  3. README.md +51 -0
  4. main.py +514 -0
  5. pyproject.toml +11 -0
  6. requirements.txt +2 -0
  7. uv.lock +0 -0
.gitignore ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv
11
+
12
+ # Local tooling
13
+ .serena/
14
+ .DS_Store
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.11
README.md ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Bayesian Linear Regression Visualizer
3
+ emoji: 📈
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: "5.50.0"
8
+ python_version: "3.11"
9
+ app_file: main.py
10
+ fullWidth: true
11
+ pinned: false
12
+ ---
13
+
14
+ # Bayes Study
15
+
16
+ ベイズ線形回帰の事前分布・尤度・事後分布を対話的に確認できる Gradio アプリです。
17
+ パラメータ空間 `(w0, w1)` の等高線と、データ空間での回帰直線群を並べて表示します。
18
+
19
+ ## セットアップ
20
+
21
+ ```bash
22
+ uv sync
23
+ ```
24
+
25
+ Hugging Face Spaces では `README.md` の frontmatter と `requirements.txt` を使ってデプロイされます。
26
+
27
+ ## 起動
28
+
29
+ ```bash
30
+ uv run python main.py
31
+ ```
32
+
33
+ ブラウザを自動で開く場合:
34
+
35
+ ```bash
36
+ uv run python main.py --browser
37
+ ```
38
+
39
+ ホストやポートを指定する場合:
40
+
41
+ ```bash
42
+ uv run python main.py --server-name 0.0.0.0 --server-port 7860
43
+ ```
44
+
45
+ ## アプリでできること
46
+
47
+ - 事前平均、事前標準偏差、相関係数からガウス事前分布を設定
48
+ - 真の切片、真の傾き、観測ノイズからデータを生成
49
+ - 使用サンプル数 `N` を変えて事後分布の収束を確認
50
+ - prior / posterior からサンプルした回帰直線群を比較
51
+ - 尤度等高線をパラメータ空間に重ねて表示
main.py ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+
5
+ import gradio as gr
6
+ import matplotlib
7
+ import numpy as np
8
+ from matplotlib.figure import Figure
9
+ from matplotlib.lines import Line2D
10
+ from numpy.typing import NDArray
11
+
12
+ # Use a headless backend so the app also works in terminal-only environments.
13
+ matplotlib.use("Agg")
14
+ import matplotlib.pyplot as plt
15
+
16
+ FloatArray = NDArray[np.float64]
17
+ APP_THEME = gr.themes.Soft(
18
+ primary_hue="sky",
19
+ secondary_hue="amber",
20
+ neutral_hue="slate",
21
+ )
22
+
23
+
24
+ def make_prior_cov(std_w0: float, std_w1: float, rho: float) -> FloatArray:
25
+ if std_w0 <= 0 or std_w1 <= 0:
26
+ raise ValueError("事前標準偏差は正の値にしてください。")
27
+ if not (-0.999 < rho < 0.999):
28
+ raise ValueError("事前相関係数 rho は -1 より大きく 1 より小さい値にしてください。")
29
+
30
+ cov = np.array(
31
+ [
32
+ [std_w0**2, rho * std_w0 * std_w1],
33
+ [rho * std_w0 * std_w1, std_w1**2],
34
+ ],
35
+ dtype=float,
36
+ )
37
+ sign, _ = np.linalg.slogdet(cov)
38
+ if sign <= 0:
39
+ raise ValueError("事前共分散行列が正定値ではありません。標準偏差と相関係数を見直してください。")
40
+ return cov
41
+
42
+
43
+ def generate_dataset(
44
+ true_w0: float,
45
+ true_w1: float,
46
+ sigma: float,
47
+ n_max: int,
48
+ seed: int,
49
+ ) -> tuple[FloatArray, FloatArray]:
50
+ if n_max < 1:
51
+ raise ValueError("N_max は 1 以上にしてください。")
52
+ if sigma <= 0:
53
+ raise ValueError("観測ノイズ標準偏差 sigma は正の値にしてください。")
54
+
55
+ rng = np.random.default_rng(seed)
56
+ x = rng.uniform(-1.0, 1.0, size=n_max)
57
+ noise = rng.normal(0.0, sigma, size=n_max)
58
+ y = true_w0 + true_w1 * x + noise
59
+ return x.astype(float), y.astype(float)
60
+
61
+
62
+ def compute_posterior(
63
+ prior_mean: FloatArray,
64
+ prior_cov: FloatArray,
65
+ x: FloatArray,
66
+ y: FloatArray,
67
+ sigma: float,
68
+ n_used: int,
69
+ ) -> tuple[FloatArray, FloatArray]:
70
+ n_used = int(np.clip(n_used, 0, len(x)))
71
+ if n_used == 0:
72
+ return prior_mean.copy(), prior_cov.copy()
73
+
74
+ phi = np.column_stack([np.ones(n_used), x[:n_used]])
75
+ y_used = y[:n_used]
76
+ prior_precision = np.linalg.inv(prior_cov)
77
+ posterior_precision = prior_precision + (phi.T @ phi) / (sigma**2)
78
+ posterior_cov = np.linalg.inv(posterior_precision)
79
+ rhs = prior_precision @ prior_mean + (phi.T @ y_used) / (sigma**2)
80
+ posterior_mean = posterior_cov @ rhs
81
+ return posterior_mean, posterior_cov
82
+
83
+
84
+ def sample_weights(mean: FloatArray, cov: FloatArray, n_lines: int, seed: int) -> FloatArray:
85
+ if n_lines < 1:
86
+ raise ValueError("表示する直線本数 n_lines は 1 以上にしてください。")
87
+
88
+ rng = np.random.default_rng(seed)
89
+ return rng.multivariate_normal(mean=mean, cov=cov, size=n_lines).astype(float)
90
+
91
+
92
+ def _gaussian_density_grid(
93
+ mean: FloatArray,
94
+ cov: FloatArray,
95
+ grid_w0: FloatArray,
96
+ grid_w1: FloatArray,
97
+ ) -> FloatArray:
98
+ cov_inv = np.linalg.inv(cov)
99
+ sign, logdet = np.linalg.slogdet(cov)
100
+ if sign <= 0:
101
+ raise ValueError("共分散行列が正定値ではありません。")
102
+
103
+ position = np.stack([grid_w0, grid_w1], axis=-1)
104
+ diff = position - mean
105
+ quad = np.einsum("...i,ij,...j->...", diff, cov_inv, diff)
106
+ log_density = -0.5 * (2 * np.log(2 * np.pi) + logdet + quad)
107
+ return np.exp(log_density)
108
+
109
+
110
+ def _likelihood_surface(
111
+ grid_w0: FloatArray,
112
+ grid_w1: FloatArray,
113
+ x_used: FloatArray,
114
+ y_used: FloatArray,
115
+ sigma: float,
116
+ ) -> FloatArray:
117
+ predictions = grid_w0[..., None] + grid_w1[..., None] * x_used
118
+ residuals = y_used - predictions
119
+ rss = np.sum(residuals**2, axis=-1)
120
+ log_likelihood = -0.5 * rss / (sigma**2)
121
+ return np.exp(log_likelihood - np.max(log_likelihood))
122
+
123
+
124
+ def _contour_levels(surface: FloatArray) -> FloatArray:
125
+ peak = float(np.max(surface))
126
+ if not np.isfinite(peak) or peak <= 0:
127
+ return np.array([1.0], dtype=float)
128
+
129
+ relative_levels = np.exp(-0.5 * np.array([7.0, 4.5, 2.5, 1.0, 0.3], dtype=float))
130
+ levels = np.sort(peak * relative_levels)
131
+ return np.unique(np.clip(levels, peak * 1e-6, peak * 0.999))
132
+
133
+
134
+ def _parameter_limits(
135
+ prior_mean: FloatArray,
136
+ prior_cov: FloatArray,
137
+ posterior_mean: FloatArray,
138
+ posterior_cov: FloatArray,
139
+ true_w: FloatArray,
140
+ ) -> tuple[tuple[float, float], tuple[float, float]]:
141
+ prior_std = 4.0 * np.sqrt(np.diag(prior_cov))
142
+ posterior_std = 4.0 * np.sqrt(np.diag(posterior_cov))
143
+
144
+ lower = np.vstack(
145
+ [
146
+ prior_mean - prior_std,
147
+ posterior_mean - posterior_std,
148
+ true_w,
149
+ ]
150
+ ).min(axis=0)
151
+ upper = np.vstack(
152
+ [
153
+ prior_mean + prior_std,
154
+ posterior_mean + posterior_std,
155
+ true_w,
156
+ ]
157
+ ).max(axis=0)
158
+ span = np.maximum(upper - lower, np.array([1.0, 1.0], dtype=float))
159
+ padding = 0.15 * span
160
+ w0_limits = (float(lower[0] - padding[0]), float(upper[0] + padding[0]))
161
+ w1_limits = (float(lower[1] - padding[1]), float(upper[1] + padding[1]))
162
+ return w0_limits, w1_limits
163
+
164
+
165
+ def plot_parameter_space(
166
+ prior_mean: FloatArray,
167
+ prior_cov: FloatArray,
168
+ posterior_mean: FloatArray,
169
+ posterior_cov: FloatArray,
170
+ true_w: FloatArray,
171
+ x: FloatArray,
172
+ y: FloatArray,
173
+ sigma: float,
174
+ n_used: int,
175
+ show_likelihood: bool,
176
+ ) -> Figure:
177
+ w0_limits, w1_limits = _parameter_limits(prior_mean, prior_cov, posterior_mean, posterior_cov, true_w)
178
+ w0_grid = np.linspace(*w0_limits, 180)
179
+ w1_grid = np.linspace(*w1_limits, 180)
180
+ grid_w0, grid_w1 = np.meshgrid(w0_grid, w1_grid)
181
+
182
+ prior_density = _gaussian_density_grid(prior_mean, prior_cov, grid_w0, grid_w1)
183
+ posterior_density = _gaussian_density_grid(posterior_mean, posterior_cov, grid_w0, grid_w1)
184
+
185
+ fig, ax = plt.subplots(figsize=(6.2, 5.2))
186
+ if show_likelihood and n_used > 0:
187
+ likelihood = _likelihood_surface(grid_w0, grid_w1, x[:n_used], y[:n_used], sigma)
188
+ ax.contour(
189
+ grid_w0,
190
+ grid_w1,
191
+ likelihood,
192
+ levels=_contour_levels(likelihood),
193
+ colors="0.55",
194
+ linestyles="dotted",
195
+ linewidths=1.1,
196
+ )
197
+
198
+ ax.contour(
199
+ grid_w0,
200
+ grid_w1,
201
+ prior_density,
202
+ levels=_contour_levels(prior_density),
203
+ colors="tab:blue",
204
+ linestyles="dashed",
205
+ linewidths=1.5,
206
+ )
207
+ ax.contour(
208
+ grid_w0,
209
+ grid_w1,
210
+ posterior_density,
211
+ levels=_contour_levels(posterior_density),
212
+ colors="tab:red",
213
+ linewidths=1.8,
214
+ )
215
+ ax.scatter(true_w[0], true_w[1], marker="*", s=140, color="black", zorder=5)
216
+ ax.scatter(posterior_mean[0], posterior_mean[1], s=44, color="tab:red", zorder=5)
217
+
218
+ handles = [
219
+ Line2D([0], [0], color="tab:blue", linestyle="dashed", linewidth=1.5, label="prior"),
220
+ Line2D([0], [0], color="tab:red", linewidth=1.8, label="posterior"),
221
+ Line2D([0], [0], marker="o", color="tab:red", linewidth=0, markersize=7, label="posterior mean"),
222
+ Line2D([0], [0], marker="*", color="black", linewidth=0, markersize=10, label="true parameter"),
223
+ ]
224
+ if show_likelihood and n_used > 0:
225
+ handles.insert(
226
+ 0,
227
+ Line2D([0], [0], color="0.55", linestyle="dotted", linewidth=1.2, label="likelihood"),
228
+ )
229
+
230
+ ax.set_title("Parameter Space")
231
+ ax.set_xlabel(r"$w_0$")
232
+ ax.set_ylabel(r"$w_1$")
233
+ ax.set_xlim(*w0_limits)
234
+ ax.set_ylim(*w1_limits)
235
+ ax.grid(alpha=0.22)
236
+ ax.legend(handles=handles, loc="best")
237
+ fig.tight_layout()
238
+ return fig
239
+
240
+
241
+ def plot_data_space(
242
+ x: FloatArray,
243
+ y: FloatArray,
244
+ n_used: int,
245
+ true_w: FloatArray,
246
+ posterior_mean: FloatArray,
247
+ sampled_w: FloatArray,
248
+ sample_label: str,
249
+ ) -> Figure:
250
+ fig, ax = plt.subplots(figsize=(6.2, 5.2))
251
+
252
+ if n_used < len(x):
253
+ ax.scatter(x[n_used:], y[n_used:], color="0.83", s=36, label="unused data", zorder=2)
254
+ if n_used > 0:
255
+ ax.scatter(x[:n_used], y[:n_used], color="tab:blue", s=42, label="used data", zorder=3)
256
+
257
+ x_line = np.linspace(-1.1, 1.1, 240)
258
+ true_line = true_w[0] + true_w[1] * x_line
259
+ posterior_line = posterior_mean[0] + posterior_mean[1] * x_line
260
+
261
+ ax.plot(x_line, true_line, color="black", linewidth=2.2, label="true line")
262
+ ax.plot(x_line, posterior_line, color="tab:red", linewidth=2.0, label="posterior mean")
263
+
264
+ for index, weights in enumerate(sampled_w):
265
+ label = sample_label if index == 0 else None
266
+ ax.plot(
267
+ x_line,
268
+ weights[0] + weights[1] * x_line,
269
+ color="tab:orange",
270
+ alpha=0.18,
271
+ linewidth=1.15,
272
+ label=label,
273
+ zorder=1,
274
+ )
275
+
276
+ ax.set_title("Data Space")
277
+ ax.set_xlabel("x")
278
+ ax.set_ylabel("y")
279
+ ax.set_xlim(-1.1, 1.1)
280
+ ax.grid(alpha=0.22)
281
+ ax.legend(loc="best")
282
+ fig.tight_layout()
283
+ return fig
284
+
285
+
286
+ def _format_array(value: FloatArray) -> str:
287
+ return np.array2string(value, precision=3, suppress_small=True, floatmode="fixed")
288
+
289
+
290
+ def _select_sampling_distribution(
291
+ sample_mode: str,
292
+ n_used: int,
293
+ prior_mean: FloatArray,
294
+ prior_cov: FloatArray,
295
+ posterior_mean: FloatArray,
296
+ posterior_cov: FloatArray,
297
+ ) -> tuple[FloatArray, FloatArray, str]:
298
+ if sample_mode == "posterior samples" and n_used > 0:
299
+ return posterior_mean, posterior_cov, "posterior samples"
300
+ if sample_mode == "posterior samples":
301
+ return prior_mean, prior_cov, "prior samples (N=0 fallback)"
302
+ return prior_mean, prior_cov, "prior samples"
303
+
304
+
305
+ def sync_n_slider(n_max: float, n_used: float) -> gr.components.Slider:
306
+ max_value = max(1, int(n_max))
307
+ current_value = min(max(0, int(n_used)), max_value)
308
+ return gr.update(maximum=max_value, value=current_value)
309
+
310
+
311
+ def update(
312
+ true_w0: float,
313
+ true_w1: float,
314
+ sigma: float,
315
+ prior_mean_w0: float,
316
+ prior_mean_w1: float,
317
+ prior_std_w0: float,
318
+ prior_std_w1: float,
319
+ prior_rho: float,
320
+ n_max: float,
321
+ n_used: float,
322
+ seed: float,
323
+ n_lines: float,
324
+ sample_mode: str,
325
+ show_likelihood: bool,
326
+ ) -> tuple[Figure, Figure, str, str, str]:
327
+ try:
328
+ n_max_int = max(1, int(n_max))
329
+ n_used_int = min(max(0, int(n_used)), n_max_int)
330
+ seed_int = int(seed)
331
+ n_lines_int = max(1, int(n_lines))
332
+
333
+ true_w = np.array([true_w0, true_w1], dtype=float)
334
+ prior_mean = np.array([prior_mean_w0, prior_mean_w1], dtype=float)
335
+ prior_cov = make_prior_cov(prior_std_w0, prior_std_w1, prior_rho)
336
+
337
+ x, y = generate_dataset(true_w0, true_w1, sigma, n_max_int, seed_int)
338
+ posterior_mean, posterior_cov = compute_posterior(
339
+ prior_mean=prior_mean,
340
+ prior_cov=prior_cov,
341
+ x=x,
342
+ y=y,
343
+ sigma=sigma,
344
+ n_used=n_used_int,
345
+ )
346
+ sample_mean, sample_cov, sample_label = _select_sampling_distribution(
347
+ sample_mode=sample_mode,
348
+ n_used=n_used_int,
349
+ prior_mean=prior_mean,
350
+ prior_cov=prior_cov,
351
+ posterior_mean=posterior_mean,
352
+ posterior_cov=posterior_cov,
353
+ )
354
+ sample_seed = seed_int + 10_000 * n_used_int + (1 if sample_label.startswith("posterior") else 0)
355
+ sampled_w = sample_weights(sample_mean, sample_cov, n_lines_int, sample_seed)
356
+
357
+ parameter_fig = plot_parameter_space(
358
+ prior_mean=prior_mean,
359
+ prior_cov=prior_cov,
360
+ posterior_mean=posterior_mean,
361
+ posterior_cov=posterior_cov,
362
+ true_w=true_w,
363
+ x=x,
364
+ y=y,
365
+ sigma=sigma,
366
+ n_used=n_used_int,
367
+ show_likelihood=show_likelihood,
368
+ )
369
+ data_fig = plot_data_space(
370
+ x=x,
371
+ y=y,
372
+ n_used=n_used_int,
373
+ true_w=true_w,
374
+ posterior_mean=posterior_mean,
375
+ sampled_w=sampled_w,
376
+ sample_label=sample_label,
377
+ )
378
+
379
+ summary = "\n".join(
380
+ [
381
+ "### Current State",
382
+ f"- 使用データ数: `{n_used_int} / {n_max_int}`",
383
+ f"- 直線サンプル元: `{sample_label}`",
384
+ f"- 尤度等高線: `{'on' if show_likelihood and n_used_int > 0 else 'off'}`",
385
+ ]
386
+ )
387
+ return (
388
+ parameter_fig,
389
+ data_fig,
390
+ _format_array(posterior_mean),
391
+ _format_array(posterior_cov),
392
+ summary,
393
+ )
394
+ except (ValueError, np.linalg.LinAlgError) as exc:
395
+ raise gr.Error(str(exc)) from exc
396
+
397
+
398
+ def build_app() -> gr.Blocks:
399
+ default_n_max = 60
400
+ default_n_used = 12
401
+
402
+ with gr.Blocks(title="Bayesian Linear Regression Visualizer", theme=APP_THEME) as demo:
403
+ gr.Markdown(
404
+ """
405
+ # Bayesian Linear Regression Visualizer
406
+ 事前分布・尤度・事後分布の関係と、パラメータ分布からサンプルした回帰直線群の変化を 2 つの図で確認できます。
407
+ """
408
+ )
409
+
410
+ with gr.Row():
411
+ with gr.Column(scale=4):
412
+ gr.Markdown("## Controls")
413
+
414
+ with gr.Group():
415
+ gr.Markdown("### 真のモデル")
416
+ true_w0 = gr.Slider(-3.0, 3.0, value=-0.3, step=0.1, label="true_w0")
417
+ true_w1 = gr.Slider(-3.0, 3.0, value=1.2, step=0.1, label="true_w1")
418
+ sigma = gr.Slider(0.05, 1.2, value=0.25, step=0.05, label="sigma")
419
+
420
+ with gr.Group():
421
+ gr.Markdown("### 事前分布")
422
+ prior_mean_w0 = gr.Slider(-3.0, 3.0, value=0.0, step=0.1, label="prior_mean_w0")
423
+ prior_mean_w1 = gr.Slider(-3.0, 3.0, value=0.0, step=0.1, label="prior_mean_w1")
424
+ prior_std_w0 = gr.Slider(0.1, 3.0, value=1.2, step=0.1, label="prior_std_w0")
425
+ prior_std_w1 = gr.Slider(0.1, 3.0, value=1.2, step=0.1, label="prior_std_w1")
426
+ prior_rho = gr.Slider(-0.95, 0.95, value=-0.25, step=0.05, label="prior_rho")
427
+
428
+ with gr.Group():
429
+ gr.Markdown("### データと描画")
430
+ n_max = gr.Slider(10, 200, value=default_n_max, step=1, label="N_max")
431
+ n_used = gr.Slider(0, default_n_max, value=default_n_used, step=1, label="N")
432
+ seed = gr.Slider(0, 9999, value=7, step=1, label="seed")
433
+ n_lines = gr.Slider(1, 50, value=20, step=1, label="n_lines")
434
+ sample_mode = gr.Radio(
435
+ choices=["prior samples", "posterior samples"],
436
+ value="posterior samples",
437
+ label="表示モード",
438
+ )
439
+ show_likelihood = gr.Checkbox(value=True, label="パラメータ空間に尤度等高線を表示")
440
+
441
+ with gr.Column(scale=6):
442
+ with gr.Row():
443
+ parameter_plot = gr.Plot(label="パラメータ空間")
444
+ data_plot = gr.Plot(label="データ空間")
445
+ with gr.Row():
446
+ posterior_mean_box = gr.Textbox(label="事後平均 m_N", lines=2)
447
+ posterior_cov_box = gr.Textbox(label="事後共分散 S_N", lines=4)
448
+ summary_box = gr.Markdown()
449
+
450
+ inputs = [
451
+ true_w0,
452
+ true_w1,
453
+ sigma,
454
+ prior_mean_w0,
455
+ prior_mean_w1,
456
+ prior_std_w0,
457
+ prior_std_w1,
458
+ prior_rho,
459
+ n_max,
460
+ n_used,
461
+ seed,
462
+ n_lines,
463
+ sample_mode,
464
+ show_likelihood,
465
+ ]
466
+ outputs = [parameter_plot, data_plot, posterior_mean_box, posterior_cov_box, summary_box]
467
+
468
+ n_max_event = n_max.change(sync_n_slider, inputs=[n_max, n_used], outputs=n_used)
469
+ n_max_event.then(update, inputs=inputs, outputs=outputs)
470
+
471
+ for component in [
472
+ true_w0,
473
+ true_w1,
474
+ sigma,
475
+ prior_mean_w0,
476
+ prior_mean_w1,
477
+ prior_std_w0,
478
+ prior_std_w1,
479
+ prior_rho,
480
+ n_used,
481
+ seed,
482
+ n_lines,
483
+ sample_mode,
484
+ show_likelihood,
485
+ ]:
486
+ component.change(update, inputs=inputs, outputs=outputs)
487
+
488
+ demo.load(update, inputs=inputs, outputs=outputs)
489
+
490
+ return demo
491
+
492
+
493
+ def main() -> None:
494
+ parser = argparse.ArgumentParser(description="Launch the Bayesian linear regression visualizer.")
495
+ parser.add_argument("--server-name", default=None, help="Host for the Gradio server.")
496
+ parser.add_argument("--server-port", type=int, default=None, help="Port for the Gradio server.")
497
+ parser.add_argument("--share", action="store_true", help="Create a public Gradio share link.")
498
+ parser.add_argument("--browser", action="store_true", help="Automatically open the app in a browser.")
499
+ args = parser.parse_args()
500
+
501
+ app = build_app()
502
+ launch_kwargs: dict[str, object] = {
503
+ "share": args.share,
504
+ "inbrowser": args.browser,
505
+ }
506
+ if args.server_name is not None:
507
+ launch_kwargs["server_name"] = args.server_name
508
+ if args.server_port is not None:
509
+ launch_kwargs["server_port"] = args.server_port
510
+ app.queue().launch(**launch_kwargs)
511
+
512
+
513
+ if __name__ == "__main__":
514
+ main()
pyproject.toml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "bayes-study"
3
+ version = "0.1.0"
4
+ description = "Interactive Bayesian linear regression visualizer built with Gradio."
5
+ readme = "README.md"
6
+ requires-python = ">=3.11,<3.13"
7
+ dependencies = [
8
+ "gradio>=5.25.0,<6",
9
+ "matplotlib>=3.9.0,<4",
10
+ "numpy>=2.1.0,<3",
11
+ ]
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ numpy==2.4.4
2
+ matplotlib==3.10.8
uv.lock ADDED
The diff for this file is too large to render. See raw diff