File size: 2,996 Bytes
c98136b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import PolynomialFeatures


def initialize_input(N, std):
    """Sample a noisy sine data set.

    Draws ``N`` inputs uniformly from [0, 1) and returns them in ascending
    order, together with targets ``sin(2*pi*x)`` plus zero-mean Gaussian
    noise of standard deviation ``std``. All randomness comes from NumPy's
    global RNG, so seed it (``np.random.seed``) beforehand for reproducible
    output.

    Returns:
        ``(x, y, fx)`` where ``x`` are the sorted inputs, ``y`` the noisy
        targets, and ``fx`` the noise-free target function.
    """
    def fx(t):
        return np.sin(2 * np.pi * t)

    # Inputs are drawn before the noise; keep this order so that a given
    # seed always reproduces the same data set.
    inputs = np.sort(np.random.uniform(0, 1, N))
    targets = fx(inputs) + np.random.normal(0, std, N)
    return inputs, targets, fx


def fit_poly(X, y, M, regularization):
    """Fit a ridge-regularised polynomial regression of degree ``M``.

    Args:
        X: Sample inputs. A 1-D array or sequence is treated as a single
            feature; a 2-D array is used as-is (one column per feature).
        y: Targets, one per sample.
        M: Polynomial degree, passed to ``PolynomialFeatures(degree=M)``.
        regularization: Natural log of the ridge penalty; the ``Ridge``
            estimator receives ``alpha = exp(regularization)``.

    Returns:
        ``(model, poly)``: the fitted ``Ridge`` estimator and the fitted
        ``PolynomialFeatures`` transformer. Predict new points with
        ``model.predict(poly.transform(X_new))``.
    """
    # The UI exposes ln(lambda) so that very small penalties are reachable.
    lam = np.exp(regularization)

    X = np.asarray(X)
    if X.ndim == 1:
        # scikit-learn expects a (n_samples, n_features) matrix.
        X = X[:, np.newaxis]

    poly = PolynomialFeatures(degree=M)
    X_poly = poly.fit_transform(X)
    model = Ridge(alpha=lam)
    model.fit(X_poly, y)
    return model, poly


def plot_and_error(M, N, std, seed, regularization, show_truth):
    """Fit a polynomial to fresh noisy samples, plot it, and report training MSE.

    This is the Gradio callback: every UI input change calls it with the
    current widget values.

    Args:
        M: Polynomial degree. Cast to ``int`` because UI widgets may deliver
            floats, which the sampler and ``PolynomialFeatures`` reject.
        N: Number of sampled points. Cast to ``int`` for the same reason.
        std: Standard deviation of the Gaussian noise added to the targets.
        seed: Seed for NumPy's global RNG. ``None`` (an empty number box)
            seeds from OS entropy; any other value is reduced modulo 2**32,
            the range ``np.random.seed`` accepts.
        regularization: Natural log of the ridge penalty, ln(lambda).
        show_truth: Whether to overlay the noise-free target function.

    Returns:
        ``(fig, error_text)``: a matplotlib ``Figure`` and a Markdown string
        with the mean squared error on the sampled (training) points.
    """
    M, N = int(M), int(N)
    # NOTE: this seeds NumPy's *global* RNG because initialize_input draws
    # from it; concurrent requests could interleave their draws.
    np.random.seed(None if seed is None else int(seed) % 2**32)

    # Prepare data
    X, y, fx = initialize_input(N, std)
    model, poly = fit_poly(X, y, M, regularization)

    # Build the Figure directly rather than through pyplot: pyplot keeps every
    # figure alive in a global registry until plt.close() (leaking one per UI
    # event) and is not thread-safe, while Gradio may run callbacks in threads.
    fig = Figure(figsize=(10, 6))
    ax = fig.subplots()
    ax.plot(X, y, 'o', mfc='none', mec='b', ms=6, label="Data")

    X_truth = np.linspace(0, 1, 400)
    y_hat = model.predict(poly.transform(X_truth[:, np.newaxis]))

    if show_truth:
        # Plot the same function the samples were generated from.
        ax.plot(X_truth, fx(X_truth), color='lightgreen', lw=2, label=r"$\sin(2\pi x)$")

    ax.plot(X_truth, y_hat, 'r-', label="Predicted polynomial")

    ax.set_xlim(0, 1)
    # Keep the +/-1.1 window (so wildly oscillating high-degree fits are
    # clipped) but widen it when noisy samples would otherwise fall off-plot.
    y_extent = max(1.1, 1.05 * np.abs(y).max(initial=0.0))
    ax.set_ylim(-y_extent, y_extent)
    ax.set_xlabel(r"$x$")
    ax.set_ylabel(r"$y$")
    ax.set_title(f"Polynomial Fit (M={M})")
    ax.legend(frameon=False)

    # Training error on the sampled points.
    y_pred = model.predict(poly.transform(X[:, np.newaxis]))
    mse = mean_squared_error(y, y_pred)
    error_text = f"📉 Mean Squared Error (MSE): **{mse:.4f}**"

    return fig, error_text


with gr.Blocks(css="""
    #title {text-align: center; font-size: 28px; font-weight: bold; margin-bottom: 20px;}
    #sliders {padding: 12px; background: #f9f9f9; border-radius: 12px; box-shadow: 0 0 10px rgba(0,0,0,0.1);}
    .gradio-container {font-family: 'Segoe UI', sans-serif;}
""") as demo:
    gr.Markdown("<div id='title'>🎯 Polynomial Fitting Playground</div>")
    with gr.Row():
        with gr.Column(elem_id="sliders", scale=1):
            M = gr.Slider(0, 15, value=3, step=1, label="Polynomial Degree (M)")
            N = gr.Slider(5, 100, value=10, step=1, label="Number of Points (N)")
            sigma = gr.Slider(0, 1, value=0.1, step=0.05, label="Noise Level (σ)")
            reg = gr.Slider(-50, 0, value=-18, step=1, label="Regularization (ln λ)")
            seed = gr.Number(value=100, label="Random Seed", precision=0)
            show_truth = gr.Checkbox(value=True, label="Show Ground Truth Function")

        with gr.Column(scale=2):
            plot_output = gr.Plot()
            error_output = gr.Markdown()

    # Order must match plot_and_error(M, N, std, seed, regularization, show_truth).
    inputs = [M, N, sigma, seed, reg, show_truth]
    outputs = [plot_output, error_output]

    # Re-render whenever any control changes.
    for inp in inputs:
        inp.change(fn=plot_and_error, inputs=inputs, outputs=outputs)

    # Render once when the page loads; otherwise the plot and error text stay
    # empty until the user first touches a control.
    demo.load(fn=plot_and_error, inputs=inputs, outputs=outputs)

if __name__ == "__main__":
    # share=True also asks Gradio for a temporary public share link, making
    # the app reachable from outside the local machine.
    demo.launch(share=True)