Msk7000 commited on
Commit
cd5bd1f
·
verified ·
1 Parent(s): 9e086cd

Upload cnn_cat_convolution_dashboard.py

Browse files
Files changed (1) hide show
  1. cnn_cat_convolution_dashboard.py +179 -64
cnn_cat_convolution_dashboard.py CHANGED
@@ -1,10 +1,10 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- Streamlit ダッシュボード版: CNN 畳み込み可視化教材(リファクタリング版)
4
  """
5
 
6
  from pathlib import Path
7
- from typing import Tuple, Optional, Dict
8
 
9
  import matplotlib.pyplot as plt
10
  import numpy as np
@@ -13,7 +13,7 @@ from matplotlib import patches
13
  from matplotlib.font_manager import FontProperties
14
 
15
  # -----------------------------
16
- # フォント設定
17
  # -----------------------------
18
  def get_japanese_font() -> Tuple[Optional[FontProperties], Optional[FontProperties]]:
19
  """プロジェクトルートにあるNotoSansJPフォントを読み込む"""
@@ -39,20 +39,25 @@ def set_jp_font(ax_or_text, is_bold: bool = False, size: int = 12):
39
  ax_or_text.set_fontsize(size)
40
 
41
  # -----------------------------
42
- # 画像生成・処理ロジック
43
  # -----------------------------
44
- def draw_polyline(img: np.ndarray, pts: list, thickness: float = 1.25):
 
45
  h, w = img.shape
46
  ys, xs = np.mgrid[0:h, 0:w]
47
  for (x1, y1), (x2, y2) in zip(pts[:-1], pts[1:]):
48
  vx, vy = x2 - x1, y2 - y1
49
  c2 = vx**2 + vy**2 + 1e-12
50
- bx = np.clip((vx * (xs - x1) + vy * (ys - y1)) / c2, 0, 1) * vx + x1
51
- by = np.clip((vx * (xs - x1) + vy * (ys - y1)) / c2, 0, 1) * vy + y1
 
 
 
 
52
  img[np.sqrt((xs - bx)**2 + (ys - by)**2) <= thickness] = 1.0
53
 
54
  def fit_binary_image_to_canvas(img: np.ndarray, target_size: int = 48, margin: int = 2) -> np.ndarray:
55
- """画像をターゲットサイズにリサイズして中央配置(ループを排除)"""
56
  coords = np.argwhere(img > 0.5)
57
  if coords.size == 0: return np.zeros((target_size, target_size))
58
 
@@ -64,10 +69,10 @@ def fit_binary_image_to_canvas(img: np.ndarray, target_size: int = 48, margin: i
64
  scale = min((target_size - 2 * margin) / max(ch, cw, 1), 1.0)
65
  new_h, new_w = int(ch * scale), int(cw * scale)
66
 
67
- # 簡易的なリサイズ処理(最近傍補間的に座標変換)
68
  out = np.zeros((target_size, target_size))
69
  y_off, x_off = (target_size - new_h) // 2, (target_size - new_w) // 2
70
 
 
71
  for sy, sx in np.argwhere(cropped > 0.5):
72
  ty, tx = y_off + int(sy * scale), x_off + int(sx * scale)
73
  if 0 <= ty < target_size and 0 <= tx < target_size:
@@ -76,18 +81,51 @@ def fit_binary_image_to_canvas(img: np.ndarray, target_size: int = 48, margin: i
76
 
77
  @st.cache_data
78
  def get_cat_image(size: int = 48) -> np.ndarray:
 
79
  base = np.zeros((64, 64))
80
- # 顔・耳・パーツの描画(元のロジックを継承)
 
 
 
81
  t = np.linspace(np.deg2rad(205), np.deg2rad(335), 160)
82
  pts_face = list(zip(32 + 18 * np.cos(t), 34 + 18 * np.sin(t)))
83
- draw_polyline(base, pts_face)
84
- for pts in [[(19, 24), (25, 10), (30, 24)], [(34, 24), (39, 10), (45, 24)]]:
85
- draw_polyline(base, pts)
86
- # ... (中略: 他のパーツも同様に描画) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  return fit_binary_image_to_canvas(base, target_size=size)
88
 
89
  # -----------------------------
90
- # 畳み込み演算
91
  # -----------------------------
92
  KERNELS = {
93
  "縦線": np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]),
@@ -97,106 +135,183 @@ KERNELS = {
97
 
98
  @st.cache_data
99
  def run_convolution(img: np.ndarray, kernel: np.ndarray) -> np.ndarray:
 
100
  kh, kw = kernel.shape
101
  oh, ow = img.shape[0] - kh + 1, img.shape[1] - kw + 1
102
- # スライディングウィンドウを効率的に作成するビューの利用も検討できるが、教材用なのでシンプルに
103
  output = np.zeros((oh, ow))
104
  for i in range(oh):
105
  for j in range(ow):
 
106
  output[i, j] = np.sum(img[i:i+kh, j:j+kw] * kernel)
107
  return output
108
 
109
  # -----------------------------
110
- # 表示用コンポーネント
111
  # -----------------------------
112
- def render_dashboard(img, k_name, row, col, show_ans):
 
113
  kernel = KERNELS[k_name]
114
  conv_full = run_convolution(img, kernel)
115
  patch = img[row:row+3, col:col+3]
116
  val = conv_full[row, col]
117
 
118
- # メイン図
119
- fig, axes = plt.subplots(1, 3, figsize=(12, 5), constrained_layout=True)
120
 
121
- # 1. 入力画像
122
- axes[0].imshow(img, cmap="gray_r")
123
- axes[0].add_patch(patches.Rectangle((col-0.5, row-0.5), 3, 3, lw=2, ec="red", fc="none"))
124
- axes[0].set_title("入力画像 (48x48)")
125
- set_jp_font(axes[0], is_bold=True)
 
 
 
126
 
127
- # 2. カーネル
128
  axes[1].set_xlim(-0.5, 2.5); axes[1].set_ylim(2.5, -0.5); axes[1].set_aspect("equal")
 
 
 
129
  for r in range(3):
130
  for c in range(3):
131
- t = axes[1].text(c, r, f"{int(kernel[r,c])}", ha="center", va="center", fontsize=18)
 
132
  set_jp_font(t, is_bold=True)
133
- if patch[r,c] > 0: t.set_bbox(dict(facecolor="mistyrose", alpha=0.5))
134
- axes[1].set_title(f"カーネル: {k_name}")
135
- set_jp_font(axes[1], is_bold=True)
 
 
 
136
 
137
- # 3. 畳み込み結果
 
138
  norm_conv = conv_full / (np.max(np.abs(conv_full)) + 1e-12)
139
- axes[2].imshow(norm_conv, cmap="bwr", vmin=-1, vmax=1)
140
- axes[2].add_patch(patches.Rectangle((col-0.5, row-0.5), 1, 1, lw=2, ec="gold", fc="none"))
141
- axes[2].set_title(f"結果: {'?' if not show_ans else int(val)}")
142
- set_jp_font(axes[2], is_bold=True)
 
 
 
 
 
143
 
144
  st.pyplot(fig)
145
 
146
  # -----------------------------
147
- # メインアプリ
148
  # -----------------------------
149
  def main():
150
  st.set_page_config(page_title="CNN Convolution Demo", layout="wide")
151
- st.title("🔢 CNNの畳み込み計算を理解しよう")
 
152
 
153
- # セッション状態の初期化
154
- if "idx" not in st.session_state: st.session_state.idx = 0
155
 
156
  img = get_cat_image()
157
- output_size = img.shape[0] - 2
 
158
 
 
159
  with st.sidebar:
160
- st.header("⚙️ 設定")
161
- k_name = st.radio("カーネル選", list(KERNELS.keys()))
162
- show_ans = st.checkbox("答えを表示", value=True)
163
 
164
  st.divider()
165
- st.write("📍 位置移動")
 
 
166
  c1, c2 = st.columns(2)
167
- if c1.button("← 前へ"): st.session_state.idx = max(0, st.session_state.idx - 1)
168
- if c2.button("次へ →"): st.session_state.idx = min(output_size**2 - 1, st.session_state.idx + 1)
 
 
169
 
170
- st.session_state.idx = st.slider("スライダー移動", 0, output_size**2 - 1, st.session_state.idx)
 
 
 
 
 
 
 
 
171
  row, col = divmod(st.session_state.idx, output_size)
 
 
 
172
 
173
- # 描画実行
174
- render_dashboard(img, k_name, row, col, show_ans)
175
 
176
- # 下部詳細エリア
177
- col_left, col_right = st.columns([1, 1])
178
  patch = img[row:row+3, col:col+3]
179
  kernel = KERNELS[k_name]
 
180
 
181
  with col_left:
182
- st.subheader("🔍 拡大表示")
183
- fig_sub, axes_sub = plt.subplots(1, 2, figsize=(6, 3))
184
- axes_sub[0].imshow(patch, cmap="gray_r"); axes_sub[0].set_title("抽出領域")
185
- axes_sub[1].imshow(patch * kernel, cmap="bwr"); axes_sub[1].set_title("要素ごとの積")
186
- for ax in axes_sub: set_jp_font(ax); ax.axis("off")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  st.pyplot(fig_sub)
188
 
189
  with col_right:
190
- st.subheader("🧮 計算式")
 
191
  if show_ans:
192
- formula = " + ".join([f"({int(p)}×{int(k)})" for p, k in zip(patch.flatten(), kernel.flatten())])
193
- st.code(f"{formula}\n\n= 合計: {int(np.sum(patch * kernel))}", language="text")
194
- st.metric("畳み込み値", int(np.sum(patch * kernel)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  else:
196
- st.info("サイドバーの「答えを表示」をチェックして確認してください。")
 
 
 
 
197
 
198
  if __name__ == "__main__":
199
- # Pythonコード作成時のデフォルト要件確認
200
- # 「このチャットでのPythonコード作成のフォント指定は、長期メモリに保存されているデフォルト
201
- # (NotoSansJP-Regular/Bold を FontProperties(fname=...) で明示指定)のままで良いですか?」
202
  main()
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ Streamlit ダッシュボード版: CNN 畳み込み可視化教材(ネコの顔全体・修正版)
4
  """
