# stat3106 / app.py
# Author: db-d2
# Commit a681fba: "Feat: Gradient descent and chain rule helper"
import gradio as gr
import numpy as np
# Fallout Terminal Theme CSS: global stylesheet string for the Gradio UI.
# Colour palette (exposed as CSS custom properties in :root):
# - Pip-Boy Amber:  #f0b030 (h1/h2 headers, buttons, bold text)
# - Terminal Green: #4ade80 (body text, inputs, code text)
# - Vault-Tec Blue: #5b9bd5 (h3-h5, tabs, links, code accent border)
# - Background:     #0c0c0c (near-black terminal)
# - Panel BG:       #141414 (slightly lifted for depth)
# The base font rule deliberately skips anything inside an <svg>: CSS
# declarations outrank SVG presentation attributes, so an !important
# font-size on a bare `*` selector would force every font-size="..." label
# in the inline diagrams to 20px and make them overflow their boxes.
FALLOUT_CSS = """
@import url('https://fonts.googleapis.com/css2?family=VT323&display=swap');
@import url('https://fonts.googleapis.com/css2?family=Share+Tech+Mono&display=swap');
:root {
--pip-amber: #f0b030;
--pip-amber-dim: #c49028;
--terminal-green: #4ade80;
--terminal-green-dim: #22c55e;
--vault-blue: #5b9bd5;
--vault-blue-dim: #4080b8;
--bg-dark: #0c0c0c;
--bg-panel: #141414;
--bg-input: #1a1a1a;
--text-muted: #888888;
}
*:not(svg *) {
font-family: 'Share Tech Mono', 'VT323', monospace !important;
font-size: 20px !important;
line-height: 1.6 !important;
}
h1 { font-size: 36px !important; }
h2 { font-size: 30px !important; }
h3 { font-size: 24px !important; }
h4, h5 { font-size: 22px !important; }
code, pre { font-size: 18px !important; }
body, .gradio-container {
background-color: var(--bg-dark) !important;
}
.gradio-container {
max-width: 1200px !important;
}
/* Main text - soft green, NO glow */
.markdown-text, .prose, p, span, label, .label-wrap {
color: var(--terminal-green) !important;
}
/* Headers - warm amber for hierarchy */
h1 {
color: var(--pip-amber) !important;
border-bottom: 2px solid var(--pip-amber-dim) !important;
padding-bottom: 8px !important;
}
h2 {
color: var(--pip-amber) !important;
border-bottom: 1px solid var(--pip-amber-dim) !important;
padding-bottom: 4px !important;
}
h3, h4, h5 {
color: var(--vault-blue) !important;
border-bottom: none !important;
}
/* Tab styling - Vault-Tec blue for navigation */
.tabs {
background-color: var(--bg-dark) !important;
border: 1px solid var(--vault-blue-dim) !important;
border-radius: 4px !important;
}
.tab-nav {
background-color: var(--bg-panel) !important;
border-bottom: 2px solid var(--vault-blue-dim) !important;
}
.tab-nav button {
background-color: var(--bg-panel) !important;
color: var(--vault-blue) !important;
border: none !important;
border-right: 1px solid var(--bg-dark) !important;
padding: 10px 16px !important;
transition: all 0.2s ease !important;
}
.tab-nav button:hover {
background-color: #1e3a5f !important;
color: #8ec5fc !important;
}
.tab-nav button.selected {
background-color: #1a3550 !important;
color: #8ec5fc !important;
border-bottom: 2px solid var(--pip-amber) !important;
}
/* Input/Output boxes - subtle with green text */
.textbox, textarea, input {
background-color: var(--bg-input) !important;
color: var(--terminal-green) !important;
border: 1px solid #333 !important;
border-radius: 3px !important;
}
.textbox:focus, textarea:focus, input:focus {
border-color: var(--terminal-green-dim) !important;
outline: none !important;
}
/* Buttons - amber accent for actions */
.primary, .secondary, button {
background-color: #2a2010 !important;
color: var(--pip-amber) !important;
border: 1px solid var(--pip-amber-dim) !important;
border-radius: 3px !important;
transition: all 0.2s ease !important;
}
button:hover {
background-color: #3d2e15 !important;
border-color: var(--pip-amber) !important;
}
/* Sliders - amber accent */
input[type="range"] {
accent-color: var(--pip-amber) !important;
}
/* Number inputs */
.number-input input {
background-color: var(--bg-input) !important;
color: var(--terminal-green) !important;
border: 1px solid #333 !important;
}
/* Code blocks - slightly blue-tinted for distinction */
code, pre {
background-color: #0d1520 !important;
color: var(--terminal-green) !important;
border: 1px solid #2a4060 !important;
border-left: 3px solid var(--vault-blue) !important;
border-radius: 3px !important;
padding: 2px 6px !important;
}
pre {
padding: 12px !important;
}
/* Tables */
table {
border-collapse: collapse !important;
}
th {
background-color: #1a2a3a !important;
color: var(--pip-amber) !important;
border: 1px solid #2a4060 !important;
padding: 8px !important;
}
td {
background-color: var(--bg-panel) !important;
color: var(--terminal-green) !important;
border: 1px solid #2a4060 !important;
padding: 8px !important;
}
/* Strong/bold text - amber for emphasis */
strong, b {
color: var(--pip-amber) !important;
font-weight: bold !important;
}
/* Links */
a {
color: var(--vault-blue) !important;
}
a:hover {
color: #8ec5fc !important;
}
/* Radio buttons and checkboxes */
.radio-group label, .checkbox-group label {
color: var(--terminal-green) !important;
}
/* Scrollbar - subtle */
::-webkit-scrollbar {
width: 8px;
height: 8px;
background-color: var(--bg-dark);
}
::-webkit-scrollbar-thumb {
background-color: #333;
border-radius: 4px;
}
::-webkit-scrollbar-thumb:hover {
background-color: #444;
}
/* Subtle scanlines - very light, not distracting */
.gradio-container::before {
content: "";
position: fixed;
top: 0;
left: 0;
width: 100%;
height: 100%;
background: repeating-linear-gradient(
0deg,
rgba(0, 0, 0, 0.03),
rgba(0, 0, 0, 0.03) 1px,
transparent 1px,
transparent 2px
);
pointer-events: none;
z-index: 1000;
}
/* Horizontal rules - amber accent */
hr {
border: none !important;
border-top: 1px solid var(--pip-amber-dim) !important;
margin: 16px 0 !important;
}
/* Blockquotes - for terminal prompts */
blockquote {
border-left: 3px solid var(--pip-amber) !important;
background-color: var(--bg-panel) !important;
padding: 8px 16px !important;
margin: 8px 0 !important;
color: var(--pip-amber) !important;
}
/* Muted/secondary text */
.secondary-text, .hint {
color: var(--text-muted) !important;
}
"""
# ============================================================================
# SVG DIAGRAM GENERATORS
# ============================================================================
def generate_forward_svg(x1, x2, w1, w2, b, z, y):
    """Build an inline SVG diagram of the single-neuron forward pass.

    The two inputs flow along weighted arrows into a "Σ + b" node, then
    through a sigmoid node to the output box. Every numeric argument is
    printed on the diagram.

    Args:
        x1, x2: Input values (printed with 2 decimals).
        w1, w2: Weights (printed with 2 decimals).
        b: Bias (printed with 2 decimals).
        z: Pre-activation value (printed with 3 decimals).
        y: Sigmoid output (printed with 4 decimals).

    Returns:
        The SVG markup as a string.
    """
    # Theme palette: the same hex values as the --bg-dark, --terminal-green,
    # --vault-blue and --pip-amber variables declared in FALLOUT_CSS.
    background = "#0c0c0c"
    green = "#4ade80"  # input borders and plain annotation text
    blue = "#5b9bd5"   # operation borders, labels and arrows
    amber = "#f0b030"  # title, output border and every computed value
    in_fill, op_fill, out_fill = "#1a3a2a", "#1a2a3a", "#2a2a1a"
    return f'''
<svg viewBox="0 0 800 320" style="width:100%; max-width:800px; height:auto; background:{background}; border-radius:8px; border:1px solid #333;">
<defs>
<marker id="arrowhead" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
<polygon points="0 0, 10 3.5, 0 7" fill="{blue}" />
</marker>
<marker id="arrowhead-amber" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
<polygon points="0 0, 10 3.5, 0 7" fill="{amber}" />
</marker>
</defs>
<!-- Title -->
<text x="400" y="30" text-anchor="middle" fill="{amber}" font-size="20" font-family="monospace">FORWARD PASS: Data Flow</text>
<!-- Input nodes -->
<rect x="40" y="60" width="80" height="60" rx="8" fill="{in_fill}" stroke="{green}" stroke-width="2"/>
<text x="80" y="85" text-anchor="middle" fill="{blue}" font-size="14" font-family="monospace">x₁</text>
<text x="80" y="108" text-anchor="middle" fill="{amber}" font-size="16" font-family="monospace" font-weight="bold">{x1:.2f}</text>
<rect x="40" y="200" width="80" height="60" rx="8" fill="{in_fill}" stroke="{green}" stroke-width="2"/>
<text x="80" y="225" text-anchor="middle" fill="{blue}" font-size="14" font-family="monospace">x₂</text>
<text x="80" y="248" text-anchor="middle" fill="{amber}" font-size="16" font-family="monospace" font-weight="bold">{x2:.2f}</text>
<!-- Weight labels on arrows -->
<line x1="120" y1="90" x2="220" y2="140" stroke="{blue}" stroke-width="2" marker-end="url(#arrowhead)"/>
<text x="155" y="100" fill="{green}" font-size="12" font-family="monospace">w₁={w1:.2f}</text>
<line x1="120" y1="230" x2="220" y2="180" stroke="{blue}" stroke-width="2" marker-end="url(#arrowhead)"/>
<text x="155" y="230" fill="{green}" font-size="12" font-family="monospace">w₂={w2:.2f}</text>
<!-- Summation node -->
<rect x="230" y="130" width="100" height="70" rx="8" fill="{op_fill}" stroke="{blue}" stroke-width="2"/>
<text x="280" y="152" text-anchor="middle" fill="{blue}" font-size="14" font-family="monospace">Σ + b</text>
<text x="280" y="175" text-anchor="middle" fill="{green}" font-size="11" font-family="monospace">b={b:.2f}</text>
<text x="280" y="193" text-anchor="middle" fill="{amber}" font-size="14" font-family="monospace" font-weight="bold">z={z:.3f}</text>
<!-- Arrow to sigmoid -->
<line x1="330" y1="165" x2="400" y2="165" stroke="{blue}" stroke-width="2" marker-end="url(#arrowhead)"/>
<!-- Sigmoid node -->
<rect x="410" y="130" width="100" height="70" rx="8" fill="{op_fill}" stroke="{blue}" stroke-width="2"/>
<text x="460" y="155" text-anchor="middle" fill="{blue}" font-size="14" font-family="monospace">σ(z)</text>
<text x="460" y="175" text-anchor="middle" fill="{green}" font-size="10" font-family="monospace">1/(1+e⁻ᶻ)</text>
<text x="460" y="193" text-anchor="middle" fill="{amber}" font-size="14" font-family="monospace" font-weight="bold">ŷ={y:.4f}</text>
<!-- Arrow to output -->
<line x1="510" y1="165" x2="580" y2="165" stroke="{amber}" stroke-width="2" marker-end="url(#arrowhead-amber)"/>
<!-- Output node -->
<rect x="590" y="130" width="100" height="70" rx="8" fill="{out_fill}" stroke="{amber}" stroke-width="2"/>
<text x="640" y="155" text-anchor="middle" fill="{amber}" font-size="14" font-family="monospace">OUTPUT</text>
<text x="640" y="180" text-anchor="middle" fill="{amber}" font-size="18" font-family="monospace" font-weight="bold">{y:.4f}</text>
<!-- Legend -->
<rect x="40" y="280" width="15" height="15" fill="{in_fill}" stroke="{green}" stroke-width="1"/>
<text x="60" y="292" fill="{green}" font-size="12" font-family="monospace">Inputs</text>
<rect x="140" y="280" width="15" height="15" fill="{op_fill}" stroke="{blue}" stroke-width="1"/>
<text x="160" y="292" fill="{green}" font-size="12" font-family="monospace">Operations</text>
<rect x="280" y="280" width="15" height="15" fill="{out_fill}" stroke="{amber}" stroke-width="1"/>
<text x="300" y="292" fill="{green}" font-size="12" font-family="monospace">Output</text>
<text x="420" y="292" fill="{amber}" font-size="12" font-family="monospace">■ Computed Values</text>
</svg>
'''
def generate_backward_svg(x1, x2, w1, w2, b, y_true, z, y_pred, dL_dy, dy_dz, dL_dz, dL_dw1, dL_dw2, dL_db, loss):
    """Build an inline SVG diagram of the backward pass with its gradients.

    The top row redraws the forward path (inputs, weights, sum, sigmoid and
    a loss node labelled BCE). Beneath it, the chain of derivatives
    dL/dy -> dy/dz -> dL/dz is drawn flowing right-to-left. Side panels
    spell out the chain-rule arithmetic for dL/dw1 and dL/dw2 and list the
    final gradients.

    Args:
        x1, x2: Input values.
        w1, w2: Weights (printed with 2 decimals, matching the forward diagram).
        b: Bias. Accepted for signature parity with the forward diagram; it
            is not drawn.
        y_true: Target label, printed as-is.
        z: Pre-activation value.
        y_pred: Sigmoid output.
        dL_dy, dy_dz, dL_dz: Derivatives along the chain from loss to z.
        dL_dw1, dL_dw2, dL_db: Final parameter gradients.
        loss: Loss value shown in the loss node.

    Returns:
        The SVG markup as a string.
    """
    bg = "#0c0c0c"
    node_fill = "#1a2a3a"
    node_stroke = "#5b9bd5"
    input_fill = "#1a3a2a"
    input_stroke = "#4ade80"
    loss_fill = "#3a1a1a"
    loss_stroke = "#ff6b6b"
    text_color = "#4ade80"
    label_color = "#5b9bd5"
    forward_arrow = "#5b9bd5"
    backward_arrow = "#ff6b6b"
    value_color = "#f0b030"
    gradient_color = "#ff6b6b"
    # Arrowheads: with orient="auto" the browser already rotates each marker
    # so its +x axis follows the line's direction. The backward marker must
    # therefore use the SAME right-pointing shape as the forward one. The
    # backward lines (including the legend sample) are drawn right-to-left,
    # and that is what makes their heads point left. A pre-flipped polygon
    # would be flipped twice and point the wrong way.
    svg = f'''
<svg viewBox="0 0 900 420" style="width:100%; max-width:900px; height:auto; background:{bg}; border-radius:8px; border:1px solid #333;">
<defs>
<marker id="fwd" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
<polygon points="0 0, 10 3.5, 0 7" fill="{forward_arrow}" />
</marker>
<marker id="bwd" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
<polygon points="0 0, 10 3.5, 0 7" fill="{backward_arrow}" />
</marker>
</defs>
<!-- Title -->
<text x="300" y="25" text-anchor="middle" fill="{value_color}" font-size="16" font-family="monospace">GRADIENT FLOW DIAGRAM</text>
<!-- FORWARD PATH - Row 1 -->
<!-- Input x1 -->
<rect x="20" y="50" width="60" height="50" rx="5" fill="{input_fill}" stroke="{input_stroke}" stroke-width="2"/>
<text x="50" y="70" text-anchor="middle" fill="{label_color}" font-size="11" font-family="monospace">x1</text>
<text x="50" y="88" text-anchor="middle" fill="{value_color}" font-size="12" font-family="monospace">{x1:.2f}</text>
<!-- Weight w1 -->
<line x1="80" y1="75" x2="105" y2="75" stroke="{forward_arrow}" stroke-width="2" marker-end="url(#fwd)"/>
<rect x="115" y="55" width="55" height="40" rx="4" fill="#1a1a2a" stroke="{forward_arrow}" stroke-width="1"/>
<text x="142" y="80" text-anchor="middle" fill="{value_color}" font-size="10" font-family="monospace">w1={w1:.2f}</text>
<!-- Arrow to Sum -->
<line x1="170" y1="75" x2="195" y2="100" stroke="{forward_arrow}" stroke-width="2" marker-end="url(#fwd)"/>
<!-- Input x2 -->
<rect x="20" y="120" width="60" height="50" rx="5" fill="{input_fill}" stroke="{input_stroke}" stroke-width="2"/>
<text x="50" y="140" text-anchor="middle" fill="{label_color}" font-size="11" font-family="monospace">x2</text>
<text x="50" y="158" text-anchor="middle" fill="{value_color}" font-size="12" font-family="monospace">{x2:.2f}</text>
<!-- Weight w2 -->
<line x1="80" y1="145" x2="105" y2="145" stroke="{forward_arrow}" stroke-width="2" marker-end="url(#fwd)"/>
<rect x="115" y="125" width="55" height="40" rx="4" fill="#1a1a2a" stroke="{forward_arrow}" stroke-width="1"/>
<text x="142" y="150" text-anchor="middle" fill="{value_color}" font-size="10" font-family="monospace">w2={w2:.2f}</text>
<!-- Arrow to Sum -->
<line x1="170" y1="145" x2="195" y2="120" stroke="{forward_arrow}" stroke-width="2" marker-end="url(#fwd)"/>
<!-- Sum node -->
<rect x="205" y="90" width="70" height="50" rx="5" fill="{node_fill}" stroke="{node_stroke}" stroke-width="2"/>
<text x="240" y="110" text-anchor="middle" fill="{label_color}" font-size="11" font-family="monospace">Sum+b</text>
<text x="240" y="128" text-anchor="middle" fill="{value_color}" font-size="10" font-family="monospace">z={z:.2f}</text>
<!-- Arrow to Sigmoid -->
<line x1="275" y1="115" x2="310" y2="115" stroke="{forward_arrow}" stroke-width="2" marker-end="url(#fwd)"/>
<!-- Sigmoid -->
<rect x="320" y="90" width="70" height="50" rx="5" fill="{node_fill}" stroke="{node_stroke}" stroke-width="2"/>
<text x="355" y="110" text-anchor="middle" fill="{label_color}" font-size="11" font-family="monospace">sigmoid</text>
<text x="355" y="128" text-anchor="middle" fill="{value_color}" font-size="10" font-family="monospace">y={y_pred:.3f}</text>
<!-- Arrow to Loss -->
<line x1="390" y1="115" x2="425" y2="115" stroke="{forward_arrow}" stroke-width="2" marker-end="url(#fwd)"/>
<!-- Loss -->
<rect x="435" y="85" width="80" height="60" rx="5" fill="{loss_fill}" stroke="{loss_stroke}" stroke-width="2"/>
<text x="475" y="105" text-anchor="middle" fill="{loss_stroke}" font-size="11" font-family="monospace">BCE</text>
<text x="475" y="122" text-anchor="middle" fill="{value_color}" font-size="11" font-family="monospace">L={loss:.4f}</text>
<text x="475" y="138" text-anchor="middle" fill="{text_color}" font-size="9" font-family="monospace">y_true={y_true}</text>
<!-- BACKWARD SECTION -->
<text x="300" y="185" text-anchor="middle" fill="{backward_arrow}" font-size="12" font-family="monospace">BACKWARD PASS (gradients)</text>
<!-- Gradient chain boxes -->
<rect x="435" y="200" width="80" height="35" rx="4" fill="{loss_fill}" stroke="{loss_stroke}" stroke-width="1"/>
<text x="475" y="222" text-anchor="middle" fill="{gradient_color}" font-size="9" font-family="monospace">dL/dy={dL_dy:.2f}</text>
<line x1="435" y1="218" x2="405" y2="218" stroke="{backward_arrow}" stroke-width="2" stroke-dasharray="4,2" marker-end="url(#bwd)"/>
<rect x="320" y="200" width="80" height="35" rx="4" fill="{node_fill}" stroke="{node_stroke}" stroke-width="1"/>
<text x="360" y="222" text-anchor="middle" fill="{gradient_color}" font-size="9" font-family="monospace">dy/dz={dy_dz:.3f}</text>
<line x1="320" y1="218" x2="290" y2="218" stroke="{backward_arrow}" stroke-width="2" stroke-dasharray="4,2" marker-end="url(#bwd)"/>
<rect x="205" y="200" width="80" height="35" rx="4" fill="{node_fill}" stroke="{node_stroke}" stroke-width="1"/>
<text x="245" y="222" text-anchor="middle" fill="{gradient_color}" font-size="9" font-family="monospace">dL/dz={dL_dz:.3f}</text>
<line x1="205" y1="218" x2="175" y2="218" stroke="{backward_arrow}" stroke-width="2" stroke-dasharray="4,2" marker-end="url(#bwd)"/>
<!-- Final Gradients Box -->
<rect x="20" y="260" width="240" height="80" rx="6" fill="#141414" stroke="{gradient_color}" stroke-width="1"/>
<text x="140" y="282" text-anchor="middle" fill="{value_color}" font-size="12" font-family="monospace">COMPUTED GRADIENTS</text>
<text x="140" y="305" text-anchor="middle" fill="{gradient_color}" font-size="11" font-family="monospace">dL/dw1 = {dL_dw1:.4f}</text>
<text x="140" y="325" text-anchor="middle" fill="{gradient_color}" font-size="11" font-family="monospace">dL/dw2 = {dL_dw2:.4f} dL/db = {dL_db:.4f}</text>
<!-- Chain Rule Box -->
<rect x="530" y="50" width="360" height="190" rx="6" fill="#141414" stroke="#555" stroke-width="1"/>
<text x="710" y="75" text-anchor="middle" fill="{value_color}" font-size="13" font-family="monospace">CHAIN RULE COMPUTATION</text>
<text x="545" y="100" fill="{text_color}" font-size="11" font-family="monospace">dL/dw1 = dL/dy * dy/dz * dz/dw1</text>
<text x="545" y="125" fill="#888" font-size="11" font-family="monospace"> = ({dL_dy:.2f}) * ({dy_dz:.4f}) * ({x1:.2f})</text>
<text x="545" y="150" fill="{value_color}" font-size="12" font-family="monospace"> = {dL_dw1:.4f}</text>
<line x1="545" y1="165" x2="875" y2="165" stroke="#333" stroke-width="1"/>
<text x="545" y="185" fill="{text_color}" font-size="10" font-family="monospace">Key: dz/dw1 = x1, dz/dw2 = x2, dz/db = 1</text>
<text x="545" y="205" fill="#888" font-size="10" font-family="monospace">The input values become gradients!</text>
<text x="545" y="225" fill="{text_color}" font-size="10" font-family="monospace">dL/dw2 = dL/dz * x2 = {dL_dz:.3f} * {x2:.2f} = {dL_dw2:.4f}</text>
<!-- Legend -->
<rect x="530" y="260" width="360" height="80" rx="6" fill="#141414" stroke="#333" stroke-width="1"/>
<text x="710" y="282" text-anchor="middle" fill="#888" font-size="11" font-family="monospace">LEGEND</text>
<line x1="550" y1="302" x2="590" y2="302" stroke="{forward_arrow}" stroke-width="2" marker-end="url(#fwd)"/>
<text x="600" y="306" fill="{text_color}" font-size="10" font-family="monospace">Forward (data)</text>
<line x1="590" y1="322" x2="550" y2="322" stroke="{backward_arrow}" stroke-width="2" stroke-dasharray="4,2" marker-end="url(#bwd)"/>
<text x="600" y="326" fill="{text_color}" font-size="10" font-family="monospace">Backward (grads)</text>
<rect x="750" y="295" width="12" height="12" fill="{gradient_color}"/>
<text x="770" y="306" fill="{text_color}" font-size="10" font-family="monospace">Gradient values</text>
</svg>
'''
    return svg
