Spaces:
Sleeping
Sleeping
Lennard Schober
commited on
Commit
·
16a2246
1
Parent(s):
c188d2e
Fix latex formatting
Browse files
app.py
CHANGED
|
@@ -13,6 +13,7 @@ glob_c = 7.5
|
|
| 13 |
|
| 14 |
n_x, n_t = 10, 10
|
| 15 |
|
|
|
|
| 16 |
def clear_npz():
|
| 17 |
current_directory = os.getcwd() # Get the current working directory
|
| 18 |
for filename in os.listdir(current_directory):
|
|
@@ -93,7 +94,6 @@ def plot_heat_equation(m, approx_type):
|
|
| 93 |
|
| 94 |
# Layout for the Plotly plot without controls
|
| 95 |
layout = go.Layout(
|
| 96 |
-
title=f"Heat Equation Approximation | Kernel = {approx_type} | m = {m}",
|
| 97 |
scene=dict(
|
| 98 |
camera=dict(
|
| 99 |
eye=dict(x=0, y=-2, z=0), # Front view
|
|
@@ -105,28 +105,6 @@ def plot_heat_equation(m, approx_type):
|
|
| 105 |
margin=dict(l=0, r=0, t=0, b=0), # Reduce margins
|
| 106 |
)
|
| 107 |
|
| 108 |
-
# Config to remove modebar buttons except the save image button
|
| 109 |
-
config = {
|
| 110 |
-
"modeBarButtonsToRemove": [
|
| 111 |
-
"pan",
|
| 112 |
-
"resetCameraLastSave",
|
| 113 |
-
"hoverClosest3d",
|
| 114 |
-
"hoverCompareCartesian",
|
| 115 |
-
"zoomIn",
|
| 116 |
-
"zoomOut",
|
| 117 |
-
"select2d",
|
| 118 |
-
"lasso2d",
|
| 119 |
-
"zoomIn2d",
|
| 120 |
-
"zoomOut2d",
|
| 121 |
-
"sendDataToCloud",
|
| 122 |
-
"zoom3d",
|
| 123 |
-
"orbitRotation",
|
| 124 |
-
"tableRotation",
|
| 125 |
-
],
|
| 126 |
-
"displayModeBar": True, # Keep the modebar visible
|
| 127 |
-
"displaylogo": False, # Hide the Plotly logo
|
| 128 |
-
}
|
| 129 |
-
|
| 130 |
# Create the figure
|
| 131 |
fig = go.Figure(data=traces, layout=layout)
|
| 132 |
|
|
@@ -148,7 +126,8 @@ def plot_heat_equation(m, approx_type):
|
|
| 148 |
"tableRotation",
|
| 149 |
"toImage",
|
| 150 |
"resetCameraDefault3d",
|
| 151 |
-
]
|
|
|
|
| 152 |
)
|
| 153 |
|
| 154 |
return fig
|
|
@@ -198,7 +177,6 @@ def plot_errors(m, approx_type):
|
|
| 198 |
|
| 199 |
# Layout for the Plotly plot without controls
|
| 200 |
layout = go.Layout(
|
| 201 |
-
title=f"Heat Equation Approximation Error | Kernel = {approx_type} | m = {m}",
|
| 202 |
scene=dict(
|
| 203 |
camera=dict(
|
| 204 |
eye=dict(x=0, y=-2, z=0), # Front view
|
|
@@ -210,28 +188,6 @@ def plot_errors(m, approx_type):
|
|
| 210 |
margin=dict(l=0, r=0, t=0, b=0), # Reduce margins
|
| 211 |
)
|
| 212 |
|
| 213 |
-
# Config to remove modebar buttons except the save image button
|
| 214 |
-
config = {
|
| 215 |
-
"modeBarButtonsToRemove": [
|
| 216 |
-
"pan",
|
| 217 |
-
"resetCameraLastSave",
|
| 218 |
-
"hoverClosest3d",
|
| 219 |
-
"hoverCompareCartesian",
|
| 220 |
-
"zoomIn",
|
| 221 |
-
"zoomOut",
|
| 222 |
-
"select2d",
|
| 223 |
-
"lasso2d",
|
| 224 |
-
"zoomIn2d",
|
| 225 |
-
"zoomOut2d",
|
| 226 |
-
"sendDataToCloud",
|
| 227 |
-
"zoom3d",
|
| 228 |
-
"orbitRotation",
|
| 229 |
-
"tableRotation",
|
| 230 |
-
],
|
| 231 |
-
"displayModeBar": True, # Keep the modebar visible
|
| 232 |
-
"displaylogo": False, # Hide the Plotly logo
|
| 233 |
-
}
|
| 234 |
-
|
| 235 |
# Create the figure
|
| 236 |
fig = go.Figure(data=traces, layout=layout)
|
| 237 |
|
|
@@ -253,7 +209,8 @@ def plot_errors(m, approx_type):
|
|
| 253 |
"tableRotation",
|
| 254 |
"toImage",
|
| 255 |
"resetCameraDefault3d",
|
| 256 |
-
]
|
|
|
|
| 257 |
)
|
| 258 |
|
| 259 |
return fig
|
|
@@ -340,7 +297,9 @@ def train_coefficients(m, kernel):
|
|
| 340 |
Phi = design_matrix(a_train, theta, kernel)
|
| 341 |
alpha = learn_coefficients(Phi, u_train)
|
| 342 |
# Validate and animate results
|
| 343 |
-
u_real = np.array(
|
|
|
|
|
|
|
| 344 |
a_test = np.c_[np.meshgrid(x, t)[0].ravel(), np.meshgrid(x, t)[1].ravel()]
|
| 345 |
u_approx = approximate_solution(a_test, alpha, theta, kernel).reshape(n_t, n_x)
|
| 346 |
|
|
@@ -415,7 +374,8 @@ def plot_function(k, a, b, c):
|
|
| 415 |
"tableRotation",
|
| 416 |
"toImage",
|
| 417 |
"resetCameraDefault3d",
|
| 418 |
-
]
|
|
|
|
| 419 |
)
|
| 420 |
|
| 421 |
return fig
|
|
@@ -434,36 +394,46 @@ def plot_all(m, kernel):
|
|
| 434 |
gr.update(visible=True, value=error_fig),
|
| 435 |
)
|
| 436 |
|
|
|
|
| 437 |
def change_quality(quality):
|
| 438 |
global n_x, n_t
|
| 439 |
-
|
| 440 |
if quality == "Low":
|
| 441 |
n_x, n_t = 10, 10
|
| 442 |
elif quality == "Mid":
|
| 443 |
n_x, n_t = 20, 20
|
| 444 |
elif quality == "High":
|
| 445 |
n_x, n_t = 40, 40
|
| 446 |
-
|
| 447 |
|
| 448 |
# Gradio interface
|
| 449 |
def create_gradio_ui():
|
| 450 |
global glob_k, glob_a, glob_b, glob_c
|
| 451 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 452 |
# Get the initial available files
|
| 453 |
with gr.Blocks() as demo:
|
| 454 |
gr.Markdown("# Learn the Coefficients for the Heat Equation using the RFM")
|
| 455 |
|
| 456 |
# Function parameter inputs
|
| 457 |
-
gr.Markdown(
|
| 458 |
-
|
| 459 |
-
## Function: $$u(x, t)\\coloneqq\\exp(-\\textcolor{magenta}{k}(\\textcolor{cyan}{a}\\pi)^2t)\\sin(\\textcolor{cyan}{a}\\pi x)+0.5\\exp(-\\textcolor{magenta}{k}(\\textcolor{lime}{b}\\pi)^2t)\\sin(\\textcolor{lime}{b}\\pi x)+0.25\\exp(-\\textcolor{magenta}{k}(\\textcolor{orange}{c}\\pi)^2t)\\sin(\\textcolor{orange}{c}\\pi x)$$
|
| 460 |
-
|
| 461 |
-
Adjust the values for <span style='color: magenta;'>k</span>, <span style='color: cyan;'>a</span>, <span style='color: lime;'>b</span> and <span style='color: orange;'>c</span> with the sliders below.
|
| 462 |
-
"""
|
| 463 |
-
)
|
| 464 |
|
| 465 |
with gr.Row():
|
| 466 |
-
with gr.Column():
|
| 467 |
k_slider = gr.Slider(
|
| 468 |
minimum=0.0001, maximum=0.1, step=0.0001, value=glob_k, label="k"
|
| 469 |
)
|
|
@@ -476,8 +446,8 @@ def create_gradio_ui():
|
|
| 476 |
c_slider = gr.Slider(
|
| 477 |
minimum=-10, maximum=10, step=0.1, value=glob_c, label="c"
|
| 478 |
)
|
| 479 |
-
|
| 480 |
-
|
| 481 |
|
| 482 |
k_slider.change(
|
| 483 |
fn=plot_function,
|
|
@@ -506,7 +476,9 @@ def create_gradio_ui():
|
|
| 506 |
quality_dropdown = gr.Dropdown(
|
| 507 |
label="Choose Quality", choices=["Low", "Mid", "High"], value="Low"
|
| 508 |
)
|
| 509 |
-
quality_dropdown.change(
|
|
|
|
|
|
|
| 510 |
kernel_dropdown = gr.Dropdown(
|
| 511 |
label="Choose Kernel", choices=["SINE", "GFF"], value="SINE"
|
| 512 |
)
|
|
@@ -530,9 +502,9 @@ def create_gradio_ui():
|
|
| 530 |
approx_button = gr.Button("Plot Approximation")
|
| 531 |
|
| 532 |
with gr.Row():
|
| 533 |
-
with gr.Column(
|
| 534 |
approx_plot = gr.Plot(visible=False)
|
| 535 |
-
with gr.Column(
|
| 536 |
error_plot = gr.Plot(visible=False)
|
| 537 |
|
| 538 |
approx_button.click(
|
|
|
|
| 13 |
|
| 14 |
n_x, n_t = 10, 10
|
| 15 |
|
| 16 |
+
|
| 17 |
def clear_npz():
|
| 18 |
current_directory = os.getcwd() # Get the current working directory
|
| 19 |
for filename in os.listdir(current_directory):
|
|
|
|
| 94 |
|
| 95 |
# Layout for the Plotly plot without controls
|
| 96 |
layout = go.Layout(
|
|
|
|
| 97 |
scene=dict(
|
| 98 |
camera=dict(
|
| 99 |
eye=dict(x=0, y=-2, z=0), # Front view
|
|
|
|
| 105 |
margin=dict(l=0, r=0, t=0, b=0), # Reduce margins
|
| 106 |
)
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
# Create the figure
|
| 109 |
fig = go.Figure(data=traces, layout=layout)
|
| 110 |
|
|
|
|
| 126 |
"tableRotation",
|
| 127 |
"toImage",
|
| 128 |
"resetCameraDefault3d",
|
| 129 |
+
],
|
| 130 |
+
legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.01),
|
| 131 |
)
|
| 132 |
|
| 133 |
return fig
|
|
|
|
| 177 |
|
| 178 |
# Layout for the Plotly plot without controls
|
| 179 |
layout = go.Layout(
|
|
|
|
| 180 |
scene=dict(
|
| 181 |
camera=dict(
|
| 182 |
eye=dict(x=0, y=-2, z=0), # Front view
|
|
|
|
| 188 |
margin=dict(l=0, r=0, t=0, b=0), # Reduce margins
|
| 189 |
)
|
| 190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
# Create the figure
|
| 192 |
fig = go.Figure(data=traces, layout=layout)
|
| 193 |
|
|
|
|
| 209 |
"tableRotation",
|
| 210 |
"toImage",
|
| 211 |
"resetCameraDefault3d",
|
| 212 |
+
],
|
| 213 |
+
legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.01),
|
| 214 |
)
|
| 215 |
|
| 216 |
return fig
|
|
|
|
| 297 |
Phi = design_matrix(a_train, theta, kernel)
|
| 298 |
alpha = learn_coefficients(Phi, u_train)
|
| 299 |
# Validate and animate results
|
| 300 |
+
u_real = np.array(
|
| 301 |
+
[complex_heat_eq_solution(x, t_i, glob_k, glob_a, glob_b, glob_c) for t_i in t]
|
| 302 |
+
)
|
| 303 |
a_test = np.c_[np.meshgrid(x, t)[0].ravel(), np.meshgrid(x, t)[1].ravel()]
|
| 304 |
u_approx = approximate_solution(a_test, alpha, theta, kernel).reshape(n_t, n_x)
|
| 305 |
|
|
|
|
| 374 |
"tableRotation",
|
| 375 |
"toImage",
|
| 376 |
"resetCameraDefault3d",
|
| 377 |
+
],
|
| 378 |
+
legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.01),
|
| 379 |
)
|
| 380 |
|
| 381 |
return fig
|
|
|
|
| 394 |
gr.update(visible=True, value=error_fig),
|
| 395 |
)
|
| 396 |
|
| 397 |
+
|
| 398 |
def change_quality(quality):
|
| 399 |
global n_x, n_t
|
| 400 |
+
|
| 401 |
if quality == "Low":
|
| 402 |
n_x, n_t = 10, 10
|
| 403 |
elif quality == "Mid":
|
| 404 |
n_x, n_t = 20, 20
|
| 405 |
elif quality == "High":
|
| 406 |
n_x, n_t = 40, 40
|
| 407 |
+
|
| 408 |
|
| 409 |
# Gradio interface
|
| 410 |
def create_gradio_ui():
|
| 411 |
global glob_k, glob_a, glob_b, glob_c
|
| 412 |
|
| 413 |
+
markdown_content = r"""
|
| 414 |
+
## Function:
|
| 415 |
+
$$
|
| 416 |
+
\begin{alignat*}{5}
|
| 417 |
+
u(x, t)
|
| 418 |
+
\coloneqq &\exp(-\textcolor{magenta}{k}(&\textcolor{cyan}{a}&\pi)^2t)\sin(&\textcolor{cyan}{a}&\pi x) \\
|
| 419 |
+
+ &\exp(-\textcolor{magenta}{k}(&\textcolor{lime}{b}&\pi)^2t)\sin(&\textcolor{lime}{b}&\pi x) \\
|
| 420 |
+
+ &\exp(-\textcolor{magenta}{k}(&\textcolor{orange}{c}&\pi)^2t)\sin(&\textcolor{orange}{c}&\pi x)
|
| 421 |
+
\end{alignat*}
|
| 422 |
+
$$
|
| 423 |
+
|
| 424 |
+
Adjust the values for <span style='color: magenta;'>k</span>, <span style='color: cyan;'>a</span>, <span style='color: lime;'>b</span> and <span style='color: orange;'>c</span> with the sliders below.
|
| 425 |
+
"""
|
| 426 |
+
|
| 427 |
# Get the initial available files
|
| 428 |
with gr.Blocks() as demo:
|
| 429 |
gr.Markdown("# Learn the Coefficients for the Heat Equation using the RFM")
|
| 430 |
|
| 431 |
# Function parameter inputs
|
| 432 |
+
gr.Markdown(markdown_content)
|
| 433 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
|
| 435 |
with gr.Row():
|
| 436 |
+
with gr.Column(min_width=500):
|
| 437 |
k_slider = gr.Slider(
|
| 438 |
minimum=0.0001, maximum=0.1, step=0.0001, value=glob_k, label="k"
|
| 439 |
)
|
|
|
|
| 446 |
c_slider = gr.Slider(
|
| 447 |
minimum=-10, maximum=10, step=0.1, value=glob_c, label="c"
|
| 448 |
)
|
| 449 |
+
with gr.Column(min_width=500):
|
| 450 |
+
plot_output = gr.Plot()
|
| 451 |
|
| 452 |
k_slider.change(
|
| 453 |
fn=plot_function,
|
|
|
|
| 476 |
quality_dropdown = gr.Dropdown(
|
| 477 |
label="Choose Quality", choices=["Low", "Mid", "High"], value="Low"
|
| 478 |
)
|
| 479 |
+
quality_dropdown.change(
|
| 480 |
+
fn=change_quality, inputs=quality_dropdown, outputs=None
|
| 481 |
+
)
|
| 482 |
kernel_dropdown = gr.Dropdown(
|
| 483 |
label="Choose Kernel", choices=["SINE", "GFF"], value="SINE"
|
| 484 |
)
|
|
|
|
| 502 |
approx_button = gr.Button("Plot Approximation")
|
| 503 |
|
| 504 |
with gr.Row():
|
| 505 |
+
with gr.Column(min_width=500):
|
| 506 |
approx_plot = gr.Plot(visible=False)
|
| 507 |
+
with gr.Column(min_width=500):
|
| 508 |
error_plot = gr.Plot(visible=False)
|
| 509 |
|
| 510 |
approx_button.click(
|