5
 
6
  from pathlib import Path
7
+ from typing import Tuple, Optional
8
 
9
  import matplotlib.pyplot as plt
10
  import numpy as np
 
13
  from matplotlib.font_manager import FontProperties
14
 
15
  # -----------------------------
16
+ # 1. フォント設定 (findfont警告回避)
17
  # -----------------------------
18
  def get_japanese_font() -> Tuple[Optional[FontProperties], Optional[FontProperties]]:
19
  """プロジェクトルートにあるNotoSansJPフォントを読み込む"""
 
39
  ax_or_text.set_fontsize(size)
40
 
41
  # -----------------------------
42
+ # 2. 画像生成ロジック (ネコの顔全体を復元)
43
  # -----------------------------
44
+ def draw_polyline(img: np.ndarray, pts: list, thickness: float = 1.0):
45
+ """配列上にポリラインを描画する(高速化版)"""
46
  h, w = img.shape
47
  ys, xs = np.mgrid[0:h, 0:w]
48
  for (x1, y1), (x2, y2) in zip(pts[:-1], pts[1:]):
49
  vx, vy = x2 - x1, y2 - y1
50
  c2 = vx**2 + vy**2 + 1e-12
51
+ # 点から線分への最短距離のパラメータt
52
+ t = np.clip((vx * (xs - x1) + vy * (ys - y1)) / c2, 0, 1)
53
+ # 最短点(bx, by)
54
+ bx = x1 + t * vx
55
+ by = y1 + t * vy
56
+ # 距離がthickness以下のピクセルを1にする
57
  img[np.sqrt((xs - bx)**2 + (ys - by)**2) <= thickness] = 1.0