# ============================================================================
# TAB 1: FORWARD PASS
# ============================================================================
def forward_pass_demo(x1, x2, w1, w2, b):
    """Run a single-neuron forward pass and explain every step.

    Computes z = w1*x1 + w2*x2 + b and y = sigmoid(z), renders the diagram
    with generate_forward_svg, and builds a Markdown walkthrough of the
    arithmetic with the actual numbers substituted.

    Args:
        x1, x2: Input values.
        w1, w2: Weights.
        b: Bias.

    Returns:
        Tuple (svg_diagram, explanation): the SVG markup string and the
        Markdown explanation string.
    """
    # Step 1: weighted sum.
    z = w1 * x1 + w2 * x2 + b
    # Step 2: sigmoid. Write -z as (0.0 - z) so that z == 0 shows as
    # "0.0000" rather than "-0.0000". Substituting the value directly (not
    # "-{z}") avoids printing "e^(--0.5)" for negative z.
    neg_z = 0.0 - z
    # For very negative z, exp(-z) overflows to inf. That still yields the
    # correct limit sigmoid(z) = 0, so silence numpy's overflow warning.
    with np.errstate(over="ignore"):
        exp_neg_z = np.exp(neg_z)
    sigmoid_z = 1 / (1 + exp_neg_z)
    svg_diagram = generate_forward_svg(x1, x2, w1, w2, b, z, sigmoid_z)
    # Blank lines separate the Markdown blocks. A "----" line directly under
    # a paragraph is a setext heading underline: it would turn that paragraph
    # into an <h2> instead of drawing a horizontal rule.
    explanation = f"""
## FORWARD PASS CALCULATION
===============================================

### STEP 1: The Weighted Sum (z)

The neuron computes a **weighted sum** of inputs plus a bias:

```
z = w1*x1 + w2*x2 + b
z = ({w1:.2f})*({x1:.2f}) + ({w2:.2f})*({x2:.2f}) + ({b:.2f})
z = ({w1*x1:.4f}) + ({w2*x2:.4f}) + ({b:.2f})
z = {z:.4f}
```

**What's happening:** Each input is scaled by its weight,
then we add them up. The bias shifts the whole thing.

-----------------------------------------------

### STEP 2: The Sigmoid Activation Function

We squash z through the **sigmoid function** to get output in (0,1):

```
sigmoid(z) = 1 / (1 + e^(-z))
           = 1 / (1 + e^({neg_z:.4f}))
           = 1 / (1 + {exp_neg_z:.4f})
           = 1 / {1 + exp_neg_z:.4f}
           = {sigmoid_z:.4f}
```

**Why sigmoid?** It smoothly maps any real number to (0,1).

- z >> 0 --> sigmoid(z) ≈ 1
- z << 0 --> sigmoid(z) ≈ 0
- z = 0 --> sigmoid(z) = 0.5

-----------------------------------------------

### SUMMARY

```
Inputs:  x1={x1:.2f}, x2={x2:.2f}
Weights: w1={w1:.2f}, w2={w2:.2f}
Bias:    b={b:.2f}
z = {z:.4f}
y = sigmoid(z) = {sigmoid_z:.4f}
```

**Interpretation:** Output of {sigmoid_z:.4f} means
{sigmoid_z*100:.1f}% probability of class 1.
"""
    return svg_diagram, explanation
