| import gradio as gr |
| import numpy as np |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
# Global stylesheet for the retro "Fallout terminal" theme: dark background,
# green body text, amber headings/buttons and blue navigation accents.
# Every declaration is marked !important so it overrides the UI framework's
# built-in styles.
# NOTE(review): presumably handed to Gradio through the `css=` argument; the
# call site is not visible in this part of the file -- confirm.
FALLOUT_CSS = """
@import url('https://fonts.googleapis.com/css2?family=VT323&display=swap');
@import url('https://fonts.googleapis.com/css2?family=Share+Tech+Mono&display=swap');

:root {
    --pip-amber: #f0b030;
    --pip-amber-dim: #c49028;
    --terminal-green: #4ade80;
    --terminal-green-dim: #22c55e;
    --vault-blue: #5b9bd5;
    --vault-blue-dim: #4080b8;
    --bg-dark: #0c0c0c;
    --bg-panel: #141414;
    --bg-input: #1a1a1a;
    --text-muted: #888888;
}

* {
    font-family: 'Share Tech Mono', 'VT323', monospace !important;
    font-size: 20px !important;
    line-height: 1.6 !important;
}

h1 { font-size: 36px !important; }
h2 { font-size: 30px !important; }
h3 { font-size: 24px !important; }
h4, h5 { font-size: 22px !important; }
code, pre { font-size: 18px !important; }

body, .gradio-container {
    background-color: var(--bg-dark) !important;
}

.gradio-container {
    max-width: 1200px !important;
}

/* Main text - soft green, NO glow */
.markdown-text, .prose, p, span, label, .label-wrap {
    color: var(--terminal-green) !important;
}

/* Headers - warm amber for hierarchy */
h1 {
    color: var(--pip-amber) !important;
    border-bottom: 2px solid var(--pip-amber-dim) !important;
    padding-bottom: 8px !important;
}

h2 {
    color: var(--pip-amber) !important;
    border-bottom: 1px solid var(--pip-amber-dim) !important;
    padding-bottom: 4px !important;
}

h3, h4, h5 {
    color: var(--vault-blue) !important;
    border-bottom: none !important;
}

/* Tab styling - Vault-Tec blue for navigation */
.tabs {
    background-color: var(--bg-dark) !important;
    border: 1px solid var(--vault-blue-dim) !important;
    border-radius: 4px !important;
}

.tab-nav {
    background-color: var(--bg-panel) !important;
    border-bottom: 2px solid var(--vault-blue-dim) !important;
}

.tab-nav button {
    background-color: var(--bg-panel) !important;
    color: var(--vault-blue) !important;
    border: none !important;
    border-right: 1px solid var(--bg-dark) !important;
    padding: 10px 16px !important;
    transition: all 0.2s ease !important;
}

.tab-nav button:hover {
    background-color: #1e3a5f !important;
    color: #8ec5fc !important;
}

.tab-nav button.selected {
    background-color: #1a3550 !important;
    color: #8ec5fc !important;
    border-bottom: 2px solid var(--pip-amber) !important;
}

/* Input/Output boxes - subtle with green text */
.textbox, textarea, input {
    background-color: var(--bg-input) !important;
    color: var(--terminal-green) !important;
    border: 1px solid #333 !important;
    border-radius: 3px !important;
}

.textbox:focus, textarea:focus, input:focus {
    border-color: var(--terminal-green-dim) !important;
    outline: none !important;
}

/* Buttons - amber accent for actions */
.primary, .secondary, button {
    background-color: #2a2010 !important;
    color: var(--pip-amber) !important;
    border: 1px solid var(--pip-amber-dim) !important;
    border-radius: 3px !important;
    transition: all 0.2s ease !important;
}

button:hover {
    background-color: #3d2e15 !important;
    border-color: var(--pip-amber) !important;
}

/* Sliders - amber accent */
input[type="range"] {
    accent-color: var(--pip-amber) !important;
}

/* Number inputs */
.number-input input {
    background-color: var(--bg-input) !important;
    color: var(--terminal-green) !important;
    border: 1px solid #333 !important;
}

/* Code blocks - slightly blue-tinted for distinction */
code, pre {
    background-color: #0d1520 !important;
    color: var(--terminal-green) !important;
    border: 1px solid #2a4060 !important;
    border-left: 3px solid var(--vault-blue) !important;
    border-radius: 3px !important;
    padding: 2px 6px !important;
}

pre {
    padding: 12px !important;
}

/* Tables */
table {
    border-collapse: collapse !important;
}

th {
    background-color: #1a2a3a !important;
    color: var(--pip-amber) !important;
    border: 1px solid #2a4060 !important;
    padding: 8px !important;
}

td {
    background-color: var(--bg-panel) !important;
    color: var(--terminal-green) !important;
    border: 1px solid #2a4060 !important;
    padding: 8px !important;
}

/* Strong/bold text - amber for emphasis */
strong, b {
    color: var(--pip-amber) !important;
    font-weight: bold !important;
}

/* Links */
a {
    color: var(--vault-blue) !important;
}

a:hover {
    color: #8ec5fc !important;
}

/* Radio buttons and checkboxes */
.radio-group label, .checkbox-group label {
    color: var(--terminal-green) !important;
}

/* Scrollbar - subtle */
::-webkit-scrollbar {
    width: 8px;
    height: 8px;
    background-color: var(--bg-dark);
}

::-webkit-scrollbar-thumb {
    background-color: #333;
    border-radius: 4px;
}

::-webkit-scrollbar-thumb:hover {
    background-color: #444;
}

/* Subtle scanlines - very light, not distracting */
.gradio-container::before {
    content: "";
    position: fixed;
    top: 0;
    left: 0;
    width: 100%;
    height: 100%;
    background: repeating-linear-gradient(
        0deg,
        rgba(0, 0, 0, 0.03),
        rgba(0, 0, 0, 0.03) 1px,
        transparent 1px,
        transparent 2px
    );
    pointer-events: none;
    z-index: 1000;
}

/* Horizontal rules - amber accent */
hr {
    border: none !important;
    border-top: 1px solid var(--pip-amber-dim) !important;
    margin: 16px 0 !important;
}

/* Blockquotes - for terminal prompts */
blockquote {
    border-left: 3px solid var(--pip-amber) !important;
    background-color: var(--bg-panel) !important;
    padding: 8px 16px !important;
    margin: 8px 0 !important;
    color: var(--pip-amber) !important;
}

/* Muted/secondary text */
.secondary-text, .hint {
    color: var(--text-muted) !important;
}
"""
|
|
| |
| |
| |
|
|
def generate_forward_svg(
    x1: float,
    x2: float,
    w1: float,
    w2: float,
    b: float,
    z: float,
    y: float,
) -> str:
    """Generate an SVG diagram showing the forward pass with actual values.

    The drawing is a fixed 800x320 layout: two input boxes, a "sum + bias"
    box, a sigmoid box and an output box joined by arrows, plus a legend.
    Only the numbers change between calls.

    Args:
        x1: First input value (rendered with 2 decimals).
        x2: Second input value (rendered with 2 decimals).
        w1: Weight shown on the arrow leaving ``x1`` (2 decimals).
        w2: Weight shown on the arrow leaving ``x2`` (2 decimals).
        b: Bias shown inside the summation box (2 decimals).
        z: Pre-activation value shown inside the summation box (3 decimals).
        y: Activation value, shown in both the sigmoid box and the output
            box (4 decimals).

    Returns:
        The complete ``<svg>`` element as a string, with a leading and a
        trailing newline.

    Note:
        Nothing is computed here; ``z`` and ``y`` are displayed exactly as
        supplied by the caller.
    """
    # Palette -- the hex values mirror the amber/green/blue theme colours.
    bg = "#0c0c0c"
    node_fill = "#1a2a3a"
    node_stroke = "#5b9bd5"
    input_fill = "#1a3a2a"
    input_stroke = "#4ade80"
    output_fill = "#2a2a1a"
    output_stroke = "#f0b030"
    text_color = "#4ade80"
    label_color = "#5b9bd5"
    arrow_color = "#5b9bd5"
    value_color = "#f0b030"

    # The markup is kept flush-left on purpose: the whitespace inside the
    # literal is part of the returned string.
    svg = f'''
<svg viewBox="0 0 800 320" style="width:100%; max-width:800px; height:auto; background:{bg}; border-radius:8px; border:1px solid #333;">
<defs>
<marker id="arrowhead" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
<polygon points="0 0, 10 3.5, 0 7" fill="{arrow_color}" />
</marker>
<marker id="arrowhead-amber" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
<polygon points="0 0, 10 3.5, 0 7" fill="{output_stroke}" />
</marker>
</defs>

<!-- Title -->
<text x="400" y="30" text-anchor="middle" fill="{output_stroke}" font-size="20" font-family="monospace">FORWARD PASS: Data Flow</text>

<!-- Input nodes -->
<rect x="40" y="60" width="80" height="60" rx="8" fill="{input_fill}" stroke="{input_stroke}" stroke-width="2"/>
<text x="80" y="85" text-anchor="middle" fill="{label_color}" font-size="14" font-family="monospace">x₁</text>
<text x="80" y="108" text-anchor="middle" fill="{value_color}" font-size="16" font-family="monospace" font-weight="bold">{x1:.2f}</text>

<rect x="40" y="200" width="80" height="60" rx="8" fill="{input_fill}" stroke="{input_stroke}" stroke-width="2"/>
<text x="80" y="225" text-anchor="middle" fill="{label_color}" font-size="14" font-family="monospace">x₂</text>
<text x="80" y="248" text-anchor="middle" fill="{value_color}" font-size="16" font-family="monospace" font-weight="bold">{x2:.2f}</text>

<!-- Weight labels on arrows -->
<line x1="120" y1="90" x2="220" y2="140" stroke="{arrow_color}" stroke-width="2" marker-end="url(#arrowhead)"/>
<text x="155" y="100" fill="{text_color}" font-size="12" font-family="monospace">w₁={w1:.2f}</text>

<line x1="120" y1="230" x2="220" y2="180" stroke="{arrow_color}" stroke-width="2" marker-end="url(#arrowhead)"/>
<text x="155" y="230" fill="{text_color}" font-size="12" font-family="monospace">w₂={w2:.2f}</text>

<!-- Summation node -->
<rect x="230" y="130" width="100" height="70" rx="8" fill="{node_fill}" stroke="{node_stroke}" stroke-width="2"/>
<text x="280" y="152" text-anchor="middle" fill="{label_color}" font-size="14" font-family="monospace">Σ + b</text>
<text x="280" y="175" text-anchor="middle" fill="{text_color}" font-size="11" font-family="monospace">b={b:.2f}</text>
<text x="280" y="193" text-anchor="middle" fill="{value_color}" font-size="14" font-family="monospace" font-weight="bold">z={z:.3f}</text>

<!-- Arrow to sigmoid -->
<line x1="330" y1="165" x2="400" y2="165" stroke="{arrow_color}" stroke-width="2" marker-end="url(#arrowhead)"/>

<!-- Sigmoid node -->
<rect x="410" y="130" width="100" height="70" rx="8" fill="{node_fill}" stroke="{node_stroke}" stroke-width="2"/>
<text x="460" y="155" text-anchor="middle" fill="{label_color}" font-size="14" font-family="monospace">σ(z)</text>
<text x="460" y="175" text-anchor="middle" fill="{text_color}" font-size="10" font-family="monospace">1/(1+e⁻ᶻ)</text>
<text x="460" y="193" text-anchor="middle" fill="{value_color}" font-size="14" font-family="monospace" font-weight="bold">ŷ={y:.4f}</text>

<!-- Arrow to output -->
<line x1="510" y1="165" x2="580" y2="165" stroke="{output_stroke}" stroke-width="2" marker-end="url(#arrowhead-amber)"/>

<!-- Output node -->
<rect x="590" y="130" width="100" height="70" rx="8" fill="{output_fill}" stroke="{output_stroke}" stroke-width="2"/>
<text x="640" y="155" text-anchor="middle" fill="{output_stroke}" font-size="14" font-family="monospace">OUTPUT</text>
<text x="640" y="180" text-anchor="middle" fill="{value_color}" font-size="18" font-family="monospace" font-weight="bold">{y:.4f}</text>

<!-- Legend -->
<rect x="40" y="280" width="15" height="15" fill="{input_fill}" stroke="{input_stroke}" stroke-width="1"/>
<text x="60" y="292" fill="{text_color}" font-size="12" font-family="monospace">Inputs</text>

<rect x="140" y="280" width="15" height="15" fill="{node_fill}" stroke="{node_stroke}" stroke-width="1"/>
<text x="160" y="292" fill="{text_color}" font-size="12" font-family="monospace">Operations</text>

<rect x="280" y="280" width="15" height="15" fill="{output_fill}" stroke="{output_stroke}" stroke-width="1"/>
<text x="300" y="292" fill="{text_color}" font-size="12" font-family="monospace">Output</text>

<text x="420" y="292" fill="{value_color}" font-size="12" font-family="monospace">■ Computed Values</text>
</svg>
'''
    return svg
