Spaces:
Sleeping
Sleeping
Upload final_graph.py
Browse files- final_graph.py +220 -0
final_graph.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import gradio as gr
|
| 4 |
+
|
| 5 |
+
def plot_secant(h):
|
| 6 |
+
x = np.linspace(-1, 2, 400)
|
| 7 |
+
y = x**2
|
| 8 |
+
m = (h**2) / h
|
| 9 |
+
fig, axs = plt.subplots(1, 2, figsize=(8, 4))
|
| 10 |
+
for ax in axs:
|
| 11 |
+
ax.set_xlim(-1, 2)
|
| 12 |
+
ax.set_ylim(-1, 4)
|
| 13 |
+
ax.set_xticks(np.arange(-1, 3, 1))
|
| 14 |
+
ax.set_yticks(np.arange(-1, 5, 1))
|
| 15 |
+
ax.grid(True, linestyle='--', linewidth=0.5, color='lightgray')
|
| 16 |
+
ax.spines['top'].set_visible(False)
|
| 17 |
+
ax.spines['right'].set_visible(False)
|
| 18 |
+
axs[0].plot(x, y, color='black')
|
| 19 |
+
axs[0].plot([0, h], [0, h**2], color='red', linewidth=2)
|
| 20 |
+
axs[0].scatter([0, h], [0, h**2], color='red', zorder=5)
|
| 21 |
+
axs[1].plot(x, m * x, color='red', linewidth=2)
|
| 22 |
+
plt.tight_layout()
|
| 23 |
+
return fig
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def plot_tangent(x0):
|
| 27 |
+
x = np.linspace(-1, 2, 400)
|
| 28 |
+
y = x**2
|
| 29 |
+
m = 2 * x0
|
| 30 |
+
y0 = x0**2
|
| 31 |
+
fig, axs = plt.subplots(1, 2, figsize=(8, 4))
|
| 32 |
+
for ax in axs:
|
| 33 |
+
ax.set_xlim(-1, 2)
|
| 34 |
+
ax.set_ylim(-1, 4)
|
| 35 |
+
ax.set_xticks(np.arange(-1, 3, 1))
|
| 36 |
+
ax.set_yticks(np.arange(-1, 5, 1))
|
| 37 |
+
ax.grid(True, linestyle='--', linewidth=0.5, color='lightgray')
|
| 38 |
+
ax.spines['top'].set_visible(False)
|
| 39 |
+
ax.spines['right'].set_visible(False)
|
| 40 |
+
axs[0].plot(x, y, color='black')
|
| 41 |
+
axs[0].plot(x, m * (x - x0) + y0, color='red', linewidth=2)
|
| 42 |
+
axs[0].scatter([x0], [y0], color='red', zorder=5)
|
| 43 |
+
axs[1].plot(x, 2 * x, color='black')
|
| 44 |
+
axs[1].scatter([x0], [m], color='red', zorder=5)
|
| 45 |
+
plt.tight_layout()
|
| 46 |
+
return fig
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def plot_gradient_descent(lr, init_x, steps):
|
| 50 |
+
n = int(steps)
|
| 51 |
+
path = [init_x]
|
| 52 |
+
for _ in range(n):
|
| 53 |
+
path.append(path[-1] - lr * 2 * path[-1])
|
| 54 |
+
xv = np.array(path)
|
| 55 |
+
|
| 56 |
+
fig, axs = plt.subplots(1, 2, figsize=(8, 4))
|
| 57 |
+
x_plot = np.linspace(-2, 2, 400)
|
| 58 |
+
axs[0].plot(x_plot, x_plot**2, color='black')
|
| 59 |
+
axs[0].plot(xv, xv**2, marker='o', color='red', linewidth=2)
|
| 60 |
+
for i in range(n):
|
| 61 |
+
axs[0].annotate('', xy=(xv[i+1], xv[i+1]**2), xytext=(xv[i], xv[i]**2), arrowprops=dict(arrowstyle='->', color='red'))
|
| 62 |
+
axs[0].set_xlim(-2, 2)
|
| 63 |
+
axs[0].set_ylim(-0.5, 5)
|
| 64 |
+
axs[0].set_title('Gradient Descent Path')
|
| 65 |
+
axs[0].grid(True, linestyle='--', linewidth=0.5, color='lightgray')
|
| 66 |
+
|
| 67 |
+
axs[1].plot(range(n+1), xv, marker='o', color='red', linewidth=2)
|
| 68 |
+
for i in range(n):
|
| 69 |
+
axs[1].annotate('', xy=(i+1, xv[i+1]), xytext=(i, xv[i]), arrowprops=dict(arrowstyle='->', color='red'))
|
| 70 |
+
axs[1].set_xlim(0, n)
|
| 71 |
+
axs[1].set_ylim(xv.min() - 0.5, xv.max() + 0.5)
|
| 72 |
+
axs[1].set_xticks(range(0, n+1, max(1, n//5)))
|
| 73 |
+
axs[1].set_xlabel('Iteration')
|
| 74 |
+
axs[1].set_title('x over Iterations')
|
| 75 |
+
axs[1].grid(True, linestyle='--', linewidth=0.5, color='lightgray')
|
| 76 |
+
|
| 77 |
+
plt.tight_layout()
|
| 78 |
+
return fig
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def plot_chain_network(x):
|
| 82 |
+
y = 2 * x
|
| 83 |
+
z = 3 * y
|
| 84 |
+
L = 4 * z
|
| 85 |
+
fig, ax = plt.subplots(figsize=(6, 2))
|
| 86 |
+
ax.axis('off')
|
| 87 |
+
pos = {'x': 0.1, 'y': 0.3, 'z': 0.5, 'L': 0.7}
|
| 88 |
+
for name in pos:
|
| 89 |
+
ax.add_patch(plt.Circle((pos[name], 0.5), 0.05, fill=False))
|
| 90 |
+
ax.text(pos[name], 0.5, name, ha='center', va='center')
|
| 91 |
+
for src, dst, lbl in [
|
| 92 |
+
('x', 'y', r'$\partial y/\partial x=2$'),
|
| 93 |
+
('y', 'z', r'$\partial z/\partial y=3$'),
|
| 94 |
+
('z', 'L', r'$\partial L/\partial z=4$')
|
| 95 |
+
]:
|
| 96 |
+
sx, dx = pos[src], pos[dst]
|
| 97 |
+
ax.annotate('', xy=(dx, 0.5), xytext=(sx, 0.5), arrowprops=dict(arrowstyle='->'))
|
| 98 |
+
ax.text((sx + dx) / 2, 0.6, lbl, ha='center', va='center')
|
| 99 |
+
ax.text(0.02, 0.15,
|
| 100 |
+
r'$\frac{\partial L}{\partial x}=\frac{\partial L}{\partial z}\cdot\frac{\partial z}{\partial y}\cdot\frac{\partial y}{\partial x}$',
|
| 101 |
+
transform=ax.transAxes, ha='left')
|
| 102 |
+
ax.text(0.02, 0.02, r'$=4\times3\times2=24$', transform=ax.transAxes, ha='left')
|
| 103 |
+
for name, val in [('x', x), ('y', y), ('z', z), ('L', L)]:
|
| 104 |
+
ax.text(pos[name], 0.3, f"{name}={val:.2f}", ha='center')
|
| 105 |
+
plt.tight_layout()
|
| 106 |
+
return fig
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def plot_backprop_dnn(x, w1, w2, t):
|
| 110 |
+
a = w1 * x
|
| 111 |
+
y = w2 * a
|
| 112 |
+
L = 0.5 * (y - t)**2
|
| 113 |
+
fig, ax = plt.subplots(figsize=(6, 2))
|
| 114 |
+
ax.axis('off')
|
| 115 |
+
pos = {'x': 0.1, 'a': 0.3, 'y': 0.5, 'L': 0.7}
|
| 116 |
+
for name in pos:
|
| 117 |
+
ax.add_patch(plt.Circle((pos[name], 0.5), 0.05, fill=False))
|
| 118 |
+
ax.text(pos[name], 0.5, name, ha='center', va='center')
|
| 119 |
+
for src, dst, lbl in [
|
| 120 |
+
('x', 'a', r'$\partial a/\partial x=w_1$'),
|
| 121 |
+
('a', 'y', r'$\partial y/\partial a=w_2$'),
|
| 122 |
+
('y', 'L', r'$\partial L/\partial y=(y-t)$')
|
| 123 |
+
]:
|
| 124 |
+
sx, dx = pos[src], pos[dst]
|
| 125 |
+
ax.annotate('', xy=(dx, 0.5), xytext=(sx, 0.5), arrowprops=dict(arrowstyle='->'))
|
| 126 |
+
ax.text((sx + dx) / 2, 0.6, lbl, ha='center', va='center')
|
| 127 |
+
ax.text(0.02, 0.15, r'$\partial L/\partial w_2=(y-t)\cdot a$', transform=ax.transAxes, ha='left')
|
| 128 |
+
ax.text(0.02, 0.02, r'$\partial L/\partial w_1=(y-t)\cdot w_2\cdot x$', transform=ax.transAxes, ha='left')
|
| 129 |
+
for name, val in [('x', x), ('a', a), ('y', y), ('L', L)]:
|
| 130 |
+
ax.text(pos[name], 0.3, f"{name}={val:.2f}", ha='center')
|
| 131 |
+
plt.tight_layout()
|
| 132 |
+
return fig
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def update_secant(h): return plot_secant(h), f'**Δx=h={h:.4f}**, (f(x+h)-f(x))/h={(h**2)/h:.4f}'
|
| 136 |
+
def update_tangent(x0): return plot_tangent(x0), f'**x={x0:.2f}**, dy/dx={2*x0:.2f}'
|
| 137 |
+
def update_gd(lr, init_x, steps): return plot_gradient_descent(lr, init_x, steps), f'lr={lr:.2f}, init={init_x:.2f}, steps={int(steps)}'
|
| 138 |
+
def update_chain(x):
|
| 139 |
+
msg = (f"**Current values:**\n"
|
| 140 |
+
f"- y = 2·{x:.2f} = {2*x:.2f}\n"
|
| 141 |
+
f"- z = 3·{2*x:.2f} = {3*(2*x):.2f}\n"
|
| 142 |
+
f"- L = 4·{3*(2*x):.2f} = {4*(3*(2*x)):.2f}\n\n"
|
| 143 |
+
"**Chain Rule:** dL/dx = 4 × 3 × 2 = 24")
|
| 144 |
+
return plot_chain_network(x), msg
|
| 145 |
+
def update_bp(x, w1, w2, t): return plot_backprop_dnn(x, w1, w2, t), ''
|
| 146 |
+
|
| 147 |
+
def load_secant(): return plot_secant(0.01), "**Hint:** try moving the slider!"
|
| 148 |
+
def load_tangent(): return plot_tangent(0.0), "**Hint:** try moving the slider!"
|
| 149 |
+
def load_gd(): return plot_gradient_descent(0.1, 1.0, 10), "**Hint:** try moving the sliders!"
|
| 150 |
+
def load_chain(): return plot_chain_network(1.0), "**Hint:** try moving the slider!"
|
| 151 |
+
def load_bp(): return plot_backprop_dnn(0.5, 1.0, 1.0, 0.0), "**Hint:** try moving the sliders!"
|
| 152 |
+
|
| 153 |
+
def reset_all(): return (
|
| 154 |
+
gr.update(value=0.01), gr.update(value=0.0), gr.update(value=0.1),
|
| 155 |
+
gr.update(value=1.0), gr.update(value=10),
|
| 156 |
+
gr.update(value=1.0), gr.update(value=0.5), gr.update(value=1.0),
|
| 157 |
+
gr.update(value=1.0), gr.update(value=0.0)
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
with gr.Blocks() as demo:
|
| 161 |
+
with gr.Tabs():
|
| 162 |
+
with gr.TabItem("Secant Approximation"):
|
| 163 |
+
gr.HTML("<p><strong>Secant Approximation</strong> <span style='cursor:help' title='Approximates derivative via (f(x+h)-f(x))/h'>❔</span></p>")
|
| 164 |
+
with gr.Row():
|
| 165 |
+
with gr.Column(scale=3):
|
| 166 |
+
h = gr.Slider(0.001, 1.0, value=0.01, step=0.001, label="h")
|
| 167 |
+
p1, m1 = gr.Plot(), gr.Markdown()
|
| 168 |
+
h.change(update_secant, [h], [p1, m1])
|
| 169 |
+
with gr.Column(scale=1):
|
| 170 |
+
gr.HTML("<p><strong>Key Question:</strong><br>What does the secant slope approximate?<br><span style='cursor:help' title='The instantaneous rate of change (derivative).'>❔</span></p>")
|
| 171 |
+
with gr.TabItem("Tangent Visualization"):
|
| 172 |
+
gr.HTML("<p><strong>Tangent Visualization</strong> <span style='cursor:help' title='Shows tangent line at x and slope dy/dx'>❔</span></p>")
|
| 173 |
+
with gr.Row():
|
| 174 |
+
with gr.Column(scale=3):
|
| 175 |
+
x0 = gr.Slider(-1.0, 2.0, value=0.0, step=0.1, label="x")
|
| 176 |
+
p2, m2 = gr.Plot(), gr.Markdown()
|
| 177 |
+
x0.change(update_tangent, [x0], [p2, m2])
|
| 178 |
+
with gr.Column(scale=1):
|
| 179 |
+
gr.HTML("<p><strong>Key Question:</strong><br>What does the tangent line represent?<br><span style='cursor:help' title='The instantaneous rate of change at the point.'>❔</span></p>")
|
| 180 |
+
with gr.TabItem("Gradient Descent"):
|
| 181 |
+
gr.HTML("<p><strong>Gradient Descent</strong> <span style='cursor:help' title='Shows gradient descent steps on x^2 curve'>❔</span></p>")
|
| 182 |
+
with gr.Row():
|
| 183 |
+
with gr.Column(scale=3):
|
| 184 |
+
lr = gr.Slider(0.01, 0.5, value=0.1, step=0.01, label="Learning Rate")
|
| 185 |
+
init = gr.Slider(-2.0, 2.0, value=1.0, step=0.1, label="Initial x")
|
| 186 |
+
st = gr.Slider(1, 50, value=10, step=1, label="Iterations")
|
| 187 |
+
pg, mg = gr.Plot(), gr.Markdown()
|
| 188 |
+
for inp in [lr, init, st]: inp.change(update_gd, [lr, init, st], [pg, mg])
|
| 189 |
+
with gr.Column(scale=1):
|
| 190 |
+
gr.HTML("<p><strong>Key Question:</strong><br>How does gradient descent move?<br><span style='cursor:help' title='It moves opposite to the gradient towards the function minimum.'>❔</span></p>")
|
| 191 |
+
with gr.TabItem("Chain Rule"):
|
| 192 |
+
gr.HTML("<p><strong>Chain Rule</strong> <span style='cursor:help' title='Visualizes chain rule in computation graph'>❔</span></p>")
|
| 193 |
+
with gr.Row():
|
| 194 |
+
with gr.Column(scale=3):
|
| 195 |
+
x_s = gr.Slider(0.0, 2.0, value=1.0, step=0.1, label="x")
|
| 196 |
+
cp, cm = gr.Plot(), gr.Markdown()
|
| 197 |
+
x_s.change(update_chain, [x_s], [cp, cm])
|
| 198 |
+
with gr.Column(scale=1):
|
| 199 |
+
gr.HTML("<p><strong>Key Question:</strong><br>How is dL/dx computed?<br><span style='cursor:help' title='By multiplying partial derivatives along graph: 4×3×2.'>❔</span></p>")
|
| 200 |
+
with gr.TabItem("Backpropagation"):
|
| 201 |
+
gr.HTML("<p><strong>Backpropagation</strong> <span style='cursor:help' title='Visualizes backprop in a simple DNN'>❔</span></p>")
|
| 202 |
+
with gr.Row():
|
| 203 |
+
with gr.Column(scale=3):
|
| 204 |
+
xb = gr.Slider(-2.0, 2.0, value=0.5, step=0.1, label="x")
|
| 205 |
+
w1b = gr.Slider(-2.0, 2.0, value=1.0, step=0.1, label="w1")
|
| 206 |
+
w2b = gr.Slider(-2.0, 2.0, value=1.0, step=0.1, label="w2")
|
| 207 |
+
tb = gr.Slider(-2.0, 2.0, value=0.0, step=0.1, label="t")
|
| 208 |
+
pb, mb = gr.Plot(), gr.Markdown()
|
| 209 |
+
for inp in [xb, w1b, w2b, tb]: inp.change(update_bp, [xb, w1b, w2b, tb], [pb, mb])
|
| 210 |
+
demo.load(load_secant, [], [p1, m1])
|
| 211 |
+
demo.load(load_tangent, [], [p2, m2])
|
| 212 |
+
demo.load(load_gd, [], [pg, mg])
|
| 213 |
+
demo.load(load_chain, [], [cp, cm])
|
| 214 |
+
demo.load(load_bp, [], [pb, mb])
|
| 215 |
+
with gr.Row():
|
| 216 |
+
reset_btn = gr.Button("Reset to default settings")
|
| 217 |
+
gr.HTML("<span style='cursor:help' title='Reset all sliders to defaults.'>❔</span>")
|
| 218 |
+
reset_btn.click(reset_all, [], [h, x0, lr, init, st, x_s, xb, w1b, w2b, tb])
|
| 219 |
+
|
| 220 |
+
demo.launch()
|