# Static overview diagram for the forward-pass tab. It shows:
# - a single neuron (x1, x2 -> x w1 / x w2 -> Σ + b -> z -> σ(z) -> ŷ), with
#   STEP 1 / STEP 2 labels;
# - a "THE MATH" panel with z = w₁x₁ + w₂x₂ + b and ŷ = σ(z);
# - a dashed prompt box that refers to an "EXECUTE FORWARD PASS" control.
# It is a plain (non-f) string, so the markup is emitted verbatim. The
# marker id "fwd-arr" differs from the marker ids used by the other
# diagrams in this file.
FORWARD_INTRO_SVG = '''
<svg viewBox="0 0 900 420" style="width:100%; max-width:900px; height:auto; background:#0c0c0c; border-radius:8px; border:1px solid #333; margin-bottom:20px;">
<defs>
<marker id="fwd-arr" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
<polygon points="0 0, 10 3.5, 0 7" fill="#5b9bd5" />
</marker>
</defs>
<!-- Title -->
<text x="450" y="35" text-anchor="middle" fill="#f0b030" font-size="20" font-family="monospace" font-weight="bold">FORWARD PASS: DATA IN → PREDICTION OUT</text>
<!-- Main neuron diagram -->
<rect x="20" y="55" width="480" height="200" rx="8" fill="#141414" stroke="#5b9bd5" stroke-width="1"/>
<text x="40" y="80" fill="#5b9bd5" font-size="14" font-family="monospace" font-weight="bold">THE SINGLE NEURON</text>
<!-- Input x1 -->
<circle cx="70" cy="120" r="22" fill="#1a3a2a" stroke="#4ade80" stroke-width="2"/>
<text x="70" y="125" text-anchor="middle" fill="#4ade80" font-size="13" font-family="monospace">x₁</text>
<!-- Input x2 -->
<circle cx="70" cy="195" r="22" fill="#1a3a2a" stroke="#4ade80" stroke-width="2"/>
<text x="70" y="200" text-anchor="middle" fill="#4ade80" font-size="13" font-family="monospace">x₂</text>
<!-- Weights on arrows -->
<line x1="92" y1="120" x2="155" y2="145" stroke="#5b9bd5" stroke-width="2" marker-end="url(#fwd-arr)"/>
<text x="108" y="115" fill="#f0b030" font-size="11" font-family="monospace">×w₁</text>
<line x1="92" y1="195" x2="155" y2="170" stroke="#5b9bd5" stroke-width="2" marker-end="url(#fwd-arr)"/>
<text x="108" y="210" fill="#f0b030" font-size="11" font-family="monospace">×w₂</text>
<!-- Summation node -->
<circle cx="185" cy="157" r="28" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="2"/>
<text x="185" y="152" text-anchor="middle" fill="#5b9bd5" font-size="16" font-family="monospace">Σ</text>
<text x="185" y="170" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">+b</text>
<!-- Arrow to z -->
<line x1="213" y1="157" x2="255" y2="157" stroke="#5b9bd5" stroke-width="2" marker-end="url(#fwd-arr)"/>
<!-- z value box -->
<rect x="265" y="140" width="40" height="35" rx="5" fill="#2a2a1a" stroke="#f0b030" stroke-width="2"/>
<text x="285" y="162" text-anchor="middle" fill="#f0b030" font-size="13" font-family="monospace">z</text>
<!-- Arrow to sigmoid -->
<line x1="305" y1="157" x2="340" y2="157" stroke="#5b9bd5" stroke-width="2" marker-end="url(#fwd-arr)"/>
<!-- Sigmoid box -->
<rect x="350" y="135" width="60" height="45" rx="5" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="2"/>
<text x="380" y="155" text-anchor="middle" fill="#5b9bd5" font-size="12" font-family="monospace">σ(z)</text>
<text x="380" y="170" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">sigmoid</text>
<!-- Arrow to output -->
<line x1="410" y1="157" x2="445" y2="157" stroke="#4ade80" stroke-width="2" marker-end="url(#fwd-arr)"/>
<!-- Output -->
<circle cx="470" cy="157" r="22" fill="#2a1a1a" stroke="#ff6b6b" stroke-width="2"/>
<text x="470" y="162" text-anchor="middle" fill="#ff6b6b" font-size="13" font-family="monospace">ŷ</text>
<!-- Step labels -->
<text x="185" y="215" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">STEP 1</text>
<text x="185" y="228" text-anchor="middle" fill="#5b9bd5" font-size="9" font-family="monospace">Weighted Sum</text>
<text x="380" y="215" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">STEP 2</text>
<text x="380" y="228" text-anchor="middle" fill="#5b9bd5" font-size="9" font-family="monospace">Activation</text>
<!-- Equations box - made wider -->
<rect x="515" y="55" width="365" height="200" rx="8" fill="#141414" stroke="#4ade80" stroke-width="1"/>
<text x="535" y="80" fill="#4ade80" font-size="14" font-family="monospace" font-weight="bold">THE MATH</text>
<text x="535" y="110" fill="#888" font-size="12" font-family="monospace">Step 1: Weighted Sum</text>
<text x="535" y="132" fill="#f0b030" font-size="14" font-family="monospace">z = w₁x₁ + w₂x₂ + b</text>
<text x="535" y="165" fill="#888" font-size="12" font-family="monospace">Step 2: Sigmoid</text>
<text x="535" y="187" fill="#f0b030" font-size="14" font-family="monospace">ŷ = σ(z) = 1/(1+e⁻ᶻ)</text>
<line x1="535" y1="200" x2="860" y2="200" stroke="#333" stroke-width="1"/>
<text x="535" y="222" fill="#4ade80" font-size="12" font-family="monospace">Output ŷ ∈ (0,1) = probability</text>
<text x="535" y="242" fill="#888" font-size="10" font-family="monospace">Squashes any real number to (0,1)</text>
<!-- Interactive prompt -->
<rect x="20" y="270" width="860" height="135" rx="8" fill="#1a1a1a" stroke="#888" stroke-width="1" stroke-dasharray="5,5"/>
<text x="450" y="300" text-anchor="middle" fill="#888" font-size="15" font-family="monospace">▼ INTERACTIVE TERMINAL ▼</text>
<text x="450" y="330" text-anchor="middle" fill="#5b9bd5" font-size="13" font-family="monospace">Adjust inputs (x₁, x₂), weights (w₁, w₂), and bias (b)</text>
<text x="450" y="360" text-anchor="middle" fill="#4ade80" font-size="13" font-family="monospace">Click "EXECUTE FORWARD PASS" to see the values</text>
<text x="450" y="390" text-anchor="middle" fill="#555" font-size="10" font-family="monospace">[Vault-Tec recommends saving your work before experiments]</text>
</svg>
'''
# Forward-pass tab intro: the overview SVG with one newline before and after.
FORWARD_INTRO = "\n" + FORWARD_INTRO_SVG + "\n"
# ============================================================================
# TAB 2: CHAIN RULE FUNDAMENTALS
# ============================================================================
# Static overview diagram for the chain-rule tab, in four panels:
#   1. the composition x -> g(x) -> u -> f(u) -> y, with the formula
#      dy/dx = (dy/du) × (du/dx);
#   2. the "cancelling fractions" intuition (du struck through);
#   3. the worked example y = (3x + 2)², giving dy/dx = 6(3x+2);
#   4. a dashed prompt box that refers to the a/b/x sliders and an
#      "APPLY CHAIN RULE" control.
# It is a plain (non-f) string, so the markup is emitted verbatim.
CHAIN_RULE_INTRO_SVG = '''
<svg viewBox="0 0 900 480" style="width:100%; max-width:900px; height:auto; background:#0c0c0c; border-radius:8px; border:1px solid #333; margin-bottom:20px;">
<defs>
<marker id="arr-blue" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
<polygon points="0 0, 10 3.5, 0 7" fill="#5b9bd5" />
</marker>
</defs>
<!-- Title -->
<text x="450" y="32" text-anchor="middle" fill="#f0b030" font-size="20" font-family="monospace" font-weight="bold">THE CHAIN RULE</text>
<!-- Section 1: Basic Idea -->
<rect x="20" y="50" width="860" height="110" rx="8" fill="#141414" stroke="#5b9bd5" stroke-width="1"/>
<text x="40" y="72" fill="#5b9bd5" font-size="14" font-family="monospace" font-weight="bold">1. THE BASIC IDEA</text>
<!-- Composition diagram - more compact -->
<rect x="50" y="90" width="45" height="32" rx="5" fill="#1a3a2a" stroke="#4ade80" stroke-width="2"/>
<text x="72" y="111" text-anchor="middle" fill="#4ade80" font-size="14" font-family="monospace">x</text>
<line x1="95" y1="106" x2="130" y2="106" stroke="#5b9bd5" stroke-width="2" marker-end="url(#arr-blue)"/>
<rect x="140" y="88" width="55" height="36" rx="5" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="2"/>
<text x="167" y="111" text-anchor="middle" fill="#5b9bd5" font-size="13" font-family="monospace">g(x)</text>
<line x1="195" y1="106" x2="230" y2="106" stroke="#5b9bd5" stroke-width="2" marker-end="url(#arr-blue)"/>
<rect x="240" y="90" width="45" height="32" rx="5" fill="#2a2a1a" stroke="#f0b030" stroke-width="2"/>
<text x="262" y="111" text-anchor="middle" fill="#f0b030" font-size="14" font-family="monospace">u</text>
<line x1="285" y1="106" x2="320" y2="106" stroke="#5b9bd5" stroke-width="2" marker-end="url(#arr-blue)"/>
<rect x="330" y="88" width="55" height="36" rx="5" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="2"/>
<text x="357" y="111" text-anchor="middle" fill="#5b9bd5" font-size="13" font-family="monospace">f(u)</text>
<line x1="385" y1="106" x2="420" y2="106" stroke="#5b9bd5" stroke-width="2" marker-end="url(#arr-blue)"/>
<rect x="430" y="90" width="45" height="32" rx="5" fill="#2a1a1a" stroke="#ff6b6b" stroke-width="2"/>
<text x="452" y="111" text-anchor="middle" fill="#ff6b6b" font-size="14" font-family="monospace">y</text>
<!-- Labels -->
<text x="262" y="140" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">u = g(x)</text>
<text x="452" y="140" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">y = f(g(x))</text>
<!-- Formula box - wider -->
<rect x="510" y="80" width="350" height="65" rx="6" fill="#0c0c0c" stroke="#f0b030" stroke-width="1"/>
<text x="685" y="105" text-anchor="middle" fill="#f0b030" font-size="13" font-family="monospace">Chain Rule Formula:</text>
<text x="685" y="130" text-anchor="middle" fill="#4ade80" font-size="15" font-family="monospace">dy/dx = (dy/du) × (du/dx)</text>
<!-- Section 2: Why It Works -->
<rect x="20" y="170" width="420" height="140" rx="8" fill="#141414" stroke="#f0b030" stroke-width="1"/>
<text x="40" y="192" fill="#f0b030" font-size="14" font-family="monospace" font-weight="bold">2. WHY IT WORKS</text>
<text x="40" y="218" fill="#4ade80" font-size="13" font-family="monospace">Think of it like fractions:</text>
<!-- Fraction visualization - smaller -->
<text x="55" y="250" fill="#5b9bd5" font-size="16" font-family="monospace">dy</text>
<line x1="50" y1="255" x2="75" y2="255" stroke="#5b9bd5" stroke-width="2"/>
<text x="55" y="272" fill="#5b9bd5" font-size="16" font-family="monospace">dx</text>
<text x="90" y="260" fill="#888" font-size="18" font-family="monospace">=</text>
<text x="115" y="250" fill="#f0b030" font-size="16" font-family="monospace">dy</text>
<line x1="110" y1="255" x2="135" y2="255" stroke="#f0b030" stroke-width="2"/>
<text x="115" y="272" fill="#ff6b6b" font-size="16" font-family="monospace" text-decoration="line-through">du</text>
<text x="150" y="260" fill="#888" font-size="18" font-family="monospace">×</text>
<text x="175" y="250" fill="#ff6b6b" font-size="16" font-family="monospace" text-decoration="line-through">du</text>
<line x1="170" y1="255" x2="195" y2="255" stroke="#f0b030" stroke-width="2"/>
<text x="175" y="272" fill="#4ade80" font-size="16" font-family="monospace">dx</text>
<text x="210" y="260" fill="#888" font-size="18" font-family="monospace">=</text>
<text x="235" y="250" fill="#f0b030" font-size="16" font-family="monospace">dy</text>
<line x1="230" y1="255" x2="255" y2="255" stroke="#4ade80" stroke-width="2"/>
<text x="235" y="272" fill="#4ade80" font-size="16" font-family="monospace">dx</text>
<!-- Section 3: Concrete Example -->
<rect x="460" y="170" width="420" height="140" rx="8" fill="#141414" stroke="#4ade80" stroke-width="1"/>
<text x="480" y="192" fill="#4ade80" font-size="14" font-family="monospace" font-weight="bold">3. EXAMPLE: y = (3x + 2)²</text>
<!-- Break down - repositioned -->
<text x="480" y="222" fill="#5b9bd5" font-size="12" font-family="monospace">Inner: u = 3x+2</text>
<text x="630" y="222" fill="#888" font-size="12" font-family="monospace">→ du/dx = 3</text>
<text x="480" y="248" fill="#5b9bd5" font-size="12" font-family="monospace">Outer: y = u²</text>
<text x="630" y="248" fill="#888" font-size="12" font-family="monospace">→ dy/du = 2u</text>
<line x1="480" y1="260" x2="860" y2="260" stroke="#333" stroke-width="1"/>
<text x="480" y="282" fill="#f0b030" font-size="13" font-family="monospace">Chain: dy/dx = 2u × 3 = 6(3x+2)</text>
<!-- Section 4: Interactive -->
<rect x="20" y="320" width="860" height="145" rx="8" fill="#1a1a1a" stroke="#888" stroke-width="1" stroke-dasharray="5,5"/>
<text x="450" y="350" text-anchor="middle" fill="#888" font-size="15" font-family="monospace">▼ INTERACTIVE TERMINAL ▼</text>
<text x="450" y="380" text-anchor="middle" fill="#5b9bd5" font-size="13" font-family="monospace">Adjust a, b, and x with the sliders</text>
<text x="450" y="410" text-anchor="middle" fill="#4ade80" font-size="13" font-family="monospace">Click "APPLY CHAIN RULE" to see values flow through</text>
<text x="450" y="440" text-anchor="middle" fill="#555" font-size="10" font-family="monospace">[Remember: derivatives chain together like Vault access codes]</text>
</svg>
'''
# Chain-rule tab intro: the overview SVG with one newline before and after.
CHAIN_RULE_INTRO = "\n" + CHAIN_RULE_INTRO_SVG + "\n"
def generate_chain_rule_svg(a, b, x_val, u, y, du_dx, dy_du, dy_dx):
    """Generate an SVG diagram of the chain rule for y = (a*x + b)^2.

    Top row: the forward flow x -> u = a*x + b -> y = u^2 with current values.
    Below it, dashed red arrows run backwards (y -> u -> x) labelled with the
    local derivatives dy/du and du/dx; a side box multiplies them into dy/dx.

    Args:
        a, b: Coefficients of the inner function (inserted into labels as-is).
        x_val: Input value, shown with two decimals.
        u, y: Forward values of the inner and outer functions.
        du_dx, dy_du, dy_dx: Local derivatives and their chain-rule product.

    Returns:
        SVG markup as a string.
    """
    bg = "#0c0c0c"
    node_fill = "#1a2a3a"
    node_stroke = "#5b9bd5"
    input_fill = "#1a3a2a"
    input_stroke = "#4ade80"
    output_fill = "#2a2a1a"
    output_stroke = "#f0b030"
    text_color = "#4ade80"
    label_color = "#5b9bd5"
    arrow_color = "#5b9bd5"
    value_color = "#f0b030"
    deriv_color = "#ff6b6b"
    # Both markers use orient="auto", which rotates the marker's +x axis onto
    # the direction of the line. The red marker is therefore drawn pointing
    # along +x (same shape as "arr"), and the derivative lines (plus their
    # legend sample) run right-to-left so the heads point backwards.
    # A pre-reversed polygon on a right-to-left line would be flipped twice
    # and end up pointing forwards.
    svg = f'''
<svg viewBox="0 0 800 280" style="width:100%; max-width:800px; height:auto; background:{bg}; border-radius:8px; border:1px solid #333;">
<defs>
<marker id="arr" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
<polygon points="0 0, 10 3.5, 0 7" fill="{arrow_color}" />
</marker>
<marker id="arr-red" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
<polygon points="0 0, 10 3.5, 0 7" fill="{deriv_color}" />
</marker>
</defs>
<!-- Title -->
<text x="400" y="30" text-anchor="middle" fill="{output_stroke}" font-size="18" font-family="monospace">CHAIN RULE: y = ({a}x + {b})²</text>
<!-- Input x -->
<rect x="50" y="80" width="80" height="60" rx="8" fill="{input_fill}" stroke="{input_stroke}" stroke-width="2"/>
<text x="90" y="105" text-anchor="middle" fill="{label_color}" font-size="14" font-family="monospace">x</text>
<text x="90" y="128" text-anchor="middle" fill="{value_color}" font-size="16" font-family="monospace">{x_val:.2f}</text>
<!-- Arrow x to u -->
<line x1="130" y1="110" x2="200" y2="110" stroke="{arrow_color}" stroke-width="2" marker-end="url(#arr)"/>
<text x="165" y="100" text-anchor="middle" fill="{text_color}" font-size="12" font-family="monospace">g(x)</text>
<!-- Inner function u -->
<rect x="210" y="80" width="120" height="60" rx="8" fill="{node_fill}" stroke="{node_stroke}" stroke-width="2"/>
<text x="270" y="100" text-anchor="middle" fill="{label_color}" font-size="12" font-family="monospace">u = {a}x + {b}</text>
<text x="270" y="125" text-anchor="middle" fill="{value_color}" font-size="16" font-family="monospace">u = {u:.2f}</text>
<!-- Arrow u to y -->
<line x1="330" y1="110" x2="400" y2="110" stroke="{arrow_color}" stroke-width="2" marker-end="url(#arr)"/>
<text x="365" y="100" text-anchor="middle" fill="{text_color}" font-size="12" font-family="monospace">f(u)</text>
<!-- Outer function y -->
<rect x="410" y="80" width="100" height="60" rx="8" fill="{output_fill}" stroke="{output_stroke}" stroke-width="2"/>
<text x="460" y="100" text-anchor="middle" fill="{label_color}" font-size="12" font-family="monospace">y = u²</text>
<text x="460" y="125" text-anchor="middle" fill="{value_color}" font-size="16" font-family="monospace">y = {y:.2f}</text>
<!-- Derivative arrows (below, going backwards) -->
<line x1="410" y1="160" x2="335" y2="160" stroke="{deriv_color}" stroke-width="2" stroke-dasharray="5,3" marker-end="url(#arr-red)"/>
<text x="372" y="180" text-anchor="middle" fill="{deriv_color}" font-size="12" font-family="monospace">dy/du = {dy_du:.2f}</text>
<line x1="210" y1="160" x2="135" y2="160" stroke="{deriv_color}" stroke-width="2" stroke-dasharray="5,3" marker-end="url(#arr-red)"/>
<text x="172" y="180" text-anchor="middle" fill="{deriv_color}" font-size="12" font-family="monospace">du/dx = {du_dx:.2f}</text>
<!-- Chain rule result box -->
<rect x="540" y="70" width="230" height="120" rx="8" fill="#141414" stroke="#333" stroke-width="1"/>
<text x="655" y="95" text-anchor="middle" fill="{output_stroke}" font-size="14" font-family="monospace">CHAIN RULE</text>
<text x="555" y="120" fill="{text_color}" font-size="13" font-family="monospace">dy/dx = dy/du × du/dx</text>
<text x="555" y="145" fill="{text_color}" font-size="13" font-family="monospace"> = {dy_du:.2f} × {du_dx:.2f}</text>
<text x="555" y="175" fill="{value_color}" font-size="16" font-family="monospace"> = {dy_dx:.2f}</text>
<!-- Legend -->
<line x1="50" y1="240" x2="90" y2="240" stroke="{arrow_color}" stroke-width="2" marker-end="url(#arr)"/>
<text x="100" y="244" fill="{text_color}" font-size="11" font-family="monospace">Forward</text>
<line x1="240" y1="240" x2="200" y2="240" stroke="{deriv_color}" stroke-width="2" stroke-dasharray="5,3" marker-end="url(#arr-red)"/>
<text x="250" y="244" fill="{text_color}" font-size="11" font-family="monospace">Derivatives (multiply!)</text>
</svg>
'''
    return svg
