rgaddamanugu3 commited on
Commit
5a1d76f
·
verified ·
1 Parent(s): 579fb08

Create final_prototype.py

Browse files
Files changed (1) hide show
  1. final_prototype.py +231 -0
final_prototype.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ import gradio as gr
4
+
5
+ def plot_secant(h):
6
+ x = np.linspace(-1, 2, 400)
7
+ y = x**2
8
+ m = (h**2) / h
9
+ fig, axs = plt.subplots(1, 2, figsize=(8, 4))
10
+ for ax in axs:
11
+ ax.set_xlim(-1, 2)
12
+ ax.set_ylim(-1, 4)
13
+ ax.set_xticks(np.arange(-1, 3, 1))
14
+ ax.set_yticks(np.arange(-1, 5, 1))
15
+ ax.grid(True, linestyle='--', linewidth=0.5, color='lightgray')
16
+ ax.spines['top'].set_visible(False)
17
+ ax.spines['right'].set_visible(False)
18
+ axs[0].plot(x, y, color='black')
19
+ axs[0].plot([0, h], [0, h**2], color='red', linewidth=2)
20
+ axs[0].scatter([0, h], [0, h**2], color='red', zorder=5)
21
+ axs[1].plot(x, m * x, color='red', linewidth=2)
22
+ plt.tight_layout()
23
+ return fig
24
+
25
+
26
+ def plot_tangent(x0):
27
+ x = np.linspace(-1, 2, 400)
28
+ y = x**2
29
+ m = 2 * x0
30
+ y0 = x0**2
31
+ fig, axs = plt.subplots(1, 2, figsize=(8, 4))
32
+ for ax in axs:
33
+ ax.set_xlim(-1, 2)
34
+ ax.set_ylim(-1, 4)
35
+ ax.set_xticks(np.arange(-1, 3, 1))
36
+ ax.set_yticks(np.arange(-1, 5, 1))
37
+ ax.grid(True, linestyle='--', linewidth=0.5, color='lightgray')
38
+ ax.spines['top'].set_visible(False)
39
+ ax.spines['right'].set_visible(False)
40
+ axs[0].plot(x, y, color='black')
41
+ axs[0].plot(x, m * (x - x0) + y0, color='red', linewidth=2)
42
+ axs[0].scatter([x0], [y0], color='red', zorder=5)
43
+ axs[1].plot(x, 2 * x, color='black')
44
+ axs[1].scatter([x0], [m], color='red', zorder=5)
45
+ plt.tight_layout()
46
+ return fig
47
+
48
+
49
+ def plot_gradient_descent(lr, init_x, steps):
50
+ n = int(steps)
51
+ path = [init_x]
52
+ for _ in range(n):
53
+ path.append(path[-1] - lr * 2 * path[-1])
54
+ xv = np.array(path)
55
+ fig, axs = plt.subplots(1, 2, figsize=(8, 4))
56
+ x_plot = np.linspace(-2, 2, 400)
57
+ axs[0].plot(x_plot, x_plot**2, color='black')
58
+ axs[0].plot(xv, xv**2, marker='o', color='red', linewidth=2)
59
+ for i in range(n):
60
+ axs[0].annotate('', xy=(xv[i+1], xv[i+1]**2), xytext=(xv[i], xv[i]**2), arrowprops=dict(arrowstyle='->', color='red'))
61
+ axs[0].set_xlim(-2, 2)
62
+ axs[0].set_ylim(-0.5, 5)
63
+ axs[0].set_title('Gradient Descent Path')
64
+ axs[0].grid(True, linestyle='--', linewidth=0.5, color='lightgray')
65
+ axs[1].plot(range(n+1), xv, marker='o', color='red', linewidth=2)
66
+ for i in range(n):
67
+ axs[1].annotate('', xy=(i+1, xv[i+1]), xytext=(i, xv[i]), arrowprops=dict(arrowstyle='->', color='red'))
68
+ axs[1].set_xlim(0, n)
69
+ axs[1].set_ylim(xv.min() - 0.5, xv.max() + 0.5)
70
+ axs[1].set_xticks(range(0, n+1, max(1, n//5)))
71
+ axs[1].set_xlabel('Iteration')
72
+ axs[1].set_title('x over Iterations')
73
+ axs[1].grid(True, linestyle='--', linewidth=0.5, color='lightgray')
74
+ plt.tight_layout()
75
+ return fig
76
+
77
+
78
+ def plot_chain_network(x):
79
+ y = 2 * x
80
+ z = 3 * y
81
+ L = 4 * z
82
+ fig, ax = plt.subplots(figsize=(6, 2))
83
+ ax.axis('off')
84
+ pos = {'x': 0.1, 'y': 0.3, 'z': 0.5, 'L': 0.7}
85
+ for name in pos:
86
+ ax.add_patch(plt.Circle((pos[name], 0.5), 0.05, fill=False))
87
+ ax.text(pos[name], 0.5, name, ha='center', va='center')
88
+ for src, dst, lbl in [
89
+ ('x', 'y', r'$\partial y/\partial x=2$'),
90
+ ('y', 'z', r'$\partial z/\partial y=3$'),
91
+ ('z', 'L', r'$\partial L/\partial z=4$')
92
+ ]:
93
+ sx, dx = pos[src], pos[dst]
94
+ ax.annotate('', xy=(dx, 0.5), xytext=(sx, 0.5), arrowprops=dict(arrowstyle='->'))
95
+ ax.text((sx + dx)/2, 0.6, lbl, ha='center', va='center')
96
+ ax.text(0.02, 0.15, r'$\frac{\partial L}{\partial x}=4\times3\times2$', transform=ax.transAxes, ha='left')
97
+ for name, val in [('x', x), ('y', y), ('z', z), ('L', L)]:
98
+ ax.text(pos[name], 0.3, f"{name}={val:.2f}", ha='center')
99
+ plt.tight_layout()
100
+ return fig
101
+
102
+
103
+ def plot_backprop_dnn(x, w1, w2, t):
104
+ a = w1 * x
105
+ y = w2 * a
106
+ L = 0.5 * (y - t)**2
107
+ fig, ax = plt.subplots(figsize=(6, 2))
108
+ ax.axis('off')
109
+ pos = {'x': 0.1, 'a': 0.3, 'y': 0.5, 'L': 0.7}
110
+ for name in pos:
111
+ ax.add_patch(plt.Circle((pos[name], 0.5), 0.05, fill=False))
112
+ ax.text(pos[name], 0.5, name, ha='center', va='center')
113
+ for src, dst, lbl in [
114
+ ('x', 'a', '∂a/∂x = w₁'),
115
+ ('a', 'y', '∂y/∂a = w₂'),
116
+ ('y', 'L', '∂L/∂y = (y - t)')
117
+ ]:
118
+ sx, dx = pos[src], pos[dst]
119
+ ax.annotate('', xy=(dx, 0.5), xytext=(sx, 0.5), arrowprops=dict(arrowstyle='->'))
120
+ ax.text((sx + dx)/2, 0.6, lbl, ha='center', va='center')
121
+ ax.text(0.02, 0.15, '∂L/∂w₂ = (y - t) · a', transform=ax.transAxes, ha='left')
122
+ ax.text(0.02, 0.02, '∂L/∂w₁ = (y - t) · w₂ · x', transform=ax.transAxes, ha='left')
123
+ plt.tight_layout()
124
+ return fig
125
+
126
+ def update_secant(h): return plot_secant(h), f'**Δx=h={h:.4f}**, slope={(h**2)/h:.4f}'
127
+
128
+ def update_tangent(x0): return plot_tangent(x0), f'**x={x0:.2f}**, dy/dx={2*x0:.2f}'
129
+
130
+ def update_gd(lr, init_x, steps): return plot_gradient_descent(lr, init_x, steps), f'lr={lr:.2f}, init={init_x:.2f}, steps={int(steps)}'
131
+
132
+ def update_chain(x): return plot_chain_network(x), f"**Values:** y={2*x:.2f}, z={3*(2*x):.2f}, L={4*(3*(2*x)):.2f}\n**dL/dx=24**"
133
+
134
+ def update_bp(x, w1, w2, t): return plot_backprop_dnn(x, w1, w2, t), ''
135
+
136
+ def load_secant(): return plot_secant(0.01), '**Hint:** try moving the slider!'
137
+
138
+ def load_tangent(): return plot_tangent(0.0), '**Hint:** try moving the slider!'
139
+
140
+ def load_gd(): return plot_gradient_descent(0.1, 1.0, 10), '**Hint:** try moving the sliders!'
141
+
142
+ def load_chain(): return plot_chain_network(1.0), '**Hint:** try moving the slider!'
143
+
144
+ def load_bp(): return plot_backprop_dnn(0.5, 1.0, 1.0, 0.0), '**Hint:** try moving the sliders!'
145
+
146
+ def reset_all(): return (
147
+ gr.update(value=0.01), gr.update(value=0.0), gr.update(value=0.1),
148
+ gr.update(value=1.0), gr.update(value=10),
149
+ gr.update(value=1.0), gr.update(value=0.5), gr.update(value=1.0),
150
+ gr.update(value=1.0), gr.update(value=0.0)
151
+ )
152
+
153
+
154
+ demo = gr.Blocks()
155
+ with demo:
156
+ gr.HTML('<div style="border:1px solid lightgray; padding:10px; border-radius:5px">')
157
+ with gr.Tabs():
158
+ with gr.TabItem('Secant Approximation'):
159
+ gr.Markdown('''
160
+ **Feature:** Explore secant line approximations via Δx/h
161
+ - Draws a chord between (x, x²) and (x+h, (x+h)²)
162
+ - As h → 0, the chord converges to the true tangent
163
+ - Goal: See how (f(x+h)-f(x)) / h approaches the instantaneous derivative
164
+ ''')
165
+ h = gr.Slider(0.001, 1.0, value=0.01, step=0.001, label='h')
166
+ p1 = gr.Plot()
167
+ m1 = gr.Markdown()
168
+ h.change(update_secant, [h], [p1, m1])
169
+ with gr.TabItem('Tangent Visualization'):
170
+ gr.Markdown('''
171
+ **Feature:** Visualize tangent line and instantaneous slope
172
+ - Draws the exact tangent at (x, x²)
173
+ - Slide x to update both the red line and its numeric slope
174
+ - Goal: Grasp the derivative as the “best‐fit” rate of change at one point
175
+ ''')
176
+ x0 = gr.Slider(-1.0, 2.0, value=0.0, step=0.1, label='x')
177
+ p2 = gr.Plot()
178
+ m2 = gr.Markdown()
179
+ x0.change(update_tangent, [x0], [p2, m2])
180
+ with gr.TabItem('Gradient Descent'):
181
+ gr.Markdown('''
182
+ **Feature:** Observe gradient descent on y=x²
183
+ - Treats the curve as a valley; you take downhill steps
184
+ - Left: path on the curve with arrows; Right: x vs. iteration
185
+ - Goal: Feel how step size (learning rate) affects speed and stability toward the minimum
186
+ ''')
187
+ lr = gr.Slider(0.01, 0.5, value=0.1, step=0.01, label='Learning Rate')
188
+ init = gr.Slider(-2.0, 2.0, value=1.0, step=0.1, label='Initial x')
189
+ st = gr.Slider(1, 50, value=10, step=1, label='Iterations')
190
+ pg = gr.Plot()
191
+ mg = gr.Markdown()
192
+ for inp in [lr, init, st]: inp.change(update_gd, [lr, init, st], [pg, mg])
193
+ with gr.TabItem('Chain Rule'):
194
+ gr.Markdown('''
195
+ **Feature:** Demonstrate the chain rule via a computation graph
196
+ - Nodes: x → y=2x → z=3y → L=4z with arrows showing ∂→
197
+ - Multiply local slopes 2 × 3 × 4 = 24 for overall dL/dx
198
+ - Goal: Visualize why and how partial derivatives combine to yield the total derivative
199
+ ''')
200
+ xs = gr.Slider(0.0, 2.0, value=1.0, step=0.1, label='x')
201
+ cp = gr.Plot()
202
+ cm = gr.Markdown()
203
+ xs.change(update_chain, [xs], [cp, cm])
204
+ with gr.TabItem('Backpropagation'):
205
+ gr.Markdown('''
206
+ **Feature:** Visualize backpropagation derivatives in a simple DNN
207
+ - Network: x → a=w₁x → y=w₂a → L=½(y–t)² with local ∂ annotations
208
+ - Tweak x, w₁, w₂, or t and watch error gradients flow backward
209
+ - Goal: Demystify how output error propagates to compute each weight’s gradient
210
+ ''')
211
+ xb = gr.Slider(-2.0, 2.0, value=0.5, step=0.1, label='x')
212
+ w1b = gr.Slider(-2.0, 2.0, value=1.0, step=0.1, label='w1')
213
+ w2b = gr.Slider(-2.0, 2.0, value=1.0, step=0.1, label='w2')
214
+ tb = gr.Slider(-2.0, 2.0, value=0.0, step=0.1, label='t')
215
+ pb = gr.Plot()
216
+ mb = gr.Markdown()
217
+ for inp in [xb, w1b, w2b, tb]: inp.change(update_bp, [xb, w1b, w2b, tb], [pb, mb])
218
+
219
+ demo.load(load_secant, [], [p1, m1])
220
+ demo.load(load_tangent, [], [p2, m2])
221
+ demo.load(load_gd, [], [pg, mg])
222
+ demo.load(load_chain, [], [cp, cm])
223
+ demo.load(load_bp, [], [pb, mb])
224
+
225
+ with gr.Row():
226
+ reset_btn = gr.Button('Reset to default settings')
227
+ gr.HTML("<span style='cursor:help' title='Reset all sliders to defaults.'></span>")
228
+ reset_btn.click(reset_all, [], [h, x0, lr, init, st, xs, xb, w1b, w2b, tb])
229
+ gr.HTML('</div>')
230
+
231
+ demo.launch()