|
|
|
|
def generate_backward_svg(
    x1: float,
    x2: float,
    w1: float,
    w2: float,
    b: float,
    y_true: float,
    z: float,
    y_pred: float,
    dL_dy: float,
    dy_dz: float,
    dL_dz: float,
    dL_dw1: float,
    dL_dw2: float,
    dL_db: float,
    loss: float,
) -> str:
    """Generate an SVG diagram showing backward pass with gradients.

    The drawing is a fixed 900x420 layout with four areas: the forward path
    (inputs, weights, sum, sigmoid, loss), a row of dashed right-to-left
    arrows for the gradient chain, a "computed gradients" box and a worked
    chain-rule box, followed by a legend. Every number shown is taken
    verbatim from the arguments; nothing is computed here.

    Args:
        x1: First input value.
        x2: Second input value.
        w1: Weight applied to ``x1`` (rendered with 1 decimal).
        w2: Weight applied to ``x2`` (rendered with 1 decimal).
        b: Bias. Accepted for signature symmetry; not drawn.
        y_true: Target label, inserted without numeric formatting.
        z: Pre-activation value.
        y_pred: Sigmoid output.
        dL_dy: Gradient of the loss w.r.t. the prediction.
        dy_dz: Gradient of the prediction w.r.t. ``z``.
        dL_dz: Gradient of the loss w.r.t. ``z``.
        dL_dw1: Gradient of the loss w.r.t. ``w1``.
        dL_dw2: Gradient of the loss w.r.t. ``w2``.
        dL_db: Gradient of the loss w.r.t. the bias.
        loss: Loss value shown in the "BCE" box.

    Returns:
        The complete ``<svg>`` element as a string, with a leading and a
        trailing newline.
    """
    # Palette -- forward elements are blue/green, gradient elements are red.
    bg = "#0c0c0c"
    node_fill = "#1a2a3a"
    node_stroke = "#5b9bd5"
    input_fill = "#1a3a2a"
    input_stroke = "#4ade80"
    loss_fill = "#3a1a1a"
    loss_stroke = "#ff6b6b"
    text_color = "#4ade80"
    label_color = "#5b9bd5"
    forward_arrow = "#5b9bd5"
    backward_arrow = "#ff6b6b"
    value_color = "#f0b030"
    gradient_color = "#ff6b6b"

    # Arrowhead fix: with orient="auto" a marker is rotated so that its +x
    # axis follows the line direction. The gradient lines are already drawn
    # right-to-left, so the "bwd" marker must use the same forward-pointing
    # triangle as "fwd". The previous mirrored triangle pointed against the
    # line (i.e. to the right) and was drawn underneath the next box. The
    # legend sample is likewise drawn right-to-left (590 -> 550) so that it
    # shows a left-pointing arrow.
    #
    # The markup is kept flush-left on purpose: the whitespace inside the
    # literal is part of the returned string.
    svg = f'''
<svg viewBox="0 0 900 420" style="width:100%; max-width:900px; height:auto; background:{bg}; border-radius:8px; border:1px solid #333;">
<defs>
<marker id="fwd" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
<polygon points="0 0, 10 3.5, 0 7" fill="{forward_arrow}" />
</marker>
<marker id="bwd" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
<polygon points="0 0, 10 3.5, 0 7" fill="{backward_arrow}" />
</marker>
</defs>

<!-- Title -->
<text x="300" y="25" text-anchor="middle" fill="{value_color}" font-size="16" font-family="monospace">GRADIENT FLOW DIAGRAM</text>

<!-- FORWARD PATH - Row 1 -->
<!-- Input x1 -->
<rect x="20" y="50" width="60" height="50" rx="5" fill="{input_fill}" stroke="{input_stroke}" stroke-width="2"/>
<text x="50" y="70" text-anchor="middle" fill="{label_color}" font-size="11" font-family="monospace">x1</text>
<text x="50" y="88" text-anchor="middle" fill="{value_color}" font-size="12" font-family="monospace">{x1:.2f}</text>

<!-- Weight w1 -->
<line x1="80" y1="75" x2="105" y2="75" stroke="{forward_arrow}" stroke-width="2" marker-end="url(#fwd)"/>
<rect x="115" y="55" width="55" height="40" rx="4" fill="#1a1a2a" stroke="{forward_arrow}" stroke-width="1"/>
<text x="142" y="80" text-anchor="middle" fill="{value_color}" font-size="10" font-family="monospace">w1={w1:.1f}</text>

<!-- Arrow to Sum -->
<line x1="170" y1="75" x2="195" y2="100" stroke="{forward_arrow}" stroke-width="2" marker-end="url(#fwd)"/>

<!-- Input x2 -->
<rect x="20" y="120" width="60" height="50" rx="5" fill="{input_fill}" stroke="{input_stroke}" stroke-width="2"/>
<text x="50" y="140" text-anchor="middle" fill="{label_color}" font-size="11" font-family="monospace">x2</text>
<text x="50" y="158" text-anchor="middle" fill="{value_color}" font-size="12" font-family="monospace">{x2:.2f}</text>

<!-- Weight w2 -->
<line x1="80" y1="145" x2="105" y2="145" stroke="{forward_arrow}" stroke-width="2" marker-end="url(#fwd)"/>
<rect x="115" y="125" width="55" height="40" rx="4" fill="#1a1a2a" stroke="{forward_arrow}" stroke-width="1"/>
<text x="142" y="150" text-anchor="middle" fill="{value_color}" font-size="10" font-family="monospace">w2={w2:.1f}</text>

<!-- Arrow to Sum -->
<line x1="170" y1="145" x2="195" y2="120" stroke="{forward_arrow}" stroke-width="2" marker-end="url(#fwd)"/>

<!-- Sum node -->
<rect x="205" y="90" width="70" height="50" rx="5" fill="{node_fill}" stroke="{node_stroke}" stroke-width="2"/>
<text x="240" y="110" text-anchor="middle" fill="{label_color}" font-size="11" font-family="monospace">Sum+b</text>
<text x="240" y="128" text-anchor="middle" fill="{value_color}" font-size="10" font-family="monospace">z={z:.2f}</text>

<!-- Arrow to Sigmoid -->
<line x1="275" y1="115" x2="310" y2="115" stroke="{forward_arrow}" stroke-width="2" marker-end="url(#fwd)"/>

<!-- Sigmoid -->
<rect x="320" y="90" width="70" height="50" rx="5" fill="{node_fill}" stroke="{node_stroke}" stroke-width="2"/>
<text x="355" y="110" text-anchor="middle" fill="{label_color}" font-size="11" font-family="monospace">sigmoid</text>
<text x="355" y="128" text-anchor="middle" fill="{value_color}" font-size="10" font-family="monospace">y={y_pred:.3f}</text>

<!-- Arrow to Loss -->
<line x1="390" y1="115" x2="425" y2="115" stroke="{forward_arrow}" stroke-width="2" marker-end="url(#fwd)"/>

<!-- Loss -->
<rect x="435" y="85" width="80" height="60" rx="5" fill="{loss_fill}" stroke="{loss_stroke}" stroke-width="2"/>
<text x="475" y="105" text-anchor="middle" fill="{loss_stroke}" font-size="11" font-family="monospace">BCE</text>
<text x="475" y="122" text-anchor="middle" fill="{value_color}" font-size="11" font-family="monospace">L={loss:.4f}</text>
<text x="475" y="138" text-anchor="middle" fill="{text_color}" font-size="9" font-family="monospace">y_true={y_true}</text>

<!-- BACKWARD SECTION -->
<text x="300" y="185" text-anchor="middle" fill="{backward_arrow}" font-size="12" font-family="monospace">BACKWARD PASS (gradients)</text>

<!-- Gradient chain boxes -->
<rect x="435" y="200" width="80" height="35" rx="4" fill="{loss_fill}" stroke="{loss_stroke}" stroke-width="1"/>
<text x="475" y="222" text-anchor="middle" fill="{gradient_color}" font-size="9" font-family="monospace">dL/dy={dL_dy:.2f}</text>

<line x1="435" y1="218" x2="405" y2="218" stroke="{backward_arrow}" stroke-width="2" stroke-dasharray="4,2" marker-end="url(#bwd)"/>

<rect x="320" y="200" width="80" height="35" rx="4" fill="{node_fill}" stroke="{node_stroke}" stroke-width="1"/>
<text x="360" y="222" text-anchor="middle" fill="{gradient_color}" font-size="9" font-family="monospace">dy/dz={dy_dz:.3f}</text>

<line x1="320" y1="218" x2="290" y2="218" stroke="{backward_arrow}" stroke-width="2" stroke-dasharray="4,2" marker-end="url(#bwd)"/>

<rect x="205" y="200" width="80" height="35" rx="4" fill="{node_fill}" stroke="{node_stroke}" stroke-width="1"/>
<text x="245" y="222" text-anchor="middle" fill="{gradient_color}" font-size="9" font-family="monospace">dL/dz={dL_dz:.3f}</text>

<line x1="205" y1="218" x2="175" y2="218" stroke="{backward_arrow}" stroke-width="2" stroke-dasharray="4,2" marker-end="url(#bwd)"/>

<!-- Final Gradients Box -->
<rect x="20" y="260" width="240" height="80" rx="6" fill="#141414" stroke="{gradient_color}" stroke-width="1"/>
<text x="140" y="282" text-anchor="middle" fill="{value_color}" font-size="12" font-family="monospace">COMPUTED GRADIENTS</text>
<text x="140" y="305" text-anchor="middle" fill="{gradient_color}" font-size="11" font-family="monospace">dL/dw1 = {dL_dw1:.4f}</text>
<text x="140" y="325" text-anchor="middle" fill="{gradient_color}" font-size="11" font-family="monospace">dL/dw2 = {dL_dw2:.4f} dL/db = {dL_db:.4f}</text>

<!-- Chain Rule Box -->
<rect x="530" y="50" width="360" height="190" rx="6" fill="#141414" stroke="#555" stroke-width="1"/>
<text x="710" y="75" text-anchor="middle" fill="{value_color}" font-size="13" font-family="monospace">CHAIN RULE COMPUTATION</text>

<text x="545" y="100" fill="{text_color}" font-size="11" font-family="monospace">dL/dw1 = dL/dy * dy/dz * dz/dw1</text>
<text x="545" y="125" fill="#888" font-size="11" font-family="monospace"> = ({dL_dy:.2f}) * ({dy_dz:.4f}) * ({x1:.2f})</text>
<text x="545" y="150" fill="{value_color}" font-size="12" font-family="monospace"> = {dL_dw1:.4f}</text>

<line x1="545" y1="165" x2="875" y2="165" stroke="#333" stroke-width="1"/>

<text x="545" y="185" fill="{text_color}" font-size="10" font-family="monospace">Key: dz/dw1 = x1, dz/dw2 = x2, dz/db = 1</text>
<text x="545" y="205" fill="#888" font-size="10" font-family="monospace">The input values become gradients!</text>
<text x="545" y="225" fill="{text_color}" font-size="10" font-family="monospace">dL/dw2 = dL/dz * x2 = {dL_dz:.3f} * {x2:.2f} = {dL_dw2:.4f}</text>

<!-- Legend -->
<rect x="530" y="260" width="360" height="80" rx="6" fill="#141414" stroke="#333" stroke-width="1"/>
<text x="710" y="282" text-anchor="middle" fill="#888" font-size="11" font-family="monospace">LEGEND</text>
<line x1="550" y1="302" x2="590" y2="302" stroke="{forward_arrow}" stroke-width="2" marker-end="url(#fwd)"/>
<text x="600" y="306" fill="{text_color}" font-size="10" font-family="monospace">Forward (data)</text>
<line x1="590" y1="322" x2="550" y2="322" stroke="{backward_arrow}" stroke-width="2" stroke-dasharray="4,2" marker-end="url(#bwd)"/>
<text x="600" y="326" fill="{text_color}" font-size="10" font-family="monospace">Backward (grads)</text>
<rect x="750" y="295" width="12" height="12" fill="{gradient_color}"/>
<text x="770" y="306" fill="{text_color}" font-size="10" font-family="monospace">Gradient values</text>
</svg>
'''
    return svg