def chain_rule_calculator(a, b, x_val):
    """Demonstrate the chain rule on y = (a*x + b)^2 at a single point.

    Splits y into u = a*x + b (inner) and y = u^2 (outer), computes
    du/dx = a, dy/du = 2u and their product dy/dx, and cross-checks the
    result with a forward finite difference.

    Args:
        a, b: Coefficients of the inner linear function.
        x_val: Point at which the derivative is evaluated.

    Returns:
        (svg_diagram, explanation): the markup from generate_chain_rule_svg
        and a Markdown walkthrough of the calculation.
    """
    # Forward pass through the composition.
    u = a * x_val + b  # inner: u = ax + b
    y = u ** 2         # outer: y = u^2
    # Local derivatives, then the chain rule.
    du_dx = a
    dy_du = 2 * u
    dy_dx = dy_du * du_dx
    # Forward finite difference as a numerical sanity check. For this
    # quadratic it exceeds the exact slope by a^2 * h (in exact arithmetic),
    # which is why the two numbers differ slightly.
    h = 0.001
    y_plus = (a * (x_val + h) + b) ** 2
    numeric_slope = (y_plus - y) / h
    svg_diagram = generate_chain_rule_svg(a, b, x_val, u, y, du_dx, dy_du, dy_dx)
    # Computed floats use 10 significant digits so binary rounding noise
    # (e.g. 0.30000000000000004) does not leak into the walkthrough.
    explanation = f"""
## CHAIN RULE CALCULATION: y = ({a}x + {b})^2
===============================================
### Setting up the composition:
```
Inner function: u = {a}x + {b}
Outer function: y = u^2
```
At x = {x_val}:
```
u = {a}*{x_val} + {b} = {u:.10g}
y = ({u:.10g})^2 = {y:.10g}
```
-----------------------------------------------
### Step 1: Find du/dx (derivative of inner function)
```
u = {a}x + {b}
du/dx = {a} (coefficient of x)
```
-----------------------------------------------
### Step 2: Find dy/du (derivative of outer function)
```
y = u^2
dy/du = 2u = 2*({u:.10g}) = {dy_du:.10g}
```
-----------------------------------------------
### Step 3: Apply the Chain Rule!
```
dy/dx = (dy/du) * (du/dx)
= {dy_du:.10g} * {du_dx}
= {dy_dx:.10g}
```
-----------------------------------------------
### VERIFICATION (optional sanity check)
If x increases by tiny amount h={h}:
```
y(x+h) = ({a}*{x_val + h:.10g} + {b})^2 = {y_plus:.6f}
y(x) = {y:.10g}
Slope ≈ (y(x+h) - y(x)) / h
= {numeric_slope:.4f}
Our dy/dx = {dy_dx:.10g}
```
The small gap ({numeric_slope - dy_dx:.4f}) is the forward-difference error a²·h = {a * a * h:.4f}; it shrinks as h → 0.
The chain rule works!
"""
    return svg_diagram, explanation