58
 
59
  def fit_binary_image_to_canvas(img: np.ndarray, target_size: int = 48, margin: int = 2) -> np.ndarray:
60
+ """された画像をキャンバス中央にリサイズして配置"""
61
  coords = np.argwhere(img > 0.5)
62
  if coords.size == 0: return np.zeros((target_size, target_size))
63
 
 
69
  scale = min((target_size - 2 * margin) / max(ch, cw, 1), 1.0)
70
  new_h, new_w = int(ch * scale), int(cw * scale)
71
 
 
72
  out = np.zeros((target_size, target_size))
73
  y_off, x_off = (target_size - new_h) // 2, (target_size - new_w) // 2
74
 
75
+ # 座標変換による簡易リサイズ
76
  for sy, sx in np.argwhere(cropped > 0.5):
77
  ty, tx = y_off + int(sy * scale), x_off + int(sx * scale)
78
  if 0 <= ty < target_size and 0 <= tx < target_size:
 
81
 
82
  @st.cache_data
83
  def get_cat_image(size: int = 48) -> np.ndarray:
84
+ """64x64のキャンバスにネコの顔全体を描画し、48x48にフィットさせる"""
85
  base = np.zeros((64, 64))
86
+
87
+ # --- [復元] ネコの顔を描画するパーツ群 ---
88
+
89
+ # 1. 顔の輪郭 (下半分のアーク)
90
  t = np.linspace(np.deg2rad(205), np.deg2rad(335), 160)