|
|
|
|
| |
| |
| |
|
|
def forward_pass_demo(x1: float, x2: float, w1: float, w2: float, b: float):
    """Step-by-step forward pass calculation.

    Computes the weighted sum ``z = w1*x1 + w2*x2 + b`` and its sigmoid,
    then renders the result twice: as an SVG diagram (via
    ``generate_forward_svg``) and as a Markdown walkthrough that spells out
    every intermediate value.

    Args:
        x1: First input value.
        x2: Second input value.
        w1: Weight applied to ``x1``.
        w2: Weight applied to ``x2``.
        b: Bias added to the weighted sum.

    Returns:
        A ``(svg_diagram, explanation)`` tuple of two strings: the SVG
        markup and the Markdown text.
    """
    # Step 1: weighted sum of the inputs plus the bias.
    z = w1 * x1 + w2 * x2 + b

    # Step 2: sigmoid activation. e^(-z) is evaluated once and reused both
    # for the result and for the intermediate lines of the walkthrough.
    exp_neg_z = np.exp(-z)
    sigmoid_z = 1 / (1 + exp_neg_z)

    svg_diagram = generate_forward_svg(x1, x2, w1, w2, b, z, sigmoid_z)

    # The Markdown is kept flush-left: leading whitespace would change how
    # the text is rendered.
    explanation = f"""
## FORWARD PASS CALCULATION
===============================================

### STEP 1: The Weighted Sum (z)

The neuron computes a **weighted sum** of inputs plus a bias:

```
z = w1*x1 + w2*x2 + b
z = ({w1:.2f})*({x1:.2f}) + ({w2:.2f})*({x2:.2f}) + ({b:.2f})
z = {w1*x1:.4f} + {w2*x2:.4f} + {b:.2f}
z = {z:.4f}
```

**What's happening:** Each input is scaled by its weight,
then we add them up. The bias shifts the whole thing.

-----------------------------------------------

### STEP 2: The Sigmoid Activation Function

We squash z through the **sigmoid function** to get output in (0,1):

```
sigmoid(z) = 1 / (1 + e^(-z))
= 1 / (1 + e^(-{z:.4f}))
= 1 / (1 + {exp_neg_z:.4f})
= 1 / {1 + exp_neg_z:.4f}
= {sigmoid_z:.4f}
```

**Why sigmoid?** It smoothly maps any real number to (0,1).
- z >> 0 --> sigmoid(z) ≈ 1
- z << 0 --> sigmoid(z) ≈ 0
- z = 0 --> sigmoid(z) = 0.5

-----------------------------------------------

### SUMMARY

```
Inputs: x1={x1:.2f}, x2={x2:.2f}
Weights: w1={w1:.2f}, w2={w2:.2f}
Bias: b={b:.2f}

z = {z:.4f}
y = sigmoid(z) = {sigmoid_z:.4f}
```

**Interpretation:** Output of {sigmoid_z:.4f} means
{sigmoid_z*100:.1f}% probability of class 1.
"""
    return svg_diagram, explanation