# ============================================================================
# TAB 3: DERIVATIVES OF KEY FUNCTIONS
# ============================================================================
# Static SVG cheat sheet for the derivatives tab. It has four panels:
#   1. sigmoid and its derivative σ(1-σ), with a hand-placed sketch of the curve
#      (the polyline points are literal, not computed);
#   2. binary cross-entropy and dL/dŷ, plus the combined shortcut dL/dz = ŷ - y;
#   3. common derivative patterns (power, exp, log, sigmoid, chain rule);
#   4. a prompt for the interactive z slider.
# Runtime content: keep the markup byte-exact unless the picture is meant to change.
DERIVATIVES_INTRO_SVG = '''
<svg viewBox="0 0 900 520" style="width:100%; max-width:900px; height:auto; background:#0c0c0c; border-radius:8px; border:1px solid #333; margin-bottom:20px;">
<!-- Title -->
<text x="450" y="32" text-anchor="middle" fill="#f0b030" font-size="18" font-family="monospace" font-weight="bold">KEY DERIVATIVES YOU NEED TO KNOW</text>
<!-- Sigmoid Section -->
<rect x="20" y="50" width="430" height="200" rx="8" fill="#141414" stroke="#4ade80" stroke-width="2"/>
<text x="35" y="75" fill="#4ade80" font-size="14" font-family="monospace" font-weight="bold">1. SIGMOID FUNCTION</text>
<!-- Mini sigmoid curve - compact -->
<polyline points="35,150 50,148 65,145 80,140 95,130 110,115 125,102 140,94 155,89 170,87"
fill="none" stroke="#4ade80" stroke-width="2"/>
<line x1="35" y1="155" x2="170" y2="155" stroke="#333" stroke-width="1"/>
<line x1="102" y1="87" x2="102" y2="155" stroke="#333" stroke-width="1" stroke-dasharray="3,3"/>
<text x="102" y="170" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">z=0 → σ=0.5</text>
<!-- Formula - compact layout -->
<text x="185" y="95" fill="#5b9bd5" font-size="10" font-family="monospace">Function:</text>
<text x="185" y="112" fill="#f0b030" font-size="10" font-family="monospace">σ(z) = 1/(1+e⁻ᶻ)</text>
<text x="185" y="135" fill="#5b9bd5" font-size="10" font-family="monospace">Derivative:</text>
<text x="185" y="152" fill="#ff6b6b" font-size="10" font-family="monospace">dσ/dz = σ(z)(1-σ(z))</text>
<!-- Key insight box -->
<rect x="30" y="190" width="410" height="50" rx="4" fill="#1a2a1a" stroke="#4ade80" stroke-width="1"/>
<text x="45" y="210" fill="#4ade80" font-size="10" font-family="monospace">Derivative uses the function itself!</text>
<text x="45" y="228" fill="#888" font-size="10" font-family="monospace">Already have σ(z)? No extra work needed.</text>
<!-- BCE Section -->
<rect x="460" y="50" width="420" height="200" rx="8" fill="#141414" stroke="#ff6b6b" stroke-width="2"/>
<text x="480" y="75" fill="#ff6b6b" font-size="14" font-family="monospace" font-weight="bold">2. BINARY CROSS-ENTROPY</text>
<text x="480" y="100" fill="#5b9bd5" font-size="11" font-family="monospace">Loss Function:</text>
<text x="480" y="120" fill="#f0b030" font-size="11" font-family="monospace">L = -[y·log(ŷ) + (1-y)·log(1-ŷ)]</text>
<text x="480" y="150" fill="#5b9bd5" font-size="11" font-family="monospace">Derivative w.r.t. ŷ:</text>
<text x="480" y="170" fill="#ff6b6b" font-size="11" font-family="monospace">dL/dŷ = -y/ŷ + (1-y)/(1-ŷ)</text>
<!-- Magic box -->
<rect x="480" y="190" width="385" height="50" rx="4" fill="#2a1a1a" stroke="#ff6b6b" stroke-width="1"/>
<text x="495" y="210" fill="#f0b030" font-size="10" font-family="monospace">Combined with sigmoid:</text>
<text x="495" y="227" fill="#4ade80" font-size="13" font-family="monospace">dL/dz = ŷ - y (Vault-Tec approved)</text>
<!-- Common Patterns Section -->
<rect x="20" y="260" width="860" height="130" rx="8" fill="#141414" stroke="#5b9bd5" stroke-width="1"/>
<text x="40" y="283" fill="#5b9bd5" font-size="14" font-family="monospace" font-weight="bold">3. DERIVATIVE PATTERNS</text>
<!-- Pattern boxes - evenly spaced -->
<rect x="35" y="300" width="135" height="75" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
<text x="102" y="320" text-anchor="middle" fill="#888" font-size="11" font-family="monospace">Powers</text>
<text x="102" y="342" text-anchor="middle" fill="#f0b030" font-size="12" font-family="monospace">d/dx[xⁿ]</text>
<text x="102" y="364" text-anchor="middle" fill="#4ade80" font-size="12" font-family="monospace">= n·xⁿ⁻¹</text>
<rect x="185" y="300" width="135" height="75" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
<text x="252" y="320" text-anchor="middle" fill="#888" font-size="11" font-family="monospace">Exponential</text>
<text x="252" y="342" text-anchor="middle" fill="#f0b030" font-size="12" font-family="monospace">d/dx[eˣ]</text>
<text x="252" y="364" text-anchor="middle" fill="#4ade80" font-size="12" font-family="monospace">= eˣ</text>
<rect x="335" y="300" width="135" height="75" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
<text x="402" y="320" text-anchor="middle" fill="#888" font-size="11" font-family="monospace">Logarithm</text>
<text x="402" y="342" text-anchor="middle" fill="#f0b030" font-size="12" font-family="monospace">d/dx[log(x)]</text>
<text x="402" y="364" text-anchor="middle" fill="#4ade80" font-size="12" font-family="monospace">= 1/x</text>
<rect x="485" y="300" width="135" height="75" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
<text x="552" y="320" text-anchor="middle" fill="#888" font-size="11" font-family="monospace">Sigmoid</text>
<text x="552" y="342" text-anchor="middle" fill="#f0b030" font-size="12" font-family="monospace">d/dx[σ(x)]</text>
<text x="552" y="364" text-anchor="middle" fill="#4ade80" font-size="12" font-family="monospace">= σ(1-σ)</text>
<rect x="635" y="300" width="230" height="75" rx="4" fill="#2a2a1a" stroke="#f0b030" stroke-width="1"/>
<text x="750" y="320" text-anchor="middle" fill="#f0b030" font-size="11" font-family="monospace">Chain Rule</text>
<text x="750" y="342" text-anchor="middle" fill="#888" font-size="11" font-family="monospace">d/dx[f(g(x))]</text>
<text x="750" y="364" text-anchor="middle" fill="#4ade80" font-size="12" font-family="monospace">= f'(g(x)) · g'(x)</text>
<!-- Interactive prompt -->
<rect x="20" y="405" width="860" height="100" rx="8" fill="#1a1a1a" stroke="#888" stroke-width="1" stroke-dasharray="5,5"/>
<text x="450" y="430" text-anchor="middle" fill="#888" font-size="14" font-family="monospace">INTERACTIVE TERMINAL</text>
<text x="450" y="455" text-anchor="middle" fill="#5b9bd5" font-size="12" font-family="monospace">Move the z slider to see sigmoid and its derivative</text>
<text x="450" y="480" text-anchor="middle" fill="#4ade80" font-size="12" font-family="monospace">Derivative peaks at z=0, vanishes at extremes</text>
</svg>
'''
# Derivatives intro content: the static cheat-sheet SVG padded with one newline on each side.
DERIVATIVES_INTRO = "\n" + DERIVATIVES_INTRO_SVG + "\n"
def generate_sigmoid_svg(z, sig, dsig):
    """Plot the sigmoid on z in [-5, 5] and draw its tangent at one point.

    Data coordinates map to SVG pixels as px = 100 + (z + 5) * 40 and
    py = 250 - sigma * 200, so the plot area spans x 100..500 and y 250..50.
    The tangent extends one z-unit (40 px) either side of the marked point.

    Args:
        z: Point at which the marker and tangent are drawn.
        sig: sigmoid(z), used for the marker height and the value panel.
        dsig: Slope of the sigmoid at z, used for the tangent and value panel.

    Returns:
        SVG markup as a string.
    """
    bg = "#0c0c0c"
    curve_color = "#4ade80"
    deriv_color = "#ff6b6b"
    point_color = "#f0b030"
    grid_color = "#333"
    text_color = "#4ade80"
    label_color = "#5b9bd5"

    def to_px(zv):
        # z in [-5, 5] -> x pixel in [100, 500]
        return 100 + (zv + 5) * 40

    def to_py(sv):
        # sigma in [0, 1] -> y pixel in [250, 50] (SVG y grows downwards)
        return 250 - sv * 200

    # 101 samples at z = -5.0, -4.9, ..., 5.0
    samples = ((k / 10, 1 / (1 + np.exp(-(k / 10)))) for k in range(-50, 51))
    curve_path = " ".join(f"{to_px(zv):.1f},{to_py(sv):.1f}" for zv, sv in samples)

    dot_x = to_px(z)
    dot_y = to_py(sig)

    # One z-unit run corresponds to a rise of dsig in sigma, negated because
    # SVG's y axis points down.
    run = 40
    rise = -dsig * 200
    x_start, y_start = dot_x - run, dot_y - rise
    x_end, y_end = dot_x + run, dot_y + rise

    return f'''
<svg viewBox="0 0 700 320" style="width:100%; max-width:700px; height:auto; background:{bg}; border-radius:8px; border:1px solid #333;">
<!-- Title -->
<text x="350" y="25" text-anchor="middle" fill="{point_color}" font-size="16" font-family="monospace">SIGMOID FUNCTION & DERIVATIVE</text>
<!-- Grid lines -->
<line x1="100" y1="150" x2="500" y2="150" stroke="{grid_color}" stroke-width="1" stroke-dasharray="3,3"/>
<text x="95" y="154" text-anchor="end" fill="{grid_color}" font-size="10" font-family="monospace">0.5</text>
<line x1="300" y1="50" x2="300" y2="250" stroke="{grid_color}" stroke-width="1" stroke-dasharray="3,3"/>
<text x="300" y="265" text-anchor="middle" fill="{grid_color}" font-size="10" font-family="monospace">z=0</text>
<!-- Axes -->
<line x1="100" y1="250" x2="500" y2="250" stroke="{label_color}" stroke-width="1"/>
<line x1="100" y1="50" x2="100" y2="250" stroke="{label_color}" stroke-width="1"/>
<!-- Axis labels -->
<text x="300" y="280" text-anchor="middle" fill="{label_color}" font-size="12" font-family="monospace">z</text>
<text x="70" y="150" text-anchor="middle" fill="{label_color}" font-size="12" font-family="monospace">σ(z)</text>
<text x="95" y="55" text-anchor="end" fill="{grid_color}" font-size="10" font-family="monospace">1.0</text>
<text x="95" y="254" text-anchor="end" fill="{grid_color}" font-size="10" font-family="monospace">0.0</text>
<text x="100" y="275" text-anchor="middle" fill="{grid_color}" font-size="10" font-family="monospace">-5</text>
<text x="500" y="275" text-anchor="middle" fill="{grid_color}" font-size="10" font-family="monospace">5</text>
<!-- Sigmoid curve -->
<polyline points="{curve_path}" fill="none" stroke="{curve_color}" stroke-width="2"/>
<!-- Tangent line at current point -->
<line x1="{x_start:.1f}" y1="{y_start:.1f}" x2="{x_end:.1f}" y2="{y_end:.1f}" stroke="{deriv_color}" stroke-width="2" stroke-dasharray="5,3"/>
<!-- Current point -->
<circle cx="{dot_x:.1f}" cy="{dot_y:.1f}" r="8" fill="{point_color}" stroke="#fff" stroke-width="2"/>
<!-- Point label -->
<line x1="{dot_x:.1f}" y1="{dot_y:.1f}" x2="{dot_x + 30:.1f}" y2="{dot_y - 30:.1f}" stroke="{point_color}" stroke-width="1"/>
<text x="{dot_x + 35:.1f}" y="{dot_y - 35:.1f}" fill="{point_color}" font-size="11" font-family="monospace">z={z:.1f}</text>
<text x="{dot_x + 35:.1f}" y="{dot_y - 22:.1f}" fill="{point_color}" font-size="11" font-family="monospace">σ={sig:.3f}</text>
<!-- Info box -->
<rect x="530" y="60" width="155" height="140" rx="6" fill="#141414" stroke="#333" stroke-width="1"/>
<text x="607" y="85" text-anchor="middle" fill="{point_color}" font-size="13" font-family="monospace">VALUES</text>
<text x="540" y="110" fill="{label_color}" font-size="12" font-family="monospace">z = {z:.2f}</text>
<text x="540" y="130" fill="{curve_color}" font-size="12" font-family="monospace">σ(z) = {sig:.4f}</text>
<text x="540" y="155" fill="{deriv_color}" font-size="12" font-family="monospace">dσ/dz = {dsig:.4f}</text>
<text x="540" y="185" fill="{text_color}" font-size="10" font-family="monospace">= σ(1-σ)</text>
<text x="540" y="198" fill="{text_color}" font-size="10" font-family="monospace">= {sig:.3f}×{1-sig:.3f}</text>
<!-- Legend -->
<line x1="530" y1="240" x2="560" y2="240" stroke="{curve_color}" stroke-width="2"/>
<text x="565" y="244" fill="{text_color}" font-size="10" font-family="monospace">σ(z)</text>
<line x1="620" y1="240" x2="650" y2="240" stroke="{deriv_color}" stroke-width="2" stroke-dasharray="5,3"/>
<text x="655" y="244" fill="{text_color}" font-size="10" font-family="monospace">tangent</text>
</svg>
'''
def sigmoid_derivative_demo(z):
    """Evaluate the sigmoid and its derivative at z.

    Returns:
        (svg_diagram, explanation): the plot from generate_sigmoid_svg and a
        Markdown walkthrough of sigma(z) = 1 / (1 + e^(-z)) and
        d(sigma)/dz = sigma(z) * (1 - sigma(z)).
    """
    exp_neg_z = np.exp(-z)
    denominator = 1 + exp_neg_z
    sig = 1 / denominator
    dsig = sig * (1 - sig)
    explanation = f"""
## SIGMOID DERIVATIVE AT z = {z}
===============================================
### Step 1: Compute sigmoid(z)
```
σ(z) = 1 / (1 + e^(-z))
= 1 / (1 + e^(-{z}))
= 1 / (1 + {exp_neg_z:.6f})
= 1 / {denominator:.6f}
= {sig:.6f}
```
-----------------------------------------------
### Step 2: Compute the derivative
Using the formula: dσ/dz = σ(z) * (1 - σ(z))
```
dσ/dz = σ(z) * (1 - σ(z))
= {sig:.6f} * (1 - {sig:.6f})
= {sig:.6f} * {1-sig:.6f}
= {dsig:.6f}
```
-----------------------------------------------
### Interpretation
At z = {z}:
- Sigmoid output: {sig:.4f} (how confident the neuron is)
- Derivative: {dsig:.4f} (how sensitive output is to z)
**Key insight:** The derivative is LARGEST when z≈0 (sigmoid≈0.5)
and SMALLEST when |z| is large. This is the "vanishing gradient"
problem - extreme values barely update!
```
z = 0 --> σ = 0.5, dσ/dz = 0.25 (maximum!)
z = 5 --> σ ≈ 0.99, dσ/dz ≈ 0.007 (tiny!)
z = -5 --> σ ≈ 0.01, dσ/dz ≈ 0.007 (tiny!)
```
"""
    return generate_sigmoid_svg(z, sig, dsig), explanation