91
  pts_face = list(zip(32 + 18 * np.cos(t), 34 + 18 * np.sin(t)))
92
+ draw_polyline(base, pts_face, thickness=1.3)
93
+
94
+ # 2. 耳
95
+ draw_polyline(base, [(19, 24), (25, 10), (30, 24)], thickness=1.25) # 左
96
+ draw_polyline(base, [(34, 24), (39, 10), (45, 24)], thickness=1.25) # 右
97
+
98
+ # 3. 頭頂部
99
+ draw_polyline(base, [(30, 24), (32, 22), (34, 24)], thickness=1.15)
100
+
101
+ # 4. ほっぺた
102
+ draw_polyline(base, [(19, 24), (15, 31), (16, 41)], thickness=1.25) # 左
103
+ draw_polyline(base, [(45, 24), (49, 31), (48, 41)], thickness=1.25) # 右
104
+
105
+ # 5. 目 (アーチ状)
106
+ draw_polyline(base, [(24, 30), (27, 28), (30, 30)], thickness=1.0) # 左
107
+ draw_polyline(base, [(34, 30), (37, 28), (40, 30)], thickness=1.0) # 右
108
+
109
+ # 6. 鼻と口
110
+ draw_polyline(base, [(30, 37), (32, 39), (34, 37), (32, 37), (30, 37)], thickness=1.0) # 鼻
111
+ draw_polyline(base, [(32, 39), (30, 42)], thickness=1.0) # 口・左
112
+ draw_polyline(base, [(32, 39), (34, 42)], thickness=1.0) # 口・右
113
+
114
+ # 7. ヒゲ (左右3本ずつ)
115
+ whiskers = [
116
+ [(17, 34), (25, 35)], [(16, 38), (25, 38)], [(17, 42), (25, 41)], # 左
117
+ [(39, 35), (47, 34)], [(39, 38), (48, 38)], [(39, 41), (47, 42)], # 右
118
+ ]
119
+ for pts in whiskers:
120
+ draw_polyline(base, pts, thickness=0.8)
121
+
122
+ # --- [復元ここまで] ---
123
+
124
+ # 指定サイズにフィットさせて返す
125
  return fit_binary_image_to_canvas(base, target_size=size)
126
 
127
  # -----------------------------
128
+ # 3. 畳み込み演算
129
  # -----------------------------
130
  KERNELS = {
131
  "縦線": np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]),
 
135
 
136
  @st.cache_data
137
  def run_convolution(img: np.ndarray, kernel: np.ndarray) -> np.ndarray:
138
+ """Validモードでの畳み込み演算(教材用シンプル実装)"""
139
  kh, kw = kernel.shape
140
  oh, ow = img.shape[0] - kh + 1, img.shape[1] - kw + 1
 
141
  output = np.zeros((oh, ow))
142
  for i in range(oh):
143
  for j in range(ow):
144
+ # 要素ごとの積の合計
145
  output[i, j] = np.sum(img[i:i+kh, j:j+kw] * kernel)
146
  return output
147
 
148
  # -----------------------------
149
+ # 4. 表示用コンポーネント
150
  # -----------------------------
151
+ def render_main_figures(img, k_name, row, col, show_ans):
152
+ """メインの3つのグラフ(入力、カーネル、結果)を描画"""
153
  kernel = KERNELS[k_name]
154
  conv_full = run_convolution(img, kernel)
155
  patch = img[row:row+3, col:col+3]
156
  val = conv_full[row, col]
157
 
158
+ fig, axes = plt.subplots(1, 3, figsize=(14, 5.5), constrained_layout=True)
 
159
 
