| import numpy as np |
| import matplotlib.pyplot as plt |
| import gradio as gr |
|
|
def f(x, func_name="Quadratic"):
    """Evaluate the selected objective function at ``x``.

    Args:
        x: Point at which to evaluate. Only arithmetic operators are used,
            so a scalar or a NumPy array (evaluated element-wise) both work.
        func_name: ``"Quadratic"`` for ``(x - 2)**2 + 1`` or ``"Quartic"``
            for ``x**4 - 3*x**2 + 2``.

    Returns:
        The objective value, with the same shape as ``x``.

    Raises:
        ValueError: If ``func_name`` is not one of the supported names.
    """
    if func_name == "Quadratic":
        return (x - 2)**2 + 1
    if func_name == "Quartic":
        return x**4 - 3*(x**2) + 2
    # Fail loudly instead of silently returning None, which would only
    # surface later as an unrelated TypeError in the caller.
    raise ValueError(
        f"Unknown function {func_name!r}; expected 'Quadratic' or 'Quartic'"
    )
|
|
def grad_f(x, func_name="Quadratic"):
    """Evaluate the derivative of the selected objective function at ``x``.

    Args:
        x: Point at which to evaluate the derivative. Only arithmetic
            operators are used, so a scalar or a NumPy array both work.
        func_name: ``"Quadratic"`` for the derivative ``2*(x - 2)`` or
            ``"Quartic"`` for the derivative ``4*x**3 - 6*x``.

    Returns:
        The derivative value, with the same shape as ``x``.

    Raises:
        ValueError: If ``func_name`` is not one of the supported names.
    """
    if func_name == "Quadratic":
        return 2*(x - 2)
    if func_name == "Quartic":
        return 4*(x**3) - 6*x
    # Fail loudly instead of silently returning None, which would only
    # surface later as an unrelated TypeError in the caller.
    raise ValueError(
        f"Unknown function {func_name!r}; expected 'Quadratic' or 'Quartic'"
    )
|
|
def run_gd(func_name, x0, lr, steps, x_min, x_max):
    """Run 1-D gradient descent and build the plots that visualise it.

    Starting from ``x0``, applies ``x <- x - lr * grad_f(x)`` up to ``steps``
    times, recording every iterate and its objective value.

    Args:
        func_name: Name of the objective, forwarded to ``f`` and ``grad_f``.
        x0: Starting point.
        lr: Learning rate (step size).
        steps: Maximum number of update steps; truncated with ``int()``.
        x_min: Left edge of the x-range over which the function is plotted.
        x_max: Right edge of the x-range over which the function is plotted.

    Returns:
        A tuple ``(fig1, fig2, final)``: ``fig1`` shows the function curve
        with the descent path overlaid, ``fig2`` shows the objective value
        per iteration, and ``final`` is a text summary of the last finite
        iterate (with a note appended if the run was stopped early because
        the iterates diverged).
    """
    x = float(x0)
    step_size = float(lr)
    xs = [x]
    ys = [float(f(x, func_name))]
    diverged = False

    for _ in range(int(steps)):
        try:
            x = x - step_size * float(grad_f(x, func_name))
            y = float(f(x, func_name))
        except OverflowError:
            # Python float powers raise OverflowError once the iterates blow
            # up (e.g. the quartic with a large learning rate).
            diverged = True
            break
        if not (np.isfinite(x) and np.isfinite(y)):
            # inf/nan would poison every later step and the plots.
            diverged = True
            break
        xs.append(x)
        ys.append(y)

    grid = np.linspace(float(x_min), float(x_max), 400)
    vals = f(grid, func_name)

    # Use explicit Figure/Axes objects rather than pyplot's implicit
    # "current figure" state, so drawing always targets the intended figure.
    fig1, ax1 = plt.subplots()
    ax1.plot(grid, vals)
    ax1.scatter(xs, ys, s=30)
    ax1.plot(xs, ys, linestyle="--")
    ax1.set_title(f"Gradient Descent Path on {func_name}")
    ax1.set_xlabel("x")
    ax1.set_ylabel("f(x)")
    ax1.grid(True)

    fig2, ax2 = plt.subplots()
    ax2.plot(range(len(ys)), ys)
    ax2.set_title("Objective Value Over Iterations")
    ax2.set_xlabel("iteration")
    ax2.set_ylabel("f(x)")
    ax2.grid(True)

    # Drop the figures from pyplot's global registry so repeated calls do not
    # accumulate open figures; the returned Figure objects remain usable.
    plt.close(fig1)
    plt.close(fig2)

    final = f"Final x = {xs[-1]:.6f}, f(x) = {ys[-1]:.6f}"
    if diverged:
        final += (
            f" (stopped after {len(xs) - 1} steps: iterates diverged;"
            " try a smaller learning rate)"
        )
    return fig1, fig2, final
|
|
# Gradio UI: the six inputs map positionally onto run_gd's parameters
# (func_name, x0, lr, steps, x_min, x_max) and the three outputs onto its
# returned (fig1, fig2, final) tuple.
demo = gr.Interface(
    fn=run_gd,
    inputs=[
        gr.Dropdown(["Quadratic", "Quartic"], value="Quadratic", label="Function"),
        gr.Slider(-10, 10, value=8, step=0.1, label="Initial x0"),
        gr.Slider(0.001, 1.0, value=0.1, step=0.001, label="Learning rate (lr)"),
        gr.Slider(1, 200, value=30, step=1, label="Steps"),
        gr.Slider(-15, 0, value=-5, step=0.5, label="Plot x_min"),
        gr.Slider(0, 15, value=10, step=0.5, label="Plot x_max"),
    ],
    outputs=[
        gr.Plot(label="Function + GD path"),
        gr.Plot(label="Loss curve"),
        gr.Textbox(label="Result"),
    ],
    title="Gradient Descent Visualizer (from scratch)",
    description="Adjust learning rate, starting point, and steps to see how gradient descent moves. Update rule is implemented manually.",
)


if __name__ == "__main__":
    # Only start the web server when run as a script, so importing this
    # module (e.g. to reuse run_gd or demo) does not launch a server.
    demo.launch()
|
|