rgaddamanugu3 commited on
Commit
1707037
verified
1 Parent(s): 79d9b7f

Upload graph.py

Browse files
Files changed (1) hide show
  1. graph.py +167 -0
graph.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ import gradio as gr
4
+
5
+
6
+ def plot_secant(h):
7
+ x = np.linspace(-1, 2, 400)
8
+ y = x**2
9
+ m = (h**2 - 0) / h
10
+ fig, axs = plt.subplots(1, 2, figsize=(8, 4))
11
+ for ax in axs:
12
+ ax.set_xlim(-1, 2); ax.set_ylim(-1, 4)
13
+ ax.set_xticks(np.arange(-1, 3, 1)); ax.set_yticks(np.arange(-1, 5, 1))
14
+ ax.grid(True, linestyle='--', linewidth=0.5, color='lightgray')
15
+ ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)
16
+ axs[0].plot(x, y, color='black'); axs[0].plot([0, h], [0, h**2], color='red', linewidth=2)
17
+ axs[0].scatter([0, h], [0, h**2], color='red', zorder=5)
18
+ axs[1].plot(x, m * x, color='red', linewidth=2)
19
+ plt.tight_layout(); return fig
20
+
21
+ def plot_tangent(x0):
22
+ x = np.linspace(-1, 2, 400); y = x**2; m = 2 * x0; y0 = x0**2
23
+ fig, axs = plt.subplots(1, 2, figsize=(8, 4))
24
+ for ax in axs:
25
+ ax.set_xlim(-1, 2); ax.set_ylim(-1, 4)
26
+ ax.set_xticks(np.arange(-1, 3, 1)); ax.set_yticks(np.arange(-1, 5, 1))
27
+ ax.grid(True, linestyle='--', linewidth=0.5, color='lightgray')
28
+ ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)
29
+ axs[0].plot(x, y, color='black'); axs[0].plot(x, m*(x-x0)+y0, color='red', linewidth=2)
30
+ axs[0].scatter([x0], [y0], color='red', zorder=5); axs[1].plot(x, 2*x, color='black')
31
+ axs[1].scatter([x0], [m], color='red', zorder=5)
32
+ plt.tight_layout(); return fig
33
+
34
+ def plot_gradient_descent(lr, init_x, steps):
35
+ n = int(steps); path=[init_x]
36
+ for _ in range(n): path.append(path[-1] - lr*2*path[-1])
37
+ xv = np.array(path)
38
+ fig, axs = plt.subplots(1, 2, figsize=(8, 4))
39
+ x_plot = np.linspace(-2, 2, 400); axs[0].plot(x_plot, x_plot**2, color='black')
40
+ axs[0].plot(xv, xv**2, marker='o', color='red', linewidth=2)
41
+ for i in range(n): axs[0].annotate('', xy=(xv[i+1], xv[i+1]**2), xytext=(xv[i], xv[i]**2), arrowprops=dict(arrowstyle='->', color='red'))
42
+ axs[0].set_xlim(-2, 2); axs[0].set_ylim(-0.5, 5); axs[0].set_title('Gradient Descent Path')
43
+ axs[0].grid(True, linestyle='--', linewidth=0.5, color='lightgray')
44
+ axs[1].plot(range(n+1), xv, marker='o', color='red', linewidth=2)
45
+ for i in range(n): axs[1].annotate('', xy=(i+1, xv[i+1]), xytext=(i, xv[i]), arrowprops=dict(arrowstyle='->', color='red'))
46
+ axs[1].set_xlim(0, n); axs[1].set_ylim(xv.min()-0.5, xv.max()+0.5)
47
+ axs[1].set_xticks(range(0, n+1, 5)); axs[1].set_xlabel('Iteration'); axs[1].set_title('x over Iterations')
48
+ axs[1].grid(True, linestyle='--', linewidth=0.5, color='lightgray')
49
+ plt.tight_layout(); return fig
50
+
51
+
52
+ def plot_chain_network(x):
53
+ y = 2 * x
54
+ z = 3 * y
55
+ L = 4 * z
56
+ fig, ax = plt.subplots(figsize=(6, 2)); ax.axis('off')
57
+ pos = {'x':0.1, 'y':0.3, 'z':0.5, 'L':0.7}
58
+
59
+ for name in pos:
60
+ ax.add_patch(plt.Circle((pos[name], 0.5), 0.05, fill=False))
61
+ ax.text(pos[name], 0.5, name, ha='center', va='center')
62
+
63
+ chain = [
64
+ ('x','y',r'$\partial y/\partial x=2$'),
65
+ ('y','z',r'$\partial z/\partial y=3$'),
66
+ ('z','L',r'$\partial L/\partial z=4$')
67
+ ]
68
+ for src, dst, lbl in chain:
69
+ sx, sy = pos[src], 0.5
70
+ dx, dy = pos[dst], 0.5
71
+ ax.annotate('', xy=(dx, dy), xytext=(sx, sy), arrowprops=dict(arrowstyle='->'))
72
+ ax.text((sx+dx)/2, 0.6, lbl, ha='center', va='center')
73
+
74
+ ax.text(0.02, 0.15,
75
+ r'$\frac{\partial L}{\partial x}=\frac{\partial L}{\partial z}\cdot\frac{\partial z}{\partial y}\cdot\frac{\partial y}{\partial x}$',
76
+ transform=ax.transAxes, ha='left')
77
+ ax.text(0.02, 0.02, r'$=4\times3\times2=24$', transform=ax.transAxes, ha='left')
78
+
79
+ ax.text(pos['x'], 0.3, f"x={x:.2f}", ha='center')
80
+ ax.text(pos['y'], 0.3, f"y={y:.2f}", ha='center')
81
+ ax.text(pos['z'], 0.3, f"z={z:.2f}", ha='center')
82
+ ax.text(pos['L'], 0.3, f"L={L:.2f}", ha='center')
83
+ plt.tight_layout(); return fig
84
+
85
+ def plot_backprop_dnn(x, w1, w2, t):
86
+ a = w1 * x
87
+ y = w2 * a
88
+ L = 0.5 * (y - t)**2
89
+ fig, ax = plt.subplots(figsize=(6, 2)); ax.axis('off')
90
+ pos = {'x':0.1,'a':0.3,'y':0.5,'L':0.7}
91
+ for name in pos:
92
+ ax.add_patch(plt.Circle((pos[name], 0.5), 0.05, fill=False))
93
+ ax.text(pos[name], 0.5, name, ha='center', va='center')
94
+ bp = [
95
+ ('x','a',r'$\partial a/\partial x=w_1$'),
96
+ ('a','y',r'$\partial y/\partial a=w_2$'),
97
+ ('y','L',r'$\partial L/\partial y=(y-t)$')
98
+ ]
99
+ for src, dst, lbl in bp:
100
+ sx, sy = pos[src], 0.5
101
+ dx, dy = pos[dst], 0.5
102
+ ax.annotate('', xy=(dx, dy), xytext=(sx, sy), arrowprops=dict(arrowstyle='->'))
103
+ ax.text((sx+dx)/2, 0.6, lbl, ha='center', va='center')
104
+ ax.text(0.02, 0.15, r'$\partial L/\partial w_2=(y-t)\cdot a$', transform=ax.transAxes, ha='left')
105
+ ax.text(0.02, 0.02, r'$\partial L/\partial w_1=(y-t)\cdot w_2\cdot x$', transform=ax.transAxes, ha='left')
106
+ ax.text(pos['x'], 0.3, f"x={x:.2f}", ha='center')
107
+ ax.text(pos['a'], 0.3, f"a={a:.2f}", ha='center')
108
+ ax.text(pos['y'], 0.3, f"y={y:.2f}", ha='center')
109
+ ax.text(pos['L'], 0.3, f"L={L:.2f}", ha='center')
110
+ plt.tight_layout(); return fig
111
+
112
+ with gr.Blocks() as demo:
113
+ with gr.Tabs():
114
+ with gr.TabItem('Secant Approximation'):
115
+ h = gr.Slider(1e-9, 2.0, value=0.01, step=0.001, label='h')
116
+ p1 = gr.Plot(); m1 = gr.Markdown()
117
+ h.change(lambda v: (plot_secant(v), f'**螖x=h={v:.4f}**, (f(x+h)-f(x))/h={(v**2)/v:.4f}'), h, [p1, m1])
118
+ p1.figure, m1.value = plot_secant(0.01), '**螖x=h=0.0100**, (f(x+h)-f(x))/h=0.0100'
119
+
120
+ with gr.TabItem('Tangent Visualization'):
121
+ x0 = gr.Slider(-1.0, 2.0, value=0.0, step=0.1, label='x')
122
+ p2 = gr.Plot(); m2 = gr.Markdown()
123
+ x0.change(lambda v: (plot_tangent(v), f'**x={v:.2f}**, dy/dx={2*v:.2f}'), x0, [p2, m2])
124
+ p2.figure, m2.value = plot_tangent(0.0), '**x=0.00**, dy/dx=0.00'
125
+
126
+ with gr.TabItem('Gradient Descent'):
127
+ lr = gr.Slider(0.01, 1.0, value=0.1, step=0.01, label='Learning Rate')
128
+ init = gr.Slider(-2.0, 2.0, value=1.0, step=0.1, label='Initial x')
129
+ st = gr.Slider(1, 50, value=10, step=1, label='Iterations')
130
+ pg = gr.Plot(); mg = gr.Markdown()
131
+ for inp in [lr, init, st]:
132
+ inp.change(lambda a, b, c: (plot_gradient_descent(a, b, c), f'lr={a:.2f}, init={b:.2f}, steps={c}'), [lr, init, st], [pg, mg])
133
+ pg.figure, mg.value = plot_gradient_descent(0.1, 1.0, 10), 'lr=0.10, init=1.00, steps=10'
134
+
135
+ with gr.TabItem('Chain Rule'):
136
+ x_slider = gr.Slider(0.0, 2.0, value=1.0, step=0.1, label='x')
137
+ chain_plot = gr.Plot()
138
+ chain_note = gr.Markdown()
139
+ def update_chain(x):
140
+ y = 2 * x
141
+ z = 3 * y
142
+ L = 4 * z
143
+ fig = plot_chain_network(x)
144
+ note = (
145
+ f"**Current values:** \n"
146
+ f"- y = 2路{x:.2f} = {y:.2f} \n"
147
+ f"- z = 3路{y:.2f} = {z:.2f} \n"
148
+ f"- L = 4路{z:.2f} = {L:.2f} \n\n"
149
+ "**Chain Rule:** dL/dx = 4 脳 3 脳 2 = 24"
150
+ )
151
+
152
+ return fig, note
153
+ x_slider.change(update_chain, x_slider, [chain_plot, chain_note])
154
+ chain_plot.figure, chain_note.value = update_chain(1.0)
155
+
156
+
157
+ with gr.TabItem('Backpropagation'):
158
+ xb = gr.Slider(-2.0, 2.0, value=0.5, step=0.1, label='x')
159
+ w1b = gr.Slider(-2.0, 2.0, value=1.0, step=0.1, label='w1')
160
+ w2b = gr.Slider(-2.0, 2.0, value=1.0, step=0.1, label='w2')
161
+ tb = gr.Slider(-2.0, 2.0, value=0.0, step=0.1, label='t')
162
+ pb = gr.Plot(); mb = gr.Markdown()
163
+ for inp in [xb, w1b, w2b, tb]:
164
+ inp.change(lambda a, b, c, d: (plot_backprop_dnn(a, b, c, d), ''), [xb, w1b, w2b, tb], [pb, mb])
165
+ pb.figure, mb.value = plot_backprop_dnn(0.5, 1.0, 1.0, 0.0), ''
166
+
167
+ demo.launch()