160
+ # 1. 入力画像 (48x48)
161
+ axes[0].imshow(img, cmap="gray_r", interpolation="nearest")
162
+ # 現在の畳み込み位置を赤枠で表示
163
+ axes[0].add_patch(patches.Rectangle((col-0.5, row-0.5), 3, 3, lw=2.5, ec="red", fc="none"))
164
+ axes[0].set_title("入力画像 (48x48 ネコの顔全体)")
165
+ axes[0].set_xticks(np.arange(0, 48, 6)); axes[0].set_yticks(np.arange(0, 48, 6))
166
+ axes[0].grid(color="lightgray", lw=0.5, alpha=0.5)
167
+ set_jp_font(axes[0], is_bold=True, size=14)
168
 
169
+ # 2. カーネル (3x3)
170
  axes[1].set_xlim(-0.5, 2.5); axes[1].set_ylim(2.5, -0.5); axes[1].set_aspect("equal")
171
+ axes[1].set_xticks([0, 1, 2]); axes[1].set_yticks([0, 1, 2])
172
+ axes[1].grid(color="black", lw=1)
173
+
174
  for r in range(3):
175
  for c in range(3):
176
+ # カーネルの数値を描画
177
+ t = axes[1].text(c, r, f"{int(kernel[r,c])}", ha="center", va="center", fontsize=24)
178
  set_jp_font(t, is_bold=True)
179
+ # 対応する入力ピクセルが黒(1)なら背景を赤くする
180
+ if patch[r,c] > 0:
181
+ t.set_bbox(dict(facecolor="mistyrose", edgecolor="none", alpha=0.8, boxstyle="round,pad=0.2"))
182
+
183
+ axes[1].set_title(f"3x3 カーネル ({k_name})")
184
+ set_jp_font(axes[1], is_bold=True, size=14)
185
 
186
+ # 3. 畳み込み結果 (46x46)
187
+ # 表示用に正規化 (-1~1)
188
  norm_conv = conv_full / (np.max(np.abs(conv_full)) + 1e-12)
189
+ axes[2].imshow(norm_conv, cmap="bwr", vmin=-1, vmax=1, interpolation="nearest")
190
+ # 現在の結果位置を金枠で表示
191
+ axes[2].add_patch(patches.Rectangle((col-0.5, row-0.5), 1, 1, lw=2.5, ec="gold", fc="none"))
192
+
193
+ res_val_str = '?' if not show_ans else str(int(val))
194
+ axes[2].set_title(f"結果 (46x46) 現在値: {res_val_str}")
195
+ axes[2].set_xticks(np.arange(0, 46, 6)); axes[2].set_yticks(np.arange(0, 46, 6))
196
+ axes[2].grid(color="lightgray", lw=0.5, alpha=0.5)
197
+ set_jp_font(axes[2], is_bold=True, size=14)
198
 
199
  st.pyplot(fig)
200
 
201
  # -----------------------------
202
+ # 5. メインアプリ
203
  # -----------------------------
204
  def main():
205
  st.set_page_config(page_title="CNN Convolution Demo", layout="wide")
206
+ st.title("🔢 CNNの畳み込み計算を���てみよう")
207
+ st.markdown("48x48のネコの顔画像に対して、3x3のカーネル(フィルタ)を滑らせて、エッジを抽出する様子を可視化します。")
208
 
209
+ # セッション状態の初期化 (前回の位置を記憶)
210
+ if "idx" not in st.session_state: st.session_state.idx = 1110 # 顔の中心付近
211
 
212
  img = get_cat_image()
213
+ # Validモードなので出力サイズは N - K + 1
214
+ output_size = img.shape[0] - 2 # 46
215
 
216
+ # --- サイドバー操作パネル ---
217
  with st.sidebar:
218
+ st.header("⚙️ 操作パネル")
219
+ k_name = st.radio("カーネル(フィルタ)を", list(KERNELS.keys()), index=2) # デフォルト輪郭
220
+ show_ans = st.checkbox("計算の答えを表示する", value=True)
221
 
222
  st.divider()
223
+ st.write("📍 **位置移動する**")
224
+
225
+ # ボタンによる移動
226
  c1, c2 = st.columns(2)
227
+ if c1.button("← 前へ (1px)", use_container_width=True):
228
+ st.session_state.idx = max(0, st.session_state.idx - 1)
229
+ if c2.button("次へ → (1px)", use_container_width=True):
230
+ st.session_state.idx = min(output_size**2 - 1, st.session_state.idx + 1)
231
 
