db-d2 committed on
Commit a681fba · 1 Parent(s): b34e98b

Feat: Gradient descent and chain rule helper

Files changed (1)
  1. app.py +1895 -0
app.py ADDED
@@ -0,0 +1,1895 @@
+ import gradio as gr
+ import numpy as np
+
+ # Fallout Terminal Theme CSS
+ # Color palette:
+ # - Pip-Boy Amber: #f0b030 (warm, readable headers)
+ # - Terminal Green: #4ade80 (softer green, easy on eyes)
+ # - Vault-Tec Blue: #5b9bd5 (trusty Vault-Tec corporate blue)
+ # - Background: #0c0c0c (near-black terminal)
+ # - Panel BG: #141414 (slightly lifted for depth)
+
+ FALLOUT_CSS = """
13
+ @import url('https://fonts.googleapis.com/css2?family=VT323&display=swap');
14
+ @import url('https://fonts.googleapis.com/css2?family=Share+Tech+Mono&display=swap');
15
+
16
+ :root {
17
+ --pip-amber: #f0b030;
18
+ --pip-amber-dim: #c49028;
19
+ --terminal-green: #4ade80;
20
+ --terminal-green-dim: #22c55e;
21
+ --vault-blue: #5b9bd5;
22
+ --vault-blue-dim: #4080b8;
23
+ --bg-dark: #0c0c0c;
24
+ --bg-panel: #141414;
25
+ --bg-input: #1a1a1a;
26
+ --text-muted: #888888;
27
+ }
28
+
29
+ * {
30
+ font-family: 'Share Tech Mono', 'VT323', monospace !important;
31
+ font-size: 20px !important;
32
+ line-height: 1.6 !important;
33
+ }
34
+
35
+ h1 { font-size: 36px !important; }
36
+ h2 { font-size: 30px !important; }
37
+ h3 { font-size: 24px !important; }
38
+ h4, h5 { font-size: 22px !important; }
39
+ code, pre { font-size: 18px !important; }
40
+
41
+ body, .gradio-container {
42
+ background-color: var(--bg-dark) !important;
43
+ }
44
+
45
+ .gradio-container {
46
+ max-width: 1200px !important;
47
+ }
48
+
49
+ /* Main text - soft green, NO glow */
50
+ .markdown-text, .prose, p, span, label, .label-wrap {
51
+ color: var(--terminal-green) !important;
52
+ }
53
+
54
+ /* Headers - warm amber for hierarchy */
55
+ h1 {
56
+ color: var(--pip-amber) !important;
57
+ border-bottom: 2px solid var(--pip-amber-dim) !important;
58
+ padding-bottom: 8px !important;
59
+ }
60
+
61
+ h2 {
62
+ color: var(--pip-amber) !important;
63
+ border-bottom: 1px solid var(--pip-amber-dim) !important;
64
+ padding-bottom: 4px !important;
65
+ }
66
+
67
+ h3, h4, h5 {
68
+ color: var(--vault-blue) !important;
69
+ border-bottom: none !important;
70
+ }
71
+
72
+ /* Tab styling - Vault-Tec blue for navigation */
73
+ .tabs {
74
+ background-color: var(--bg-dark) !important;
75
+ border: 1px solid var(--vault-blue-dim) !important;
76
+ border-radius: 4px !important;
77
+ }
78
+
79
+ .tab-nav {
80
+ background-color: var(--bg-panel) !important;
81
+ border-bottom: 2px solid var(--vault-blue-dim) !important;
82
+ }
83
+
84
+ .tab-nav button {
85
+ background-color: var(--bg-panel) !important;
86
+ color: var(--vault-blue) !important;
87
+ border: none !important;
88
+ border-right: 1px solid var(--bg-dark) !important;
89
+ padding: 10px 16px !important;
90
+ transition: all 0.2s ease !important;
91
+ }
92
+
93
+ .tab-nav button:hover {
94
+ background-color: #1e3a5f !important;
95
+ color: #8ec5fc !important;
96
+ }
97
+
98
+ .tab-nav button.selected {
99
+ background-color: #1a3550 !important;
100
+ color: #8ec5fc !important;
101
+ border-bottom: 2px solid var(--pip-amber) !important;
102
+ }
103
+
104
+ /* Input/Output boxes - subtle with green text */
105
+ .textbox, textarea, input {
106
+ background-color: var(--bg-input) !important;
107
+ color: var(--terminal-green) !important;
108
+ border: 1px solid #333 !important;
109
+ border-radius: 3px !important;
110
+ }
111
+
112
+ .textbox:focus, textarea:focus, input:focus {
113
+ border-color: var(--terminal-green-dim) !important;
114
+ outline: none !important;
115
+ }
116
+
117
+ /* Buttons - amber accent for actions */
118
+ .primary, .secondary, button {
119
+ background-color: #2a2010 !important;
120
+ color: var(--pip-amber) !important;
121
+ border: 1px solid var(--pip-amber-dim) !important;
122
+ border-radius: 3px !important;
123
+ transition: all 0.2s ease !important;
124
+ }
125
+
126
+ button:hover {
127
+ background-color: #3d2e15 !important;
128
+ border-color: var(--pip-amber) !important;
129
+ }
130
+
131
+ /* Sliders - amber accent */
132
+ input[type="range"] {
133
+ accent-color: var(--pip-amber) !important;
134
+ }
135
+
136
+ /* Number inputs */
137
+ .number-input input {
138
+ background-color: var(--bg-input) !important;
139
+ color: var(--terminal-green) !important;
140
+ border: 1px solid #333 !important;
141
+ }
142
+
143
+ /* Code blocks - slightly blue-tinted for distinction */
144
+ code, pre {
145
+ background-color: #0d1520 !important;
146
+ color: var(--terminal-green) !important;
147
+ border: 1px solid #2a4060 !important;
148
+ border-left: 3px solid var(--vault-blue) !important;
149
+ border-radius: 3px !important;
150
+ padding: 2px 6px !important;
151
+ }
152
+
153
+ pre {
154
+ padding: 12px !important;
155
+ }
156
+
157
+ /* Tables */
158
+ table {
159
+ border-collapse: collapse !important;
160
+ }
161
+
162
+ th {
163
+ background-color: #1a2a3a !important;
164
+ color: var(--pip-amber) !important;
165
+ border: 1px solid #2a4060 !important;
166
+ padding: 8px !important;
167
+ }
168
+
169
+ td {
170
+ background-color: var(--bg-panel) !important;
171
+ color: var(--terminal-green) !important;
172
+ border: 1px solid #2a4060 !important;
173
+ padding: 8px !important;
174
+ }
175
+
176
+ /* Strong/bold text - amber for emphasis */
177
+ strong, b {
178
+ color: var(--pip-amber) !important;
179
+ font-weight: bold !important;
180
+ }
181
+
182
+ /* Links */
183
+ a {
184
+ color: var(--vault-blue) !important;
185
+ }
186
+
187
+ a:hover {
188
+ color: #8ec5fc !important;
189
+ }
190
+
191
+ /* Radio buttons and checkboxes */
192
+ .radio-group label, .checkbox-group label {
193
+ color: var(--terminal-green) !important;
194
+ }
195
+
196
+ /* Scrollbar - subtle */
197
+ ::-webkit-scrollbar {
198
+ width: 8px;
199
+ height: 8px;
200
+ background-color: var(--bg-dark);
201
+ }
202
+
203
+ ::-webkit-scrollbar-thumb {
204
+ background-color: #333;
205
+ border-radius: 4px;
206
+ }
207
+
208
+ ::-webkit-scrollbar-thumb:hover {
209
+ background-color: #444;
210
+ }
211
+
212
+ /* Subtle scanlines - very light, not distracting */
213
+ .gradio-container::before {
214
+ content: "";
215
+ position: fixed;
216
+ top: 0;
217
+ left: 0;
218
+ width: 100%;
219
+ height: 100%;
220
+ background: repeating-linear-gradient(
221
+ 0deg,
222
+ rgba(0, 0, 0, 0.03),
223
+ rgba(0, 0, 0, 0.03) 1px,
224
+ transparent 1px,
225
+ transparent 2px
226
+ );
227
+ pointer-events: none;
228
+ z-index: 1000;
229
+ }
230
+
231
+ /* Horizontal rules - amber accent */
232
+ hr {
233
+ border: none !important;
234
+ border-top: 1px solid var(--pip-amber-dim) !important;
235
+ margin: 16px 0 !important;
236
+ }
237
+
238
+ /* Blockquotes - for terminal prompts */
239
+ blockquote {
240
+ border-left: 3px solid var(--pip-amber) !important;
241
+ background-color: var(--bg-panel) !important;
242
+ padding: 8px 16px !important;
243
+ margin: 8px 0 !important;
244
+ color: var(--pip-amber) !important;
245
+ }
246
+
247
+ /* Muted/secondary text */
248
+ .secondary-text, .hint {
249
+ color: var(--text-muted) !important;
250
+ }
251
+ """
+
+ # ============================================================================
+ # SVG DIAGRAM GENERATORS
+ # ============================================================================
+
+ def generate_forward_svg(x1, x2, w1, w2, b, z, y):
+     """Generate an SVG diagram showing the forward pass with actual values."""
+
+     # Colors matching our theme
+     bg = "#0c0c0c"
+     node_fill = "#1a2a3a"
+     node_stroke = "#5b9bd5"
+     input_fill = "#1a3a2a"
+     input_stroke = "#4ade80"
+     output_fill = "#2a2a1a"
+     output_stroke = "#f0b030"
+     text_color = "#4ade80"
+     label_color = "#5b9bd5"
+     arrow_color = "#5b9bd5"
+     value_color = "#f0b030"
+
+     svg = f'''
274
+ <svg viewBox="0 0 800 320" style="width:100%; max-width:800px; height:auto; background:{bg}; border-radius:8px; border:1px solid #333;">
275
+ <defs>
276
+ <marker id="arrowhead" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
277
+ <polygon points="0 0, 10 3.5, 0 7" fill="{arrow_color}" />
278
+ </marker>
279
+ <marker id="arrowhead-amber" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
280
+ <polygon points="0 0, 10 3.5, 0 7" fill="{output_stroke}" />
281
+ </marker>
282
+ </defs>
283
+
284
+ <!-- Title -->
285
+ <text x="400" y="30" text-anchor="middle" fill="{output_stroke}" font-size="20" font-family="monospace">FORWARD PASS: Data Flow</text>
286
+
287
+ <!-- Input nodes -->
288
+ <rect x="40" y="60" width="80" height="60" rx="8" fill="{input_fill}" stroke="{input_stroke}" stroke-width="2"/>
289
+ <text x="80" y="85" text-anchor="middle" fill="{label_color}" font-size="14" font-family="monospace">x₁</text>
290
+ <text x="80" y="108" text-anchor="middle" fill="{value_color}" font-size="16" font-family="monospace" font-weight="bold">{x1:.2f}</text>
291
+
292
+ <rect x="40" y="200" width="80" height="60" rx="8" fill="{input_fill}" stroke="{input_stroke}" stroke-width="2"/>
293
+ <text x="80" y="225" text-anchor="middle" fill="{label_color}" font-size="14" font-family="monospace">x₂</text>
294
+ <text x="80" y="248" text-anchor="middle" fill="{value_color}" font-size="16" font-family="monospace" font-weight="bold">{x2:.2f}</text>
295
+
296
+ <!-- Weight labels on arrows -->
297
+ <line x1="120" y1="90" x2="220" y2="140" stroke="{arrow_color}" stroke-width="2" marker-end="url(#arrowhead)"/>
298
+ <text x="155" y="100" fill="{text_color}" font-size="12" font-family="monospace">w₁={w1:.2f}</text>
299
+
300
+ <line x1="120" y1="230" x2="220" y2="180" stroke="{arrow_color}" stroke-width="2" marker-end="url(#arrowhead)"/>
301
+ <text x="155" y="230" fill="{text_color}" font-size="12" font-family="monospace">w₂={w2:.2f}</text>
302
+
303
+ <!-- Summation node -->
304
+ <rect x="230" y="130" width="100" height="70" rx="8" fill="{node_fill}" stroke="{node_stroke}" stroke-width="2"/>
305
+ <text x="280" y="152" text-anchor="middle" fill="{label_color}" font-size="14" font-family="monospace">Σ + b</text>
306
+ <text x="280" y="175" text-anchor="middle" fill="{text_color}" font-size="11" font-family="monospace">b={b:.2f}</text>
307
+ <text x="280" y="193" text-anchor="middle" fill="{value_color}" font-size="14" font-family="monospace" font-weight="bold">z={z:.3f}</text>
308
+
309
+ <!-- Arrow to sigmoid -->
310
+ <line x1="330" y1="165" x2="400" y2="165" stroke="{arrow_color}" stroke-width="2" marker-end="url(#arrowhead)"/>
311
+
312
+ <!-- Sigmoid node -->
313
+ <rect x="410" y="130" width="100" height="70" rx="8" fill="{node_fill}" stroke="{node_stroke}" stroke-width="2"/>
314
+ <text x="460" y="155" text-anchor="middle" fill="{label_color}" font-size="14" font-family="monospace">σ(z)</text>
315
+ <text x="460" y="175" text-anchor="middle" fill="{text_color}" font-size="10" font-family="monospace">1/(1+e⁻ᶻ)</text>
316
+ <text x="460" y="193" text-anchor="middle" fill="{value_color}" font-size="14" font-family="monospace" font-weight="bold">ŷ={y:.4f}</text>
317
+
318
+ <!-- Arrow to output -->
319
+ <line x1="510" y1="165" x2="580" y2="165" stroke="{output_stroke}" stroke-width="2" marker-end="url(#arrowhead-amber)"/>
320
+
321
+ <!-- Output node -->
322
+ <rect x="590" y="130" width="100" height="70" rx="8" fill="{output_fill}" stroke="{output_stroke}" stroke-width="2"/>
323
+ <text x="640" y="155" text-anchor="middle" fill="{output_stroke}" font-size="14" font-family="monospace">OUTPUT</text>
324
+ <text x="640" y="180" text-anchor="middle" fill="{value_color}" font-size="18" font-family="monospace" font-weight="bold">{y:.4f}</text>
325
+
326
+ <!-- Legend -->
327
+ <rect x="40" y="280" width="15" height="15" fill="{input_fill}" stroke="{input_stroke}" stroke-width="1"/>
328
+ <text x="60" y="292" fill="{text_color}" font-size="12" font-family="monospace">Inputs</text>
329
+
330
+ <rect x="140" y="280" width="15" height="15" fill="{node_fill}" stroke="{node_stroke}" stroke-width="1"/>
331
+ <text x="160" y="292" fill="{text_color}" font-size="12" font-family="monospace">Operations</text>
332
+
333
+ <rect x="280" y="280" width="15" height="15" fill="{output_fill}" stroke="{output_stroke}" stroke-width="1"/>
334
+ <text x="300" y="292" fill="{text_color}" font-size="12" font-family="monospace">Output</text>
335
+
336
+ <text x="420" y="292" fill="{value_color}" font-size="12" font-family="monospace">■ Computed Values</text>
337
+ </svg>
338
+ '''
+     return svg
+
+
+ def generate_backward_svg(x1, x2, w1, w2, b, y_true, z, y_pred, dL_dy, dy_dz, dL_dz, dL_dw1, dL_dw2, dL_db, loss):
+     """Generate an SVG diagram showing backward pass with gradients."""
+
+     bg = "#0c0c0c"
+     node_fill = "#1a2a3a"
+     node_stroke = "#5b9bd5"
+     input_fill = "#1a3a2a"
+     input_stroke = "#4ade80"
+     loss_fill = "#3a1a1a"
+     loss_stroke = "#ff6b6b"
+     text_color = "#4ade80"
+     label_color = "#5b9bd5"
+     forward_arrow = "#5b9bd5"
+     backward_arrow = "#ff6b6b"
+     value_color = "#f0b030"
+     gradient_color = "#ff6b6b"
+
+     svg = f'''
360
+ <svg viewBox="0 0 900 420" style="width:100%; max-width:900px; height:auto; background:{bg}; border-radius:8px; border:1px solid #333;">
361
+ <defs>
362
+ <marker id="fwd" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
363
+ <polygon points="0 0, 10 3.5, 0 7" fill="{forward_arrow}" />
364
+ </marker>
365
+ <marker id="bwd" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
+ <polygon points="0 0, 10 3.5, 0 7" fill="{backward_arrow}" />
367
+ </marker>
368
+ </defs>
369
+
370
+ <!-- Title -->
371
+ <text x="300" y="25" text-anchor="middle" fill="{value_color}" font-size="16" font-family="monospace">GRADIENT FLOW DIAGRAM</text>
372
+
373
+ <!-- FORWARD PATH - Row 1 -->
374
+ <!-- Input x1 -->
375
+ <rect x="20" y="50" width="60" height="50" rx="5" fill="{input_fill}" stroke="{input_stroke}" stroke-width="2"/>
376
+ <text x="50" y="70" text-anchor="middle" fill="{label_color}" font-size="11" font-family="monospace">x1</text>
377
+ <text x="50" y="88" text-anchor="middle" fill="{value_color}" font-size="12" font-family="monospace">{x1:.2f}</text>
378
+
379
+ <!-- Weight w1 -->
380
+ <line x1="80" y1="75" x2="105" y2="75" stroke="{forward_arrow}" stroke-width="2" marker-end="url(#fwd)"/>
381
+ <rect x="115" y="55" width="55" height="40" rx="4" fill="#1a1a2a" stroke="{forward_arrow}" stroke-width="1"/>
382
+ <text x="142" y="80" text-anchor="middle" fill="{value_color}" font-size="10" font-family="monospace">w1={w1:.1f}</text>
383
+
384
+ <!-- Arrow to Sum -->
385
+ <line x1="170" y1="75" x2="195" y2="100" stroke="{forward_arrow}" stroke-width="2" marker-end="url(#fwd)"/>
386
+
387
+ <!-- Input x2 -->
388
+ <rect x="20" y="120" width="60" height="50" rx="5" fill="{input_fill}" stroke="{input_stroke}" stroke-width="2"/>
389
+ <text x="50" y="140" text-anchor="middle" fill="{label_color}" font-size="11" font-family="monospace">x2</text>
390
+ <text x="50" y="158" text-anchor="middle" fill="{value_color}" font-size="12" font-family="monospace">{x2:.2f}</text>
391
+
392
+ <!-- Weight w2 -->
393
+ <line x1="80" y1="145" x2="105" y2="145" stroke="{forward_arrow}" stroke-width="2" marker-end="url(#fwd)"/>
394
+ <rect x="115" y="125" width="55" height="40" rx="4" fill="#1a1a2a" stroke="{forward_arrow}" stroke-width="1"/>
395
+ <text x="142" y="150" text-anchor="middle" fill="{value_color}" font-size="10" font-family="monospace">w2={w2:.1f}</text>
396
+
397
+ <!-- Arrow to Sum -->
398
+ <line x1="170" y1="145" x2="195" y2="120" stroke="{forward_arrow}" stroke-width="2" marker-end="url(#fwd)"/>
399
+
400
+ <!-- Sum node -->
401
+ <rect x="205" y="90" width="70" height="50" rx="5" fill="{node_fill}" stroke="{node_stroke}" stroke-width="2"/>
402
+ <text x="240" y="110" text-anchor="middle" fill="{label_color}" font-size="11" font-family="monospace">Sum+b</text>
403
+ <text x="240" y="128" text-anchor="middle" fill="{value_color}" font-size="10" font-family="monospace">z={z:.2f}</text>
404
+
405
+ <!-- Arrow to Sigmoid -->
406
+ <line x1="275" y1="115" x2="310" y2="115" stroke="{forward_arrow}" stroke-width="2" marker-end="url(#fwd)"/>
407
+
408
+ <!-- Sigmoid -->
409
+ <rect x="320" y="90" width="70" height="50" rx="5" fill="{node_fill}" stroke="{node_stroke}" stroke-width="2"/>
410
+ <text x="355" y="110" text-anchor="middle" fill="{label_color}" font-size="11" font-family="monospace">sigmoid</text>
411
+ <text x="355" y="128" text-anchor="middle" fill="{value_color}" font-size="10" font-family="monospace">y={y_pred:.3f}</text>
412
+
413
+ <!-- Arrow to Loss -->
414
+ <line x1="390" y1="115" x2="425" y2="115" stroke="{forward_arrow}" stroke-width="2" marker-end="url(#fwd)"/>
415
+
416
+ <!-- Loss -->
417
+ <rect x="435" y="85" width="80" height="60" rx="5" fill="{loss_fill}" stroke="{loss_stroke}" stroke-width="2"/>
418
+ <text x="475" y="105" text-anchor="middle" fill="{loss_stroke}" font-size="11" font-family="monospace">BCE</text>
419
+ <text x="475" y="122" text-anchor="middle" fill="{value_color}" font-size="11" font-family="monospace">L={loss:.4f}</text>
420
+ <text x="475" y="138" text-anchor="middle" fill="{text_color}" font-size="9" font-family="monospace">y_true={y_true}</text>
421
+
422
+ <!-- BACKWARD SECTION -->
423
+ <text x="300" y="185" text-anchor="middle" fill="{backward_arrow}" font-size="12" font-family="monospace">BACKWARD PASS (gradients)</text>
424
+
425
+ <!-- Gradient chain boxes -->
426
+ <rect x="435" y="200" width="80" height="35" rx="4" fill="{loss_fill}" stroke="{loss_stroke}" stroke-width="1"/>
427
+ <text x="475" y="222" text-anchor="middle" fill="{gradient_color}" font-size="9" font-family="monospace">dL/dy={dL_dy:.2f}</text>
428
+
429
+ <line x1="435" y1="218" x2="405" y2="218" stroke="{backward_arrow}" stroke-width="2" stroke-dasharray="4,2" marker-end="url(#bwd)"/>
430
+
431
+ <rect x="320" y="200" width="80" height="35" rx="4" fill="{node_fill}" stroke="{node_stroke}" stroke-width="1"/>
432
+ <text x="360" y="222" text-anchor="middle" fill="{gradient_color}" font-size="9" font-family="monospace">dy/dz={dy_dz:.3f}</text>
433
+
434
+ <line x1="320" y1="218" x2="290" y2="218" stroke="{backward_arrow}" stroke-width="2" stroke-dasharray="4,2" marker-end="url(#bwd)"/>
435
+
436
+ <rect x="205" y="200" width="80" height="35" rx="4" fill="{node_fill}" stroke="{node_stroke}" stroke-width="1"/>
437
+ <text x="245" y="222" text-anchor="middle" fill="{gradient_color}" font-size="9" font-family="monospace">dL/dz={dL_dz:.3f}</text>
438
+
439
+ <line x1="205" y1="218" x2="175" y2="218" stroke="{backward_arrow}" stroke-width="2" stroke-dasharray="4,2" marker-end="url(#bwd)"/>
440
+
441
+ <!-- Final Gradients Box -->
442
+ <rect x="20" y="260" width="240" height="80" rx="6" fill="#141414" stroke="{gradient_color}" stroke-width="1"/>
443
+ <text x="140" y="282" text-anchor="middle" fill="{value_color}" font-size="12" font-family="monospace">COMPUTED GRADIENTS</text>
444
+ <text x="140" y="305" text-anchor="middle" fill="{gradient_color}" font-size="11" font-family="monospace">dL/dw1 = {dL_dw1:.4f}</text>
445
+ <text x="140" y="325" text-anchor="middle" fill="{gradient_color}" font-size="11" font-family="monospace">dL/dw2 = {dL_dw2:.4f} dL/db = {dL_db:.4f}</text>
446
+
447
+ <!-- Chain Rule Box -->
448
+ <rect x="530" y="50" width="360" height="190" rx="6" fill="#141414" stroke="#555" stroke-width="1"/>
449
+ <text x="710" y="75" text-anchor="middle" fill="{value_color}" font-size="13" font-family="monospace">CHAIN RULE COMPUTATION</text>
450
+
451
+ <text x="545" y="100" fill="{text_color}" font-size="11" font-family="monospace">dL/dw1 = dL/dy * dy/dz * dz/dw1</text>
452
+ <text x="545" y="125" fill="#888" font-size="11" font-family="monospace"> = ({dL_dy:.2f}) * ({dy_dz:.4f}) * ({x1:.2f})</text>
453
+ <text x="545" y="150" fill="{value_color}" font-size="12" font-family="monospace"> = {dL_dw1:.4f}</text>
454
+
455
+ <line x1="545" y1="165" x2="875" y2="165" stroke="#333" stroke-width="1"/>
456
+
457
+ <text x="545" y="185" fill="{text_color}" font-size="10" font-family="monospace">Key: dz/dw1 = x1, dz/dw2 = x2, dz/db = 1</text>
458
+ <text x="545" y="205" fill="#888" font-size="10" font-family="monospace">The input values become gradients!</text>
459
+ <text x="545" y="225" fill="{text_color}" font-size="10" font-family="monospace">dL/dw2 = dL/dz * x2 = {dL_dz:.3f} * {x2:.2f} = {dL_dw2:.4f}</text>
460
+
461
+ <!-- Legend -->
462
+ <rect x="530" y="260" width="360" height="80" rx="6" fill="#141414" stroke="#333" stroke-width="1"/>
463
+ <text x="710" y="282" text-anchor="middle" fill="#888" font-size="11" font-family="monospace">LEGEND</text>
464
+ <line x1="550" y1="302" x2="590" y2="302" stroke="{forward_arrow}" stroke-width="2" marker-end="url(#fwd)"/>
465
+ <text x="600" y="306" fill="{text_color}" font-size="10" font-family="monospace">Forward (data)</text>
466
+ <line x1="550" y1="322" x2="590" y2="322" stroke="{backward_arrow}" stroke-width="2" stroke-dasharray="4,2" marker-end="url(#bwd)"/>
467
+ <text x="600" y="326" fill="{text_color}" font-size="10" font-family="monospace">Backward (grads)</text>
468
+ <rect x="750" y="295" width="12" height="12" fill="{gradient_color}"/>
469
+ <text x="770" y="306" fill="{text_color}" font-size="10" font-family="monospace">Gradient values</text>
470
+ </svg>
471
+ '''
+     return svg
+
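The gradient arguments that `generate_backward_svg` receives (`dL_dy`, `dy_dz`, `dL_dz`, `dL_dw1`, `dL_dw2`, `dL_db`) follow from the chain rule shown in the diagram. The sketch below shows one way they might be computed for a sigmoid neuron, assuming the loss labelled "BCE" is the usual binary cross-entropy L = -[y ln(ŷ) + (1 - y) ln(1 - ŷ)]. The helper name `neuron_gradients` and the clipping constant `eps` are illustrative assumptions and do not appear in this part of the commit.

```python
import numpy as np


def neuron_gradients(x1, x2, w1, w2, b, y_true, eps=1e-12):
    """Hypothetical helper: forward pass plus chain-rule gradients."""
    # Forward pass: weighted sum, then sigmoid.
    z = w1 * x1 + w2 * x2 + b
    y_pred = 1.0 / (1.0 + np.exp(-z))

    # Clip so the logs and divisions below stay finite (assumed safeguard).
    y_c = np.clip(y_pred, eps, 1.0 - eps)
    loss = -(y_true * np.log(y_c) + (1 - y_true) * np.log(1 - y_c))

    # Backward pass, one chain-rule link at a time.
    dL_dy = -(y_true / y_c) + (1 - y_true) / (1 - y_c)
    dy_dz = y_c * (1 - y_c)
    dL_dz = dL_dy * dy_dz  # algebraically equal to y_pred - y_true

    return {
        "z": z,
        "y_pred": y_pred,
        "loss": loss,
        "dL_dy": dL_dy,
        "dy_dz": dy_dz,
        "dL_dz": dL_dz,
        "dL_dw1": dL_dz * x1,  # dz/dw1 = x1
        "dL_dw2": dL_dz * x2,  # dz/dw2 = x2
        "dL_db": dL_dz,        # dz/db = 1
    }


g = neuron_gradients(x1=1.0, x2=2.0, w1=0.5, w2=-0.3, b=0.1, y_true=1)
```

A single gradient-descent step would then subtract a step size times each gradient from its parameter, for example `w1 -= lr * g["dL_dw1"]`, where `lr` is an assumed learning rate.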
+
+ # ============================================================================
+ # TAB 1: FORWARD PASS
+ # ============================================================================
+
+ def forward_pass_demo(x1, x2, w1, w2, b):
+     """Step-by-step forward pass calculation."""
+
+     # Step 1: Weighted sum
+     z = w1 * x1 + w2 * x2 + b
+
+     # Step 2: Sigmoid activation
+     sigmoid_z = 1 / (1 + np.exp(-z))
+
+     # Generate SVG diagram
+     svg_diagram = generate_forward_svg(x1, x2, w1, w2, b, z, sigmoid_z)
+
+     explanation = f"""
+ ## FORWARD PASS CALCULATION
+ ===============================================
+
+ ### STEP 1: The Weighted Sum (z)
+
+ The neuron computes a **weighted sum** of inputs plus a bias:
+
+ ```
+ z = w1*x1 + w2*x2 + b
+ z = ({w1:.2f})*({x1:.2f}) + ({w2:.2f})*({x2:.2f}) + ({b:.2f})
+ z = {w1*x1:.4f} + {w2*x2:.4f} + {b:.2f}
+ z = {z:.4f}
+ ```
+
+ **What's happening:** Each input is scaled by its weight,
+ then we add them up. The bias shifts the whole thing.
+
+ -----------------------------------------------
+
+ ### STEP 2: The Sigmoid Activation Function
+
+ We squash z through the **sigmoid function** to get output in (0,1):
+
+ ```
+ sigmoid(z) = 1 / (1 + e^(-z))
+            = 1 / (1 + e^(-{z:.4f}))
+            = 1 / (1 + {np.exp(-z):.4f})
+            = 1 / {1 + np.exp(-z):.4f}
+            = {sigmoid_z:.4f}
+ ```
+
+ **Why sigmoid?** It smoothly maps any real number to (0,1).
+ - z >> 0 --> sigmoid(z) ≈ 1
+ - z << 0 --> sigmoid(z) ≈ 0
+ - z = 0 --> sigmoid(z) = 0.5
+
+ -----------------------------------------------
+
+ ### SUMMARY
+
+ ```
+ Inputs: x1={x1:.2f}, x2={x2:.2f}
+ Weights: w1={w1:.2f}, w2={w2:.2f}
+ Bias: b={b:.2f}
+
+ z = {z:.4f}
+ y = sigmoid(z) = {sigmoid_z:.4f}
+ ```
+
+ **Interpretation:** Output of {sigmoid_z:.4f} means
+ {sigmoid_z*100:.1f}% probability of class 1.
+ """
+     return svg_diagram, explanation
+
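`forward_pass_demo` calls `np.exp(-z)` directly. For strongly negative `z` (below roughly -709 in float64) that call overflows to `inf`, NumPy emits a RuntimeWarning, and the formatted intermediate value in the explanation reads `inf`. If the inputs can reach that range, a numerically stable variant could look like the sketch below. The names `stable_sigmoid` and `forward` are hypothetical and are not part of this commit.

```python
import math


def stable_sigmoid(z: float) -> float:
    """Sigmoid that never exponentiates a large positive number."""
    if z >= 0:
        # exp(-z) is at most 1 here, so it cannot overflow.
        return 1.0 / (1.0 + math.exp(-z))
    # For z < 0, exp(z) lies in (0, 1); use the equivalent form e^z / (1 + e^z).
    ez = math.exp(z)
    return ez / (1.0 + ez)


def forward(x1, x2, w1, w2, b):
    """Weighted sum followed by the sigmoid activation."""
    z = w1 * x1 + w2 * x2 + b
    return z, stable_sigmoid(z)


z, y_hat = forward(x1=1.0, x2=2.0, w1=0.5, w2=-0.3, b=0.1)
```

Both branches compute the same function. The split only chooses the form whose exponent is non-positive, so the result matches `1 / (1 + np.exp(-z))` wherever that expression does not overflow.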
546
+
547
+ FORWARD_INTRO_SVG = '''
548
+ <svg viewBox="0 0 900 420" style="width:100%; max-width:900px; height:auto; background:#0c0c0c; border-radius:8px; border:1px solid #333; margin-bottom:20px;">
549
+ <defs>
550
+ <marker id="fwd-arr" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
551
+ <polygon points="0 0, 10 3.5, 0 7" fill="#5b9bd5" />
552
+ </marker>
553
+ </defs>
554
+
555
+ <!-- Title -->
556
+ <text x="450" y="35" text-anchor="middle" fill="#f0b030" font-size="20" font-family="monospace" font-weight="bold">FORWARD PASS: DATA IN → PREDICTION OUT</text>
557
+
558
+ <!-- Main neuron diagram -->
559
+ <rect x="20" y="55" width="480" height="200" rx="8" fill="#141414" stroke="#5b9bd5" stroke-width="1"/>
560
+ <text x="40" y="80" fill="#5b9bd5" font-size="14" font-family="monospace" font-weight="bold">THE SINGLE NEURON</text>
561
+
562
+ <!-- Input x1 -->
563
+ <circle cx="70" cy="120" r="22" fill="#1a3a2a" stroke="#4ade80" stroke-width="2"/>
564
+ <text x="70" y="125" text-anchor="middle" fill="#4ade80" font-size="13" font-family="monospace">x₁</text>
565
+
566
+ <!-- Input x2 -->
567
+ <circle cx="70" cy="195" r="22" fill="#1a3a2a" stroke="#4ade80" stroke-width="2"/>
568
+ <text x="70" y="200" text-anchor="middle" fill="#4ade80" font-size="13" font-family="monospace">x₂</text>
569
+
570
+ <!-- Weights on arrows -->
571
+ <line x1="92" y1="120" x2="155" y2="145" stroke="#5b9bd5" stroke-width="2" marker-end="url(#fwd-arr)"/>
572
+ <text x="108" y="115" fill="#f0b030" font-size="11" font-family="monospace">×w₁</text>
573
+
574
+ <line x1="92" y1="195" x2="155" y2="170" stroke="#5b9bd5" stroke-width="2" marker-end="url(#fwd-arr)"/>
575
+ <text x="108" y="210" fill="#f0b030" font-size="11" font-family="monospace">×w₂</text>
576
+
577
+ <!-- Summation node -->
578
+ <circle cx="185" cy="157" r="28" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="2"/>
579
+ <text x="185" y="152" text-anchor="middle" fill="#5b9bd5" font-size="16" font-family="monospace">Σ</text>
580
+ <text x="185" y="170" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">+b</text>
581
+
582
+ <!-- Arrow to z -->
583
+ <line x1="213" y1="157" x2="255" y2="157" stroke="#5b9bd5" stroke-width="2" marker-end="url(#fwd-arr)"/>
584
+
585
+ <!-- z value box -->
586
+ <rect x="265" y="140" width="40" height="35" rx="5" fill="#2a2a1a" stroke="#f0b030" stroke-width="2"/>
587
+ <text x="285" y="162" text-anchor="middle" fill="#f0b030" font-size="13" font-family="monospace">z</text>
588
+
589
+ <!-- Arrow to sigmoid -->
590
+ <line x1="305" y1="157" x2="340" y2="157" stroke="#5b9bd5" stroke-width="2" marker-end="url(#fwd-arr)"/>
591
+
592
+ <!-- Sigmoid box -->
593
+ <rect x="350" y="135" width="60" height="45" rx="5" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="2"/>
594
+ <text x="380" y="155" text-anchor="middle" fill="#5b9bd5" font-size="12" font-family="monospace">σ(z)</text>
595
+ <text x="380" y="170" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">sigmoid</text>
596
+
597
+ <!-- Arrow to output -->
598
+ <line x1="410" y1="157" x2="445" y2="157" stroke="#4ade80" stroke-width="2" marker-end="url(#fwd-arr)"/>
599
+
600
+ <!-- Output -->
601
+ <circle cx="470" cy="157" r="22" fill="#2a1a1a" stroke="#ff6b6b" stroke-width="2"/>
602
+ <text x="470" y="162" text-anchor="middle" fill="#ff6b6b" font-size="13" font-family="monospace">ŷ</text>
603
+
604
+ <!-- Step labels -->
605
+ <text x="185" y="215" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">STEP 1</text>
606
+ <text x="185" y="228" text-anchor="middle" fill="#5b9bd5" font-size="9" font-family="monospace">Weighted Sum</text>
607
+
608
+ <text x="380" y="215" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">STEP 2</text>
609
+ <text x="380" y="228" text-anchor="middle" fill="#5b9bd5" font-size="9" font-family="monospace">Activation</text>
610
+
611
+ <!-- Equations box - made wider -->
612
+ <rect x="515" y="55" width="365" height="200" rx="8" fill="#141414" stroke="#4ade80" stroke-width="1"/>
613
+ <text x="535" y="80" fill="#4ade80" font-size="14" font-family="monospace" font-weight="bold">THE MATH</text>
614
+
615
+ <text x="535" y="110" fill="#888" font-size="12" font-family="monospace">Step 1: Weighted Sum</text>
616
+ <text x="535" y="132" fill="#f0b030" font-size="14" font-family="monospace">z = w₁x₁ + w₂x₂ + b</text>
617
+
618
+ <text x="535" y="165" fill="#888" font-size="12" font-family="monospace">Step 2: Sigmoid</text>
619
+ <text x="535" y="187" fill="#f0b030" font-size="14" font-family="monospace">ŷ = σ(z) = 1/(1+e⁻ᶻ)</text>
620
+
621
+ <line x1="535" y1="200" x2="860" y2="200" stroke="#333" stroke-width="1"/>
622
+
623
+ <text x="535" y="222" fill="#4ade80" font-size="12" font-family="monospace">Output ŷ ∈ (0,1) = probability</text>
624
+ <text x="535" y="242" fill="#888" font-size="10" font-family="monospace">Squashes any real number to (0,1)</text>
625
+
626
+ <!-- Interactive prompt -->
627
+ <rect x="20" y="270" width="860" height="135" rx="8" fill="#1a1a1a" stroke="#888" stroke-width="1" stroke-dasharray="5,5"/>
628
+ <text x="450" y="300" text-anchor="middle" fill="#888" font-size="15" font-family="monospace">▼ INTERACTIVE TERMINAL ▼</text>
629
+ <text x="450" y="330" text-anchor="middle" fill="#5b9bd5" font-size="13" font-family="monospace">Adjust inputs (x₁, x₂), weights (w₁, w₂), and bias (b)</text>
630
+ <text x="450" y="360" text-anchor="middle" fill="#4ade80" font-size="13" font-family="monospace">Click "EXECUTE FORWARD PASS" to see the values</text>
631
+ <text x="450" y="390" text-anchor="middle" fill="#555" font-size="10" font-family="monospace">[Vault-Tec recommends saving your work before experiments]</text>
632
+ </svg>
633
+ '''
634
+
635
+ FORWARD_INTRO = f"""
636
+ {FORWARD_INTRO_SVG}
637
+ """
638
+
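The two steps in the "THE MATH" box above can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions: `forward_pass` is a hypothetical helper that does not appear in app.py, and the input values are made up for the example.

```python
import math


def forward_pass(x1, x2, w1, w2, b):
    """Return (z, y_hat) for a single sigmoid neuron with two inputs.

    Hypothetical helper for illustration; not part of app.py.
    """
    z = w1 * x1 + w2 * x2 + b             # Step 1: weighted sum
    y_hat = 1.0 / (1.0 + math.exp(-z))    # Step 2: sigmoid activation
    return z, y_hat


# Made-up values chosen so the weighted sum is exactly zero.
z, y_hat = forward_pass(x1=1.0, x2=2.0, w1=0.5, w2=-0.25, b=0.0)
print(z, y_hat)  # 0.0 0.5
```

With z = 0 the sigmoid returns 0.5, the midpoint of its (0, 1) output range.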
639
+ # ============================================================================
640
+ # TAB 2: CHAIN RULE FUNDAMENTALS
641
+ # ============================================================================
642
+
643
+ CHAIN_RULE_INTRO_SVG = '''
644
+ <svg viewBox="0 0 900 480" style="width:100%; max-width:900px; height:auto; background:#0c0c0c; border-radius:8px; border:1px solid #333; margin-bottom:20px;">
645
+ <defs>
646
+ <marker id="arr-blue" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
647
+ <polygon points="0 0, 10 3.5, 0 7" fill="#5b9bd5" />
648
+ </marker>
649
+ </defs>
650
+
651
+ <!-- Title -->
652
+ <text x="450" y="32" text-anchor="middle" fill="#f0b030" font-size="20" font-family="monospace" font-weight="bold">THE CHAIN RULE</text>
653
+
654
+ <!-- Section 1: Basic Idea -->
655
+ <rect x="20" y="50" width="860" height="110" rx="8" fill="#141414" stroke="#5b9bd5" stroke-width="1"/>
656
+ <text x="40" y="72" fill="#5b9bd5" font-size="14" font-family="monospace" font-weight="bold">1. THE BASIC IDEA</text>
657
+
658
+ <!-- Composition diagram - more compact -->
659
+ <rect x="50" y="90" width="45" height="32" rx="5" fill="#1a3a2a" stroke="#4ade80" stroke-width="2"/>
660
+ <text x="72" y="111" text-anchor="middle" fill="#4ade80" font-size="14" font-family="monospace">x</text>
661
+
662
+ <line x1="95" y1="106" x2="130" y2="106" stroke="#5b9bd5" stroke-width="2" marker-end="url(#arr-blue)"/>
663
+
664
+ <rect x="140" y="88" width="55" height="36" rx="5" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="2"/>
665
+ <text x="167" y="111" text-anchor="middle" fill="#5b9bd5" font-size="13" font-family="monospace">g(x)</text>
666
+
667
+ <line x1="195" y1="106" x2="230" y2="106" stroke="#5b9bd5" stroke-width="2" marker-end="url(#arr-blue)"/>
668
+
669
+ <rect x="240" y="90" width="45" height="32" rx="5" fill="#2a2a1a" stroke="#f0b030" stroke-width="2"/>
670
+ <text x="262" y="111" text-anchor="middle" fill="#f0b030" font-size="14" font-family="monospace">u</text>
671
+
672
+ <line x1="285" y1="106" x2="320" y2="106" stroke="#5b9bd5" stroke-width="2" marker-end="url(#arr-blue)"/>
673
+
674
+ <rect x="330" y="88" width="55" height="36" rx="5" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="2"/>
675
+ <text x="357" y="111" text-anchor="middle" fill="#5b9bd5" font-size="13" font-family="monospace">f(u)</text>
676
+
677
+ <line x1="385" y1="106" x2="420" y2="106" stroke="#5b9bd5" stroke-width="2" marker-end="url(#arr-blue)"/>
678
+
679
+ <rect x="430" y="90" width="45" height="32" rx="5" fill="#2a1a1a" stroke="#ff6b6b" stroke-width="2"/>
680
+ <text x="452" y="111" text-anchor="middle" fill="#ff6b6b" font-size="14" font-family="monospace">y</text>
681
+
682
+ <!-- Labels -->
683
+ <text x="262" y="140" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">u = g(x)</text>
684
+ <text x="452" y="140" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">y = f(g(x))</text>
685
+
686
+ <!-- Formula box - wider -->
687
+ <rect x="510" y="80" width="350" height="65" rx="6" fill="#0c0c0c" stroke="#f0b030" stroke-width="1"/>
688
+ <text x="685" y="105" text-anchor="middle" fill="#f0b030" font-size="13" font-family="monospace">Chain Rule Formula:</text>
689
+ <text x="685" y="130" text-anchor="middle" fill="#4ade80" font-size="15" font-family="monospace">dy/dx = (dy/du) × (du/dx)</text>
690
+
691
+ <!-- Section 2: Why It Works -->
692
+ <rect x="20" y="170" width="420" height="140" rx="8" fill="#141414" stroke="#f0b030" stroke-width="1"/>
693
+ <text x="40" y="192" fill="#f0b030" font-size="14" font-family="monospace" font-weight="bold">2. WHY IT WORKS</text>
694
+
695
+ <text x="40" y="218" fill="#4ade80" font-size="13" font-family="monospace">Think of it like fractions:</text>
696
+
697
+ <!-- Fraction visualization - smaller -->
698
+ <text x="55" y="250" fill="#5b9bd5" font-size="16" font-family="monospace">dy</text>
699
+ <line x1="50" y1="255" x2="75" y2="255" stroke="#5b9bd5" stroke-width="2"/>
700
+ <text x="55" y="272" fill="#5b9bd5" font-size="16" font-family="monospace">dx</text>
701
+
702
+ <text x="90" y="260" fill="#888" font-size="18" font-family="monospace">=</text>
703
+
704
+ <text x="115" y="250" fill="#f0b030" font-size="16" font-family="monospace">dy</text>
705
+ <line x1="110" y1="255" x2="135" y2="255" stroke="#f0b030" stroke-width="2"/>
706
+ <text x="115" y="272" fill="#ff6b6b" font-size="16" font-family="monospace" text-decoration="line-through">du</text>
707
+
708
+ <text x="150" y="260" fill="#888" font-size="18" font-family="monospace">×</text>
709
+
710
+ <text x="175" y="250" fill="#ff6b6b" font-size="16" font-family="monospace" text-decoration="line-through">du</text>
711
+ <line x1="170" y1="255" x2="195" y2="255" stroke="#f0b030" stroke-width="2"/>
712
+ <text x="175" y="272" fill="#4ade80" font-size="16" font-family="monospace">dx</text>
713
+
714
+ <text x="210" y="260" fill="#888" font-size="18" font-family="monospace">=</text>
715
+
716
+ <text x="235" y="250" fill="#f0b030" font-size="16" font-family="monospace">dy</text>
717
+ <line x1="230" y1="255" x2="255" y2="255" stroke="#4ade80" stroke-width="2"/>
718
+ <text x="235" y="272" fill="#4ade80" font-size="16" font-family="monospace">dx</text>
719
+
720
+
721
+ <!-- Section 3: Concrete Example -->
722
+ <rect x="460" y="170" width="420" height="140" rx="8" fill="#141414" stroke="#4ade80" stroke-width="1"/>
723
+ <text x="480" y="192" fill="#4ade80" font-size="14" font-family="monospace" font-weight="bold">3. EXAMPLE: y = (3x + 2)²</text>
724
+
725
+ <!-- Break down - repositioned -->
726
+ <text x="480" y="222" fill="#5b9bd5" font-size="12" font-family="monospace">Inner: u = 3x+2</text>
727
+ <text x="630" y="222" fill="#888" font-size="12" font-family="monospace">→ du/dx = 3</text>
728
+
729
+ <text x="480" y="248" fill="#5b9bd5" font-size="12" font-family="monospace">Outer: y = u²</text>
730
+ <text x="630" y="248" fill="#888" font-size="12" font-family="monospace">→ dy/du = 2u</text>
731
+
732
+ <line x1="480" y1="260" x2="860" y2="260" stroke="#333" stroke-width="1"/>
733
+
734
+ <text x="480" y="282" fill="#f0b030" font-size="13" font-family="monospace">Chain: dy/dx = 2u × 3 = 6(3x+2)</text>
735
+
736
+ <!-- Section 4: Interactive -->
737
+ <rect x="20" y="320" width="860" height="145" rx="8" fill="#1a1a1a" stroke="#888" stroke-width="1" stroke-dasharray="5,5"/>
738
+ <text x="450" y="350" text-anchor="middle" fill="#888" font-size="15" font-family="monospace">▼ INTERACTIVE TERMINAL ▼</text>
739
+ <text x="450" y="380" text-anchor="middle" fill="#5b9bd5" font-size="13" font-family="monospace">Adjust a, b, and x with the sliders</text>
740
+ <text x="450" y="410" text-anchor="middle" fill="#4ade80" font-size="13" font-family="monospace">Click "APPLY CHAIN RULE" to see values flow through</text>
741
+ <text x="450" y="440" text-anchor="middle" fill="#555" font-size="10" font-family="monospace">[Remember: derivatives chain together like Vault access codes]</text>
742
+ </svg>
743
+ '''
744
+
745
+ CHAIN_RULE_INTRO = f"""
746
+ {CHAIN_RULE_INTRO_SVG}
747
+ """
748
+
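The worked example in section 3 of the SVG above, y = (3x + 2)², can be checked directly. The sketch below assumes nothing beyond the formulas in the diagram; the function names `g`, `f`, and `dy_dx` are hypothetical and chosen to mirror the labels in the figure.

```python
def g(x):
    """Inner function: u = 3x + 2."""
    return 3 * x + 2


def f(u):
    """Outer function: y = u^2."""
    return u ** 2


def dy_dx(x):
    """Chain rule: dy/dx = (dy/du) * (du/dx)."""
    u = g(x)
    dy_du = 2 * u   # derivative of the outer function, evaluated at u
    du_dx = 3       # derivative of the inner function (constant)
    return dy_du * du_dx


x = 1.0
print(f(g(x)), dy_dx(x))  # 25.0 30.0
```

At x = 1 the closed form 6(3x + 2) also gives 30, which agrees with the product of the two local derivatives.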
749
+ def generate_chain_rule_svg(a, b, x_val, u, y, du_dx, dy_du, dy_dx):
750
+ """Generate SVG showing chain rule visually."""
751
+
752
+ bg = "#0c0c0c"
753
+ node_fill = "#1a2a3a"
754
+ node_stroke = "#5b9bd5"
755
+ input_fill = "#1a3a2a"
756
+ input_stroke = "#4ade80"
757
+ output_fill = "#2a2a1a"
758
+ output_stroke = "#f0b030"
759
+ text_color = "#4ade80"
760
+ label_color = "#5b9bd5"
761
+ arrow_color = "#5b9bd5"
762
+ value_color = "#f0b030"
763
+ deriv_color = "#ff6b6b"
764
+
765
+ svg = f'''
766
+ <svg viewBox="0 0 800 280" style="width:100%; max-width:800px; height:auto; background:{bg}; border-radius:8px; border:1px solid #333;">
767
+ <defs>
768
+ <marker id="arr" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
769
+ <polygon points="0 0, 10 3.5, 0 7" fill="{arrow_color}" />
770
+ </marker>
771
+ <marker id="arr-red" markerWidth="10" markerHeight="7" refX="0" refY="3.5" orient="auto">
772
+ <polygon points="10 0, 0 3.5, 10 7" fill="{deriv_color}" />
773
+ </marker>
774
+ </defs>
775
+
776
+ <!-- Title -->
777
+ <text x="400" y="30" text-anchor="middle" fill="{output_stroke}" font-size="18" font-family="monospace">CHAIN RULE: y = ({a}x + {b})²</text>
778
+
779
+ <!-- Input x -->
780
+ <rect x="50" y="80" width="80" height="60" rx="8" fill="{input_fill}" stroke="{input_stroke}" stroke-width="2"/>
781
+ <text x="90" y="105" text-anchor="middle" fill="{label_color}" font-size="14" font-family="monospace">x</text>
782
+ <text x="90" y="128" text-anchor="middle" fill="{value_color}" font-size="16" font-family="monospace">{x_val:.2f}</text>
783
+
784
+ <!-- Arrow x to u -->
785
+ <line x1="130" y1="110" x2="200" y2="110" stroke="{arrow_color}" stroke-width="2" marker-end="url(#arr)"/>
786
+ <text x="165" y="100" text-anchor="middle" fill="{text_color}" font-size="12" font-family="monospace">g(x)</text>
787
+
788
+ <!-- Inner function u -->
789
+ <rect x="210" y="80" width="120" height="60" rx="8" fill="{node_fill}" stroke="{node_stroke}" stroke-width="2"/>
790
+ <text x="270" y="100" text-anchor="middle" fill="{label_color}" font-size="12" font-family="monospace">u = {a}x + {b}</text>
791
+ <text x="270" y="125" text-anchor="middle" fill="{value_color}" font-size="16" font-family="monospace">u = {u:.2f}</text>
792
+
793
+ <!-- Arrow u to y -->
794
+ <line x1="330" y1="110" x2="400" y2="110" stroke="{arrow_color}" stroke-width="2" marker-end="url(#arr)"/>
795
+ <text x="365" y="100" text-anchor="middle" fill="{text_color}" font-size="12" font-family="monospace">f(u)</text>
796
+
797
+ <!-- Outer function y -->
798
+ <rect x="410" y="80" width="100" height="60" rx="8" fill="{output_fill}" stroke="{output_stroke}" stroke-width="2"/>
799
+ <text x="460" y="100" text-anchor="middle" fill="{label_color}" font-size="12" font-family="monospace">y = u²</text>
800
+ <text x="460" y="125" text-anchor="middle" fill="{value_color}" font-size="16" font-family="monospace">y = {y:.2f}</text>
801
+
802
+ <!-- Derivative arrows (below, going backwards) -->
803
+ <line x1="410" y1="160" x2="335" y2="160" stroke="{deriv_color}" stroke-width="2" stroke-dasharray="5,3" marker-end="url(#arr-red)"/>
804
+ <text x="372" y="180" text-anchor="middle" fill="{deriv_color}" font-size="12" font-family="monospace">dy/du = {dy_du:.2f}</text>
805
+
806
+ <line x1="210" y1="160" x2="135" y2="160" stroke="{deriv_color}" stroke-width="2" stroke-dasharray="5,3" marker-end="url(#arr-red)"/>
807
+ <text x="172" y="180" text-anchor="middle" fill="{deriv_color}" font-size="12" font-family="monospace">du/dx = {du_dx:.2f}</text>
808
+
809
+ <!-- Chain rule result box -->
810
+ <rect x="540" y="70" width="230" height="120" rx="8" fill="#141414" stroke="#333" stroke-width="1"/>
811
+ <text x="655" y="95" text-anchor="middle" fill="{output_stroke}" font-size="14" font-family="monospace">CHAIN RULE</text>
812
+ <text x="555" y="120" fill="{text_color}" font-size="13" font-family="monospace">dy/dx = dy/du × du/dx</text>
813
+ <text x="555" y="145" fill="{text_color}" font-size="13" font-family="monospace"> = {dy_du:.2f} × {du_dx:.2f}</text>
814
+ <text x="555" y="175" fill="{value_color}" font-size="16" font-family="monospace"> = {dy_dx:.2f}</text>
815
+
816
+ <!-- Legend -->
817
+ <line x1="50" y1="240" x2="90" y2="240" stroke="{arrow_color}" stroke-width="2" marker-end="url(#arr)"/>
818
+ <text x="100" y="244" fill="{text_color}" font-size="11" font-family="monospace">Forward</text>
819
+
820
+ <line x1="200" y1="240" x2="240" y2="240" stroke="{deriv_color}" stroke-width="2" stroke-dasharray="5,3" marker-end="url(#arr-red)"/>
821
+ <text x="250" y="244" fill="{text_color}" font-size="11" font-family="monospace">Derivatives (multiply!)</text>
822
+ </svg>
823
+ '''
824
+ return svg
825
+
826
+
827
+ def chain_rule_calculator(a, b, x_val):
828
+ """Demonstrate the chain rule on y = (ax + b)^2; return (svg_diagram, explanation)."""
829
+
830
+ # u = ax + b
831
+ u = a * x_val + b
832
+
833
+ # y = u^2
834
+ y = u ** 2
835
+
836
+ # Derivatives
837
+ du_dx = a
838
+ dy_du = 2 * u
839
+ dy_dx = dy_du * du_dx
840
+
841
+ # Generate SVG
842
+ svg_diagram = generate_chain_rule_svg(a, b, x_val, u, y, du_dx, dy_du, dy_dx)
843
+
844
+ explanation = f"""
845
+ ## CHAIN RULE CALCULATION: y = ({a}x + {b})^2
846
+ ===============================================
847
+
848
+ ### Setting up the composition:
849
+
850
+ ```
851
+ Inner function: u = {a}x + {b}
852
+ Outer function: y = u^2
853
+ ```
854
+
855
+ At x = {x_val}:
856
+ ```
857
+ u = {a}*{x_val} + {b} = {u}
858
+ y = ({u})^2 = {y}
859
+ ```
860
+
861
+ -----------------------------------------------
862
+
863
+ ### Step 1: Find du/dx (derivative of inner function)
864
+
865
+ ```
866
+ u = {a}x + {b}
867
+
868
+ du/dx = {a} (coefficient of x)
869
+ ```
870
+
871
+ -----------------------------------------------
872
+
873
+ ### Step 2: Find dy/du (derivative of outer function)
874
+
875
+ ```
876
+ y = u^2
877
+
878
+ dy/du = 2u = 2*({u}) = {dy_du}
879
+ ```
880
+
881
+ -----------------------------------------------
882
+
883
+ ### Step 3: Apply the Chain Rule!
884
+
885
+ ```
886
+ dy/dx = (dy/du) * (du/dx)
887
+ = {dy_du} * {du_dx}
888
+ = {dy_dx}
889
+ ```
890
+
891
+ -----------------------------------------------
892
+
893
+ ### VERIFICATION (optional sanity check)
894
+
895
+ If x increases by a tiny amount h = 0.001:
896
+ ```
897
+ y(x+h) = ({a}*{x_val+0.001:.3f} + {b})^2 = {(a*(x_val+0.001) + b)**2:.6f}
898
+ y(x) = {y}
899
+
900
+ Slope ≈ (y(x+h) - y(x)) / h
901
+ = {((a*(x_val+0.001) + b)**2 - y) / 0.001:.4f}
902
+
903
+ Our dy/dx = {dy_dx}
904
+ ```
905
+
906
+ The chain rule works! The small gap between the two slopes is finite-difference error, which shrinks as h gets smaller.
907
+ """
908
+ return svg_diagram, explanation
909
+
910
+
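The verification step in `chain_rule_calculator` uses a one-sided (forward) difference with h = 0.001. As a hedged aside, the sketch below compares that estimate with a central difference for the same function y = (ax + b)². All helper names here are hypothetical and not part of app.py. For this particular quadratic, the forward difference overshoots the chain-rule slope by a²·h, while the central difference matches it up to floating-point rounding.

```python
def y_of(x, a, b):
    """y = (a*x + b)^2, the function used in the chain rule tab."""
    return (a * x + b) ** 2


def chain_rule_slope(x, a, b):
    """dy/dx = (dy/du) * (du/dx) = 2u * a, with u = a*x + b."""
    return 2 * (a * x + b) * a


def forward_diff(x, a, b, h=1e-3):
    """One-sided estimate, the same form as the app's sanity check."""
    return (y_of(x + h, a, b) - y_of(x, a, b)) / h


def central_diff(x, a, b, h=1e-3):
    """Two-sided estimate; its error term is O(h^2) in general."""
    return (y_of(x + h, a, b) - y_of(x - h, a, b)) / (2 * h)


a, b, x = 3.0, 2.0, 1.0
exact = chain_rule_slope(x, a, b)
print("chain rule:", exact)
print("forward error:", forward_diff(x, a, b) - exact)
print("central error:", central_diff(x, a, b) - exact)
```

With a = 3 and h = 0.001 the forward error is about 0.009, which is the size of gap a reader should expect to see in the app's verification output.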
911
+ # ============================================================================
912
+ # TAB 3: DERIVATIVES OF KEY FUNCTIONS
913
+ # ============================================================================
914
+
915
+ DERIVATIVES_INTRO_SVG = '''
916
+ <svg viewBox="0 0 900 520" style="width:100%; max-width:900px; height:auto; background:#0c0c0c; border-radius:8px; border:1px solid #333; margin-bottom:20px;">
917
+
918
+ <!-- Title -->
919
+ <text x="450" y="32" text-anchor="middle" fill="#f0b030" font-size="18" font-family="monospace" font-weight="bold">KEY DERIVATIVES YOU NEED TO KNOW</text>
920
+
921
+ <!-- Sigmoid Section -->
922
+ <rect x="20" y="50" width="430" height="200" rx="8" fill="#141414" stroke="#4ade80" stroke-width="2"/>
923
+ <text x="35" y="75" fill="#4ade80" font-size="14" font-family="monospace" font-weight="bold">1. SIGMOID FUNCTION</text>
924
+
925
+ <!-- Mini sigmoid curve - compact -->
926
+ <polyline points="35,150 50,148 65,145 80,140 95,130 110,115 125,102 140,94 155,89 170,87"
927
+ fill="none" stroke="#4ade80" stroke-width="2"/>
928
+ <line x1="35" y1="155" x2="170" y2="155" stroke="#333" stroke-width="1"/>
929
+ <line x1="102" y1="87" x2="102" y2="155" stroke="#333" stroke-width="1" stroke-dasharray="3,3"/>
930
+ <text x="102" y="170" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">z=0 → σ=0.5</text>
931
+
932
+ <!-- Formula - compact layout -->
933
+ <text x="185" y="95" fill="#5b9bd5" font-size="10" font-family="monospace">Function:</text>
934
+ <text x="185" y="112" fill="#f0b030" font-size="10" font-family="monospace">σ(z) = 1/(1+e⁻ᶻ)</text>
935
+
936
+ <text x="185" y="135" fill="#5b9bd5" font-size="10" font-family="monospace">Derivative:</text>
937
+ <text x="185" y="152" fill="#ff6b6b" font-size="10" font-family="monospace">dσ/dz = σ(z)(1-σ(z))</text>
938
+
939
+ <!-- Key insight box -->
940
+ <rect x="30" y="190" width="410" height="50" rx="4" fill="#1a2a1a" stroke="#4ade80" stroke-width="1"/>
941
+ <text x="45" y="210" fill="#4ade80" font-size="10" font-family="monospace">Derivative uses the function itself!</text>
942
+ <text x="45" y="228" fill="#888" font-size="10" font-family="monospace">Already have σ(z)? No extra work needed.</text>
943
+
944
+ <!-- BCE Section -->
945
+ <rect x="460" y="50" width="420" height="200" rx="8" fill="#141414" stroke="#ff6b6b" stroke-width="2"/>
946
+ <text x="480" y="75" fill="#ff6b6b" font-size="14" font-family="monospace" font-weight="bold">2. BINARY CROSS-ENTROPY</text>
947
+
948
+ <text x="480" y="100" fill="#5b9bd5" font-size="11" font-family="monospace">Loss Function:</text>
949
+ <text x="480" y="120" fill="#f0b030" font-size="11" font-family="monospace">L = -[y·log(ŷ) + (1-y)·log(1-ŷ)]</text>
950
+
951
+ <text x="480" y="150" fill="#5b9bd5" font-size="11" font-family="monospace">Derivative w.r.t. ŷ:</text>
952
+ <text x="480" y="170" fill="#ff6b6b" font-size="11" font-family="monospace">dL/dŷ = -y/ŷ + (1-y)/(1-ŷ)</text>
953
+
954
+ <!-- Magic box -->
955
+ <rect x="480" y="190" width="385" height="50" rx="4" fill="#2a1a1a" stroke="#ff6b6b" stroke-width="1"/>
956
+ <text x="495" y="210" fill="#f0b030" font-size="10" font-family="monospace">Combined with sigmoid:</text>
957
+ <text x="495" y="227" fill="#4ade80" font-size="13" font-family="monospace">dL/dz = ŷ - y (Vault-Tec approved)</text>
958
+
959
+ <!-- Common Patterns Section -->
960
+ <rect x="20" y="260" width="860" height="130" rx="8" fill="#141414" stroke="#5b9bd5" stroke-width="1"/>
961
+ <text x="40" y="283" fill="#5b9bd5" font-size="14" font-family="monospace" font-weight="bold">3. DERIVATIVE PATTERNS</text>
962
+
963
+ <!-- Pattern boxes - evenly spaced -->
964
+ <rect x="35" y="300" width="135" height="75" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
965
+ <text x="102" y="320" text-anchor="middle" fill="#888" font-size="11" font-family="monospace">Powers</text>
966
+ <text x="102" y="342" text-anchor="middle" fill="#f0b030" font-size="12" font-family="monospace">d/dx[xⁿ]</text>
967
+ <text x="102" y="364" text-anchor="middle" fill="#4ade80" font-size="12" font-family="monospace">= n·xⁿ⁻¹</text>
968
+
969
+ <rect x="185" y="300" width="135" height="75" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
970
+ <text x="252" y="320" text-anchor="middle" fill="#888" font-size="11" font-family="monospace">Exponential</text>
971
+ <text x="252" y="342" text-anchor="middle" fill="#f0b030" font-size="12" font-family="monospace">d/dx[eˣ]</text>
972
+ <text x="252" y="364" text-anchor="middle" fill="#4ade80" font-size="12" font-family="monospace">= eˣ</text>
973
+
974
+ <rect x="335" y="300" width="135" height="75" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
975
+ <text x="402" y="320" text-anchor="middle" fill="#888" font-size="11" font-family="monospace">Logarithm</text>
976
+ <text x="402" y="342" text-anchor="middle" fill="#f0b030" font-size="12" font-family="monospace">d/dx[log(x)]</text>
977
+ <text x="402" y="364" text-anchor="middle" fill="#4ade80" font-size="12" font-family="monospace">= 1/x</text>
978
+
979
+ <rect x="485" y="300" width="135" height="75" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
980
+ <text x="552" y="320" text-anchor="middle" fill="#888" font-size="11" font-family="monospace">Sigmoid</text>
981
+ <text x="552" y="342" text-anchor="middle" fill="#f0b030" font-size="12" font-family="monospace">d/dx[σ(x)]</text>
982
+ <text x="552" y="364" text-anchor="middle" fill="#4ade80" font-size="12" font-family="monospace">= σ(1-σ)</text>
983
+
984
+ <rect x="635" y="300" width="230" height="75" rx="4" fill="#2a2a1a" stroke="#f0b030" stroke-width="1"/>
985
+ <text x="750" y="320" text-anchor="middle" fill="#f0b030" font-size="11" font-family="monospace">Chain Rule</text>
986
+ <text x="750" y="342" text-anchor="middle" fill="#888" font-size="11" font-family="monospace">d/dx[f(g(x))]</text>
987
+ <text x="750" y="364" text-anchor="middle" fill="#4ade80" font-size="12" font-family="monospace">= f'(g(x)) · g'(x)</text>
988
+
989
+ <!-- Interactive prompt -->
990
+ <rect x="20" y="405" width="860" height="100" rx="8" fill="#1a1a1a" stroke="#888" stroke-width="1" stroke-dasharray="5,5"/>
991
+ <text x="450" y="430" text-anchor="middle" fill="#888" font-size="14" font-family="monospace">INTERACTIVE TERMINAL</text>
992
+ <text x="450" y="455" text-anchor="middle" fill="#5b9bd5" font-size="12" font-family="monospace">Move the z slider to see sigmoid and its derivative</text>
993
+ <text x="450" y="480" text-anchor="middle" fill="#4ade80" font-size="12" font-family="monospace">Derivative peaks at z=0, vanishes at extremes</text>
994
+ </svg>
995
+ '''
996
+
997
+ DERIVATIVES_INTRO = f"""
998
+ {DERIVATIVES_INTRO_SVG}
999
+ """
1000
+
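The "Combined with sigmoid" box above states dL/dz = ŷ - y. A minimal sketch, assuming the binary cross-entropy and sigmoid formulas exactly as written in the SVG, multiplies the two local derivatives and compares the product with that shortcut. The function names are hypothetical and the sample (z, y) pairs are made up.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def dL_dz_via_chain(z, y):
    """dL/dz = (dL/dy_hat) * (dy_hat/dz), using the two formulas separately."""
    y_hat = sigmoid(z)
    dL_dyhat = -y / y_hat + (1 - y) / (1 - y_hat)   # BCE derivative w.r.t. y_hat
    dyhat_dz = y_hat * (1 - y_hat)                   # sigmoid derivative
    return dL_dyhat * dyhat_dz


def dL_dz_shortcut(z, y):
    """The simplified form: y_hat - y."""
    return sigmoid(z) - y


for z, y in [(0.0, 1.0), (2.0, 0.0), (-1.5, 1.0)]:
    print(z, y, dL_dz_via_chain(z, y), dL_dz_shortcut(z, y))
```

The ŷ(1-ŷ) factor from the sigmoid cancels the denominators in the loss derivative, which is why the two columns agree up to rounding.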
1001
+ def generate_sigmoid_svg(z, sig, dsig):
1002
+ """Generate SVG showing sigmoid function and derivative visually."""
1003
+
1004
+ bg = "#0c0c0c"
1005
+ curve_color = "#4ade80"
1006
+ deriv_color = "#ff6b6b"
1007
+ point_color = "#f0b030"
1008
+ grid_color = "#333"
1009
+ text_color = "#4ade80"
1010
+ label_color = "#5b9bd5"
1011
+
1012
+ # Generate sigmoid curve points
1013
+ curve_points = []
1014
+ for i in range(-50, 51):
1015
+ x_pt = i / 10 # -5 to 5
1016
+ y_pt = 1 / (1 + np.exp(-x_pt))
1017
+ # Map to SVG coordinates: x: -5..5 -> 100..500, y: 0..1 -> 250..50
1018
+ svg_x = 100 + (x_pt + 5) * 40
1019
+ svg_y = 250 - y_pt * 200
1020
+ curve_points.append(f"{svg_x:.1f},{svg_y:.1f}")
1021
+
1022
+ curve_path = " ".join(curve_points)
1023
+
1024
+ # Current point coordinates
1025
+ pt_x = 100 + (z + 5) * 40
1026
+ pt_y = 250 - sig * 200
1027
+
1028
+ # Tangent line (slope = dsig, in SVG coordinates)
1029
+ # The slope in data space is dsig, but in SVG space y is inverted
1030
+ tangent_dx = 40
1031
+ tangent_dy = -dsig * 200
1032
+ t_x1 = pt_x - tangent_dx
1033
+ t_y1 = pt_y - tangent_dy
1034
+ t_x2 = pt_x + tangent_dx
1035
+ t_y2 = pt_y + tangent_dy
1036
+
1037
+ svg = f'''
1038
+ <svg viewBox="0 0 700 320" style="width:100%; max-width:700px; height:auto; background:{bg}; border-radius:8px; border:1px solid #333;">
1039
+
1040
+ <!-- Title -->
1041
+ <text x="350" y="25" text-anchor="middle" fill="{point_color}" font-size="16" font-family="monospace">SIGMOID FUNCTION &amp; DERIVATIVE</text>
1042
+
1043
+ <!-- Grid lines -->
1044
+ <line x1="100" y1="150" x2="500" y2="150" stroke="{grid_color}" stroke-width="1" stroke-dasharray="3,3"/>
1045
+ <text x="95" y="154" text-anchor="end" fill="{grid_color}" font-size="10" font-family="monospace">0.5</text>
1046
+
1047
+ <line x1="300" y1="50" x2="300" y2="250" stroke="{grid_color}" stroke-width="1" stroke-dasharray="3,3"/>
1048
+ <text x="300" y="265" text-anchor="middle" fill="{grid_color}" font-size="10" font-family="monospace">z=0</text>
1049
+
1050
+ <!-- Axes -->
1051
+ <line x1="100" y1="250" x2="500" y2="250" stroke="{label_color}" stroke-width="1"/>
1052
+ <line x1="100" y1="50" x2="100" y2="250" stroke="{label_color}" stroke-width="1"/>
1053
+
1054
+ <!-- Axis labels -->
1055
+ <text x="300" y="280" text-anchor="middle" fill="{label_color}" font-size="12" font-family="monospace">z</text>
1056
+ <text x="70" y="150" text-anchor="middle" fill="{label_color}" font-size="12" font-family="monospace">σ(z)</text>
1057
+ <text x="95" y="55" text-anchor="end" fill="{grid_color}" font-size="10" font-family="monospace">1.0</text>
1058
+ <text x="95" y="254" text-anchor="end" fill="{grid_color}" font-size="10" font-family="monospace">0.0</text>
1059
+ <text x="100" y="275" text-anchor="middle" fill="{grid_color}" font-size="10" font-family="monospace">-5</text>
1060
+ <text x="500" y="275" text-anchor="middle" fill="{grid_color}" font-size="10" font-family="monospace">5</text>
1061
+
1062
+ <!-- Sigmoid curve -->
1063
+ <polyline points="{curve_path}" fill="none" stroke="{curve_color}" stroke-width="2"/>
1064
+
1065
+ <!-- Tangent line at current point -->
1066
+ <line x1="{t_x1:.1f}" y1="{t_y1:.1f}" x2="{t_x2:.1f}" y2="{t_y2:.1f}" stroke="{deriv_color}" stroke-width="2" stroke-dasharray="5,3"/>
1067
+
1068
+ <!-- Current point -->
1069
+ <circle cx="{pt_x:.1f}" cy="{pt_y:.1f}" r="8" fill="{point_color}" stroke="#fff" stroke-width="2"/>
1070
+
1071
+ <!-- Point label -->
1072
+ <line x1="{pt_x:.1f}" y1="{pt_y:.1f}" x2="{pt_x + 30:.1f}" y2="{pt_y - 30:.1f}" stroke="{point_color}" stroke-width="1"/>
1073
+ <text x="{pt_x + 35:.1f}" y="{pt_y - 35:.1f}" fill="{point_color}" font-size="11" font-family="monospace">z={z:.1f}</text>
1074
+ <text x="{pt_x + 35:.1f}" y="{pt_y - 22:.1f}" fill="{point_color}" font-size="11" font-family="monospace">σ={sig:.3f}</text>
1075
+
1076
+ <!-- Info box -->
1077
+ <rect x="530" y="60" width="155" height="140" rx="6" fill="#141414" stroke="#333" stroke-width="1"/>
1078
+ <text x="607" y="85" text-anchor="middle" fill="{point_color}" font-size="13" font-family="monospace">VALUES</text>
1079
+
1080
+ <text x="540" y="110" fill="{label_color}" font-size="12" font-family="monospace">z = {z:.2f}</text>
1081
+ <text x="540" y="130" fill="{curve_color}" font-size="12" font-family="monospace">σ(z) = {sig:.4f}</text>
1082
+ <text x="540" y="155" fill="{deriv_color}" font-size="12" font-family="monospace">dσ/dz = {dsig:.4f}</text>
1083
+
1084
+ <text x="540" y="185" fill="{text_color}" font-size="10" font-family="monospace">= σ(1-σ)</text>
1085
+ <text x="540" y="198" fill="{text_color}" font-size="10" font-family="monospace">= {sig:.3f}×{1-sig:.3f}</text>
1086
+
1087
+ <!-- Legend -->
1088
+ <line x1="530" y1="240" x2="560" y2="240" stroke="{curve_color}" stroke-width="2"/>
1089
+ <text x="565" y="244" fill="{text_color}" font-size="10" font-family="monospace">σ(z)</text>
1090
+
1091
+ <line x1="620" y1="240" x2="650" y2="240" stroke="{deriv_color}" stroke-width="2" stroke-dasharray="5,3"/>
1092
+ <text x="655" y="244" fill="{text_color}" font-size="10" font-family="monospace">tangent</text>
1093
+
1094
+ </svg>
1095
+ '''
1096
+ return svg
1097
+
1098
+
1099
+ def sigmoid_derivative_demo(z):
1100
+ """Show sigmoid and its derivative."""
1101
+
1102
+ sig = 1 / (1 + np.exp(-z))
1103
+ dsig = sig * (1 - sig)
1104
+
1105
+ svg_diagram = generate_sigmoid_svg(z, sig, dsig)
1106
+
1107
+ explanation = f"""
1108
+ ## SIGMOID DERIVATIVE AT z = {z}
1109
+ ===============================================
1110
+
1111
+ ### Step 1: Compute sigmoid(z)
1112
+
1113
+ ```
1114
+ σ(z) = 1 / (1 + e^(-z))
1115
+ = 1 / (1 + e^(-{z}))
1116
+ = 1 / (1 + {np.exp(-z):.6f})
1117
+ = 1 / {1 + np.exp(-z):.6f}
1118
+ = {sig:.6f}
1119
+ ```
1120
+
1121
+ -----------------------------------------------
1122
+
1123
+ ### Step 2: Compute the derivative
1124
+
1125
+ Using the formula: dσ/dz = σ(z) * (1 - σ(z))
1126
+
1127
+ ```
1128
+ dσ/dz = σ(z) * (1 - σ(z))
1129
+ = {sig:.6f} * (1 - {sig:.6f})
1130
+ = {sig:.6f} * {1-sig:.6f}
1131
+ = {dsig:.6f}
1132
+ ```
1133
+
1134
+ -----------------------------------------------
1135
+
1136
+ ### Interpretation
1137
+
1138
+ At z = {z}:
1139
+ - Sigmoid output: {sig:.4f} (how confident the neuron is)
1140
+ - Derivative: {dsig:.4f} (how sensitive output is to z)
1141
+
1142
+ **Key insight:** The derivative is LARGEST at z = 0 (sigmoid = 0.5)
1143
+ and shrinks toward 0 as |z| grows. This is the source of the "vanishing
1144
+ gradient" problem: a saturated neuron passes back almost no gradient, so its weights barely update!
1145
+
1146
+ ```
1147
+ z = 0 --> σ = 0.5, dσ/dz = 0.25 (maximum!)
1148
+ z = 5 --> σ ≈ 0.99, dσ/dz ≈ 0.007 (tiny!)
1149
+ z = -5 --> σ ≈ 0.01, dσ/dz ≈ 0.007 (tiny!)
1150
+ ```
1151
+ """
1152
+ return svg_diagram, explanation
1153
+
1154
+
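The reference values quoted in the interpretation text (0.25 at z = 0, roughly 0.007 at z = ±5) can be reproduced with the sketch below. It uses a branch on the sign of z so that `math.exp` is never called with a large positive argument. That branching form is an assumption of this example, not what `sigmoid_derivative_demo` does, and the helper names are hypothetical.

```python
import math


def sigmoid(z):
    """Sigmoid that avoids overflow in exp() for large |z|."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)          # z < 0, so e is in (0, 1) and cannot overflow
    return e / (1.0 + e)


def sigmoid_grad(z):
    """d(sigma)/dz = sigma(z) * (1 - sigma(z))."""
    s = sigmoid(z)
    return s * (1.0 - s)


for z in (-5.0, 0.0, 5.0):
    print(f"z={z:+.1f}  sigma={sigmoid(z):.4f}  dsigma/dz={sigmoid_grad(z):.4f}")

# Far outside the slider-style range the gradient underflows to zero
# instead of raising an OverflowError.
print(sigmoid_grad(-1000.0), sigmoid_grad(1000.0))
```

The gradient is symmetric about z = 0, so the two saturated ends behave the same way.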
1155
+ # ============================================================================
1156
+ # TAB 4: BACKWARD PASS (THE MAIN EVENT)
1157
+ # ============================================================================
1158
+
1159
+ BACKWARD_INTRO_SVG = '''
1160
+ <svg viewBox="0 0 900 500" style="width:100%; max-width:900px; height:auto; background:#0c0c0c; border-radius:8px; border:1px solid #333; margin-bottom:20px;">
1161
+ <defs>
1162
+ <marker id="fwd-b" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
1163
+ <polygon points="0 0, 10 3.5, 0 7" fill="#5b9bd5" />
1164
+ </marker>
1165
+ <marker id="bwd-b" markerWidth="10" markerHeight="7" refX="0" refY="3.5" orient="auto">
1166
+ <polygon points="10 0, 0 3.5, 10 7" fill="#ff6b6b" />
1167
+ </marker>
1168
+ </defs>
1169
+
1170
+ <!-- Title -->
1171
+ <text x="450" y="32" text-anchor="middle" fill="#f0b030" font-size="18" font-family="monospace" font-weight="bold">BACKPROPAGATION: LEARNING FROM MISTAKES</text>
1172
+
1173
+ <!-- Main concept - THE BIG PICTURE -->
1174
+ <rect x="20" y="50" width="550" height="170" rx="8" fill="#141414" stroke="#5b9bd5" stroke-width="1"/>
1175
+ <text x="40" y="75" fill="#5b9bd5" font-size="14" font-family="monospace" font-weight="bold">THE BIG PICTURE</text>
1176
+
1177
+ <!-- Forward pass row -->
1178
+ <text x="40" y="105" fill="#5b9bd5" font-size="12" font-family="monospace">FWD:</text>
1179
+ <rect x="85" y="90" width="50" height="30" rx="4" fill="#1a3a2a" stroke="#4ade80" stroke-width="1"/>
1180
+ <text x="110" y="110" text-anchor="middle" fill="#4ade80" font-size="12" font-family="monospace">x</text>
1181
+ <line x1="135" y1="105" x2="175" y2="105" stroke="#5b9bd5" stroke-width="2" marker-end="url(#fwd-b)"/>
1182
+ <rect x="185" y="90" width="50" height="30" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
1183
+ <text x="210" y="110" text-anchor="middle" fill="#5b9bd5" font-size="12" font-family="monospace">z</text>
1184
+ <line x1="235" y1="105" x2="275" y2="105" stroke="#5b9bd5" stroke-width="2" marker-end="url(#fwd-b)"/>
1185
+ <rect x="285" y="90" width="50" height="30" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
1186
+ <text x="310" y="110" text-anchor="middle" fill="#5b9bd5" font-size="12" font-family="monospace">ŷ</text>
1187
+ <line x1="335" y1="105" x2="375" y2="105" stroke="#5b9bd5" stroke-width="2" marker-end="url(#fwd-b)"/>
1188
+ <rect x="385" y="90" width="50" height="30" rx="4" fill="#2a1a1a" stroke="#ff6b6b" stroke-width="1"/>
1189
+ <text x="410" y="110" text-anchor="middle" fill="#ff6b6b" font-size="12" font-family="monospace">L</text>
1190
+ <text x="450" y="110" fill="#888" font-size="11" font-family="monospace">Data flows</text>
1191
+
1192
+ <!-- Backward pass row - wider boxes for partials -->
1193
+ <text x="40" y="160" fill="#ff6b6b" font-size="12" font-family="monospace">BWD:</text>
1194
+ <rect x="85" y="145" width="50" height="30" rx="4" fill="#1a3a2a" stroke="#4ade80" stroke-width="1"/>
1195
+ <text x="110" y="165" text-anchor="middle" fill="#ff6b6b" font-size="10" font-family="monospace">dL/dx</text>
1196
+ <line x1="135" y1="160" x2="175" y2="160" stroke="#ff6b6b" stroke-width="2" stroke-dasharray="4,2" marker-end="url(#bwd-b)"/>
1197
+ <rect x="185" y="145" width="50" height="30" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
1198
+ <text x="210" y="165" text-anchor="middle" fill="#ff6b6b" font-size="10" font-family="monospace">dL/dz</text>
1199
+ <line x1="235" y1="160" x2="275" y2="160" stroke="#ff6b6b" stroke-width="2" stroke-dasharray="4,2" marker-end="url(#bwd-b)"/>
1200
+ <rect x="285" y="145" width="50" height="30" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
1201
+ <text x="310" y="165" text-anchor="middle" fill="#ff6b6b" font-size="10" font-family="monospace">dL/dŷ</text>
1202
+ <line x1="335" y1="160" x2="375" y2="160" stroke="#ff6b6b" stroke-width="2" stroke-dasharray="4,2" marker-end="url(#bwd-b)"/>
1203
+ <rect x="385" y="145" width="50" height="30" rx="4" fill="#2a1a1a" stroke="#ff6b6b" stroke-width="1"/>
1204
+ <text x="410" y="165" text-anchor="middle" fill="#ff6b6b" font-size="12" font-family="monospace">1</text>
1205
+ <text x="450" y="165" fill="#888" font-size="11" font-family="monospace">Grads flow</text>
1206
+
1207
+ <!-- Goal text -->
1208
+ <text x="40" y="205" fill="#f0b030" font-size="11" font-family="monospace">GOAL: Find dL/dw1, dL/dw2, dL/db to update the weights and bias</text>
1209
+
1210
+ <!-- Key insight box -->
1211
+ <rect x="585" y="50" width="295" height="170" rx="6" fill="#2a2a1a" stroke="#f0b030" stroke-width="2"/>
1212
+ <text x="732" y="78" text-anchor="middle" fill="#f0b030" font-size="13" font-family="monospace" font-weight="bold">KEY INSIGHT</text>
1213
+ <text x="600" y="108" fill="#4ade80" font-size="11" font-family="monospace">Same computation graph,</text>
1214
+ <text x="600" y="128" fill="#4ade80" font-size="11" font-family="monospace">opposite direction!</text>
1215
+ <line x1="600" y1="142" x2="865" y2="142" stroke="#333" stroke-width="1"/>
1216
+ <text x="600" y="165" fill="#888" font-size="11" font-family="monospace">At each node multiply:</text>
1217
+ <text x="600" y="188" fill="#5b9bd5" font-size="10" font-family="monospace">upstream × local derivative</text>
1218
+
1219
+ <!-- Chain rule explanation -->
1220
+ <rect x="20" y="235" width="420" height="120" rx="8" fill="#141414" stroke="#ff6b6b" stroke-width="1"/>
1221
+ <text x="40" y="260" fill="#ff6b6b" font-size="13" font-family="monospace" font-weight="bold">CHAIN RULE IN ACTION</text>
1222
+ <text x="40" y="290" fill="#888" font-size="12" font-family="monospace">To find dL/dw1:</text>
1223
+ <text x="40" y="315" fill="#f0b030" font-size="12" font-family="monospace">dL/dw1 = dL/dy * dy/dz * dz/dw1</text>
1224
+ <text x="40" y="340" fill="#4ade80" font-size="11" font-family="monospace">Multiply derivatives along the path!</text>
1225
+
1226
+ <!-- Visual chain - wider boxes -->
1227
+ <rect x="455" y="235" width="425" height="120" rx="8" fill="#141414" stroke="#4ade80" stroke-width="1"/>
1228
+ <text x="475" y="260" fill="#4ade80" font-size="13" font-family="monospace" font-weight="bold">VISUAL MULTIPLICATION</text>
1229
+
1230
+ <rect x="475" y="280" width="62" height="35" rx="4" fill="#2a1a1a" stroke="#ff6b6b" stroke-width="1"/>
1231
+ <text x="506" y="302" text-anchor="middle" fill="#ff6b6b" font-size="10" font-family="monospace">dL/dy</text>
1232
+
1233
+ <text x="545" y="302" fill="#888" font-size="16" font-family="monospace">×</text>
1234
+
1235
+ <rect x="562" y="280" width="62" height="35" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
1236
+ <text x="593" y="302" text-anchor="middle" fill="#5b9bd5" font-size="10" font-family="monospace">dy/dz</text>
1237
+
1238
+ <text x="632" y="302" fill="#888" font-size="16" font-family="monospace">×</text>
1239
+
1240
+ <rect x="650" y="280" width="70" height="35" rx="4" fill="#1a3a2a" stroke="#4ade80" stroke-width="1"/>
1241
+ <text x="685" y="302" text-anchor="middle" fill="#4ade80" font-size="10" font-family="monospace">dz/dw1</text>
1242
+
1243
+ <text x="728" y="302" fill="#888" font-size="16" font-family="monospace">=</text>
1244
+
1245
+ <rect x="745" y="280" width="70" height="35" rx="4" fill="#2a2a1a" stroke="#f0b030" stroke-width="2"/>
1246
+ <text x="780" y="302" text-anchor="middle" fill="#f0b030" font-size="10" font-family="monospace">dL/dw1</text>
1247
+
1248
+ <text x="475" y="340" fill="#888" font-size="11" font-family="monospace">upstream × local = gradient passed backward</text>
1249
+
1250
+ <!-- Interactive prompt -->
1251
+ <rect x="20" y="370" width="860" height="115" rx="8" fill="#1a1a1a" stroke="#888" stroke-width="1" stroke-dasharray="5,5"/>
1252
+ <text x="450" y="395" text-anchor="middle" fill="#888" font-size="14" font-family="monospace">INTERACTIVE TERMINAL</text>
1253
+ <text x="450" y="420" text-anchor="middle" fill="#5b9bd5" font-size="12" font-family="monospace">Blue arrows = forward data | Red arrows = backward gradients</text>
1254
+ <text x="450" y="445" text-anchor="middle" fill="#4ade80" font-size="12" font-family="monospace">Click "EXECUTE FULL BACKPROP" to see all values calculated</text>
1255
+ <text x="450" y="470" text-anchor="middle" fill="#555" font-size="10" font-family="monospace">[Vault-Tec tip: errors propagate backward, just like rumors in the cafeteria]</text>
1256
+ </svg>
1257
+ '''
1258
+
1259
+ BACKWARD_INTRO = f"""
1260
+ {BACKWARD_INTRO_SVG}
1261
+ """
1262
+
1263
+ def backward_pass_demo(x1, x2, w1, w2, b, y_true):
1264
+ """Complete forward + backward pass with detailed chain rule."""
1265
+
1266
+ # Forward pass
1267
+ z = w1 * x1 + w2 * x2 + b
1268
+ y_pred = 1 / (1 + np.exp(-z))
1269
+
1270
+ # Binary cross-entropy loss (with small epsilon for numerical stability)
1271
+ eps = 1e-7
1272
+ y_pred_clipped = np.clip(y_pred, eps, 1 - eps)
1273
+ loss = -(y_true * np.log(y_pred_clipped) + (1 - y_true) * np.log(1 - y_pred_clipped))
1274
+
1275
+ # Backward pass - compute all gradients
1276
+ # dL/dy_pred
1277
+ dL_dy = -y_true / y_pred_clipped + (1 - y_true) / (1 - y_pred_clipped)
1278
+
1279
+ # dy_pred/dz (sigmoid derivative)
1280
+ dy_dz = y_pred * (1 - y_pred)
1281
+
1282
+ # dz/dw1, dz/dw2, dz/db
1283
+ dz_dw1 = x1
1284
+ dz_dw2 = x2
1285
+ dz_db = 1
1286
+
1287
+ # Chain rule to get final gradients
1288
+ dL_dz = dL_dy * dy_dz # This is the "upstream gradient"
1289
+ dL_dw1 = dL_dz * dz_dw1
1290
+ dL_dw2 = dL_dz * dz_dw2
1291
+ dL_db = dL_dz * dz_db
1292
+
1293
+ # Generate SVG diagram
1294
+ svg_diagram = generate_backward_svg(
1295
+ x1, x2, w1, w2, b, y_true, z, y_pred,
1296
+ dL_dy, dy_dz, dL_dz, dL_dw1, dL_dw2, dL_db, loss
1297
+ )
1298
+
1299
+ explanation = f"""
1300
+ ## COMPLETE BACKPROP WALKTHROUGH
1301
+ ===============================================
1302
+
1303
+ ### GIVEN:
1304
+ ```
1305
+ Inputs: x1 = {x1}, x2 = {x2}
1306
+ Weights: w1 = {w1}, w2 = {w2}
1307
+ Bias: b = {b}
1308
+ True label: y_true = {y_true}
1309
+ ```
1310
+
1311
+ ===============================================
1312
+ ## PART 1: FORWARD PASS (review)
1313
+ ===============================================
1314
+
1315
+ **Step 1a: Weighted sum**
1316
+ ```
1317
+ z = w1*x1 + w2*x2 + b
1318
+ = ({w1})*({x1}) + ({w2})*({x2}) + ({b})
1319
+ = {z:.6f}
1320
+ ```
1321
+
1322
+ **Step 1b: Sigmoid activation**
1323
+ ```
1324
+ y_pred = sigmoid(z) = 1/(1+e^(-z))
1325
+ = 1/(1+e^(-{z:.4f}))
1326
+ = {y_pred:.6f}
1327
+ ```
1328
+
1329
+ **Step 1c: Binary Cross-Entropy Loss**
1330
+ ```
1331
+ L = -[y_true*log(y_pred) + (1-y_true)*log(1-y_pred)]
1332
+ = -[{y_true}*log({y_pred:.6f}) + {1-y_true}*log({1-y_pred:.6f})]
1333
+ = -[{y_true * np.log(y_pred_clipped):.6f} + {(1-y_true) * np.log(1-y_pred_clipped):.6f}]
1334
+ = {loss:.6f}
1335
+ ```
1336
+
1337
+ ===============================================
1338
+ ## PART 2: BACKWARD PASS (reversing the flow)
1339
+ ===============================================
1340
+
1341
+ We need: dL/dw1, dL/dw2, dL/db
1342
+
1343
+ **The computation graph:**
1344
+ ```
1345
+ w1,x1,w2,x2,b --> z --> y_pred --> L
1345
+                 |     |          |
1346
+               dz/dw dy/dz      dL/dy
1348
+ ```
1349
+
1350
+ We work BACKWARDS from Loss to weights.
1351
+
1352
+ -----------------------------------------------
1353
+ ### STEP 2a: dL/dy_pred (how loss changes with prediction)
1354
+
1355
+ ```
1356
+ L = -y_true*log(y_pred) - (1-y_true)*log(1-y_pred)
1357
+
1358
+ dL/dy_pred = -y_true/y_pred + (1-y_true)/(1-y_pred)
1359
+ = -{y_true}/{y_pred:.6f} + {1-y_true}/{1-y_pred:.6f}
1360
+ = {-y_true/y_pred_clipped:.6f} + {(1-y_true)/(1-y_pred_clipped):.6f}
1361
+ = {dL_dy:.6f}
1362
+ ```
1363
+
1364
+ -----------------------------------------------
1365
+ ### STEP 2b: dy_pred/dz (sigmoid derivative)
1366
+
1367
+ Using: d/dz[sigmoid(z)] = sigmoid(z)*(1-sigmoid(z))
1368
+
1369
+ ```
1370
+ dy/dz = y_pred * (1 - y_pred)
1371
+ = {y_pred:.6f} * (1 - {y_pred:.6f})
1372
+ = {y_pred:.6f} * {1-y_pred:.6f}
1373
+ = {dy_dz:.6f}
1374
+ ```
1375
+
1376
+ -----------------------------------------------
1377
+ ### STEP 2c: dz/dw1, dz/dw2, dz/db
1378
+
1379
+ Since z = w1*x1 + w2*x2 + b:
1380
+
1381
+ ```
1382
+ dz/dw1 = x1 = {dz_dw1}
1383
+ dz/dw2 = x2 = {dz_dw2}
1384
+ dz/db = 1 = {dz_db}
1385
+ ```
1386
+
1387
+ -----------------------------------------------
1388
+ ### STEP 2d: CHAIN RULE - Put it together!
1389
+
1390
+ First, compute dL/dz (the "upstream gradient"):
1391
+ ```
1392
+ dL/dz = (dL/dy_pred) * (dy_pred/dz)
1393
+ = {dL_dy:.6f} * {dy_dz:.6f}
1394
+ = {dL_dz:.6f}
1395
+ ```
1396
+
1397
+ Now chain to each weight:
1398
+ ```
1399
+ dL/dw1 = (dL/dz) * (dz/dw1)
1400
+ = {dL_dz:.6f} * {dz_dw1}
1401
+ = {dL_dw1:.6f}
1402
+
1403
+ dL/dw2 = (dL/dz) * (dz/dw2)
1404
+ = {dL_dz:.6f} * {dz_dw2}
1405
+ = {dL_dw2:.6f}
1406
+
1407
+ dL/db = (dL/dz) * (dz/db)
1408
+ = {dL_dz:.6f} * {dz_db}
1409
+ = {dL_db:.6f}
1410
+ ```
1411
+
1412
+ ===============================================
1413
+ ## PART 3: GRADIENT DESCENT UPDATE
1414
+ ===============================================
1415
+
1416
+ With learning rate α = 0.1:
1417
+
1418
+ ```
1419
+ w1_new = w1 - α * dL/dw1
1420
+ = {w1} - 0.1 * {dL_dw1:.6f}
1421
+ = {w1 - 0.1 * dL_dw1:.6f}
1422
+
1423
+ w2_new = w2 - α * dL/dw2
1424
+ = {w2} - 0.1 * {dL_dw2:.6f}
1425
+ = {w2 - 0.1 * dL_dw2:.6f}
1426
+
1427
+ b_new = b - α * dL/db
1428
+ = {b} - 0.1 * {dL_db:.6f}
1429
+ = {b - 0.1 * dL_db:.6f}
1430
+ ```
1431
+
1432
+ **We've completed one step of learning!**
1433
+
1434
+ ===============================================
1435
+ ## SUMMARY TABLE
1436
+ ===============================================
1437
+
1438
+ | Gradient | Value | Meaning |
1439
+ |----------|-------|---------|
1440
+ | dL/dy | {dL_dy:.4f} | Loss sensitivity to prediction |
1441
+ | dy/dz | {dy_dz:.4f} | Sigmoid sensitivity |
1442
+ | dL/dz | {dL_dz:.4f} | "Upstream gradient" |
1443
+ | dL/dw1 | {dL_dw1:.4f} | How to adjust w1 |
1444
+ | dL/dw2 | {dL_dw2:.4f} | How to adjust w2 |
1445
+ | dL/db | {dL_db:.4f} | How to adjust bias |
1446
+ """
1447
+ return svg_diagram, explanation
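The chain-rule gradients computed in `backward_pass_demo` can be sanity-checked against a central-difference estimate. The sketch below is not part of the commit: the helper names `bce_loss`, `analytic_grads`, and `numeric_grads` are hypothetical, and it assumes the same single-neuron setup (two inputs, sigmoid, clipped BCE). It also uses the simplification dL/dz = y_pred - y_true rather than multiplying dL/dy by dy/dz explicitly.

```python
import numpy as np


def bce_loss(w1, w2, b, x1, x2, y_true, eps=1e-7):
    """Forward pass only: weighted sum, sigmoid, then clipped BCE."""
    z = w1 * x1 + w2 * x2 + b
    y_pred = np.clip(1.0 / (1.0 + np.exp(-z)), eps, 1.0 - eps)
    return -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))


def analytic_grads(w1, w2, b, x1, x2, y_true):
    """Chain-rule gradients (dL/dw1, dL/dw2, dL/db) via dL/dz = y_pred - y_true."""
    z = w1 * x1 + w2 * x2 + b
    y_pred = 1.0 / (1.0 + np.exp(-z))
    dL_dz = y_pred - y_true
    return dL_dz * x1, dL_dz * x2, dL_dz


def numeric_grads(w1, w2, b, x1, x2, y_true, h=1e-5):
    """Central-difference estimate: [L(p + h) - L(p - h)] / (2h) per parameter."""
    params = np.array([w1, w2, b], dtype=float)
    grads = np.zeros(3)
    for i in range(3):
        step = np.zeros(3)
        step[i] = h
        plus = bce_loss(*(params + step), x1, x2, y_true)
        minus = bce_loss(*(params - step), x1, x2, y_true)
        grads[i] = (plus - minus) / (2 * h)
    return tuple(grads)


if __name__ == "__main__":
    # Slider defaults from the backward-pass tab: x1=1, x2=2, w1=0.5, w2=-0.3, b=0.1, y_true=1
    print("analytic:", analytic_grads(0.5, -0.3, 0.1, 1.0, 2.0, 1.0))
    print("numeric: ", numeric_grads(0.5, -0.3, 0.1, 1.0, 2.0, 1.0))
```

With the tab's default slider values, z is 0 and y_pred is 0.5, so both routes should give roughly (-0.5, -1.0, -0.5). A large disagreement between the two would point to a mistake in one of the local derivatives.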
1448
+
1449
+
1450
+ # ============================================================================
1451
+ # TAB 5: PRACTICE PROBLEMS
1452
+ # ============================================================================
1453
+
1454
+ PRACTICE_INTRO = """
1455
+ # PRACTICE: COMPUTE BY HAND FIRST!
1456
+ ===============================================
1457
+
1458
+ Welcome to the Gradient Occupational Aptitude Test (G.O.A.T.).
1459
+ Per Vault-Tec guidelines, pencil-and-paper practice builds neural
1460
+ pathways (the biological kind). Complete these problems to determine
1461
+ your future as a Machine Learning Specialist.
1462
+
1463
+ ## TIPS FOR HAND CALCULATION
1464
+
1465
+ 1. **Draw the computation graph** - boxes for operations,
1466
+ arrows for data flow
1467
+
1468
+ 2. **Forward pass first** - compute all intermediate values
1469
+
1470
+ 3. **Backward pass** - start from loss, work backwards
1471
+
1472
+ 4. **Check dimensions** - gradient of scalar w.r.t. vector
1473
+ has same shape as the vector
1474
+
1475
+ 5. **Verify numerically** - if unsure, use a tiny h to approximate:
1475
+ df/dx ≈ (f(x+h) - f(x-h)) / (2h)
1477
+
1478
+ ## PRACTICE PROBLEMS
1479
+
1480
+ Select a problem below and try it before clicking "REVEAL SOLUTION"!
1481
+ """
1482
+
1483
+ def practice_problem(problem_num):
1484
+ """Generate practice problems with solutions."""
1485
+
1486
+ problems = {
1487
+ 1: {
1488
+ "question": """
1489
+ ### Problem 1: Simple Chain Rule
1490
+
1491
+ Compute dy/dx where:
1492
+ ```
1493
+ y = (2x + 3)^3
1494
+ ```
1495
+
1496
+ at x = 1.
1497
+
1498
+ **Hint:** Let u = 2x + 3, so y = u^3
1499
+ """,
1500
+ "solution": """
1501
+ ### Solution to Problem 1
1502
+
1503
+ **Step 1: Identify the composition**
1504
+ ```
1505
+ u = 2x + 3 (inner)
1506
+ y = u^3 (outer)
1507
+ ```
1508
+
1509
+ **Step 2: Find individual derivatives**
1510
+ ```
1511
+ du/dx = 2
1512
+ dy/du = 3u^2
1513
+ ```
1514
+
1515
+ **Step 3: Apply chain rule**
1516
+ ```
1517
+ dy/dx = (dy/du) * (du/dx)
1518
+ = 3u^2 * 2
1519
+ = 6u^2
1520
+ = 6(2x + 3)^2
1521
+ ```
1522
+
1523
+ **Step 4: Evaluate at x = 1**
1524
+ ```
1525
+ dy/dx = 6(2*1 + 3)^2
1526
+ = 6(5)^2
1527
+ = 6 * 25
1528
+ = 150
1529
+ ```
1530
+
1531
+ **Answer: dy/dx = 150 at x = 1**
1532
+ """
1533
+ },
1534
+ 2: {
1535
+ "question": """
1536
+ ### Problem 2: Sigmoid Derivative
1537
+
1538
+ Given z = 2, compute:
1539
+ 1. sigmoid(z)
1540
+ 2. d/dz[sigmoid(z)]
1541
+
1542
+ **Reminder:** sigmoid(z) = 1/(1+e^(-z))
1543
+ d/dz[sigmoid(z)] = sigmoid(z) * (1 - sigmoid(z))
1544
+ """,
1545
+ "solution": f"""
1546
+ ### Solution to Problem 2
1547
+
1548
+ **Step 1: Compute sigmoid(2)**
1549
+ ```
1550
+ sigmoid(2) = 1/(1 + e^(-2))
1551
+ = 1/(1 + {np.exp(-2):.6f})
1552
+ = 1/{1 + np.exp(-2):.6f}
1553
+ = {1/(1+np.exp(-2)):.6f}
1554
+ ```
1555
+
1556
+ **Step 2: Compute derivative**
1557
+ ```
1558
+ Let s = sigmoid(2) = {1/(1+np.exp(-2)):.6f}
1559
+
1560
+ ds/dz = s * (1 - s)
1561
+ = {1/(1+np.exp(-2)):.6f} * (1 - {1/(1+np.exp(-2)):.6f})
1562
+ = {1/(1+np.exp(-2)):.6f} * {1 - 1/(1+np.exp(-2)):.6f}
1563
+ = {(1/(1+np.exp(-2))) * (1 - 1/(1+np.exp(-2))):.6f}
1564
+ ```
1565
+
1566
+ **Answers:**
1567
+ - sigmoid(2) ≈ 0.8808
1568
+ - d/dz[sigmoid(2)] ≈ 0.1050
1569
+ """
1570
+ },
1571
+ 3: {
1572
+ "question": """
1573
+ ### Problem 3: Full Backprop (Mini Version)
1574
+
1575
+ Single neuron with:
1576
+ ```
1577
+ x = 2
1578
+ w = 0.5
1579
+ b = -1
1580
+ y_true = 1
1581
+ ```
1582
+
1583
+ Using sigmoid activation and BCE loss, find dL/dw.
1584
+
1585
+ **Steps to follow:**
1586
+ 1. Forward: z = wx + b
1587
+ 2. Forward: y_pred = sigmoid(z)
1588
+ 3. Forward: L = BCE(y_true, y_pred)
1589
+ 4. Backward: Apply chain rule
1590
+ """,
1591
+ "solution": """
1592
+ ### Solution to Problem 3
1593
+
1594
+ **Forward Pass:**
1595
+ ```
1596
+ z = w*x + b = 0.5*2 + (-1) = 0
1597
+
1598
+ y_pred = sigmoid(0) = 0.5
1599
+
1600
+ L = -[1*log(0.5) + 0*log(0.5)]
1601
+ = -log(0.5)
1602
+ ≈ 0.693
1603
+ ```
1604
+
1605
+ **Backward Pass:**
1606
+
1607
+ dL/dy_pred:
1608
+ ```
1609
+ = -y_true/y_pred + (1-y_true)/(1-y_pred)
1610
+ = -1/0.5 + 0/0.5
1611
+ = -2
1612
+ ```
1613
+
1614
+ dy_pred/dz:
1615
+ ```
1616
+ = y_pred * (1 - y_pred)
1617
+ = 0.5 * 0.5 = 0.25
1618
+ ```
1619
+
1620
+ dz/dw:
1621
+ ```
1622
+ = x = 2
1623
+ ```
1624
+
1625
+ **Chain Rule:**
1626
+ ```
1627
+ dL/dw = (dL/dy) * (dy/dz) * (dz/dw)
1628
+ = (-2) * (0.25) * (2)
1629
+ = -1.0
1630
+ ```
1631
+
1632
+ **Answer: dL/dw = -1.0**
1633
+
1634
+ **Interpretation:** Negative gradient means we should
1635
+ INCREASE w to reduce loss (moving opposite to gradient).
1636
+ """
1637
+ }
1638
+ }
1639
+
1640
+ prob = problems.get(problem_num, problems[1])
1641
+ return prob["question"], prob["solution"]
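The three hand-computed answers in `practice_problem` can be cross-checked with the numerical approximation suggested in the tips. The following is a rough sketch rather than code from the commit: `central_diff`, `sigmoid`, and `bce_of_w` are hypothetical helpers, and the step size `h=1e-5` is an assumed default.

```python
import math


def central_diff(f, x, h=1e-5):
    """Approximate df/dx at x with a symmetric finite difference."""
    return (f(x + h) - f(x - h)) / (2 * h)


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def bce_of_w(w, x=2.0, b=-1.0, y_true=1.0):
    """Problem 3 loss as a function of w only (x, b, y_true held fixed)."""
    y_pred = sigmoid(w * x + b)
    return -(y_true * math.log(y_pred) + (1.0 - y_true) * math.log(1.0 - y_pred))


# Problem 1: y = (2x + 3)^3 at x = 1, hand answer 150
p1 = central_diff(lambda x: (2 * x + 3) ** 3, 1.0)

# Problem 2: d/dz sigmoid(z) at z = 2, hand answer about 0.1050
p2 = central_diff(sigmoid, 2.0)

# Problem 3: dL/dw at w = 0.5, hand answer -1.0
p3 = central_diff(bce_of_w, 0.5)

print(f"Problem 1: {p1:.4f}")
print(f"Problem 2: {p2:.4f}")
print(f"Problem 3: {p3:.4f}")
```

Each printed value should agree with the corresponding solution text to about four decimal places; the small residual comes from the finite step size.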
1642
+
1643
+
1644
+ # ============================================================================
1645
+ # BUILD THE GRADIO APP
1646
+ # ============================================================================
1647
+
1648
+ with gr.Blocks(title="BACKPROP TERMINAL v1.0") as demo:
1649
+
1650
+ gr.Markdown("""
1651
+ # > VAULT-TEC NEURAL NETWORK TRAINING TERMINAL
1652
+ ## > SECURITY CLEARANCE: STAT 3106
1653
+ ### > INITIALIZING BACKPROPAGATION MODULES...
1654
+ """)
1655
+
1656
+ with gr.Tabs():
1657
+ # TAB 1: Forward Pass
1658
+ with gr.TabItem("01: FORWARD PASS"):
1659
+ gr.HTML(FORWARD_INTRO)
1660
+
1661
+ with gr.Row():
1662
+ with gr.Column(scale=1):
1663
+ gr.Markdown("### INPUT PARAMETERS")
1664
+ x1_input = gr.Slider(minimum=-5, maximum=5, value=1.0, step=0.1, label="x1 (input 1)")
1665
+ x2_input = gr.Slider(minimum=-5, maximum=5, value=2.0, step=0.1, label="x2 (input 2)")
1666
+ w1_input = gr.Slider(minimum=-2, maximum=2, value=0.5, step=0.1, label="w1 (weight 1)")
1667
+ w2_input = gr.Slider(minimum=-2, maximum=2, value=-0.3, step=0.1, label="w2 (weight 2)")
1668
+ b_input = gr.Slider(minimum=-2, maximum=2, value=0.1, step=0.1, label="b (bias)")
1669
+ forward_btn = gr.Button(">> EXECUTE FORWARD PASS <<")
1670
+
1671
+ with gr.Column(scale=2):
1672
+ forward_svg = gr.HTML(label="Computation Graph")
1673
+ forward_output = gr.Markdown(label="Calculation")
1674
+
1675
+ forward_btn.click(
1676
+ forward_pass_demo,
1677
+ inputs=[x1_input, x2_input, w1_input, w2_input, b_input],
1678
+ outputs=[forward_svg, forward_output]
1679
+ )
1680
+
1681
+ # TAB 2: Chain Rule
1682
+ with gr.TabItem("02: CHAIN RULE"):
1683
+ gr.HTML(CHAIN_RULE_INTRO)
1684
+
1685
+ with gr.Row():
1686
+ with gr.Column(scale=1):
1687
+ gr.Markdown("### FUNCTION: y = (ax + b)²")
1688
+ a_input = gr.Slider(minimum=-5, maximum=5, value=3.0, step=0.1, label="a (coefficient)")
1689
+ b2_input = gr.Slider(minimum=-5, maximum=5, value=2.0, step=0.1, label="b (constant)")
1690
+ x_input = gr.Slider(minimum=-5, maximum=5, value=1.0, step=0.1, label="x (evaluation point)")
1691
+ chain_btn = gr.Button(">> APPLY CHAIN RULE <<")
1692
+
1693
+ with gr.Column(scale=2):
1694
+ chain_svg = gr.HTML(label="Chain Rule Visualization")
1695
+ chain_output = gr.Markdown(label="Chain Rule Breakdown")
1696
+
1697
+ chain_btn.click(
1698
+ chain_rule_calculator,
1699
+ inputs=[a_input, b2_input, x_input],
1700
+ outputs=[chain_svg, chain_output]
1701
+ )
1702
+
1703
+ # TAB 3: Key Derivatives
1704
+ with gr.TabItem("03: KEY DERIVATIVES"):
1705
+ gr.HTML(DERIVATIVES_INTRO)
1706
+
1707
+ with gr.Row():
1708
+ with gr.Column(scale=1):
1709
+ gr.Markdown("### SIGMOID DERIVATIVE CALCULATOR")
1710
+ z_input = gr.Slider(
1711
+ minimum=-5, maximum=5, value=0, step=0.1,
1712
+ label="z value"
1713
+ )
1714
+ sigmoid_btn = gr.Button(">> COMPUTE SIGMOID DERIVATIVE <<")
1715
+
1716
+ with gr.Column(scale=2):
1717
+ sigmoid_svg = gr.HTML(label="Sigmoid Visualization")
1718
+ sigmoid_output = gr.Markdown(label="Derivative Calculation")
1719
+
1720
+ sigmoid_btn.click(
1721
+ sigmoid_derivative_demo,
1722
+ inputs=[z_input],
1723
+ outputs=[sigmoid_svg, sigmoid_output]
1724
+ )
1725
+
1726
+ # TAB 4: Backward Pass
1727
+ with gr.TabItem("04: BACKWARD PASS"):
1728
+ gr.HTML(BACKWARD_INTRO)
1729
+
1730
+ with gr.Row():
1731
+ with gr.Column(scale=1):
1732
+ gr.Markdown("### NETWORK CONFIGURATION")
1733
+ bx1 = gr.Slider(minimum=-5, maximum=5, value=1.0, step=0.1, label="x1")
1734
+ bx2 = gr.Slider(minimum=-5, maximum=5, value=2.0, step=0.1, label="x2")
1735
+ bw1 = gr.Slider(minimum=-2, maximum=2, value=0.5, step=0.1, label="w1")
1736
+ bw2 = gr.Slider(minimum=-2, maximum=2, value=-0.3, step=0.1, label="w2")
1737
+ bb = gr.Slider(minimum=-2, maximum=2, value=0.1, step=0.1, label="bias")
1738
+ by_true = gr.Slider(minimum=0, maximum=1, value=1, step=1, label="y_true (0 or 1)")
1739
+ back_btn = gr.Button(">> EXECUTE FULL BACKPROP <<")
1740
+
1741
+ with gr.Column(scale=2):
1742
+ back_svg = gr.HTML(label="Backprop Graph")
1743
+ back_output = gr.Markdown(label="Complete Backprop Trace")
1744
+
1745
+ back_btn.click(
1746
+ backward_pass_demo,
1747
+ inputs=[bx1, bx2, bw1, bw2, bb, by_true],
1748
+ outputs=[back_svg, back_output]
1749
+ )
1750
+
1751
+ # TAB 5: Practice
1752
+ with gr.TabItem("05: PRACTICE"):
1753
+ gr.Markdown(PRACTICE_INTRO)
1754
+
1755
+ with gr.Row():
1756
+ with gr.Column():
1757
+ problem_select = gr.Radio(
1758
+ choices=["Problem 1: Chain Rule", "Problem 2: Sigmoid", "Problem 3: Full Backprop"],
1759
+ label="Select Problem",
1760
+ value="Problem 1: Chain Rule"
1761
+ )
1762
+ show_problem_btn = gr.Button(">> SHOW PROBLEM <<")
1763
+ show_answer_btn = gr.Button(">> REVEAL SOLUTION <<")
1764
+
1765
+ with gr.Column():
1766
+ problem_display = gr.Markdown(label="Problem")
1767
+ solution_display = gr.Markdown(label="Solution", visible=False)
1768
+
1769
+ def show_problem(selection):
1770
+ prob_num = int(selection.split(":")[0].split()[-1])
1771
+ q, _ = practice_problem(prob_num)
1772
+ return q, gr.update(visible=False, value="")
1773
+
1774
+ def show_solution(selection):
1775
+ prob_num = int(selection.split(":")[0].split()[-1])
1776
+ _, s = practice_problem(prob_num)
1777
+ return gr.update(visible=True, value=s)
1778
+
1779
+ show_problem_btn.click(
1780
+ show_problem,
1781
+ inputs=[problem_select],
1782
+ outputs=[problem_display, solution_display]
1783
+ )
1784
+
1785
+ show_answer_btn.click(
1786
+ show_solution,
1787
+ inputs=[problem_select],
1788
+ outputs=[solution_display]
1789
+ )
1790
+
1791
+ # TAB 6: Quick Reference
1792
+ with gr.TabItem("06: REFERENCE"):
1793
+ gr.Markdown("""
1794
+ # QUICK REFERENCE CARD
1795
+ ===============================================
1796
+
1797
+ ## CHAIN RULE
1798
+
1799
+ ```
1800
+ y = f(g(x))
1801
+
1802
+ dy/dx = (df/dg) * (dg/dx)
1803
+ ```
1804
+
1805
+ For longer chains: multiply the derivatives of every link along the path!
1806
+
1807
+ -----------------------------------------------
1808
+
1809
+ ## COMMON DERIVATIVES
1810
+
1811
+ | Function | Derivative |
1812
+ |----------|------------|
1813
+ | x^n | n*x^(n-1) |
1814
+ | e^x | e^x |
1815
+ | log(x) | 1/x |
1816
+ | sigmoid(x) | sigmoid(x)*(1-sigmoid(x)) |
1817
+ | ReLU(x) | 1 if x>0, else 0 |
1818
+
1819
+ -----------------------------------------------
1820
+
1821
+ ## NEURAL NETWORK CHAIN
1822
+
1823
+ For a single neuron with sigmoid:
1824
+ ```
1825
+ z = Σ(wi*xi) + b
1826
+ y = sigmoid(z)
1827
+ L = loss(y, y_true)
1828
+
1829
+ dL/dwi = (dL/dy) * (dy/dz) * (dz/dwi)
1830
+ = (dL/dy) * sigmoid'(z) * xi
1831
+ ```
1832
+
1833
+ -----------------------------------------------
1834
+
1835
+ ## GRADIENT DESCENT
1836
+
1837
+ ```
1838
+ w_new = w_old - learning_rate * dL/dw
1839
+ ```
1840
+
1841
+ The gradient points UPHILL; we step in the opposite direction.
1842
+
1843
+ -----------------------------------------------
1844
+
1845
+ ## BCE LOSS GRADIENT (sigmoid output)
1846
+
1847
+ For BCE loss with sigmoid output:
1848
+
1849
+ ```
1850
+ dL/dz = y_pred - y_true
1851
+ ```
1852
+
1853
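+ This clean result comes from cancellation in the chain: dL/dy = (y_pred - y_true) / (y_pred*(1-y_pred)), and the sigmoid derivative y_pred*(1-y_pred) cancels the denominator.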
1854
+
1855
+ -----------------------------------------------
1856
+
1857
+ ## DEBUGGING TIPS
1858
+
1859
+ 1. **Gradient check:** Compare the analytic gradient with a numerical estimate
1860
+ ```
1861
+ dL/dw ≈ [L(w+h) - L(w-h)] / (2h)
1862
+ ```
1863
+
1864
+ 2. **Shapes must match:** gradient of L w.r.t. W has same shape as W
1865
+
1866
+ 3. **Large gradients?** Try gradient clipping or a smaller learning rate
1867
+
1868
+ 4. **Vanishing gradients?** Consider ReLU or residual connections
1869
+ """)
1870
+
1871
+ gr.Markdown("""
1872
+ ---
1873
+ > TERMINAL SESSION ACTIVE
1874
+
1875
+ > VAULT-TEC WISHES YOU A PLEASANT TRAINING EXPERIENCE
1876
+ """)
1877
+
1878
+
1879
+ if __name__ == "__main__":
1880
+ demo.launch(
1881
+ server_port=7860,
1882
+ css=FALLOUT_CSS,
1883
+ js="""
1884
+ () => {
1885
+ // Force dark mode and hide theme toggle
1886
+ document.body.classList.add('dark');
1887
+ const style = document.createElement('style');
1888
+ style.textContent = `
1889
+ .dark-mode-toggle, [aria-label="Toggle dark mode"],
1890
+ button[title*="theme"], .theme-toggle { display: none !important; }
1891
+ `;
1892
+ document.head.appendChild(style);
1893
+ }
1894
+ """
1895
+ )
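The walkthrough and the reference card show a single update, `w_new = w_old - learning_rate * dL/dw`. As a minimal sketch of what repeating that step looks like, the loop below applies it many times to the same single-neuron example. It is not part of app.py: `train_neuron` and its `steps` parameter are hypothetical, the learning rate 0.1 matches the walkthrough, and dL/dz uses the `y_pred - y_true` shortcut from the reference card.

```python
import numpy as np


def train_neuron(x1, x2, y_true, w1, w2, b, lr=0.1, steps=50):
    """Repeat the one-step update from the walkthrough and record the loss each step."""
    eps = 1e-7
    losses = []
    for _ in range(steps):
        # Forward pass
        z = w1 * x1 + w2 * x2 + b
        y_pred = 1.0 / (1.0 + np.exp(-z))
        y_clip = np.clip(y_pred, eps, 1.0 - eps)
        loss = -(y_true * np.log(y_clip) + (1.0 - y_true) * np.log(1.0 - y_clip))
        losses.append(float(loss))

        # Backward pass: BCE + sigmoid simplifies to dL/dz = y_pred - y_true
        dL_dz = y_pred - y_true

        # Gradient descent update: step against the gradient
        w1 -= lr * dL_dz * x1
        w2 -= lr * dL_dz * x2
        b -= lr * dL_dz * 1.0
    return (w1, w2, b), losses


if __name__ == "__main__":
    params, losses = train_neuron(1.0, 2.0, 1.0, w1=0.5, w2=-0.3, b=0.1)
    print("first loss:", round(losses[0], 6))
    print("last loss: ", round(losses[-1], 6))
    print("final (w1, w2, b):", params)
```

With a single training example and y_true = 1, every step pushes z upward, so the recorded loss should shrink on each iteration. Real training would average gradients over many examples instead of fitting one point.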