# ============================================================================
# TAB 4: BACKWARD PASS (THE MAIN EVENT)
# ============================================================================
# Static SVG overview for the backpropagation tab. It contains:
#   - a forward row (x -> z -> y -> L) and a backward row of gradients
#     (dL/dx, dL/dz, dL/dy, 1);
#   - the dL/dw1 = dL/dy * dy/dz * dz/dw1 chain, shown as formula and boxes;
#   - the interactive-terminal prompt.
# NOTE(review): the "bwd-b" marker polygon points toward -x and the BWD lines run
# left-to-right, so with orient="auto" each red arrowhead is drawn at the right
# end of its line (next to the upstream box) pointing left.
# Runtime content: keep the markup byte-exact unless the picture is meant to change.
BACKWARD_INTRO_SVG = '''
<svg viewBox="0 0 900 500" style="width:100%; max-width:900px; height:auto; background:#0c0c0c; border-radius:8px; border:1px solid #333; margin-bottom:20px;">
<defs>
<marker id="fwd-b" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
<polygon points="0 0, 10 3.5, 0 7" fill="#5b9bd5" />
</marker>
<marker id="bwd-b" markerWidth="10" markerHeight="7" refX="0" refY="3.5" orient="auto">
<polygon points="10 0, 0 3.5, 10 7" fill="#ff6b6b" />
</marker>
</defs>
<!-- Title -->
<text x="450" y="32" text-anchor="middle" fill="#f0b030" font-size="18" font-family="monospace" font-weight="bold">BACKPROPAGATION: LEARNING FROM MISTAKES</text>
<!-- Main concept - THE BIG PICTURE -->
<rect x="20" y="50" width="550" height="170" rx="8" fill="#141414" stroke="#5b9bd5" stroke-width="1"/>
<text x="40" y="75" fill="#5b9bd5" font-size="14" font-family="monospace" font-weight="bold">THE BIG PICTURE</text>
<!-- Forward pass row -->
<text x="40" y="105" fill="#5b9bd5" font-size="12" font-family="monospace">FWD:</text>
<rect x="85" y="90" width="50" height="30" rx="4" fill="#1a3a2a" stroke="#4ade80" stroke-width="1"/>
<text x="110" y="110" text-anchor="middle" fill="#4ade80" font-size="12" font-family="monospace">x</text>
<line x1="135" y1="105" x2="175" y2="105" stroke="#5b9bd5" stroke-width="2" marker-end="url(#fwd-b)"/>
<rect x="185" y="90" width="50" height="30" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
<text x="210" y="110" text-anchor="middle" fill="#5b9bd5" font-size="12" font-family="monospace">z</text>
<line x1="235" y1="105" x2="275" y2="105" stroke="#5b9bd5" stroke-width="2" marker-end="url(#fwd-b)"/>
<rect x="285" y="90" width="50" height="30" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
<text x="310" y="110" text-anchor="middle" fill="#5b9bd5" font-size="12" font-family="monospace">y</text>
<line x1="335" y1="105" x2="375" y2="105" stroke="#5b9bd5" stroke-width="2" marker-end="url(#fwd-b)"/>
<rect x="385" y="90" width="50" height="30" rx="4" fill="#2a1a1a" stroke="#ff6b6b" stroke-width="1"/>
<text x="410" y="110" text-anchor="middle" fill="#ff6b6b" font-size="12" font-family="monospace">L</text>
<text x="450" y="110" fill="#888" font-size="11" font-family="monospace">Data flows</text>
<!-- Backward pass row - wider boxes for partials -->
<text x="40" y="160" fill="#ff6b6b" font-size="12" font-family="monospace">BWD:</text>
<rect x="85" y="145" width="50" height="30" rx="4" fill="#1a3a2a" stroke="#4ade80" stroke-width="1"/>
<text x="110" y="165" text-anchor="middle" fill="#ff6b6b" font-size="10" font-family="monospace">dL/dx</text>
<line x1="135" y1="160" x2="175" y2="160" stroke="#ff6b6b" stroke-width="2" stroke-dasharray="4,2" marker-end="url(#bwd-b)"/>
<rect x="185" y="145" width="50" height="30" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
<text x="210" y="165" text-anchor="middle" fill="#ff6b6b" font-size="10" font-family="monospace">dL/dz</text>
<line x1="235" y1="160" x2="275" y2="160" stroke="#ff6b6b" stroke-width="2" stroke-dasharray="4,2" marker-end="url(#bwd-b)"/>
<rect x="285" y="145" width="50" height="30" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
<text x="310" y="165" text-anchor="middle" fill="#ff6b6b" font-size="10" font-family="monospace">dL/dy</text>
<line x1="335" y1="160" x2="375" y2="160" stroke="#ff6b6b" stroke-width="2" stroke-dasharray="4,2" marker-end="url(#bwd-b)"/>
<rect x="385" y="145" width="50" height="30" rx="4" fill="#2a1a1a" stroke="#ff6b6b" stroke-width="1"/>
<text x="410" y="165" text-anchor="middle" fill="#ff6b6b" font-size="12" font-family="monospace">1</text>
<text x="450" y="165" fill="#888" font-size="11" font-family="monospace">Grads flow</text>
<!-- Goal text -->
<text x="40" y="205" fill="#f0b030" font-size="11" font-family="monospace">GOAL: Find dL/dw1, dL/dw2, dL/db to update weights</text>
<!-- Key insight box -->
<rect x="585" y="50" width="295" height="170" rx="6" fill="#2a2a1a" stroke="#f0b030" stroke-width="2"/>
<text x="732" y="78" text-anchor="middle" fill="#f0b030" font-size="13" font-family="monospace" font-weight="bold">KEY INSIGHT</text>
<text x="600" y="108" fill="#4ade80" font-size="11" font-family="monospace">Same computation graph,</text>
<text x="600" y="128" fill="#4ade80" font-size="11" font-family="monospace">opposite direction!</text>
<line x1="600" y1="142" x2="865" y2="142" stroke="#333" stroke-width="1"/>
<text x="600" y="165" fill="#888" font-size="11" font-family="monospace">At each node multiply:</text>
<text x="600" y="188" fill="#5b9bd5" font-size="10" font-family="monospace">upstream × local derivative</text>
<!-- Chain rule explanation -->
<rect x="20" y="235" width="420" height="120" rx="8" fill="#141414" stroke="#ff6b6b" stroke-width="1"/>
<text x="40" y="260" fill="#ff6b6b" font-size="13" font-family="monospace" font-weight="bold">CHAIN RULE IN ACTION</text>
<text x="40" y="290" fill="#888" font-size="12" font-family="monospace">To find dL/dw1:</text>
<text x="40" y="315" fill="#f0b030" font-size="12" font-family="monospace">dL/dw1 = dL/dy * dy/dz * dz/dw1</text>
<text x="40" y="340" fill="#4ade80" font-size="11" font-family="monospace">Multiply derivatives along the path!</text>
<!-- Visual chain - wider boxes -->
<rect x="455" y="235" width="425" height="120" rx="8" fill="#141414" stroke="#4ade80" stroke-width="1"/>
<text x="475" y="260" fill="#4ade80" font-size="13" font-family="monospace" font-weight="bold">VISUAL MULTIPLICATION</text>
<rect x="475" y="280" width="62" height="35" rx="4" fill="#2a1a1a" stroke="#ff6b6b" stroke-width="1"/>
<text x="506" y="302" text-anchor="middle" fill="#ff6b6b" font-size="10" font-family="monospace">dL/dy</text>
<text x="545" y="302" fill="#888" font-size="16" font-family="monospace">×</text>
<rect x="562" y="280" width="62" height="35" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
<text x="593" y="302" text-anchor="middle" fill="#5b9bd5" font-size="10" font-family="monospace">dy/dz</text>
<text x="632" y="302" fill="#888" font-size="16" font-family="monospace">×</text>
<rect x="650" y="280" width="70" height="35" rx="4" fill="#1a3a2a" stroke="#4ade80" stroke-width="1"/>
<text x="685" y="302" text-anchor="middle" fill="#4ade80" font-size="10" font-family="monospace">dz/dw1</text>
<text x="728" y="302" fill="#888" font-size="16" font-family="monospace">=</text>
<rect x="745" y="280" width="70" height="35" rx="4" fill="#2a2a1a" stroke="#f0b030" stroke-width="2"/>
<text x="780" y="302" text-anchor="middle" fill="#f0b030" font-size="10" font-family="monospace">dL/dw1</text>
<text x="475" y="340" fill="#888" font-size="11" font-family="monospace">upstream × local = pass backward</text>
<!-- Interactive prompt -->
<rect x="20" y="370" width="860" height="115" rx="8" fill="#1a1a1a" stroke="#888" stroke-width="1" stroke-dasharray="5,5"/>
<text x="450" y="395" text-anchor="middle" fill="#888" font-size="14" font-family="monospace">INTERACTIVE TERMINAL</text>
<text x="450" y="420" text-anchor="middle" fill="#5b9bd5" font-size="12" font-family="monospace">Blue arrows = forward data | Red arrows = backward gradients</text>
<text x="450" y="445" text-anchor="middle" fill="#4ade80" font-size="12" font-family="monospace">Click "EXECUTE FULL BACKPROP" to see all values calculated</text>
<text x="450" y="470" text-anchor="middle" fill="#555" font-size="10" font-family="monospace">[Vault-Tec tip: errors propagate backward, just like rumors in the cafeteria]</text>
</svg>
'''
# Backprop intro content: the static overview SVG padded with one newline on each side.
BACKWARD_INTRO = "\n" + BACKWARD_INTRO_SVG + "\n"
def backward_pass_demo(x1, x2, w1, w2, b, y_true):
    """Run one forward + backward pass for a single sigmoid neuron with BCE loss.

    Model: z = w1*x1 + w2*x2 + b, y_pred = sigmoid(z), and
    L = -[y_true*log(y_pred) + (1-y_true)*log(1-y_pred)].
    Gradients come from the chain rule dL/dw = dL/dy_pred * dy_pred/dz * dz/dw,
    and one gradient-descent step (learning rate 0.1) is shown.

    Args:
        x1, x2: Input features.
        w1, w2: Weights.
        b: Bias.
        y_true: Target label (presumably 0 or 1 — TODO confirm against the UI).

    Returns:
        (svg_diagram, explanation): the markup from generate_backward_svg and a
        Markdown walkthrough of every intermediate value.
    """
    # Forward pass
    z = w1 * x1 + w2 * x2 + b
    y_pred = 1 / (1 + np.exp(-z))
    # Clip the prediction away from 0 and 1 so log() and the divisions in
    # dL/dy_pred stay finite. The clipped value is used for EVERY backward
    # factor, so dL/dy_pred * dy_pred/dz reduces to y_pred_clipped - y_true.
    # Mixing a clipped dL/dy_pred with an unclipped sigmoid derivative made
    # dL/dz collapse toward 0 whenever the clip was active (|z| > ~16).
    eps = 1e-7
    y_pred_clipped = np.clip(y_pred, eps, 1 - eps)
    loss = -(y_true * np.log(y_pred_clipped) + (1 - y_true) * np.log(1 - y_pred_clipped))
    # Backward pass - local derivatives
    dL_dy = -y_true / y_pred_clipped + (1 - y_true) / (1 - y_pred_clipped)
    dy_dz = y_pred_clipped * (1 - y_pred_clipped)  # sigmoid derivative
    dz_dw1 = x1
    dz_dw2 = x2
    dz_db = 1
    # Chain rule to get final gradients
    dL_dz = dL_dy * dy_dz  # the "upstream gradient"
    dL_dw1 = dL_dz * dz_dw1
    dL_dw2 = dL_dz * dz_dw2
    dL_db = dL_dz * dz_db
    # Learning rate for the illustrated gradient-descent step.
    lr = 0.1
    svg_diagram = generate_backward_svg(
        x1, x2, w1, w2, b, y_true, z, y_pred,
        dL_dy, dy_dz, dL_dz, dL_dw1, dL_dw2, dL_db, loss
    )
    explanation = f"""
## COMPLETE BACKPROP WALKTHROUGH
===============================================
### GIVEN:
```
Inputs: x1 = {x1}, x2 = {x2}
Weights: w1 = {w1}, w2 = {w2}
Bias: b = {b}
True label: y_true = {y_true}
```
===============================================
## PART 1: FORWARD PASS (review)
===============================================
**Step 1a: Weighted sum**
```
z = w1*x1 + w2*x2 + b
= ({w1})*({x1}) + ({w2})*({x2}) + ({b})
= {z:.6f}
```
**Step 1b: Sigmoid activation**
```
y_pred = sigmoid(z) = 1/(1+e^(-z))
= 1/(1+e^(-{z:.4f}))
= {y_pred:.6f}
```
**Step 1c: Binary Cross-Entropy Loss**
```
L = -[y_true*log(y_pred) + (1-y_true)*log(1-y_pred)]
= -[{y_true}*log({y_pred_clipped:.6f}) + {1-y_true}*log({1-y_pred_clipped:.6f})]
= -[{y_true * np.log(y_pred_clipped):.6f} + {(1-y_true) * np.log(1-y_pred_clipped):.6f}]
= {loss:.6f}
```
(y_pred is clipped to [{eps}, 1-{eps}] for the log terms and all gradients below.)
===============================================
## PART 2: BACKWARD PASS (reversing the flow)
===============================================
We need: dL/dw1, dL/dw2, dL/db
**The computation graph:**
```
w1,x1,w2,x2,b --> z --> y_pred --> L
| | |
dz/dw dy/dz dL/dy
```
We work BACKWARDS from Loss to weights.
-----------------------------------------------
### STEP 2a: dL/dy_pred (how loss changes with prediction)
```
L = -y_true*log(y_pred) - (1-y_true)*log(1-y_pred)
dL/dy_pred = -y_true/y_pred + (1-y_true)/(1-y_pred)
= -{y_true}/{y_pred_clipped:.6f} + {1-y_true}/{1-y_pred_clipped:.6f}
= {-y_true/y_pred_clipped:.6f} + {(1-y_true)/(1-y_pred_clipped):.6f}
= {dL_dy:.6f}
```
-----------------------------------------------
### STEP 2b: dy_pred/dz (sigmoid derivative)
Using: d/dz[sigmoid(z)] = sigmoid(z)*(1-sigmoid(z))
```
dy/dz = y_pred * (1 - y_pred)
= {y_pred_clipped:.6f} * (1 - {y_pred_clipped:.6f})
= {y_pred_clipped:.6f} * {1-y_pred_clipped:.6f}
= {dy_dz:.6f}
```
-----------------------------------------------
### STEP 2c: dz/dw1, dz/dw2, dz/db
Since z = w1*x1 + w2*x2 + b:
```
dz/dw1 = x1 = {dz_dw1}
dz/dw2 = x2 = {dz_dw2}
dz/db = 1 = {dz_db}
```
-----------------------------------------------
### STEP 2d: CHAIN RULE - Put it together!
First, compute dL/dz (the "upstream gradient"):
```
dL/dz = (dL/dy_pred) * (dy_pred/dz)
= {dL_dy:.6f} * {dy_dz:.6f}
= {dL_dz:.6f}
```
Now chain to each weight:
```
dL/dw1 = (dL/dz) * (dz/dw1)
= {dL_dz:.6f} * {dz_dw1}
= {dL_dw1:.6f}
dL/dw2 = (dL/dz) * (dz/dw2)
= {dL_dz:.6f} * {dz_dw2}
= {dL_dw2:.6f}
dL/db = (dL/dz) * (dz/db)
= {dL_dz:.6f} * {dz_db}
= {dL_db:.6f}
```
===============================================
## PART 3: GRADIENT DESCENT UPDATE
===============================================
With learning rate α = {lr}:
```
w1_new = w1 - α * dL/dw1
= {w1} - {lr} * {dL_dw1:.6f}
= {w1 - lr * dL_dw1:.6f}
w2_new = w2 - α * dL/dw2
= {w2} - {lr} * {dL_dw2:.6f}
= {w2 - lr * dL_dw2:.6f}
b_new = b - α * dL/db
= {b} - {lr} * {dL_db:.6f}
= {b - lr * dL_db:.6f}
```
**We've completed one step of learning!**
===============================================
## SUMMARY TABLE
===============================================
| Gradient | Value | Meaning |
|----------|-------|---------|
| dL/dy | {dL_dy:.4f} | Loss sensitivity to prediction |
| dy/dz | {dy_dz:.4f} | Sigmoid sensitivity |
| dL/dz | {dL_dz:.4f} | "Upstream gradient" |
| dL/dw1 | {dL_dw1:.4f} | How to adjust w1 |
| dL/dw2 | {dL_dw2:.4f} | How to adjust w2 |
| dL/db | {dL_db:.4f} | How to adjust bias |
"""
    return svg_diagram, explanation