232
+ # スライダーによる移動
233
+ st.session_state.idx = st.slider(
234
+ "スライダーで連続移動",
235
+ 0, output_size**2 - 1,
236
+ st.session_state.idx,
237
+ label_visibility="collapsed"
238
+ )
239
+
240
+ # 1次元インデックスを2次元座標(行i, 列j)に変換
241
  row, col = divmod(st.session_state.idx, output_size)
242
+ st.caption(f"現在の中心座標 (Valid領域): 行={row}, 列={col}")
243
+
244
+ st.info("赤枠(入力)の9マスの数値と、カーネルの9マスの数値を掛け算して合計したものが、金枠(結果)の1マスの数値になります。")
245
 
246
+ # --- メインエリア描画 ---
247
+ render_main_figures(img, k_name, row, col, show_ans)
248
 
249
+ # --- 下部詳細エリア(拡大図と計算式) ---
250
+ col_left, col_right = st.columns([1, 1.2])
251
  patch = img[row:row+3, col:col+3]
252
  kernel = KERNELS[k_name]
253
+ products = patch * kernel
254
 
255
  with col_left:
256
+ st.subheader("🔍 現在位置の拡大図 (3x3)")
257
+ fig_sub, axes_sub = plt.subplots(1, 2, figsize=(7, 3.5), constrained_layout=True)
258
+
259
+ # 拡大した抽出領域
260
+ axes_sub[0].imshow(patch, cmap="gray_r", interpolation="nearest")
261
+ axes_sub[0].set_title("抽出された3x3領域\n(0:白, 1:黒)")
262
+ set_jp_font(axes_sub[0], size=11)
263
+
264
+ # 要素ごとの積
265
+ axes_sub[1].imshow(products, cmap="bwr", vmin=-2, vmax=8, interpolation="nearest")
266
+ axes_sub[1].set_title("要素ごとの掛け算の結果\n(Patch × Kernel)")
267
+ set_jp_font(axes_sub[1], size=11)
268
+
269
+ for ax in axes_sub:
270
+ ax.set_xticks([0, 1, 2]); ax.set_yticks([0, 1, 2])
271
+ ax.grid(color="gray", lw=0.5)
272
+ # 数値をオーバーレイ
273
+ mat = patch if ax == axes_sub[0] else products
274
+ for r in range(3):
275
+ for c in range(3):
276
+ t = ax.text(c, r, f"{int(mat[r,c])}", ha="center", va="center", fontsize=16)
277
+ set_jp_font(t, is_bold=True)
278
+ t.set_bbox(dict(facecolor="white", alpha=0.5, edgecolor="none"))
279
+
280
  st.pyplot(fig_sub)
281
 
282
  with col_right:
283
+ st.subheader(f"🧮 計算式 (行={row}, 列={col})")
284
+
285
  if show_ans:
286
+ # フラット化して計算式を生成
287
+ p_f = patch.flatten().astype(int)
288
+ k_f = kernel.flatten().astype(int)
289
+
290
+ # 3x3の形式で見せるための改行付きリスト
291
+ formula_lines = []
292
+ total_val = 0
293
+ for r in range(3):
294
+ line_terms = []
295
+ for c in range(3):
296
+ p, k = int(patch[r, c]), int(kernel[r, c])
297
+ total_val += p * k
298
+ # 教材用に分かりやすく (入力 × カーネル)
299
+ line_terms.append(f"({p}×{k:2})")
300
+ formula_lines.append(" + ".join(line_terms))
301
+
302
+ formula_text = " \n+ ".join(formula_lines)
303
+
304
+ st.code(
305
+ f"要素ごとの積の合計:\n\n {formula_text}\n\n= 合計: {total_val}",
306
+ language="text"
307
+ )
308
+ st.metric(label="この位置の畳み込み出力値", value=total_val)
309
  else:
310
+ st.warning("サイドバーの「計算の答えを表示する」をチェックして、手計算の結果を確認してください。")
311
+ st.code(
312
+ "要素ごとの積の合計:\n\n(ここを計算してみよう)\n\n= 合計: ?",
313
+ language="text"
314
+ )
315
 
316
  if __name__ == "__main__":
 
 
 
317
  main()