|
|
|
|
| FORWARD_INTRO_SVG = ''' |
| <svg viewBox="0 0 900 420" style="width:100%; max-width:900px; height:auto; background:#0c0c0c; border-radius:8px; border:1px solid #333; margin-bottom:20px;"> |
| <defs> |
| <marker id="fwd-arr" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto"> |
| <polygon points="0 0, 10 3.5, 0 7" fill="#5b9bd5" /> |
| </marker> |
| </defs> |
| |
| <!-- Title --> |
| <text x="450" y="35" text-anchor="middle" fill="#f0b030" font-size="20" font-family="monospace" font-weight="bold">FORWARD PASS: DATA IN → PREDICTION OUT</text> |
| |
| <!-- Main neuron diagram --> |
| <rect x="20" y="55" width="480" height="200" rx="8" fill="#141414" stroke="#5b9bd5" stroke-width="1"/> |
| <text x="40" y="80" fill="#5b9bd5" font-size="14" font-family="monospace" font-weight="bold">THE SINGLE NEURON</text> |
| |
| <!-- Input x1 --> |
| <circle cx="70" cy="120" r="22" fill="#1a3a2a" stroke="#4ade80" stroke-width="2"/> |
| <text x="70" y="125" text-anchor="middle" fill="#4ade80" font-size="13" font-family="monospace">x₁</text> |
| |
| <!-- Input x2 --> |
| <circle cx="70" cy="195" r="22" fill="#1a3a2a" stroke="#4ade80" stroke-width="2"/> |
| <text x="70" y="200" text-anchor="middle" fill="#4ade80" font-size="13" font-family="monospace">x₂</text> |
| |
| <!-- Weights on arrows --> |
| <line x1="92" y1="120" x2="155" y2="145" stroke="#5b9bd5" stroke-width="2" marker-end="url(#fwd-arr)"/> |
| <text x="108" y="115" fill="#f0b030" font-size="11" font-family="monospace">×w₁</text> |
| |
| <line x1="92" y1="195" x2="155" y2="170" stroke="#5b9bd5" stroke-width="2" marker-end="url(#fwd-arr)"/> |
| <text x="108" y="210" fill="#f0b030" font-size="11" font-family="monospace">×w₂</text> |
| |
| <!-- Summation node --> |
| <circle cx="185" cy="157" r="28" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="2"/> |
| <text x="185" y="152" text-anchor="middle" fill="#5b9bd5" font-size="16" font-family="monospace">Σ</text> |
| <text x="185" y="170" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">+b</text> |
| |
| <!-- Arrow to z --> |
| <line x1="213" y1="157" x2="255" y2="157" stroke="#5b9bd5" stroke-width="2" marker-end="url(#fwd-arr)"/> |
| |
| <!-- z value box --> |
| <rect x="265" y="140" width="40" height="35" rx="5" fill="#2a2a1a" stroke="#f0b030" stroke-width="2"/> |
| <text x="285" y="162" text-anchor="middle" fill="#f0b030" font-size="13" font-family="monospace">z</text> |
| |
| <!-- Arrow to sigmoid --> |
| <line x1="305" y1="157" x2="340" y2="157" stroke="#5b9bd5" stroke-width="2" marker-end="url(#fwd-arr)"/> |
| |
| <!-- Sigmoid box --> |
| <rect x="350" y="135" width="60" height="45" rx="5" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="2"/> |
| <text x="380" y="155" text-anchor="middle" fill="#5b9bd5" font-size="12" font-family="monospace">σ(z)</text> |
| <text x="380" y="170" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">sigmoid</text> |
| |
| <!-- Arrow to output --> |
| <line x1="410" y1="157" x2="445" y2="157" stroke="#4ade80" stroke-width="2" marker-end="url(#fwd-arr)"/> |
| |
| <!-- Output --> |
| <circle cx="470" cy="157" r="22" fill="#2a1a1a" stroke="#ff6b6b" stroke-width="2"/> |
| <text x="470" y="162" text-anchor="middle" fill="#ff6b6b" font-size="13" font-family="monospace">ŷ</text> |
| |
| <!-- Step labels --> |
| <text x="185" y="215" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">STEP 1</text> |
| <text x="185" y="228" text-anchor="middle" fill="#5b9bd5" font-size="9" font-family="monospace">Weighted Sum</text> |
| |
| <text x="380" y="215" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">STEP 2</text> |
| <text x="380" y="228" text-anchor="middle" fill="#5b9bd5" font-size="9" font-family="monospace">Activation</text> |
| |
| <!-- Equations box - made wider --> |
| <rect x="515" y="55" width="365" height="200" rx="8" fill="#141414" stroke="#4ade80" stroke-width="1"/> |
| <text x="535" y="80" fill="#4ade80" font-size="14" font-family="monospace" font-weight="bold">THE MATH</text> |
| |
| <text x="535" y="110" fill="#888" font-size="12" font-family="monospace">Step 1: Weighted Sum</text> |
| <text x="535" y="132" fill="#f0b030" font-size="14" font-family="monospace">z = w₁x₁ + w₂x₂ + b</text> |
| |
| <text x="535" y="165" fill="#888" font-size="12" font-family="monospace">Step 2: Sigmoid</text> |
| <text x="535" y="187" fill="#f0b030" font-size="14" font-family="monospace">ŷ = σ(z) = 1/(1+e⁻ᶻ)</text> |
| |
| <line x1="535" y1="200" x2="860" y2="200" stroke="#333" stroke-width="1"/> |
| |
| <text x="535" y="222" fill="#4ade80" font-size="12" font-family="monospace">Output ŷ ∈ (0,1) = probability</text> |
| <text x="535" y="242" fill="#888" font-size="10" font-family="monospace">Squashes any real number to (0,1)</text> |
| |
| <!-- Interactive prompt --> |
| <rect x="20" y="270" width="860" height="135" rx="8" fill="#1a1a1a" stroke="#888" stroke-width="1" stroke-dasharray="5,5"/> |
| <text x="450" y="300" text-anchor="middle" fill="#888" font-size="15" font-family="monospace">▼ INTERACTIVE TERMINAL ▼</text> |
| <text x="450" y="330" text-anchor="middle" fill="#5b9bd5" font-size="13" font-family="monospace">Adjust inputs (x₁, x₂), weights (w₁, w₂), and bias (b)</text> |
| <text x="450" y="360" text-anchor="middle" fill="#4ade80" font-size="13" font-family="monospace">Click "EXECUTE FORWARD PASS" to see the values</text> |
| <text x="450" y="390" text-anchor="middle" fill="#555" font-size="10" font-family="monospace">[Vault-Tec recommends saving your work before experiments]</text> |
| </svg> |
| ''' |
|
|
| FORWARD_INTRO = f""" |
| {FORWARD_INTRO_SVG} |
| """ |
|
|
| |
| |
| |
|
|
| CHAIN_RULE_INTRO_SVG = ''' |
| <svg viewBox="0 0 900 480" style="width:100%; max-width:900px; height:auto; background:#0c0c0c; border-radius:8px; border:1px solid #333; margin-bottom:20px;"> |
| <defs> |
| <marker id="arr-blue" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto"> |
| <polygon points="0 0, 10 3.5, 0 7" fill="#5b9bd5" /> |
| </marker> |
| </defs> |
| |
| <!-- Title --> |
| <text x="450" y="32" text-anchor="middle" fill="#f0b030" font-size="20" font-family="monospace" font-weight="bold">THE CHAIN RULE</text> |
| |
| <!-- Section 1: Basic Idea --> |
| <rect x="20" y="50" width="860" height="110" rx="8" fill="#141414" stroke="#5b9bd5" stroke-width="1"/> |
| <text x="40" y="72" fill="#5b9bd5" font-size="14" font-family="monospace" font-weight="bold">1. THE BASIC IDEA</text> |
| |
| <!-- Composition diagram - more compact --> |
| <rect x="50" y="90" width="45" height="32" rx="5" fill="#1a3a2a" stroke="#4ade80" stroke-width="2"/> |
| <text x="72" y="111" text-anchor="middle" fill="#4ade80" font-size="14" font-family="monospace">x</text> |
| |
| <line x1="95" y1="106" x2="130" y2="106" stroke="#5b9bd5" stroke-width="2" marker-end="url(#arr-blue)"/> |
| |
| <rect x="140" y="88" width="55" height="36" rx="5" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="2"/> |
| <text x="167" y="111" text-anchor="middle" fill="#5b9bd5" font-size="13" font-family="monospace">g(x)</text> |
| |
| <line x1="195" y1="106" x2="230" y2="106" stroke="#5b9bd5" stroke-width="2" marker-end="url(#arr-blue)"/> |
| |
| <rect x="240" y="90" width="45" height="32" rx="5" fill="#2a2a1a" stroke="#f0b030" stroke-width="2"/> |
| <text x="262" y="111" text-anchor="middle" fill="#f0b030" font-size="14" font-family="monospace">u</text> |
| |
| <line x1="285" y1="106" x2="320" y2="106" stroke="#5b9bd5" stroke-width="2" marker-end="url(#arr-blue)"/> |
| |
| <rect x="330" y="88" width="55" height="36" rx="5" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="2"/> |
| <text x="357" y="111" text-anchor="middle" fill="#5b9bd5" font-size="13" font-family="monospace">f(u)</text> |
| |
| <line x1="385" y1="106" x2="420" y2="106" stroke="#5b9bd5" stroke-width="2" marker-end="url(#arr-blue)"/> |
| |
| <rect x="430" y="90" width="45" height="32" rx="5" fill="#2a1a1a" stroke="#ff6b6b" stroke-width="2"/> |
| <text x="452" y="111" text-anchor="middle" fill="#ff6b6b" font-size="14" font-family="monospace">y</text> |
| |
| <!-- Labels --> |
| <text x="262" y="140" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">u = g(x)</text> |
| <text x="452" y="140" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">y = f(g(x))</text> |
| |
| <!-- Formula box - wider --> |
| <rect x="510" y="80" width="350" height="65" rx="6" fill="#0c0c0c" stroke="#f0b030" stroke-width="1"/> |
| <text x="685" y="105" text-anchor="middle" fill="#f0b030" font-size="13" font-family="monospace">Chain Rule Formula:</text> |
| <text x="685" y="130" text-anchor="middle" fill="#4ade80" font-size="15" font-family="monospace">dy/dx = (dy/du) × (du/dx)</text> |
| |
| <!-- Section 2: Why It Works --> |
| <rect x="20" y="170" width="420" height="140" rx="8" fill="#141414" stroke="#f0b030" stroke-width="1"/> |
| <text x="40" y="192" fill="#f0b030" font-size="14" font-family="monospace" font-weight="bold">2. WHY IT WORKS</text> |
| |
| <text x="40" y="218" fill="#4ade80" font-size="13" font-family="monospace">Think of it like fractions:</text> |
| |
| <!-- Fraction visualization - smaller --> |
| <text x="55" y="250" fill="#5b9bd5" font-size="16" font-family="monospace">dy</text> |
| <line x1="50" y1="255" x2="75" y2="255" stroke="#5b9bd5" stroke-width="2"/> |
| <text x="55" y="272" fill="#5b9bd5" font-size="16" font-family="monospace">dx</text> |
| |
| <text x="90" y="260" fill="#888" font-size="18" font-family="monospace">=</text> |
| |
| <text x="115" y="250" fill="#f0b030" font-size="16" font-family="monospace">dy</text> |
| <line x1="110" y1="255" x2="135" y2="255" stroke="#f0b030" stroke-width="2"/> |
| <text x="115" y="272" fill="#ff6b6b" font-size="16" font-family="monospace" text-decoration="line-through">du</text> |
| |
| <text x="150" y="260" fill="#888" font-size="18" font-family="monospace">×</text> |
| |
| <text x="175" y="250" fill="#ff6b6b" font-size="16" font-family="monospace" text-decoration="line-through">du</text> |
| <line x1="170" y1="255" x2="195" y2="255" stroke="#f0b030" stroke-width="2"/> |
| <text x="175" y="272" fill="#4ade80" font-size="16" font-family="monospace">dx</text> |
| |
| <text x="210" y="260" fill="#888" font-size="18" font-family="monospace">=</text> |
| |
| <text x="235" y="250" fill="#f0b030" font-size="16" font-family="monospace">dy</text> |
| <line x1="230" y1="255" x2="255" y2="255" stroke="#4ade80" stroke-width="2"/> |
| <text x="235" y="272" fill="#4ade80" font-size="16" font-family="monospace">dx</text> |
| |
| |
| <!-- Section 3: Concrete Example --> |
| <rect x="460" y="170" width="420" height="140" rx="8" fill="#141414" stroke="#4ade80" stroke-width="1"/> |
| <text x="480" y="192" fill="#4ade80" font-size="14" font-family="monospace" font-weight="bold">3. EXAMPLE: y = (3x + 2)²</text> |
| |
| <!-- Break down - repositioned --> |
| <text x="480" y="222" fill="#5b9bd5" font-size="12" font-family="monospace">Inner: u = 3x+2</text> |
| <text x="630" y="222" fill="#888" font-size="12" font-family="monospace">→ du/dx = 3</text> |
| |
| <text x="480" y="248" fill="#5b9bd5" font-size="12" font-family="monospace">Outer: y = u²</text> |
| <text x="630" y="248" fill="#888" font-size="12" font-family="monospace">→ dy/du = 2u</text> |
| |
| <line x1="480" y1="260" x2="860" y2="260" stroke="#333" stroke-width="1"/> |
| |
| <text x="480" y="282" fill="#f0b030" font-size="13" font-family="monospace">Chain: dy/dx = 2u × 3 = 6(3x+2)</text> |
| |
| <!-- Section 4: Interactive --> |
| <rect x="20" y="320" width="860" height="145" rx="8" fill="#1a1a1a" stroke="#888" stroke-width="1" stroke-dasharray="5,5"/> |
| <text x="450" y="350" text-anchor="middle" fill="#888" font-size="15" font-family="monospace">▼ INTERACTIVE TERMINAL ▼</text> |
| <text x="450" y="380" text-anchor="middle" fill="#5b9bd5" font-size="13" font-family="monospace">Adjust a, b, and x with the sliders</text> |
| <text x="450" y="410" text-anchor="middle" fill="#4ade80" font-size="13" font-family="monospace">Click "APPLY CHAIN RULE" to see values flow through</text> |
| <text x="450" y="440" text-anchor="middle" fill="#555" font-size="10" font-family="monospace">[Remember: derivatives chain together like Vault access codes]</text> |
| </svg> |
| ''' |
|
|
# Intro content: the CHAIN_RULE_INTRO_SVG markup with one newline before and
# one newline after it.
CHAIN_RULE_INTRO = "\n" + CHAIN_RULE_INTRO_SVG + "\n"
|
|
def generate_chain_rule_svg(a, b, x_val, u, y, du_dx, dy_du, dy_dx):
    """Render the chain-rule flow for y = (ax + b)² as an inline SVG string.

    The diagram shows the forward path x -> u -> y as solid arrows, the two
    local derivatives as dashed arrows running the opposite way, and a result
    box that multiplies them into dy/dx.

    Args:
        a: Coefficient of x in the inner function (shown verbatim).
        b: Constant term of the inner function (shown verbatim).
        x_val: Input value; rendered with two decimals.
        u: Value of the inner function at ``x_val``; two decimals.
        y: Value of the outer function at ``u``; two decimals.
        du_dx: Derivative of the inner function; two decimals.
        dy_du: Derivative of the outer function; two decimals.
        dy_dx: Product of the two derivatives; two decimals.

    Returns:
        The SVG markup as a ``str``.
    """
    # Colour palette for the diagram.
    bg = "#0c0c0c"
    node_fill = "#1a2a3a"
    node_stroke = "#5b9bd5"
    input_fill = "#1a3a2a"
    input_stroke = "#4ade80"
    output_fill = "#2a2a1a"
    output_stroke = "#f0b030"
    text_color = "#4ade80"
    label_color = "#5b9bd5"
    arrow_color = "#5b9bd5"
    value_color = "#f0b030"
    deriv_color = "#ff6b6b"

    # NOTE: with orient="auto" a marker's +x axis follows the direction in
    # which its line is drawn. The derivative lines are drawn right-to-left,
    # so "arr-red" uses the same forward-pointing triangle as "arr"; the head
    # then lands on the left end of each dashed line and points left. The
    # legend sample is drawn right-to-left for the same reason.
    svg = f'''
<svg viewBox="0 0 800 280" style="width:100%; max-width:800px; height:auto; background:{bg}; border-radius:8px; border:1px solid #333;">
<defs>
<marker id="arr" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
<polygon points="0 0, 10 3.5, 0 7" fill="{arrow_color}" />
</marker>
<marker id="arr-red" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
<polygon points="0 0, 10 3.5, 0 7" fill="{deriv_color}" />
</marker>
</defs>

<!-- Title -->
<text x="400" y="30" text-anchor="middle" fill="{output_stroke}" font-size="18" font-family="monospace">CHAIN RULE: y = ({a}x + {b})²</text>

<!-- Input x -->
<rect x="50" y="80" width="80" height="60" rx="8" fill="{input_fill}" stroke="{input_stroke}" stroke-width="2"/>
<text x="90" y="105" text-anchor="middle" fill="{label_color}" font-size="14" font-family="monospace">x</text>
<text x="90" y="128" text-anchor="middle" fill="{value_color}" font-size="16" font-family="monospace">{x_val:.2f}</text>

<!-- Arrow x to u -->
<line x1="130" y1="110" x2="200" y2="110" stroke="{arrow_color}" stroke-width="2" marker-end="url(#arr)"/>
<text x="165" y="100" text-anchor="middle" fill="{text_color}" font-size="12" font-family="monospace">g(x)</text>

<!-- Inner function u -->
<rect x="210" y="80" width="120" height="60" rx="8" fill="{node_fill}" stroke="{node_stroke}" stroke-width="2"/>
<text x="270" y="100" text-anchor="middle" fill="{label_color}" font-size="12" font-family="monospace">u = {a}x + {b}</text>
<text x="270" y="125" text-anchor="middle" fill="{value_color}" font-size="16" font-family="monospace">u = {u:.2f}</text>

<!-- Arrow u to y -->
<line x1="330" y1="110" x2="400" y2="110" stroke="{arrow_color}" stroke-width="2" marker-end="url(#arr)"/>
<text x="365" y="100" text-anchor="middle" fill="{text_color}" font-size="12" font-family="monospace">f(u)</text>

<!-- Outer function y -->
<rect x="410" y="80" width="100" height="60" rx="8" fill="{output_fill}" stroke="{output_stroke}" stroke-width="2"/>
<text x="460" y="100" text-anchor="middle" fill="{label_color}" font-size="12" font-family="monospace">y = u²</text>
<text x="460" y="125" text-anchor="middle" fill="{value_color}" font-size="16" font-family="monospace">y = {y:.2f}</text>

<!-- Derivative arrows (below, going backwards) -->
<line x1="410" y1="160" x2="335" y2="160" stroke="{deriv_color}" stroke-width="2" stroke-dasharray="5,3" marker-end="url(#arr-red)"/>
<text x="372" y="180" text-anchor="middle" fill="{deriv_color}" font-size="12" font-family="monospace">dy/du = {dy_du:.2f}</text>

<line x1="210" y1="160" x2="135" y2="160" stroke="{deriv_color}" stroke-width="2" stroke-dasharray="5,3" marker-end="url(#arr-red)"/>
<text x="172" y="180" text-anchor="middle" fill="{deriv_color}" font-size="12" font-family="monospace">du/dx = {du_dx:.2f}</text>

<!-- Chain rule result box -->
<rect x="540" y="70" width="230" height="120" rx="8" fill="#141414" stroke="#333" stroke-width="1"/>
<text x="655" y="95" text-anchor="middle" fill="{output_stroke}" font-size="14" font-family="monospace">CHAIN RULE</text>
<text x="555" y="120" fill="{text_color}" font-size="13" font-family="monospace">dy/dx = dy/du × du/dx</text>
<text x="555" y="145" fill="{text_color}" font-size="13" font-family="monospace"> = {dy_du:.2f} × {du_dx:.2f}</text>
<text x="555" y="175" fill="{value_color}" font-size="16" font-family="monospace"> = {dy_dx:.2f}</text>

<!-- Legend -->
<line x1="50" y1="240" x2="90" y2="240" stroke="{arrow_color}" stroke-width="2" marker-end="url(#arr)"/>
<text x="100" y="244" fill="{text_color}" font-size="11" font-family="monospace">Forward</text>

<line x1="240" y1="240" x2="200" y2="240" stroke="{deriv_color}" stroke-width="2" stroke-dasharray="5,3" marker-end="url(#arr-red)"/>
<text x="250" y="244" fill="{text_color}" font-size="11" font-family="monospace">Derivatives (multiply!)</text>
</svg>
'''
    return svg
