Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import jax | |
| import jax.numpy as jnp | |
| import matplotlib.pyplot as plt | |
def parabola_fn(x):
    """Upper branch of the sideways parabola ``x = y**2``, i.e. ``y = sqrt(x)``.

    Accepts Python scalars or JAX arrays; it is differentiated with
    ``jax.grad`` and evaluated on a ``jnp.linspace`` grid elsewhere in the app.
    """
    return pow(x, 0.5)
def circle_fn(x):
    """Upper half of the unit circle: ``y = sqrt(1 - x**2)``."""
    radicand = 1 - x**2
    return radicand**0.5
# Derivatives of the two curves, obtained with JAX automatic differentiation.
d_parabola_fn, d_circle_fn = (jax.grad(curve) for curve in (parabola_fn, circle_fn))
def loss_fn(params):
    """Score how far the current guess is from a circle tangent to all three curves.

    The inner circle is assumed to touch the parabola ``parabola_fn`` at
    ``x = params["x1"]`` and the unit circle ``circle_fn`` at
    ``x = params["x2"]``. Its centre is taken as the intersection of the
    normal lines (perpendiculars to the tangents) at those two points.

    Args:
        params: dict with entries ``"x1"`` and ``"x2"``.

    Returns:
        ``(loss, aux)``. ``loss`` is the sum of squared pairwise differences
        between three squared distances measured from the centre: to the
        parabola touching point (d1), to the circle touching point (d2), and
        to the line ``x = 0`` (d3). ``aux`` contains the centre
        ``"x_star"``/``"y_star"``, the two normal lines as callables
        ``"perpendicular_parabola_fn"``/``"perpendicular_circle_fn"``, and
        ``"r"`` = sqrt(d1).
    """
    x1, x2 = params["x1"], params["x2"]

    def normal_line(curve, slope_of_curve, at):
        # Slope/intercept of the line through (at, curve(at)) that is
        # perpendicular to the tangent of ``curve`` there.
        slope = -1 / slope_of_curve(at)
        return slope, curve(at) - slope * at

    def as_function(slope, intercept):
        # Turn (m, c) into the callable y = m * x + c used for plotting.
        return lambda x: slope * x + intercept

    m1, c1 = normal_line(parabola_fn, d_parabola_fn, x1)
    m2, c2 = normal_line(circle_fn, d_circle_fn, x2)

    # Where the two normal lines cross: the candidate centre of the circle.
    x_star = (c2 - c1) / (m1 - m2)
    y_star = m1 * x_star + c1

    def squared_gap(px, py):
        # Squared Euclidean distance from the centre to (px, py).
        return (x_star - px) ** 2 + (y_star - py) ** 2

    d1 = squared_gap(x1, parabola_fn(x1))
    d2 = squared_gap(x2, circle_fn(x2))
    d3 = x_star**2  # squared distance from the centre to the line x = 0

    aux = {
        "x_star": x_star,
        "y_star": y_star,
        "perpendicular_parabola_fn": as_function(m1, c1),
        "perpendicular_circle_fn": as_function(m2, c2),
        "r": d1**0.5,
    }

    # Zero exactly when d1 == d2 == d3.
    loss = sum(diff**2 for diff in (d1 - d2, d1 - d3, d2 - d3))
    return loss, aux
# 100-point sample grid on [0, 1] used to draw the curves in the left-hand plot.
x = jnp.linspace(0.0, 1.0, num=100)
# Page header and a short explanation of the method implemented in loss_fn.
# Blank lines separate markdown paragraphs; without them the lines merge.
st.title("Radius of the Circle: Optimization Playground")
st.markdown(
    r"""
Inspired by: https://twitter.com/iwontoffendyou/status/1704935240907518367

Find the radius of the inner circle that is tangent to the parabola $y = \sqrt{x}$, the unit circle $y = \sqrt{1 - x^2}$ and the line $x = 0$.

Method:

- The inner circle touches the parabola at $x = x_1$ and the unit circle at $x = x_2$.
- Call the center of the inner circle $(x^*, y^*)$.
- The distances from $(x^*, y^*)$ to the parabola, to the unit circle and to the line $x = 0$ must all be equal (each one is the radius).
- The center lies on the normal line of each curve at its touching point. For a curve $h$ at $x = x_0$, the normal line is $y = m x + c$ with $m = -\frac{1}{h'(x_0)}$ and $c = h(x_0) - m x_0$.
    - Normal line of the parabola $f(x) = \sqrt{x}$ at $x_1$: $y = m_1 x + c_1$
    - Normal line of the unit circle $g(x) = \sqrt{1 - x^2}$ at $x_2$: $y = m_2 x + c_2$
- The two normal lines intersect at $(x^*, y^*)$: $x^* = \frac{c_2 - c_1}{m_1 - m_2}$ and $y^* = m_1 x^* + c_1 = m_2 x^* + c_2$.
- We compare three squared distances: $d_1 = (x^* - x_1)^2 + (y^* - f(x_1))^2$, $d_2 = (x^* - x_2)^2 + (y^* - g(x_2))^2$ and $d_3 = {x^*}^2$.
- The loss is $L = (d_1 - d_2)^2 + (d_1 - d_3)^2 + (d_2 - d_3)^2$, which is zero exactly when all three distances agree.
"""
)
# Control panel: starting point of the search in the left column,
# optimiser settings in the right column.
col1, col2 = st.columns(2)
x1 = col1.slider(
    "initial x1 (x intersection with parabola)",
    min_value=0.0,
    max_value=1.0,
    value=0.5,
)
x2 = col1.slider(
    "initial x2 (x intersection with the circle)",
    min_value=0.0,
    max_value=1.0,
    value=0.5,
)
n_epochs = col2.slider("n_epochs", min_value=0, max_value=1000, value=50)
lr = col2.slider("lr", min_value=0.0, max_value=1.0, value=0.1, step=0.01)
# Gate for the optimisation loop further down (`if submit:`).
submit = st.button("submit")
# Gradient-descent state seeded from the sliders; ``params`` is updated in
# place by the loop below.
params = {"x1": x1, "x2": x2}
losses = []
value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

