rgaddamanugu3 commited on
Commit
dfe1048
·
verified ·
1 Parent(s): d1fe919

Upload final_graph.py

Browse files
Files changed (1) hide show
  1. final_graph.py +220 -0
final_graph.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ import gradio as gr
4
+
5
+ def plot_secant(h):
6
+ x = np.linspace(-1, 2, 400)
7
+ y = x**2
8
+ m = (h**2) / h
9
+ fig, axs = plt.subplots(1, 2, figsize=(8, 4))
10
+ for ax in axs:
11
+ ax.set_xlim(-1, 2)
12
+ ax.set_ylim(-1, 4)
13
+ ax.set_xticks(np.arange(-1, 3, 1))
14
+ ax.set_yticks(np.arange(-1, 5, 1))
15
+ ax.grid(True, linestyle='--', linewidth=0.5, color='lightgray')
16
+ ax.spines['top'].set_visible(False)
17
+ ax.spines['right'].set_visible(False)
18
+ axs[0].plot(x, y, color='black')
19
+ axs[0].plot([0, h], [0, h**2], color='red', linewidth=2)
20
+ axs[0].scatter([0, h], [0, h**2], color='red', zorder=5)
21
+ axs[1].plot(x, m * x, color='red', linewidth=2)
22
+ plt.tight_layout()
23
+ return fig
24
+
25
+
26
+ def plot_tangent(x0):
27
+ x = np.linspace(-1, 2, 400)
28
+ y = x**2
29
+ m = 2 * x0
30
+ y0 = x0**2
31
+ fig, axs = plt.subplots(1, 2, figsize=(8, 4))
32
+ for ax in axs:
33
+ ax.set_xlim(-1, 2)
34
+ ax.set_ylim(-1, 4)
35
+ ax.set_xticks(np.arange(-1, 3, 1))
36
+ ax.set_yticks(np.arange(-1, 5, 1))
37
+ ax.grid(True, linestyle='--', linewidth=0.5, color='lightgray')
38
+ ax.spines['top'].set_visible(False)
39
+ ax.spines['right'].set_visible(False)
40
+ axs[0].plot(x, y, color='black')
41
+ axs[0].plot(x, m * (x - x0) + y0, color='red', linewidth=2)
42
+ axs[0].scatter([x0], [y0], color='red', zorder=5)
43
+ axs[1].plot(x, 2 * x, color='black')
44
+ axs[1].scatter([x0], [m], color='red', zorder=5)
45
+ plt.tight_layout()
46
+ return fig
47
+
48
+
49
+ def plot_gradient_descent(lr, init_x, steps):
50
+ n = int(steps)
51
+ path = [init_x]
52
+ for _ in range(n):
53
+ path.append(path[-1] - lr * 2 * path[-1])
54
+ xv = np.array(path)
55
+
56
+ fig, axs = plt.subplots(1, 2, figsize=(8, 4))
57
+ x_plot = np.linspace(-2, 2, 400)
58
+ axs[0].plot(x_plot, x_plot**2, color='black')
59
+ axs[0].plot(xv, xv**2, marker='o', color='red', linewidth=2)
60
+ for i in range(n):
61
+ axs[0].annotate('', xy=(xv[i+1], xv[i+1]**2), xytext=(xv[i], xv[i]**2), arrowprops=dict(arrowstyle='->', color='red'))
62
+ axs[0].set_xlim(-2, 2)
63
+ axs[0].set_ylim(-0.5, 5)
64
+ axs[0].set_title('Gradient Descent Path')
65
+ axs[0].grid(True, linestyle='--', linewidth=0.5, color='lightgray')
66
+
67
+ axs[1].plot(range(n+1), xv, marker='o', color='red', linewidth=2)
68
+ for i in range(n):
69
+ axs[1].annotate('', xy=(i+1, xv[i+1]), xytext=(i, xv[i]), arrowprops=dict(arrowstyle='->', color='red'))
70
+ axs[1].set_xlim(0, n)
71
+ axs[1].set_ylim(xv.min() - 0.5, xv.max() + 0.5)
72
+ axs[1].set_xticks(range(0, n+1, max(1, n//5)))
73
+ axs[1].set_xlabel('Iteration')
74
+ axs[1].set_title('x over Iterations')
75
+ axs[1].grid(True, linestyle='--', linewidth=0.5, color='lightgray')
76
+
77
+ plt.tight_layout()
78
+ return fig
79
+
80
+
81
+ def plot_chain_network(x):
82
+ y = 2 * x
83
+ z = 3 * y
84
+ L = 4 * z
85
+ fig, ax = plt.subplots(figsize=(6, 2))
86
+ ax.axis('off')
87
+ pos = {'x': 0.1, 'y': 0.3, 'z': 0.5, 'L': 0.7}
88
+ for name in pos:
89
+ ax.add_patch(plt.Circle((pos[name], 0.5), 0.05, fill=False))
90
+ ax.text(pos[name], 0.5, name, ha='center', va='center')
91
+ for src, dst, lbl in [
92
+ ('x', 'y', r'$\partial y/\partial x=2$'),
93
+ ('y', 'z', r'$\partial z/\partial y=3$'),
94
+ ('z', 'L', r'$\partial L/\partial z=4$')
95
+ ]:
96
+ sx, dx = pos[src], pos[dst]
97
+ ax.annotate('', xy=(dx, 0.5), xytext=(sx, 0.5), arrowprops=dict(arrowstyle='->'))
98
+ ax.text((sx + dx) / 2, 0.6, lbl, ha='center', va='center')
99
+ ax.text(0.02, 0.15,
100
+ r'$\frac{\partial L}{\partial x}=\frac{\partial L}{\partial z}\cdot\frac{\partial z}{\partial y}\cdot\frac{\partial y}{\partial x}$',
101
+ transform=ax.transAxes, ha='left')
102
+ ax.text(0.02, 0.02, r'$=4\times3\times2=24$', transform=ax.transAxes, ha='left')
103
+ for name, val in [('x', x), ('y', y), ('z', z), ('L', L)]:
104
+ ax.text(pos[name], 0.3, f"{name}={val:.2f}", ha='center')
105
+ plt.tight_layout()
106
+ return fig
107
+
108
+
109
+ def plot_backprop_dnn(x, w1, w2, t):
110
+ a = w1 * x
111
+ y = w2 * a
112
+ L = 0.5 * (y - t)**2
113
+ fig, ax = plt.subplots(figsize=(6, 2))
114
+ ax.axis('off')
115
+ pos = {'x': 0.1, 'a': 0.3, 'y': 0.5, 'L': 0.7}
116
+ for name in pos:
117
+ ax.add_patch(plt.Circle((pos[name], 0.5), 0.05, fill=False))
118
+ ax.text(pos[name], 0.5, name, ha='center', va='center')
119
+ for src, dst, lbl in [
120
+ ('x', 'a', r'$\partial a/\partial x=w_1$'),
121
+ ('a', 'y', r'$\partial y/\partial a=w_2$'),
122
+ ('y', 'L', r'$\partial L/\partial y=(y-t)$')
123
+ ]:
124
+ sx, dx = pos[src], pos[dst]
125
+ ax.annotate('', xy=(dx, 0.5), xytext=(sx, 0.5), arrowprops=dict(arrowstyle='->'))
126
+ ax.text((sx + dx) / 2, 0.6, lbl, ha='center', va='center')
127
+ ax.text(0.02, 0.15, r'$\partial L/\partial w_2=(y-t)\cdot a$', transform=ax.transAxes, ha='left')
128
+ ax.text(0.02, 0.02, r'$\partial L/\partial w_1=(y-t)\cdot w_2\cdot x$', transform=ax.transAxes, ha='left')
129
+ for name, val in [('x', x), ('a', a), ('y', y), ('L', L)]:
130
+ ax.text(pos[name], 0.3, f"{name}={val:.2f}", ha='center')
131
+ plt.tight_layout()
132
+ return fig
133
+
134
+
135
+ def update_secant(h): return plot_secant(h), f'**Δx=h={h:.4f}**, (f(x+h)-f(x))/h={(h**2)/h:.4f}'
136
+ def update_tangent(x0): return plot_tangent(x0), f'**x={x0:.2f}**, dy/dx={2*x0:.2f}'
137
+ def update_gd(lr, init_x, steps): return plot_gradient_descent(lr, init_x, steps), f'lr={lr:.2f}, init={init_x:.2f}, steps={int(steps)}'
138
+ def update_chain(x):
139
+ msg = (f"**Current values:**\n"
140
+ f"- y = 2·{x:.2f} = {2*x:.2f}\n"
141
+ f"- z = 3·{2*x:.2f} = {3*(2*x):.2f}\n"
142
+ f"- L = 4·{3*(2*x):.2f} = {4*(3*(2*x)):.2f}\n\n"
143
+ "**Chain Rule:** dL/dx = 4 × 3 × 2 = 24")
144
+ return plot_chain_network(x), msg
145
+ def update_bp(x, w1, w2, t): return plot_backprop_dnn(x, w1, w2, t), ''
146
+
147
+ def load_secant(): return plot_secant(0.01), "**Hint:** try moving the slider!"
148
+ def load_tangent(): return plot_tangent(0.0), "**Hint:** try moving the slider!"
149
+ def load_gd(): return plot_gradient_descent(0.1, 1.0, 10), "**Hint:** try moving the sliders!"
150
+ def load_chain(): return plot_chain_network(1.0), "**Hint:** try moving the slider!"
151
+ def load_bp(): return plot_backprop_dnn(0.5, 1.0, 1.0, 0.0), "**Hint:** try moving the sliders!"
152
+
153
+ def reset_all(): return (
154
+ gr.update(value=0.01), gr.update(value=0.0), gr.update(value=0.1),
155
+ gr.update(value=1.0), gr.update(value=10),
156
+ gr.update(value=1.0), gr.update(value=0.5), gr.update(value=1.0),
157
+ gr.update(value=1.0), gr.update(value=0.0)
158
+ )
159
+
160
+ with gr.Blocks() as demo:
161
+ with gr.Tabs():
162
+ with gr.TabItem("Secant Approximation"):
163
+ gr.HTML("<p><strong>Secant Approximation</strong> <span style='cursor:help' title='Approximates derivative via (f(x+h)-f(x))/h'>❔</span></p>")
164
+ with gr.Row():
165
+ with gr.Column(scale=3):
166
+ h = gr.Slider(0.001, 1.0, value=0.01, step=0.001, label="h")
167
+ p1, m1 = gr.Plot(), gr.Markdown()
168
+ h.change(update_secant, [h], [p1, m1])
169
+ with gr.Column(scale=1):
170
+ gr.HTML("<p><strong>Key Question:</strong><br>What does the secant slope approximate?<br><span style='cursor:help' title='The instantaneous rate of change (derivative).'>❔</span></p>")
171
+ with gr.TabItem("Tangent Visualization"):
172
+ gr.HTML("<p><strong>Tangent Visualization</strong> <span style='cursor:help' title='Shows tangent line at x and slope dy/dx'>❔</span></p>")
173
+ with gr.Row():
174
+ with gr.Column(scale=3):
175
+ x0 = gr.Slider(-1.0, 2.0, value=0.0, step=0.1, label="x")
176
+ p2, m2 = gr.Plot(), gr.Markdown()
177
+ x0.change(update_tangent, [x0], [p2, m2])
178
+ with gr.Column(scale=1):
179
+ gr.HTML("<p><strong>Key Question:</strong><br>What does the tangent line represent?<br><span style='cursor:help' title='The instantaneous rate of change at the point.'>❔</span></p>")
180
+ with gr.TabItem("Gradient Descent"):
181
+ gr.HTML("<p><strong>Gradient Descent</strong> <span style='cursor:help' title='Shows gradient descent steps on x^2 curve'>❔</span></p>")
182
+ with gr.Row():
183
+ with gr.Column(scale=3):
184
+ lr = gr.Slider(0.01, 0.5, value=0.1, step=0.01, label="Learning Rate")
185
+ init = gr.Slider(-2.0, 2.0, value=1.0, step=0.1, label="Initial x")
186
+ st = gr.Slider(1, 50, value=10, step=1, label="Iterations")
187
+ pg, mg = gr.Plot(), gr.Markdown()
188
+ for inp in [lr, init, st]: inp.change(update_gd, [lr, init, st], [pg, mg])
189
+ with gr.Column(scale=1):
190
+ gr.HTML("<p><strong>Key Question:</strong><br>How does gradient descent move?<br><span style='cursor:help' title='It moves opposite to the gradient towards the function minimum.'>❔</span></p>")
191
+ with gr.TabItem("Chain Rule"):
192
+ gr.HTML("<p><strong>Chain Rule</strong> <span style='cursor:help' title='Visualizes chain rule in computation graph'>❔</span></p>")
193
+ with gr.Row():
194
+ with gr.Column(scale=3):
195
+ x_s = gr.Slider(0.0, 2.0, value=1.0, step=0.1, label="x")
196
+ cp, cm = gr.Plot(), gr.Markdown()
197
+ x_s.change(update_chain, [x_s], [cp, cm])
198
+ with gr.Column(scale=1):
199
+ gr.HTML("<p><strong>Key Question:</strong><br>How is dL/dx computed?<br><span style='cursor:help' title='By multiplying partial derivatives along graph: 4×3×2.'>❔</span></p>")
200
+ with gr.TabItem("Backpropagation"):
201
+ gr.HTML("<p><strong>Backpropagation</strong> <span style='cursor:help' title='Visualizes backprop in a simple DNN'>❔</span></p>")
202
+ with gr.Row():
203
+ with gr.Column(scale=3):
204
+ xb = gr.Slider(-2.0, 2.0, value=0.5, step=0.1, label="x")
205
+ w1b = gr.Slider(-2.0, 2.0, value=1.0, step=0.1, label="w1")
206
+ w2b = gr.Slider(-2.0, 2.0, value=1.0, step=0.1, label="w2")
207
+ tb = gr.Slider(-2.0, 2.0, value=0.0, step=0.1, label="t")
208
+ pb, mb = gr.Plot(), gr.Markdown()
209
+ for inp in [xb, w1b, w2b, tb]: inp.change(update_bp, [xb, w1b, w2b, tb], [pb, mb])
210
+ demo.load(load_secant, [], [p1, m1])
211
+ demo.load(load_tangent, [], [p2, m2])
212
+ demo.load(load_gd, [], [pg, mg])
213
+ demo.load(load_chain, [], [cp, cm])
214
+ demo.load(load_bp, [], [pb, mb])
215
+ with gr.Row():
216
+ reset_btn = gr.Button("Reset to default settings")
217
+ gr.HTML("<span style='cursor:help' title='Reset all sliders to defaults.'>❔</span>")
218
+ reset_btn.click(reset_all, [], [h, x0, lr, init, st, x_s, xb, w1b, w2b, tb])
219
+
220
+ demo.launch()