# ============================================================================
# TAB 5: PRACTICE PROBLEMS
# ============================================================================
# Markdown intro for the practice-problems tab. It contains a Vault-Tec themed
# preamble, tips for working gradients by hand (including the forward-difference
# check df/dx ≈ (f(x+h) - f(x)) / h), and a prompt to attempt a problem before
# pressing "Check Answer".
PRACTICE_INTRO = """
# PRACTICE: COMPUTE BY HAND FIRST!
===============================================
Welcome to the Gradient Occupational Aptitude Test (G.O.A.T.).
Per Vault-Tec guidelines, pencil-and-paper practice builds neural
pathways (the biological kind). Complete these problems to determine
your future as a Machine Learning Specialist.
## TIPS FOR HAND CALCULATION
1. **Draw the computation graph** - boxes for operations,
arrows for data flow
2. **Forward pass first** - compute all intermediate values
3. **Backward pass** - start from loss, work backwards
4. **Check dimensions** - gradient of scalar w.r.t. vector
has same shape as the vector
5. **Verify numerically** - if unsure, use tiny h to approximate:
df/dx ≈ (f(x+h) - f(x)) / h
## PRACTICE PROBLEMS
Select a problem below and try it before clicking "Check Answer"!
"""
def practice_problem(problem_num):
    """Return the markdown for one hand-calculation practice problem.

    Args:
        problem_num: Problem number, 1, 2 or 3. Any other value falls back
            to problem 1.

    Returns:
        A ``(question_markdown, solution_markdown)`` tuple of strings.
    """
    # sigmoid(2) is quoted several times in the problem 2 walkthrough;
    # compute it once so every printed figure comes from the same value.
    s2 = 1 / (1 + np.exp(-2))

    problems = {
        1: {
            "question": """
### Problem 1: Simple Chain Rule
Compute dy/dx where:
```
y = (2x + 3)^3
```
at x = 1.
**Hint:** Let u = 2x + 3, so y = u^3
""",
            "solution": """
### Solution to Problem 1
**Step 1: Identify the composition**
```
u = 2x + 3 (inner)
y = u^3 (outer)
```
**Step 2: Find individual derivatives**
```
du/dx = 2
dy/du = 3u^2
```
**Step 3: Apply chain rule**
```
dy/dx = (dy/du) * (du/dx)
= 3u^2 * 2
= 6u^2
= 6(2x + 3)^2
```
**Step 4: Evaluate at x = 1**
```
dy/dx = 6(2*1 + 3)^2
= 6(5)^2
= 6 * 25
= 150
```
**Answer: dy/dx = 150 at x = 1**
"""
        },
        2: {
            "question": """
### Problem 2: Sigmoid Derivative
Given z = 2, compute:
1. sigmoid(z)
2. d/dz[sigmoid(z)]
**Reminder:** sigmoid(z) = 1/(1+e^(-z))
d/dz[sigmoid(z)] = sigmoid(z) * (1 - sigmoid(z))
""",
            "solution": f"""
### Solution to Problem 2
**Step 1: Compute sigmoid(2)**
```
sigmoid(2) = 1/(1 + e^(-2))
= 1/(1 + {np.exp(-2):.6f})
= 1/{1 + np.exp(-2):.6f}
= {s2:.6f}
```
**Step 2: Compute derivative**
```
Let s = sigmoid(2) = {s2:.6f}
ds/dz = s * (1 - s)
= {s2:.6f} * (1 - {s2:.6f})
= {s2:.6f} * {1 - s2:.6f}
= {s2 * (1 - s2):.6f}
```
**Answers:**
- sigmoid(2) ≈ 0.8808
- d/dz[sigmoid(2)] ≈ 0.1050
"""
        },
        3: {
            "question": """
### Problem 3: Full Backprop (Mini Version)
Single neuron with:
```
x = 2
w = 0.5
b = -1
y_true = 1
```
Using sigmoid activation and BCE loss, find dL/dw.
**Steps to follow:**
1. Forward: z = wx + b
2. Forward: y_pred = sigmoid(z)
3. Forward: L = BCE(y_true, y_pred)
4. Backward: Apply chain rule
""",
            # The BCE expansion spells out the (1-y_true)*log(1-y_pred) term
            # explicitly; it vanishes here only because y_true = 1.
            "solution": """
### Solution to Problem 3
**Forward Pass:**
```
z = w*x + b = 0.5*2 + (-1) = 0
y_pred = sigmoid(0) = 0.5
L = -[y_true*log(y_pred) + (1-y_true)*log(1-y_pred)]
= -[1*log(0.5) + 0*log(1-0.5)]
= -log(0.5)
= 0.693
```
**Backward Pass:**
dL/dy_pred:
```
= -y_true/y_pred + (1-y_true)/(1-y_pred)
= -1/0.5 + 0/0.5
= -2
```
dy_pred/dz:
```
= y_pred * (1 - y_pred)
= 0.5 * 0.5 = 0.25
```
dz/dw:
```
= x = 2
```
**Chain Rule:**
```
dL/dw = (dL/dy) * (dy/dz) * (dz/dw)
= (-2) * (0.25) * (2)
= -1.0
```
**Answer: dL/dw = -1.0**
**Interpretation:** Negative gradient means we should
INCREASE w to reduce loss (moving opposite to gradient).
"""
        }
    }
    # Unknown numbers fall back to problem 1 rather than raising.
    prob = problems.get(problem_num, problems[1])
    return prob["question"], prob["solution"]