|
|
|
|
def chain_rule_calculator(a, b, x_val):
    """Work through the chain rule for y = (ax + b)^2 at a single point.

    Computes the inner value u = a*x + b, the outer value y = u^2, the two
    local derivatives and their product, then cross-checks the result with a
    forward finite difference (h = 0.001).

    Args:
        a: Coefficient of x in the inner function.
        b: Constant term of the inner function.
        x_val: Point at which the function and its derivative are evaluated.

    Returns:
        A ``(svg_diagram, explanation)`` tuple: the SVG produced by
        ``generate_chain_rule_svg`` and a Markdown walkthrough string.
    """
    # Forward evaluation of the composition.
    u = a * x_val + b
    y = u ** 2

    # Local derivatives and their chain-rule product.
    du_dx = a
    dy_du = 2 * u
    dy_dx = dy_du * du_dx

    svg_diagram = generate_chain_rule_svg(a, b, x_val, u, y, du_dx, dy_du, dy_dx)

    # Forward-difference sanity check, computed once and reused below.
    h = 0.001
    x_step = x_val + h
    y_step = (a * x_step + b) ** 2
    slope = (y_step - y) / h

    # Computed values use ".10g": it keeps clean results compact ("5", "12.25")
    # while hiding binary floating-point residue such as 0.30000000000000004.
    explanation = f"""
## CHAIN RULE CALCULATION: y = ({a}x + {b})^2
===============================================

### Setting up the composition:

```
Inner function: u = {a}x + {b}
Outer function: y = u^2
```

At x = {x_val}:
```
u = {a}*{x_val} + {b} = {u:.10g}
y = ({u:.10g})^2 = {y:.10g}
```

-----------------------------------------------

### Step 1: Find du/dx (derivative of inner function)

```
u = {a}x + {b}

du/dx = {a} (coefficient of x)
```

-----------------------------------------------

### Step 2: Find dy/du (derivative of outer function)

```
y = u^2

dy/du = 2u = 2*({u:.10g}) = {dy_du:.10g}
```

-----------------------------------------------

### Step 3: Apply the Chain Rule!

```
dy/dx = (dy/du) * (du/dx)
= {dy_du:.10g} * {du_dx}
= {dy_dx:.10g}
```

-----------------------------------------------

### VERIFICATION (optional sanity check)

If x increases by tiny amount h=0.001:
```
y(x+h) = ({a}*{x_step:.10g} + {b})^2 = {y_step:.6f}
y(x) = {y:.10g}

Slope ≈ (y(x+h) - y(x)) / h
= {slope:.4f}

Our dy/dx = {dy_dx:.10g}
```

The chain rule works!
"""
    return svg_diagram, explanation
|
|
|
|
| |
| |
| |
|
|
# Static SVG overview panel. Its sections, in order: the sigmoid function and
# its derivative, binary cross-entropy and its derivative, a row of common
# derivative patterns, and a prompt pointing at the interactive z slider.
DERIVATIVES_INTRO_SVG = '''
<svg viewBox="0 0 900 520" style="width:100%; max-width:900px; height:auto; background:#0c0c0c; border-radius:8px; border:1px solid #333; margin-bottom:20px;">

<!-- Title -->
<text x="450" y="32" text-anchor="middle" fill="#f0b030" font-size="18" font-family="monospace" font-weight="bold">KEY DERIVATIVES YOU NEED TO KNOW</text>

<!-- Sigmoid Section -->
<rect x="20" y="50" width="430" height="200" rx="8" fill="#141414" stroke="#4ade80" stroke-width="2"/>
<text x="35" y="75" fill="#4ade80" font-size="14" font-family="monospace" font-weight="bold">1. SIGMOID FUNCTION</text>

<!-- Mini sigmoid curve - compact -->
<polyline points="35,150 50,148 65,145 80,140 95,130 110,115 125,102 140,94 155,89 170,87"
fill="none" stroke="#4ade80" stroke-width="2"/>
<line x1="35" y1="155" x2="170" y2="155" stroke="#333" stroke-width="1"/>
<line x1="102" y1="87" x2="102" y2="155" stroke="#333" stroke-width="1" stroke-dasharray="3,3"/>
<text x="102" y="170" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">z=0 → σ=0.5</text>

<!-- Formula - compact layout -->
<text x="185" y="95" fill="#5b9bd5" font-size="10" font-family="monospace">Function:</text>
<text x="185" y="112" fill="#f0b030" font-size="10" font-family="monospace">σ(z) = 1/(1+e⁻ᶻ)</text>

<text x="185" y="135" fill="#5b9bd5" font-size="10" font-family="monospace">Derivative:</text>
<text x="185" y="152" fill="#ff6b6b" font-size="10" font-family="monospace">dσ/dz = σ(z)(1-σ(z))</text>

<!-- Key insight box -->
<rect x="30" y="190" width="410" height="50" rx="4" fill="#1a2a1a" stroke="#4ade80" stroke-width="1"/>
<text x="45" y="210" fill="#4ade80" font-size="10" font-family="monospace">Derivative uses the function itself!</text>
<text x="45" y="228" fill="#888" font-size="10" font-family="monospace">Already have σ(z)? No extra work needed.</text>

<!-- BCE Section -->
<rect x="460" y="50" width="420" height="200" rx="8" fill="#141414" stroke="#ff6b6b" stroke-width="2"/>
<text x="480" y="75" fill="#ff6b6b" font-size="14" font-family="monospace" font-weight="bold">2. BINARY CROSS-ENTROPY</text>

<text x="480" y="100" fill="#5b9bd5" font-size="11" font-family="monospace">Loss Function:</text>
<text x="480" y="120" fill="#f0b030" font-size="11" font-family="monospace">L = -[y·log(ŷ) + (1-y)·log(1-ŷ)]</text>

<text x="480" y="150" fill="#5b9bd5" font-size="11" font-family="monospace">Derivative w.r.t. ŷ:</text>
<text x="480" y="170" fill="#ff6b6b" font-size="11" font-family="monospace">dL/dŷ = -y/ŷ + (1-y)/(1-ŷ)</text>

<!-- Magic box -->
<rect x="480" y="190" width="385" height="50" rx="4" fill="#2a1a1a" stroke="#ff6b6b" stroke-width="1"/>
<text x="495" y="210" fill="#f0b030" font-size="10" font-family="monospace">Combined with sigmoid:</text>
<text x="495" y="227" fill="#4ade80" font-size="13" font-family="monospace">dL/dz = ŷ - y (Vault-Tec approved)</text>

<!-- Common Patterns Section -->
<rect x="20" y="260" width="860" height="130" rx="8" fill="#141414" stroke="#5b9bd5" stroke-width="1"/>
<text x="40" y="283" fill="#5b9bd5" font-size="14" font-family="monospace" font-weight="bold">3. DERIVATIVE PATTERNS</text>

<!-- Pattern boxes - evenly spaced -->
<rect x="35" y="300" width="135" height="75" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
<text x="102" y="320" text-anchor="middle" fill="#888" font-size="11" font-family="monospace">Powers</text>
<text x="102" y="342" text-anchor="middle" fill="#f0b030" font-size="12" font-family="monospace">d/dx[xⁿ]</text>
<text x="102" y="364" text-anchor="middle" fill="#4ade80" font-size="12" font-family="monospace">= n·xⁿ⁻¹</text>

<rect x="185" y="300" width="135" height="75" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
<text x="252" y="320" text-anchor="middle" fill="#888" font-size="11" font-family="monospace">Exponential</text>
<text x="252" y="342" text-anchor="middle" fill="#f0b030" font-size="12" font-family="monospace">d/dx[eˣ]</text>
<text x="252" y="364" text-anchor="middle" fill="#4ade80" font-size="12" font-family="monospace">= eˣ</text>

<rect x="335" y="300" width="135" height="75" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
<text x="402" y="320" text-anchor="middle" fill="#888" font-size="11" font-family="monospace">Logarithm</text>
<text x="402" y="342" text-anchor="middle" fill="#f0b030" font-size="12" font-family="monospace">d/dx[log(x)]</text>
<text x="402" y="364" text-anchor="middle" fill="#4ade80" font-size="12" font-family="monospace">= 1/x</text>

<rect x="485" y="300" width="135" height="75" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
<text x="552" y="320" text-anchor="middle" fill="#888" font-size="11" font-family="monospace">Sigmoid</text>
<text x="552" y="342" text-anchor="middle" fill="#f0b030" font-size="12" font-family="monospace">d/dx[σ(x)]</text>
<text x="552" y="364" text-anchor="middle" fill="#4ade80" font-size="12" font-family="monospace">= σ(1-σ)</text>

<rect x="635" y="300" width="230" height="75" rx="4" fill="#2a2a1a" stroke="#f0b030" stroke-width="1"/>
<text x="750" y="320" text-anchor="middle" fill="#f0b030" font-size="11" font-family="monospace">Chain Rule</text>
<text x="750" y="342" text-anchor="middle" fill="#888" font-size="11" font-family="monospace">d/dx[f(g(x))]</text>
<text x="750" y="364" text-anchor="middle" fill="#4ade80" font-size="12" font-family="monospace">= f'(g(x)) · g'(x)</text>

<!-- Interactive prompt -->
<rect x="20" y="405" width="860" height="100" rx="8" fill="#1a1a1a" stroke="#888" stroke-width="1" stroke-dasharray="5,5"/>
<text x="450" y="430" text-anchor="middle" fill="#888" font-size="14" font-family="monospace">INTERACTIVE TERMINAL</text>
<text x="450" y="455" text-anchor="middle" fill="#5b9bd5" font-size="12" font-family="monospace">Move the z slider to see sigmoid and its derivative</text>
<text x="450" y="480" text-anchor="middle" fill="#4ade80" font-size="12" font-family="monospace">Derivative peaks at z=0, vanishes at extremes</text>
</svg>
'''


