Spaces:
Sleeping
Sleeping
Lennard Schober
commited on
Commit
·
0bb1ad3
1
Parent(s):
c9793cd
Fix heat equation solution
Browse files
app.py
CHANGED
|
@@ -6,10 +6,10 @@ import numpy as np
|
|
| 6 |
import gradio as gr
|
| 7 |
import plotly.graph_objs as go
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
|
| 14 |
|
| 15 |
def clear_npz():
|
|
@@ -26,12 +26,13 @@ def clear_npz():
|
|
| 26 |
print(f"Failed to delete {file_path}. Reason: {e}")
|
| 27 |
|
| 28 |
|
| 29 |
-
def complex_heat_eq_solution(x, t, a=glob_a, b=glob_b, c=glob_c
|
| 30 |
-
global glob_a, glob_b, glob_c
|
|
|
|
| 31 |
return (
|
| 32 |
-
np.exp(-
|
| 33 |
-
+
|
| 34 |
-
+
|
| 35 |
)
|
| 36 |
|
| 37 |
|
|
@@ -358,15 +359,15 @@ def train_coefficients(m, kernel):
|
|
| 358 |
return f"Training completed in {time.time() - start_time:.2f} seconds. The average error is {avg_err}."
|
| 359 |
|
| 360 |
|
| 361 |
-
def plot_function(a, b, c
|
| 362 |
-
global glob_a, glob_b, glob_c
|
| 363 |
|
| 364 |
-
glob_a, glob_b, glob_c
|
| 365 |
|
| 366 |
x = np.linspace(0, 1, 100)
|
| 367 |
t = np.linspace(0, 5, 500)
|
| 368 |
X, T = np.meshgrid(x, t) # Create the mesh grid
|
| 369 |
-
Z = complex_heat_eq_solution(X, T,
|
| 370 |
|
| 371 |
traces = []
|
| 372 |
traces.append(
|
|
@@ -419,19 +420,25 @@ def plot_function(a, b, c, d, k=0.5):
|
|
| 419 |
|
| 420 |
return fig
|
| 421 |
|
|
|
|
| 422 |
def plot_all(m, kernel):
|
| 423 |
# Generate the plot content (replace this with your actual plot logic)
|
| 424 |
-
approx_fig = plot_heat_equation(
|
| 425 |
-
|
| 426 |
-
|
|
|
|
|
|
|
| 427 |
# Return the figures and make the plots visible
|
| 428 |
return (
|
| 429 |
gr.update(visible=True, value=approx_fig),
|
| 430 |
gr.update(visible=True, value=error_fig),
|
| 431 |
)
|
| 432 |
|
|
|
|
| 433 |
# Gradio interface
|
| 434 |
def create_gradio_ui():
|
|
|
|
|
|
|
| 435 |
# Get the initial available files
|
| 436 |
with gr.Blocks() as demo:
|
| 437 |
gr.Markdown("# Learn the Coefficients for the Heat Equation using the RFM")
|
|
@@ -447,39 +454,39 @@ def create_gradio_ui():
|
|
| 447 |
|
| 448 |
with gr.Row():
|
| 449 |
with gr.Column():
|
|
|
|
|
|
|
|
|
|
| 450 |
a_slider = gr.Slider(
|
| 451 |
-
minimum=-10, maximum
|
| 452 |
)
|
| 453 |
b_slider = gr.Slider(
|
| 454 |
-
minimum=-10, maximum=10, step=1, value
|
| 455 |
)
|
| 456 |
c_slider = gr.Slider(
|
| 457 |
-
minimum=-10, maximum
|
| 458 |
-
)
|
| 459 |
-
d_slider = gr.Slider(
|
| 460 |
-
minimum=-10, maximum=10, step=1, value=7, label="d"
|
| 461 |
)
|
| 462 |
|
| 463 |
plot_output = gr.Plot()
|
| 464 |
-
|
| 465 |
-
|
| 466 |
fn=plot_function,
|
| 467 |
-
inputs=[a_slider, b_slider, c_slider
|
| 468 |
outputs=[plot_output],
|
| 469 |
)
|
| 470 |
-
|
| 471 |
fn=plot_function,
|
| 472 |
-
inputs=[a_slider, b_slider, c_slider
|
| 473 |
outputs=[plot_output],
|
| 474 |
)
|
| 475 |
-
|
| 476 |
fn=plot_function,
|
| 477 |
-
inputs=[a_slider, b_slider, c_slider
|
| 478 |
outputs=[plot_output],
|
| 479 |
)
|
| 480 |
-
|
| 481 |
fn=plot_function,
|
| 482 |
-
inputs=[a_slider, b_slider, c_slider
|
| 483 |
outputs=[plot_output],
|
| 484 |
)
|
| 485 |
|
|
@@ -522,7 +529,7 @@ def create_gradio_ui():
|
|
| 522 |
demo.load(fn=clear_npz, inputs=None, outputs=None)
|
| 523 |
demo.load(
|
| 524 |
fn=plot_function,
|
| 525 |
-
inputs=[a_slider, b_slider, c_slider
|
| 526 |
outputs=[plot_output],
|
| 527 |
)
|
| 528 |
|
|
|
|
| 6 |
import gradio as gr
|
| 7 |
import plotly.graph_objs as go
|
| 8 |
|
| 9 |
+
glob_k = 0.0025
|
| 10 |
+
glob_a = -2.
|
| 11 |
+
glob_b = 4.
|
| 12 |
+
glob_c = 7.5
|
| 13 |
|
| 14 |
|
| 15 |
def clear_npz():
|
|
|
|
| 26 |
print(f"Failed to delete {file_path}. Reason: {e}")
|
| 27 |
|
| 28 |
|
| 29 |
+
def complex_heat_eq_solution(x, t, k=glob_k, a=glob_a, b=glob_b, c=glob_c):
|
| 30 |
+
global glob_k, glob_a, glob_b, glob_c
|
| 31 |
+
glob_k, glob_a, glob_b, glob_c = k, a, b, c
|
| 32 |
return (
|
| 33 |
+
np.exp(-glob_k * (glob_b * np.pi) ** 2 * t) * np.cos(glob_a * np.pi * x)
|
| 34 |
+
+ np.exp(-glob_k * (glob_b * np.pi) ** 2 * t) * np.sin(glob_b * np.pi * x)
|
| 35 |
+
+ np.exp(-glob_k * (glob_c * np.pi) ** 2 * t) * np.sin(glob_c * np.pi * x)
|
| 36 |
)
|
| 37 |
|
| 38 |
|
|
|
|
| 359 |
return f"Training completed in {time.time() - start_time:.2f} seconds. The average error is {avg_err}."
|
| 360 |
|
| 361 |
|
| 362 |
+
def plot_function(k, a, b, c):
|
| 363 |
+
global glob_k, glob_a, glob_b, glob_c
|
| 364 |
|
| 365 |
+
glob_k, glob_a, glob_b, glob_c = k, a, b, c
|
| 366 |
|
| 367 |
x = np.linspace(0, 1, 100)
|
| 368 |
t = np.linspace(0, 5, 500)
|
| 369 |
X, T = np.meshgrid(x, t) # Create the mesh grid
|
| 370 |
+
Z = complex_heat_eq_solution(X, T, glob_k, glob_a, glob_b, glob_c)
|
| 371 |
|
| 372 |
traces = []
|
| 373 |
traces.append(
|
|
|
|
| 420 |
|
| 421 |
return fig
|
| 422 |
|
| 423 |
+
|
| 424 |
def plot_all(m, kernel):
|
| 425 |
# Generate the plot content (replace this with your actual plot logic)
|
| 426 |
+
approx_fig = plot_heat_equation(
|
| 427 |
+
m, kernel
|
| 428 |
+
) # Replace with your function for approx_plot
|
| 429 |
+
error_fig = plot_errors(m, kernel) # Replace with your function for error_plot
|
| 430 |
+
|
| 431 |
# Return the figures and make the plots visible
|
| 432 |
return (
|
| 433 |
gr.update(visible=True, value=approx_fig),
|
| 434 |
gr.update(visible=True, value=error_fig),
|
| 435 |
)
|
| 436 |
|
| 437 |
+
|
| 438 |
# Gradio interface
|
| 439 |
def create_gradio_ui():
|
| 440 |
+
global glob_k, glob_a, glob_b, glob_c
|
| 441 |
+
|
| 442 |
# Get the initial available files
|
| 443 |
with gr.Blocks() as demo:
|
| 444 |
gr.Markdown("# Learn the Coefficients for the Heat Equation using the RFM")
|
|
|
|
| 454 |
|
| 455 |
with gr.Row():
|
| 456 |
with gr.Column():
|
| 457 |
+
k_slider = gr.Slider(
|
| 458 |
+
minimum=0.0001, maximum=0.1, step=0.0001, value=glob_k, label="k"
|
| 459 |
+
)
|
| 460 |
a_slider = gr.Slider(
|
| 461 |
+
minimum=-10, maximum=10, step=0.1, value=glob_a, label="a"
|
| 462 |
)
|
| 463 |
b_slider = gr.Slider(
|
| 464 |
+
minimum=-10, maximum=10, step=0.1, value=glob_b, label="b"
|
| 465 |
)
|
| 466 |
c_slider = gr.Slider(
|
| 467 |
+
minimum=-10, maximum=10, step=0.1, value=glob_c, label="c"
|
|
|
|
|
|
|
|
|
|
| 468 |
)
|
| 469 |
|
| 470 |
plot_output = gr.Plot()
|
| 471 |
+
|
| 472 |
+
k_slider.change(
|
| 473 |
fn=plot_function,
|
| 474 |
+
inputs=[k_slider, a_slider, b_slider, c_slider],
|
| 475 |
outputs=[plot_output],
|
| 476 |
)
|
| 477 |
+
a_slider.change(
|
| 478 |
fn=plot_function,
|
| 479 |
+
inputs=[k_slider, a_slider, b_slider, c_slider],
|
| 480 |
outputs=[plot_output],
|
| 481 |
)
|
| 482 |
+
b_slider.change(
|
| 483 |
fn=plot_function,
|
| 484 |
+
inputs=[k_slider, a_slider, b_slider, c_slider],
|
| 485 |
outputs=[plot_output],
|
| 486 |
)
|
| 487 |
+
c_slider.change(
|
| 488 |
fn=plot_function,
|
| 489 |
+
inputs=[k_slider, a_slider, b_slider, c_slider],
|
| 490 |
outputs=[plot_output],
|
| 491 |
)
|
| 492 |
|
|
|
|
| 529 |
demo.load(fn=clear_npz, inputs=None, outputs=None)
|
| 530 |
demo.load(
|
| 531 |
fn=plot_function,
|
| 532 |
+
inputs=[k_slider, a_slider, b_slider, c_slider],
|
| 533 |
outputs=[plot_output],
|
| 534 |
)
|
| 535 |
|