# ============================================================================
# BUILD THE GRADIO APP
# ============================================================================
# Layout: a banner, six tabs, then a footer. Each interactive tab pairs a
# column of input controls with a column of output panes and wires a button
# to the matching demo function (each handler returns (svg_html, markdown)).
with gr.Blocks(title="BACKPROP TERMINAL v1.0") as demo:
    gr.Markdown("""
# > VAULT-TEC NEURAL NETWORK TRAINING TERMINAL
## > SECURITY CLEARANCE: STAT 3106
### > INITIALIZING BACKPROPAGATION MODULES...
""")
    with gr.Tabs():
        # TAB 1: Forward Pass
        with gr.TabItem("01: FORWARD PASS"):
            gr.HTML(FORWARD_INTRO)
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### INPUT PARAMETERS")
                    x1_input = gr.Slider(minimum=-5, maximum=5, value=1.0, step=0.1, label="x1 (input 1)")
                    x2_input = gr.Slider(minimum=-5, maximum=5, value=2.0, step=0.1, label="x2 (input 2)")
                    w1_input = gr.Slider(minimum=-2, maximum=2, value=0.5, step=0.1, label="w1 (weight 1)")
                    w2_input = gr.Slider(minimum=-2, maximum=2, value=-0.3, step=0.1, label="w2 (weight 2)")
                    b_input = gr.Slider(minimum=-2, maximum=2, value=0.1, step=0.1, label="b (bias)")
                    forward_btn = gr.Button(">> EXECUTE FORWARD PASS <<")
                with gr.Column(scale=2):
                    forward_svg = gr.HTML(label="Computation Graph")
                    forward_output = gr.Markdown(label="Calculation")
            forward_btn.click(
                forward_pass_demo,
                inputs=[x1_input, x2_input, w1_input, w2_input, b_input],
                outputs=[forward_svg, forward_output]
            )

        # TAB 2: Chain Rule
        with gr.TabItem("02: CHAIN RULE"):
            gr.HTML(CHAIN_RULE_INTRO)
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### FUNCTION: y = (ax + b)²")
                    a_input = gr.Slider(minimum=-5, maximum=5, value=3.0, step=0.1, label="a (coefficient)")
                    b2_input = gr.Slider(minimum=-5, maximum=5, value=2.0, step=0.1, label="b (constant)")
                    x_input = gr.Slider(minimum=-5, maximum=5, value=1.0, step=0.1, label="x (evaluation point)")
                    chain_btn = gr.Button(">> APPLY CHAIN RULE <<")
                with gr.Column(scale=2):
                    chain_svg = gr.HTML(label="Chain Rule Visualization")
                    chain_output = gr.Markdown(label="Chain Rule Breakdown")
            chain_btn.click(
                chain_rule_calculator,
                inputs=[a_input, b2_input, x_input],
                outputs=[chain_svg, chain_output]
            )

        # TAB 3: Key Derivatives
        with gr.TabItem("03: KEY DERIVATIVES"):
            gr.HTML(DERIVATIVES_INTRO)
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### SIGMOID DERIVATIVE CALCULATOR")
                    z_input = gr.Slider(
                        minimum=-5, maximum=5, value=0, step=0.1,
                        label="z value"
                    )
                    sigmoid_btn = gr.Button(">> COMPUTE SIGMOID DERIVATIVE <<")
                with gr.Column(scale=2):
                    sigmoid_svg = gr.HTML(label="Sigmoid Visualization")
                    sigmoid_output = gr.Markdown(label="Derivative Calculation")
            sigmoid_btn.click(
                sigmoid_derivative_demo,
                inputs=[z_input],
                outputs=[sigmoid_svg, sigmoid_output]
            )

        # TAB 4: Backward Pass
        with gr.TabItem("04: BACKWARD PASS"):
            gr.HTML(BACKWARD_INTRO)
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### NETWORK CONFIGURATION")
                    bx1 = gr.Slider(minimum=-5, maximum=5, value=1.0, step=0.1, label="x1")
                    bx2 = gr.Slider(minimum=-5, maximum=5, value=2.0, step=0.1, label="x2")
                    bw1 = gr.Slider(minimum=-2, maximum=2, value=0.5, step=0.1, label="w1")
                    bw2 = gr.Slider(minimum=-2, maximum=2, value=-0.3, step=0.1, label="w2")
                    bb = gr.Slider(minimum=-2, maximum=2, value=0.1, step=0.1, label="bias")
                    # step=1 on a [0, 1] slider restricts the label to 0 or 1.
                    by_true = gr.Slider(minimum=0, maximum=1, value=1, step=1, label="y_true (0 or 1)")
                    back_btn = gr.Button(">> EXECUTE FULL BACKPROP <<")
                with gr.Column(scale=2):
                    back_svg = gr.HTML(label="Backprop Graph")
                    back_output = gr.Markdown(label="Complete Backprop Trace")
            back_btn.click(
                backward_pass_demo,
                inputs=[bx1, bx2, bw1, bw2, bb, by_true],
                outputs=[back_svg, back_output]
            )

        # TAB 5: Practice
        with gr.TabItem("05: PRACTICE"):
            gr.Markdown(PRACTICE_INTRO)
            # Radio label -> problem number passed to practice_problem().
            # Unknown labels fall back to 1, matching practice_problem's own
            # fallback for unknown numbers.
            problem_choices = {
                "Problem 1: Chain Rule": 1,
                "Problem 2: Sigmoid": 2,
                "Problem 3: Full Backprop": 3,
            }
            default_problem = "Problem 1: Chain Rule"
            with gr.Row():
                with gr.Column():
                    problem_select = gr.Radio(
                        choices=list(problem_choices),
                        label="Select Problem",
                        value=default_problem
                    )
                    show_problem_btn = gr.Button(">> SHOW PROBLEM <<")
                    show_answer_btn = gr.Button(">> REVEAL SOLUTION <<")
                with gr.Column():
                    # Pre-filled with the default selection's question so the
                    # pane is never blank while a problem is selected.
                    problem_display = gr.Markdown(
                        value=practice_problem(problem_choices[default_problem])[0],
                        label="Problem"
                    )
                    solution_display = gr.Markdown(label="Solution", visible=False)

            def show_problem(selection):
                """Show the selected question and hide/clear any revealed solution."""
                q, _ = practice_problem(problem_choices.get(selection, 1))
                return q, gr.update(visible=False, value="")

            def show_solution(selection):
                """Reveal the solution for the currently selected problem."""
                _, s = practice_problem(problem_choices.get(selection, 1))
                return gr.update(visible=True, value=s)

            show_problem_btn.click(
                show_problem,
                inputs=[problem_select],
                outputs=[problem_display, solution_display]
            )
            # Changing the selection refreshes the question and hides the
            # previous solution, so "REVEAL SOLUTION" can never show one
            # problem's solution beneath a different problem's question.
            problem_select.change(
                show_problem,
                inputs=[problem_select],
                outputs=[problem_display, solution_display]
            )
            show_answer_btn.click(
                show_solution,
                inputs=[problem_select],
                outputs=[solution_display]
            )

        # TAB 6: Quick Reference
        with gr.TabItem("06: REFERENCE"):
            gr.Markdown("""
# QUICK REFERENCE CARD
===============================================
## CHAIN RULE
```
y = f(g(x))
dy/dx = (df/dg) * (dg/dx)
```
For longer chains: just multiply all the derivatives!
-----------------------------------------------
## COMMON DERIVATIVES
| Function | Derivative |
|----------|------------|
| x^n | n*x^(n-1) |
| e^x | e^x |
| log(x) | 1/x |
| sigmoid(x) | sigmoid(x)*(1-sigmoid(x)) |
| ReLU(x) | 1 if x>0, else 0 |
-----------------------------------------------
## NEURAL NETWORK CHAIN
For a single neuron with sigmoid:
```
z = Σ(wi*xi) + b
y = sigmoid(z)
L = loss(y, y_true)
dL/dwi = (dL/dy) * (dy/dz) * (dz/dwi)
= (dL/dy) * sigmoid'(z) * xi
```
-----------------------------------------------
## GRADIENT DESCENT
```
w_new = w_old - learning_rate * dL/dw
```
The gradient points UPHILL; we go opposite direction.
-----------------------------------------------
## BCE LOSS GRADIENT (sigmoid output)
For BCE loss with sigmoid output:
```
dL/dz = y_pred - y_true
```
This clean result comes from cancellation in the chain!
-----------------------------------------------
## DEBUGGING TIPS
1. **Gradient check:** Compare with numerical gradient
```
dL/dw ≈ [L(w+h) - L(w-h)] / (2h)
```
2. **Shapes must match:** gradient of L w.r.t. W has same shape as W
3. **Large gradients?** Try gradient clipping or smaller learning rate
4. **Vanishing gradients?** Consider ReLU or residual connections
""")
    gr.Markdown("""
---
> TERMINAL SESSION ACTIVE
> VAULT-TEC WISHES YOU A PLEASANT TRAINING EXPERIENCE
""")
if __name__ == "__main__":
    # Page-load script: adds the 'dark' class to <body> and injects a
    # stylesheet that hides theme-toggle controls, so the terminal palette
    # from FALLOUT_CSS is the only look the user can get.
    force_dark_js = """
() => {
// Force dark mode and hide theme toggle
document.body.classList.add('dark');
const style = document.createElement('style');
style.textContent = `
.dark-mode-toggle, [aria-label="Toggle dark mode"],
button[title*="theme"], .theme-toggle { display: none !important; }
`;
document.head.appendChild(style);
}
"""
    demo.launch(js=force_dark_js, css=FALLOUT_CSS, server_port=7860)