# Intro content: the SVG above with one newline before and one after it.
DERIVATIVES_INTRO = "\n" + DERIVATIVES_INTRO_SVG + "\n"
|
|
def generate_sigmoid_svg(z, sig, dsig):
    """Render the sigmoid curve with the current point and its tangent as SVG.

    The plot area maps z in [-5, 5] onto x pixels 100..500 (40 px per unit)
    and σ in [0, 1] onto y pixels 250..50 (200 px per unit, y grows downward).

    Args:
        z: Current input value; marks the highlighted point on the curve.
        sig: Sigmoid value at ``z``.
        dsig: Sigmoid derivative at ``z``; sets the slope of the tangent line.

    Returns:
        The SVG markup as a ``str``. The markup is well-formed XML (the
        ampersand in the title is written as ``&amp;``).
    """
    bg = "#0c0c0c"
    curve_color = "#4ade80"
    deriv_color = "#ff6b6b"
    point_color = "#f0b030"
    grid_color = "#333"
    text_color = "#4ade80"
    label_color = "#5b9bd5"

    # Sample the curve every 0.1 from z = -5 to z = 5 and convert each sample
    # to "x,y" pixel coordinates for the polyline.
    sample_zs = [i / 10 for i in range(-50, 51)]
    curve_path = " ".join(
        f"{100 + (x_pt + 5) * 40:.1f},{250 - (1 / (1 + np.exp(-x_pt))) * 200:.1f}"
        for x_pt in sample_zs
    )

    # Pixel position of the highlighted point.
    pt_x = 100 + (z + 5) * 40
    pt_y = 250 - sig * 200

    # Tangent segment spanning one unit of z on either side of the point.
    # A rise of dsig in σ is dsig*200 px upward, hence the negative sign.
    tangent_dx = 40
    tangent_dy = -dsig * 200
    t_x1 = pt_x - tangent_dx
    t_y1 = pt_y - tangent_dy
    t_x2 = pt_x + tangent_dx
    t_y2 = pt_y + tangent_dy

    svg = f'''
<svg viewBox="0 0 700 320" style="width:100%; max-width:700px; height:auto; background:{bg}; border-radius:8px; border:1px solid #333;">

<!-- Title -->
<text x="350" y="25" text-anchor="middle" fill="{point_color}" font-size="16" font-family="monospace">SIGMOID FUNCTION &amp; DERIVATIVE</text>

<!-- Grid lines -->
<line x1="100" y1="150" x2="500" y2="150" stroke="{grid_color}" stroke-width="1" stroke-dasharray="3,3"/>
<text x="95" y="154" text-anchor="end" fill="{grid_color}" font-size="10" font-family="monospace">0.5</text>

<line x1="300" y1="50" x2="300" y2="250" stroke="{grid_color}" stroke-width="1" stroke-dasharray="3,3"/>
<text x="300" y="265" text-anchor="middle" fill="{grid_color}" font-size="10" font-family="monospace">z=0</text>

<!-- Axes -->
<line x1="100" y1="250" x2="500" y2="250" stroke="{label_color}" stroke-width="1"/>
<line x1="100" y1="50" x2="100" y2="250" stroke="{label_color}" stroke-width="1"/>

<!-- Axis labels -->
<text x="300" y="280" text-anchor="middle" fill="{label_color}" font-size="12" font-family="monospace">z</text>
<text x="70" y="150" text-anchor="middle" fill="{label_color}" font-size="12" font-family="monospace">σ(z)</text>
<text x="95" y="55" text-anchor="end" fill="{grid_color}" font-size="10" font-family="monospace">1.0</text>
<text x="95" y="254" text-anchor="end" fill="{grid_color}" font-size="10" font-family="monospace">0.0</text>
<text x="100" y="275" text-anchor="middle" fill="{grid_color}" font-size="10" font-family="monospace">-5</text>
<text x="500" y="275" text-anchor="middle" fill="{grid_color}" font-size="10" font-family="monospace">5</text>

<!-- Sigmoid curve -->
<polyline points="{curve_path}" fill="none" stroke="{curve_color}" stroke-width="2"/>

<!-- Tangent line at current point -->
<line x1="{t_x1:.1f}" y1="{t_y1:.1f}" x2="{t_x2:.1f}" y2="{t_y2:.1f}" stroke="{deriv_color}" stroke-width="2" stroke-dasharray="5,3"/>

<!-- Current point -->
<circle cx="{pt_x:.1f}" cy="{pt_y:.1f}" r="8" fill="{point_color}" stroke="#fff" stroke-width="2"/>

<!-- Point label -->
<line x1="{pt_x:.1f}" y1="{pt_y:.1f}" x2="{pt_x + 30:.1f}" y2="{pt_y - 30:.1f}" stroke="{point_color}" stroke-width="1"/>
<text x="{pt_x + 35:.1f}" y="{pt_y - 35:.1f}" fill="{point_color}" font-size="11" font-family="monospace">z={z:.1f}</text>
<text x="{pt_x + 35:.1f}" y="{pt_y - 22:.1f}" fill="{point_color}" font-size="11" font-family="monospace">σ={sig:.3f}</text>

<!-- Info box -->
<rect x="530" y="60" width="155" height="140" rx="6" fill="#141414" stroke="#333" stroke-width="1"/>
<text x="607" y="85" text-anchor="middle" fill="{point_color}" font-size="13" font-family="monospace">VALUES</text>

<text x="540" y="110" fill="{label_color}" font-size="12" font-family="monospace">z = {z:.2f}</text>
<text x="540" y="130" fill="{curve_color}" font-size="12" font-family="monospace">σ(z) = {sig:.4f}</text>
<text x="540" y="155" fill="{deriv_color}" font-size="12" font-family="monospace">dσ/dz = {dsig:.4f}</text>

<text x="540" y="185" fill="{text_color}" font-size="10" font-family="monospace">= σ(1-σ)</text>
<text x="540" y="198" fill="{text_color}" font-size="10" font-family="monospace">= {sig:.3f}×{1-sig:.3f}</text>

<!-- Legend -->
<line x1="530" y1="240" x2="560" y2="240" stroke="{curve_color}" stroke-width="2"/>
<text x="565" y="244" fill="{text_color}" font-size="10" font-family="monospace">σ(z)</text>

<line x1="620" y1="240" x2="650" y2="240" stroke="{deriv_color}" stroke-width="2" stroke-dasharray="5,3"/>
<text x="655" y="244" fill="{text_color}" font-size="10" font-family="monospace">tangent</text>

</svg>
'''
    return svg
|
|
|
|
def sigmoid_derivative_demo(z):
    """Evaluate the sigmoid and its derivative at ``z`` and explain each step.

    Args:
        z: Input value at which σ(z) and dσ/dz are evaluated.

    Returns:
        A ``(svg_diagram, explanation)`` tuple: the SVG from
        ``generate_sigmoid_svg`` and a Markdown walkthrough string.
    """
    # Intermediate quantities are named once so the walkthrough below can show
    # every step of the calculation without recomputing it.
    exp_neg_z = np.exp(-z)
    denominator = 1 + exp_neg_z
    sig = 1 / denominator
    complement = 1 - sig
    dsig = sig * complement

    svg_diagram = generate_sigmoid_svg(z, sig, dsig)

    explanation = f"""
## SIGMOID DERIVATIVE AT z = {z}
===============================================

### Step 1: Compute sigmoid(z)

```
σ(z) = 1 / (1 + e^(-z))
= 1 / (1 + e^(-{z}))
= 1 / (1 + {exp_neg_z:.6f})
= 1 / {denominator:.6f}
= {sig:.6f}
```

-----------------------------------------------

### Step 2: Compute the derivative

Using the formula: dσ/dz = σ(z) * (1 - σ(z))

```
dσ/dz = σ(z) * (1 - σ(z))
= {sig:.6f} * (1 - {sig:.6f})
= {sig:.6f} * {complement:.6f}
= {dsig:.6f}
```

-----------------------------------------------

### Interpretation

At z = {z}:
- Sigmoid output: {sig:.4f} (how confident the neuron is)
- Derivative: {dsig:.4f} (how sensitive output is to z)

**Key insight:** The derivative is LARGEST when z≈0 (sigmoid≈0.5)
and SMALLEST when |z| is large. This is the "vanishing gradient"
problem - extreme values barely update!

```
z = 0 --> σ = 0.5, dσ/dz = 0.25 (maximum!)
z = 5 --> σ ≈ 0.99, dσ/dz ≈ 0.007 (tiny!)
z = -5 --> σ ≈ 0.01, dσ/dz ≈ 0.007 (tiny!)
```
"""
    return svg_diagram, explanation
|
|
|
|
| |
| |
| |
|
|
# Static SVG overview panel for backpropagation. Its sections, in order: a
# forward row (x -> z -> y -> L) above a backward row of gradients, a
# key-insight box, a chain-rule formula for dL/dw1, the same product drawn as
# boxes, and a prompt for the interactive demo.
BACKWARD_INTRO_SVG = '''
<svg viewBox="0 0 900 500" style="width:100%; max-width:900px; height:auto; background:#0c0c0c; border-radius:8px; border:1px solid #333; margin-bottom:20px;">
<defs>
<marker id="fwd-b" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
<polygon points="0 0, 10 3.5, 0 7" fill="#5b9bd5" />
</marker>
<marker id="bwd-b" markerWidth="10" markerHeight="7" refX="0" refY="3.5" orient="auto">
<polygon points="10 0, 0 3.5, 10 7" fill="#ff6b6b" />
</marker>
</defs>

<!-- Title -->
<text x="450" y="32" text-anchor="middle" fill="#f0b030" font-size="18" font-family="monospace" font-weight="bold">BACKPROPAGATION: LEARNING FROM MISTAKES</text>

<!-- Main concept - THE BIG PICTURE -->
<rect x="20" y="50" width="550" height="170" rx="8" fill="#141414" stroke="#5b9bd5" stroke-width="1"/>
<text x="40" y="75" fill="#5b9bd5" font-size="14" font-family="monospace" font-weight="bold">THE BIG PICTURE</text>

<!-- Forward pass row -->
<text x="40" y="105" fill="#5b9bd5" font-size="12" font-family="monospace">FWD:</text>
<rect x="85" y="90" width="50" height="30" rx="4" fill="#1a3a2a" stroke="#4ade80" stroke-width="1"/>
<text x="110" y="110" text-anchor="middle" fill="#4ade80" font-size="12" font-family="monospace">x</text>
<line x1="135" y1="105" x2="175" y2="105" stroke="#5b9bd5" stroke-width="2" marker-end="url(#fwd-b)"/>
<rect x="185" y="90" width="50" height="30" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
<text x="210" y="110" text-anchor="middle" fill="#5b9bd5" font-size="12" font-family="monospace">z</text>
<line x1="235" y1="105" x2="275" y2="105" stroke="#5b9bd5" stroke-width="2" marker-end="url(#fwd-b)"/>
<rect x="285" y="90" width="50" height="30" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
<text x="310" y="110" text-anchor="middle" fill="#5b9bd5" font-size="12" font-family="monospace">y</text>
<line x1="335" y1="105" x2="375" y2="105" stroke="#5b9bd5" stroke-width="2" marker-end="url(#fwd-b)"/>
<rect x="385" y="90" width="50" height="30" rx="4" fill="#2a1a1a" stroke="#ff6b6b" stroke-width="1"/>
<text x="410" y="110" text-anchor="middle" fill="#ff6b6b" font-size="12" font-family="monospace">L</text>
<text x="450" y="110" fill="#888" font-size="11" font-family="monospace">Data flows</text>

<!-- Backward pass row - wider boxes for partials -->
<text x="40" y="160" fill="#ff6b6b" font-size="12" font-family="monospace">BWD:</text>
<rect x="85" y="145" width="50" height="30" rx="4" fill="#1a3a2a" stroke="#4ade80" stroke-width="1"/>
<text x="110" y="165" text-anchor="middle" fill="#ff6b6b" font-size="10" font-family="monospace">dL/dx</text>
<line x1="135" y1="160" x2="175" y2="160" stroke="#ff6b6b" stroke-width="2" stroke-dasharray="4,2" marker-end="url(#bwd-b)"/>
<rect x="185" y="145" width="50" height="30" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
<text x="210" y="165" text-anchor="middle" fill="#ff6b6b" font-size="10" font-family="monospace">dL/dz</text>
<line x1="235" y1="160" x2="275" y2="160" stroke="#ff6b6b" stroke-width="2" stroke-dasharray="4,2" marker-end="url(#bwd-b)"/>
<rect x="285" y="145" width="50" height="30" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
<text x="310" y="165" text-anchor="middle" fill="#ff6b6b" font-size="10" font-family="monospace">dL/dy</text>
<line x1="335" y1="160" x2="375" y2="160" stroke="#ff6b6b" stroke-width="2" stroke-dasharray="4,2" marker-end="url(#bwd-b)"/>
<rect x="385" y="145" width="50" height="30" rx="4" fill="#2a1a1a" stroke="#ff6b6b" stroke-width="1"/>
<text x="410" y="165" text-anchor="middle" fill="#ff6b6b" font-size="12" font-family="monospace">1</text>
<text x="450" y="165" fill="#888" font-size="11" font-family="monospace">Grads flow</text>

<!-- Goal text -->
<text x="40" y="205" fill="#f0b030" font-size="11" font-family="monospace">GOAL: Find dL/dw1, dL/dw2, dL/db to update weights</text>

<!-- Key insight box -->
<rect x="585" y="50" width="295" height="170" rx="6" fill="#2a2a1a" stroke="#f0b030" stroke-width="2"/>
<text x="732" y="78" text-anchor="middle" fill="#f0b030" font-size="13" font-family="monospace" font-weight="bold">KEY INSIGHT</text>
<text x="600" y="108" fill="#4ade80" font-size="11" font-family="monospace">Same computation graph,</text>
<text x="600" y="128" fill="#4ade80" font-size="11" font-family="monospace">opposite direction!</text>
<line x1="600" y1="142" x2="865" y2="142" stroke="#333" stroke-width="1"/>
<text x="600" y="165" fill="#888" font-size="11" font-family="monospace">At each node multiply:</text>
<text x="600" y="188" fill="#5b9bd5" font-size="10" font-family="monospace">upstream × local derivative</text>

<!-- Chain rule explanation -->
<rect x="20" y="235" width="420" height="120" rx="8" fill="#141414" stroke="#ff6b6b" stroke-width="1"/>
<text x="40" y="260" fill="#ff6b6b" font-size="13" font-family="monospace" font-weight="bold">CHAIN RULE IN ACTION</text>
<text x="40" y="290" fill="#888" font-size="12" font-family="monospace">To find dL/dw1:</text>
<text x="40" y="315" fill="#f0b030" font-size="12" font-family="monospace">dL/dw1 = dL/dy * dy/dz * dz/dw1</text>
<text x="40" y="340" fill="#4ade80" font-size="11" font-family="monospace">Multiply derivatives along the path!</text>

<!-- Visual chain - wider boxes -->
<rect x="455" y="235" width="425" height="120" rx="8" fill="#141414" stroke="#4ade80" stroke-width="1"/>
<text x="475" y="260" fill="#4ade80" font-size="13" font-family="monospace" font-weight="bold">VISUAL MULTIPLICATION</text>

<rect x="475" y="280" width="62" height="35" rx="4" fill="#2a1a1a" stroke="#ff6b6b" stroke-width="1"/>
<text x="506" y="302" text-anchor="middle" fill="#ff6b6b" font-size="10" font-family="monospace">dL/dy</text>

<text x="545" y="302" fill="#888" font-size="16" font-family="monospace">×</text>

<rect x="562" y="280" width="62" height="35" rx="4" fill="#1a2a3a" stroke="#5b9bd5" stroke-width="1"/>
<text x="593" y="302" text-anchor="middle" fill="#5b9bd5" font-size="10" font-family="monospace">dy/dz</text>

<text x="632" y="302" fill="#888" font-size="16" font-family="monospace">×</text>

<rect x="650" y="280" width="70" height="35" rx="4" fill="#1a3a2a" stroke="#4ade80" stroke-width="1"/>
<text x="685" y="302" text-anchor="middle" fill="#4ade80" font-size="10" font-family="monospace">dz/dw1</text>

<text x="728" y="302" fill="#888" font-size="16" font-family="monospace">=</text>

<rect x="745" y="280" width="70" height="35" rx="4" fill="#2a2a1a" stroke="#f0b030" stroke-width="2"/>
<text x="780" y="302" text-anchor="middle" fill="#f0b030" font-size="10" font-family="monospace">dL/dw1</text>

<text x="475" y="340" fill="#888" font-size="11" font-family="monospace">upstream × local = pass backward</text>

<!-- Interactive prompt -->
<rect x="20" y="370" width="860" height="115" rx="8" fill="#1a1a1a" stroke="#888" stroke-width="1" stroke-dasharray="5,5"/>
<text x="450" y="395" text-anchor="middle" fill="#888" font-size="14" font-family="monospace">INTERACTIVE TERMINAL</text>
<text x="450" y="420" text-anchor="middle" fill="#5b9bd5" font-size="12" font-family="monospace">Blue arrows = forward data | Red arrows = backward gradients</text>
<text x="450" y="445" text-anchor="middle" fill="#4ade80" font-size="12" font-family="monospace">Click "EXECUTE FULL BACKPROP" to see all values calculated</text>
<text x="450" y="470" text-anchor="middle" fill="#555" font-size="10" font-family="monospace">[Vault-Tec tip: errors propagate backward, just like rumors in the cafeteria]</text>
</svg>
'''