# --- Figure: geometry on the left, loss curve on the right ----------------
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
axes[0].set_xlim(0, 1)
axes[0].set_ylim(0, 1)
axes[0].set_aspect("equal")

value, aux = loss_fn(params)
# The parabola and the unit circle never change, so they are drawn once and
# not touched again. Everything that depends on (x1, x2) keeps a handle so it
# can be updated in place each epoch.
axes[0].plot(x, parabola_fn(x), color="red")
(pbola_perpendicular_plot,) = axes[0].plot(
    x, aux["perpendicular_parabola_fn"](x), color="red", linestyle="--"
)
axes[0].plot(x, circle_fn(x), color="blue")
(circle_perpendicular_plot,) = axes[0].plot(
    x, aux["perpendicular_circle_fn"](x), color="blue", linestyle="--"
)
x_star, y_star = aux["x_star"], aux["y_star"]
# Horizontal line through the centre, i.e. perpendicular to the x = 0 line.
(x0_perpendicular_plot,) = axes[0].plot(
    [0, 1], [y_star, y_star], color="black", linestyle="--"
)
radius = aux["r"]
circle_patch = axes[0].add_patch(plt.Circle((x_star, y_star), radius, fill=False))

# Identical limits (n_epochs == 0) only produce a matplotlib warning, so keep
# the x-range non-degenerate.
axes[1].set_xlim(0, max(n_epochs, 1))
# A non-finite initial loss (e.g. x2 == 0 gives an infinite normal slope)
# would make set_ylim raise, so only pin the y-range when it is usable.
if jnp.isfinite(value) and value > 0:
    axes[1].set_ylim(0, float(value))
(loss_plot,) = axes[1].plot(losses, color="black")

pbar = st.progress(0)
# ``st.empty()`` holds a single element: each ``st.pyplot`` call inside the
# ``with`` replaces the previous frame, which animates the optimisation.
with st.empty():
    st.pyplot(fig)
    # Run the optimisation only when the submit button was clicked.
    if submit:
        for i in range(n_epochs):
            # Plain gradient-descent step on (x1, x2).
            (value, _), grad = value_and_grad_fn(params)
            params["x1"] -= lr * grad["x1"]
            params["x2"] -= lr * grad["x2"]
            losses.append(value)
            _, aux = loss_fn(params)

            pbola_perpendicular_plot.set_data(x, aux["perpendicular_parabola_fn"](x))
            circle_perpendicular_plot.set_data(x, aux["perpendicular_circle_fn"](x))
            x_star, y_star = aux["x_star"], aux["y_star"]
            x0_perpendicular_plot.set_data([0, 1], [y_star, y_star])
            radius = aux["r"]
            # Move the existing circle instead of adding a new patch each
            # epoch; adding left stale circles behind and slowed rendering.
            circle_patch.set_center((x_star, y_star))
            circle_patch.set_radius(radius)
            loss_plot.set_data(range(len(losses)), losses)
            # (i + 1) so the bar reaches 100% on the last epoch.
            pbar.progress((i + 1) / n_epochs)

            axes[0].set_title(
                f"x1: {params['x1']:.3f}, x2: {params['x2']:.3f} \n r: {radius:.4f}"
            )
            axes[1].set_title(f"epoch: {i}, loss: {value:.5f}")
            st.pyplot(fig)

# The whole script re-runs on every widget interaction; release this run's
# figure so pyplot does not keep every figure alive.
plt.close(fig)