# Intro content: the SVG above with one newline before and one after it.
BACKWARD_INTRO = "\n" + BACKWARD_INTRO_SVG + "\n"
|
|
def backward_pass_demo(x1, x2, w1, w2, b, y_true):
    """Run one forward and backward pass of a two-input sigmoid neuron.

    Computes z = w1*x1 + w2*x2 + b, the sigmoid prediction, the binary
    cross-entropy loss, every chain-rule factor down to dL/dw1, dL/dw2 and
    dL/db, and a single gradient-descent update with learning rate 0.1.

    Args:
        x1: First input.
        x2: Second input.
        w1: Weight applied to ``x1``.
        w2: Weight applied to ``x2``.
        b: Bias term.
        y_true: Target label used in the loss.

    Returns:
        A ``(svg_diagram, explanation)`` tuple: the SVG from
        ``generate_backward_svg`` and a Markdown walkthrough string.
    """
    # ---- Forward pass ----
    z = w1 * x1 + w2 * x2 + b
    y_pred = 1 / (1 + np.exp(-z))
    pred_complement = 1 - y_pred
    label_complement = 1 - y_true

    # Clip the prediction away from 0 and 1 so log() and the divisions below
    # stay finite.
    eps = 1e-7
    pred_safe = np.clip(y_pred, eps, 1 - eps)
    positive_term = y_true * np.log(pred_safe)
    negative_term = (1 - y_true) * np.log(1 - pred_safe)
    loss = -(positive_term + negative_term)

    # ---- Backward pass: local derivatives ----
    dL_dy_first = -y_true / pred_safe
    dL_dy_second = (1 - y_true) / (1 - pred_safe)
    dL_dy = dL_dy_first + dL_dy_second

    dy_dz = y_pred * (1 - y_pred)

    dz_dw1, dz_dw2, dz_db = x1, x2, 1

    # ---- Backward pass: chain the factors together ----
    dL_dz = dL_dy * dy_dz
    dL_dw1 = dL_dz * dz_dw1
    dL_dw2 = dL_dz * dz_dw2
    dL_db = dL_dz * dz_db

    # ---- One gradient-descent step (learning rate 0.1) ----
    w1_new = w1 - 0.1 * dL_dw1
    w2_new = w2 - 0.1 * dL_dw2
    b_new = b - 0.1 * dL_db

    svg_diagram = generate_backward_svg(
        x1, x2, w1, w2, b, y_true, z, y_pred,
        dL_dy, dy_dz, dL_dz, dL_dw1, dL_dw2, dL_db, loss
    )

    explanation = f"""
## COMPLETE BACKPROP WALKTHROUGH
===============================================

### GIVEN:
```
Inputs: x1 = {x1}, x2 = {x2}
Weights: w1 = {w1}, w2 = {w2}
Bias: b = {b}
True label: y_true = {y_true}
```

===============================================
## PART 1: FORWARD PASS (review)
===============================================

**Step 1a: Weighted sum**
```
z = w1*x1 + w2*x2 + b
= ({w1})*({x1}) + ({w2})*({x2}) + ({b})
= {z:.6f}
```

**Step 1b: Sigmoid activation**
```
y_pred = sigmoid(z) = 1/(1+e^(-z))
= 1/(1+e^(-{z:.4f}))
= {y_pred:.6f}
```

**Step 1c: Binary Cross-Entropy Loss**
```
L = -[y_true*log(y_pred) + (1-y_true)*log(1-y_pred)]
= -[{y_true}*log({y_pred:.6f}) + {label_complement}*log({pred_complement:.6f})]
= -[{positive_term:.6f} + {negative_term:.6f}]
= {loss:.6f}
```

===============================================
## PART 2: BACKWARD PASS (reversing the flow)
===============================================

We need: dL/dw1, dL/dw2, dL/db

**The computation graph:**
```
w1,x1,w2,x2,b --> z --> y_pred --> L
| | |
dz/dw dy/dz dL/dy
```

We work BACKWARDS from Loss to weights.

-----------------------------------------------
### STEP 2a: dL/dy_pred (how loss changes with prediction)

```
L = -y_true*log(y_pred) - (1-y_true)*log(1-y_pred)

dL/dy_pred = -y_true/y_pred + (1-y_true)/(1-y_pred)
= -{y_true}/{y_pred:.6f} + {label_complement}/{pred_complement:.6f}
= {dL_dy_first:.6f} + {dL_dy_second:.6f}
= {dL_dy:.6f}
```

-----------------------------------------------
### STEP 2b: dy_pred/dz (sigmoid derivative)

Using: d/dz[sigmoid(z)] = sigmoid(z)*(1-sigmoid(z))

```
dy/dz = y_pred * (1 - y_pred)
= {y_pred:.6f} * (1 - {y_pred:.6f})
= {y_pred:.6f} * {pred_complement:.6f}
= {dy_dz:.6f}
```

-----------------------------------------------
### STEP 2c: dz/dw1, dz/dw2, dz/db

Since z = w1*x1 + w2*x2 + b:

```
dz/dw1 = x1 = {dz_dw1}
dz/dw2 = x2 = {dz_dw2}
dz/db = 1 = {dz_db}
```

-----------------------------------------------
### STEP 2d: CHAIN RULE - Put it together!

First, compute dL/dz (the "upstream gradient"):
```
dL/dz = (dL/dy_pred) * (dy_pred/dz)
= {dL_dy:.6f} * {dy_dz:.6f}
= {dL_dz:.6f}
```

Now chain to each weight:
```
dL/dw1 = (dL/dz) * (dz/dw1)
= {dL_dz:.6f} * {dz_dw1}
= {dL_dw1:.6f}

dL/dw2 = (dL/dz) * (dz/dw2)
= {dL_dz:.6f} * {dz_dw2}
= {dL_dw2:.6f}

dL/db = (dL/dz) * (dz/db)
= {dL_dz:.6f} * {dz_db}
= {dL_db:.6f}
```

===============================================
## PART 3: GRADIENT DESCENT UPDATE
===============================================

With learning rate α = 0.1:

```
w1_new = w1 - α * dL/dw1
= {w1} - 0.1 * {dL_dw1:.6f}
= {w1_new:.6f}

w2_new = w2 - α * dL/dw2
= {w2} - 0.1 * {dL_dw2:.6f}
= {w2_new:.6f}

b_new = b - α * dL/db
= {b} - 0.1 * {dL_db:.6f}
= {b_new:.6f}
```

**We've completed one step of learning!**

===============================================
## SUMMARY TABLE
===============================================

| Gradient | Value | Meaning |
|----------|-------|---------|
| dL/dy | {dL_dy:.4f} | Loss sensitivity to prediction |
| dy/dz | {dy_dz:.4f} | Sigmoid sensitivity |
| dL/dz | {dL_dz:.4f} | "Upstream gradient" |
| dL/dw1 | {dL_dw1:.4f} | How to adjust w1 |
| dL/dw2 | {dL_dw2:.4f} | How to adjust w2 |
| dL/db | {dL_db:.4f} | How to adjust bias |
"""
    return svg_diagram, explanation
|
|
|
|
| |
| |
| |
|
|
# Markdown shown at the top of the "05: PRACTICE" tab.
# The closing sentence names the ">> REVEAL SOLUTION <<" button, which is the
# control that actually exists in that tab (there is no "Check Answer" button).
PRACTICE_INTRO = """
# PRACTICE: COMPUTE BY HAND FIRST!
===============================================

Welcome to the Gradient Occupational Aptitude Test (G.O.A.T.).
Per Vault-Tec guidelines, pencil-and-paper practice builds neural
pathways (the biological kind). Complete these problems to determine
your future as a Machine Learning Specialist.

## TIPS FOR HAND CALCULATION

1. **Draw the computation graph** - boxes for operations,
arrows for data flow

2. **Forward pass first** - compute all intermediate values

3. **Backward pass** - start from loss, work backwards

4. **Check dimensions** - gradient of scalar w.r.t. vector
has same shape as the vector

5. **Verify numerically** - if unsure, use tiny h to approximate:
df/dx ≈ (f(x+h) - f(x)) / h

## PRACTICE PROBLEMS

Select a problem below and try it before clicking "REVEAL SOLUTION"!
"""
|
|
def practice_problem(problem_num):
    """Look up one practice problem and return its markdown texts.

    Args:
        problem_num: Key of the requested problem (1, 2 or 3). A key that is
            not in the catalogue falls back to problem 1.

    Returns:
        A ``(question, solution)`` tuple of markdown strings.
    """
    # Problem 2 evaluates the sigmoid at z = 2; name each piece once so the
    # solution template below stays readable.
    exp_term = np.exp(-2)
    sig = 1 / (1 + exp_term)

    chain_rule_question = """
### Problem 1: Simple Chain Rule

Compute dy/dx where:
```
y = (2x + 3)^3
```

at x = 1.

**Hint:** Let u = 2x + 3, so y = u^3
"""

    chain_rule_solution = """
### Solution to Problem 1

**Step 1: Identify the composition**
```
u = 2x + 3 (inner)
y = u^3 (outer)
```

**Step 2: Find individual derivatives**
```
du/dx = 2
dy/du = 3u^2
```

**Step 3: Apply chain rule**
```
dy/dx = (dy/du) * (du/dx)
= 3u^2 * 2
= 6u^2
= 6(2x + 3)^2
```

**Step 4: Evaluate at x = 1**
```
dy/dx = 6(2*1 + 3)^2
= 6(5)^2
= 6 * 25
= 150
```

**Answer: dy/dx = 150 at x = 1**
"""

    sigmoid_question = """
### Problem 2: Sigmoid Derivative

Given z = 2, compute:
1. sigmoid(z)
2. d/dz[sigmoid(z)]

**Reminder:** sigmoid(z) = 1/(1+e^(-z))
d/dz[sigmoid(z)] = sigmoid(z) * (1 - sigmoid(z))
"""

    sigmoid_solution = f"""
### Solution to Problem 2

**Step 1: Compute sigmoid(2)**
```
sigmoid(2) = 1/(1 + e^(-2))
= 1/(1 + {exp_term:.6f})
= 1/{1 + exp_term:.6f}
= {sig:.6f}
```

**Step 2: Compute derivative**
```
Let s = sigmoid(2) = {sig:.6f}

ds/dz = s * (1 - s)
= {sig:.6f} * (1 - {sig:.6f})
= {sig:.6f} * {1 - sig:.6f}
= {sig * (1 - sig):.6f}
```

**Answers:**
- sigmoid(2) ≈ 0.8808
- d/dz[sigmoid(2)] ≈ 0.1050
"""

    backprop_question = """
### Problem 3: Full Backprop (Mini Version)

Single neuron with:
```
x = 2
w = 0.5
b = -1
y_true = 1
```

Using sigmoid activation and BCE loss, find dL/dw.

**Steps to follow:**
1. Forward: z = wx + b
2. Forward: y_pred = sigmoid(z)
3. Forward: L = BCE(y_true, y_pred)
4. Backward: Apply chain rule
"""

    backprop_solution = """
### Solution to Problem 3

**Forward Pass:**
```
z = w*x + b = 0.5*2 + (-1) = 0

y_pred = sigmoid(0) = 0.5

L = -[1*log(0.5) + 0*log(0.5)]
= -log(0.5)
= 0.693
```

**Backward Pass:**

dL/dy_pred:
```
= -y_true/y_pred + (1-y_true)/(1-y_pred)
= -1/0.5 + 0/0.5
= -2
```

dy_pred/dz:
```
= y_pred * (1 - y_pred)
= 0.5 * 0.5 = 0.25
```

dz/dw:
```
= x = 2
```

**Chain Rule:**
```
dL/dw = (dL/dy) * (dy/dz) * (dz/dw)
= (-2) * (0.25) * (2)
= -1.0
```

**Answer: dL/dw = -1.0**

**Interpretation:** Negative gradient means we should
INCREASE w to reduce loss (moving opposite to gradient).
"""

    # Each entry is already the (question, solution) pair callers receive.
    catalog = {
        1: (chain_rule_question, chain_rule_solution),
        2: (sigmoid_question, sigmoid_solution),
        3: (backprop_question, backprop_solution),
    }
    return catalog.get(problem_num, catalog[1])
|
|
|
|
| |
| |
| |
|
|
# Top-level UI definition: a banner, six tabs (four interactive demos, a
# practice area and a static reference card) and a footer. Each demo tab
# follows the same layout: an intro block, a column of sliders with a run
# button on the left, and an HTML diagram plus a markdown trace on the right.
with gr.Blocks(title="BACKPROP TERMINAL v1.0") as demo:

    gr.Markdown("""
# > VAULT-TEC NEURAL NETWORK TRAINING TERMINAL
## > SECURITY CLEARANCE: STAT 3106
### > INITIALIZING BACKPROPAGATION MODULES...
""")

    with gr.Tabs():
        # --- Tab 1: forward pass through a two-input neuron ---
        with gr.TabItem("01: FORWARD PASS"):
            gr.HTML(FORWARD_INTRO)

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### INPUT PARAMETERS")
                    x1_input = gr.Slider(minimum=-5, maximum=5, value=1.0, step=0.1, label="x1 (input 1)")
                    x2_input = gr.Slider(minimum=-5, maximum=5, value=2.0, step=0.1, label="x2 (input 2)")
                    w1_input = gr.Slider(minimum=-2, maximum=2, value=0.5, step=0.1, label="w1 (weight 1)")
                    w2_input = gr.Slider(minimum=-2, maximum=2, value=-0.3, step=0.1, label="w2 (weight 2)")
                    b_input = gr.Slider(minimum=-2, maximum=2, value=0.1, step=0.1, label="b (bias)")
                    forward_btn = gr.Button(">> EXECUTE FORWARD PASS <<")

                with gr.Column(scale=2):
                    forward_svg = gr.HTML(label="Computation Graph")
                    forward_output = gr.Markdown(label="Calculation")

            forward_btn.click(
                forward_pass_demo,
                inputs=[x1_input, x2_input, w1_input, w2_input, b_input],
                outputs=[forward_svg, forward_output]
            )

        # --- Tab 2: chain rule on y = (ax + b)² ---
        with gr.TabItem("02: CHAIN RULE"):
            gr.HTML(CHAIN_RULE_INTRO)

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### FUNCTION: y = (ax + b)²")
                    a_input = gr.Slider(minimum=-5, maximum=5, value=3.0, step=0.1, label="a (coefficient)")
                    b2_input = gr.Slider(minimum=-5, maximum=5, value=2.0, step=0.1, label="b (constant)")
                    x_input = gr.Slider(minimum=-5, maximum=5, value=1.0, step=0.1, label="x (evaluation point)")
                    chain_btn = gr.Button(">> APPLY CHAIN RULE <<")

                with gr.Column(scale=2):
                    chain_svg = gr.HTML(label="Chain Rule Visualization")
                    chain_output = gr.Markdown(label="Chain Rule Breakdown")

            chain_btn.click(
                chain_rule_calculator,
                inputs=[a_input, b2_input, x_input],
                outputs=[chain_svg, chain_output]
            )

        # --- Tab 3: sigmoid derivative at a chosen z ---
        with gr.TabItem("03: KEY DERIVATIVES"):
            gr.HTML(DERIVATIVES_INTRO)

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### SIGMOID DERIVATIVE CALCULATOR")
                    z_input = gr.Slider(
                        minimum=-5, maximum=5, value=0, step=0.1,
                        label="z value"
                    )
                    sigmoid_btn = gr.Button(">> COMPUTE SIGMOID DERIVATIVE <<")

                with gr.Column(scale=2):
                    sigmoid_svg = gr.HTML(label="Sigmoid Visualization")
                    sigmoid_output = gr.Markdown(label="Derivative Calculation")

            sigmoid_btn.click(
                sigmoid_derivative_demo,
                inputs=[z_input],
                outputs=[sigmoid_svg, sigmoid_output]
            )

        # --- Tab 4: full backward pass for the two-input neuron ---
        with gr.TabItem("04: BACKWARD PASS"):
            gr.HTML(BACKWARD_INTRO)

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### NETWORK CONFIGURATION")
                    bx1 = gr.Slider(minimum=-5, maximum=5, value=1.0, step=0.1, label="x1")
                    bx2 = gr.Slider(minimum=-5, maximum=5, value=2.0, step=0.1, label="x2")
                    bw1 = gr.Slider(minimum=-2, maximum=2, value=0.5, step=0.1, label="w1")
                    bw2 = gr.Slider(minimum=-2, maximum=2, value=-0.3, step=0.1, label="w2")
                    bb = gr.Slider(minimum=-2, maximum=2, value=0.1, step=0.1, label="bias")
                    # step=1 on a 0..1 range restricts the target to the two class labels.
                    by_true = gr.Slider(minimum=0, maximum=1, value=1, step=1, label="y_true (0 or 1)")
                    back_btn = gr.Button(">> EXECUTE FULL BACKPROP <<")

                with gr.Column(scale=2):
                    back_svg = gr.HTML(label="Backprop Graph")
                    back_output = gr.Markdown(label="Complete Backprop Trace")

            back_btn.click(
                backward_pass_demo,
                inputs=[bx1, bx2, bw1, bw2, bb, by_true],
                outputs=[back_svg, back_output]
            )

        # --- Tab 5: hand-calculation practice problems ---
        with gr.TabItem("05: PRACTICE"):
            gr.Markdown(PRACTICE_INTRO)

            with gr.Row():
                with gr.Column():
                    problem_select = gr.Radio(
                        choices=["Problem 1: Chain Rule", "Problem 2: Sigmoid", "Problem 3: Full Backprop"],
                        label="Select Problem",
                        value="Problem 1: Chain Rule"
                    )
                    show_problem_btn = gr.Button(">> SHOW PROBLEM <<")
                    show_answer_btn = gr.Button(">> REVEAL SOLUTION <<")

                with gr.Column():
                    problem_display = gr.Markdown(label="Problem")
                    # Hidden until the user explicitly asks for the solution.
                    solution_display = gr.Markdown(label="Solution", visible=False)

            def _selected_problem_number(selection):
                """Return the integer id embedded in a radio label.

                For a label such as "Problem 2: Sigmoid" this takes the text
                before the first colon and converts its last word to int.
                """
                return int(selection.split(":")[0].split()[-1])

            def show_problem(selection):
                """Return the selected problem's question and a hidden, emptied solution."""
                question, _ = practice_problem(_selected_problem_number(selection))
                return question, gr.update(visible=False, value="")

            def show_solution(selection):
                """Return an update that reveals the selected problem's solution."""
                _, solution = practice_problem(_selected_problem_number(selection))
                return gr.update(visible=True, value=solution)

            show_problem_btn.click(
                show_problem,
                inputs=[problem_select],
                outputs=[problem_display, solution_display]
            )

            # show_solution reads the *current* radio value, so the question
            # pane must follow the radio as well; otherwise switching problems
            # and then revealing would pair one problem's question with
            # another problem's solution.
            problem_select.change(
                show_problem,
                inputs=[problem_select],
                outputs=[problem_display, solution_display]
            )

            show_answer_btn.click(
                show_solution,
                inputs=[problem_select],
                outputs=[solution_display]
            )

        # --- Tab 6: static quick-reference card ---
        with gr.TabItem("06: REFERENCE"):
            gr.Markdown("""
# QUICK REFERENCE CARD
===============================================

## CHAIN RULE

```
y = f(g(x))

dy/dx = (df/dg) * (dg/dx)
```

For longer chains: just multiply all the derivatives!

-----------------------------------------------

## COMMON DERIVATIVES

| Function | Derivative |
|----------|------------|
| x^n | n*x^(n-1) |
| e^x | e^x |
| log(x) | 1/x |
| sigmoid(x) | sigmoid(x)*(1-sigmoid(x)) |
| ReLU(x) | 1 if x>0, else 0 |

-----------------------------------------------

## NEURAL NETWORK CHAIN

For a single neuron with sigmoid:
```
z = Σ(wi*xi) + b
y = sigmoid(z)
L = loss(y, y_true)

dL/dwi = (dL/dy) * (dy/dz) * (dz/dwi)
= (dL/dy) * sigmoid'(z) * xi
```

-----------------------------------------------

## GRADIENT DESCENT

```
w_new = w_old - learning_rate * dL/dw
```

The gradient points UPHILL; we go opposite direction.

-----------------------------------------------

## BCE LOSS GRADIENT (sigmoid output)

For BCE loss with sigmoid output:

```
dL/dz = y_pred - y_true
```

This clean result comes from cancellation in the chain!

-----------------------------------------------

## DEBUGGING TIPS

1. **Gradient check:** Compare with numerical gradient
```
dL/dw ≈ [L(w+h) - L(w-h)] / (2h)
```

2. **Shapes must match:** gradient of L w.r.t. W has same shape as W

3. **Large gradients?** Try gradient clipping or smaller learning rate

4. **Vanishing gradients?** Consider ReLU or residual connections
""")

    gr.Markdown("""
---
> TERMINAL SESSION ACTIVE

> VAULT-TEC WISHES YOU A PLEASANT TRAINING EXPERIENCE
""")
|
|
|
|
if __name__ == "__main__":
    # Browser-side snippet handed to launch(): it adds the 'dark' class to
    # <body> and injects a stylesheet that hides theme-toggle controls.
    force_dark_js = """
    () => {
        // Force dark mode and hide theme toggle
        document.body.classList.add('dark');
        const style = document.createElement('style');
        style.textContent = `
            .dark-mode-toggle, [aria-label="Toggle dark mode"],
            button[title*="theme"], .theme-toggle { display: none !important; }
        `;
        document.head.appendChild(style);
    }
    """

    # Collect the launch settings in one place, then start the app.
    launch_options = {
        "server_port": 7860,
        "css": FALLOUT_CSS,
        "js": force_dark_js,
    }
    demo.launch(**launch_options)
|
|