Spaces:
Sleeping
Sleeping
Lennard Schober
commited on
Commit
·
c9793cd
1
Parent(s):
e1d92e1
Plot everything on page
Browse files
app.py
CHANGED
|
@@ -11,6 +11,7 @@ glob_b = -2
|
|
| 11 |
glob_c = -4
|
| 12 |
glob_d = 7
|
| 13 |
|
|
|
|
| 14 |
def clear_npz():
|
| 15 |
current_directory = os.getcwd() # Get the current working directory
|
| 16 |
for filename in os.listdir(current_directory):
|
|
@@ -128,7 +129,28 @@ def plot_heat_equation(m, approx_type):
|
|
| 128 |
# Create the figure
|
| 129 |
fig = go.Figure(data=traces, layout=layout)
|
| 130 |
|
| 131 |
-
fig.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
|
| 134 |
def plot_errors(m, approx_type):
|
|
@@ -213,7 +235,29 @@ def plot_errors(m, approx_type):
|
|
| 213 |
# Create the figure
|
| 214 |
fig = go.Figure(data=traces, layout=layout)
|
| 215 |
|
| 216 |
-
fig.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
def generate_data(n_x=32, n_t=50):
|
| 219 |
"""Generate training data."""
|
|
@@ -318,12 +362,12 @@ def plot_function(a, b, c, d, k=0.5):
|
|
| 318 |
global glob_a, glob_b, glob_c, glob_d
|
| 319 |
|
| 320 |
glob_a, glob_b, glob_c, glob_d = a, b, c, d
|
| 321 |
-
|
| 322 |
x = np.linspace(0, 1, 100)
|
| 323 |
t = np.linspace(0, 5, 500)
|
| 324 |
X, T = np.meshgrid(x, t) # Create the mesh grid
|
| 325 |
Z = complex_heat_eq_solution(X, T, a, b, c, d)
|
| 326 |
-
|
| 327 |
traces = []
|
| 328 |
traces.append(
|
| 329 |
go.Surface(
|
|
@@ -348,11 +392,10 @@ def plot_function(a, b, c, d, k=0.5):
|
|
| 348 |
),
|
| 349 |
margin=dict(l=0, r=0, t=0, b=0), # Reduce margins
|
| 350 |
)
|
| 351 |
-
|
| 352 |
# Create the figure
|
| 353 |
fig = go.Figure(data=traces, layout=layout)
|
| 354 |
|
| 355 |
-
# fig.show(config=config)
|
| 356 |
fig.update_layout(
|
| 357 |
modebar_remove=[
|
| 358 |
"pan",
|
|
@@ -370,12 +413,22 @@ def plot_function(a, b, c, d, k=0.5):
|
|
| 370 |
"orbitRotation",
|
| 371 |
"tableRotation",
|
| 372 |
"toImage",
|
| 373 |
-
"resetCameraDefault3d"
|
| 374 |
]
|
| 375 |
)
|
| 376 |
|
| 377 |
return fig
|
| 378 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
|
| 380 |
# Gradio interface
|
| 381 |
def create_gradio_ui():
|
|
@@ -394,10 +447,18 @@ def create_gradio_ui():
|
|
| 394 |
|
| 395 |
with gr.Row():
|
| 396 |
with gr.Column():
|
| 397 |
-
a_slider = gr.Slider(
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 401 |
|
| 402 |
plot_output = gr.Plot()
|
| 403 |
|
|
@@ -437,7 +498,7 @@ def create_gradio_ui():
|
|
| 437 |
# Output to show status
|
| 438 |
output = gr.Textbox(label="Status", interactive=False)
|
| 439 |
|
| 440 |
-
with gr.
|
| 441 |
# Button to train coefficients
|
| 442 |
train_button = gr.Button("Train Coefficients")
|
| 443 |
# Function to trigger training and update dropdown
|
|
@@ -446,18 +507,24 @@ def create_gradio_ui():
|
|
| 446 |
inputs=[m_slider, kernel_dropdown],
|
| 447 |
outputs=output,
|
| 448 |
)
|
| 449 |
-
with gr.Row():
|
| 450 |
approx_button = gr.Button("Plot Approximation")
|
| 451 |
-
approx_button.click(
|
| 452 |
-
fn=plot_heat_equation, inputs=[m_slider, kernel_dropdown], outputs=None
|
| 453 |
-
)
|
| 454 |
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 459 |
demo.load(fn=clear_npz, inputs=None, outputs=None)
|
| 460 |
-
demo.load(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 461 |
|
| 462 |
return demo
|
| 463 |
|
|
|
|
| 11 |
glob_c = -4
|
| 12 |
glob_d = 7
|
| 13 |
|
| 14 |
+
|
| 15 |
def clear_npz():
|
| 16 |
current_directory = os.getcwd() # Get the current working directory
|
| 17 |
for filename in os.listdir(current_directory):
|
|
|
|
| 129 |
# Create the figure
|
| 130 |
fig = go.Figure(data=traces, layout=layout)
|
| 131 |
|
| 132 |
+
fig.update_layout(
|
| 133 |
+
modebar_remove=[
|
| 134 |
+
"pan",
|
| 135 |
+
"resetCameraLastSave",
|
| 136 |
+
"hoverClosest3d",
|
| 137 |
+
"hoverCompareCartesian",
|
| 138 |
+
"zoomIn",
|
| 139 |
+
"zoomOut",
|
| 140 |
+
"select2d",
|
| 141 |
+
"lasso2d",
|
| 142 |
+
"zoomIn2d",
|
| 143 |
+
"zoomOut2d",
|
| 144 |
+
"sendDataToCloud",
|
| 145 |
+
"zoom3d",
|
| 146 |
+
"orbitRotation",
|
| 147 |
+
"tableRotation",
|
| 148 |
+
"toImage",
|
| 149 |
+
"resetCameraDefault3d",
|
| 150 |
+
]
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
return fig
|
| 154 |
|
| 155 |
|
| 156 |
def plot_errors(m, approx_type):
|
|
|
|
| 235 |
# Create the figure
|
| 236 |
fig = go.Figure(data=traces, layout=layout)
|
| 237 |
|
| 238 |
+
fig.update_layout(
|
| 239 |
+
modebar_remove=[
|
| 240 |
+
"pan",
|
| 241 |
+
"resetCameraLastSave",
|
| 242 |
+
"hoverClosest3d",
|
| 243 |
+
"hoverCompareCartesian",
|
| 244 |
+
"zoomIn",
|
| 245 |
+
"zoomOut",
|
| 246 |
+
"select2d",
|
| 247 |
+
"lasso2d",
|
| 248 |
+
"zoomIn2d",
|
| 249 |
+
"zoomOut2d",
|
| 250 |
+
"sendDataToCloud",
|
| 251 |
+
"zoom3d",
|
| 252 |
+
"orbitRotation",
|
| 253 |
+
"tableRotation",
|
| 254 |
+
"toImage",
|
| 255 |
+
"resetCameraDefault3d",
|
| 256 |
+
]
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
return fig
|
| 260 |
+
|
| 261 |
|
| 262 |
def generate_data(n_x=32, n_t=50):
|
| 263 |
"""Generate training data."""
|
|
|
|
| 362 |
global glob_a, glob_b, glob_c, glob_d
|
| 363 |
|
| 364 |
glob_a, glob_b, glob_c, glob_d = a, b, c, d
|
| 365 |
+
|
| 366 |
x = np.linspace(0, 1, 100)
|
| 367 |
t = np.linspace(0, 5, 500)
|
| 368 |
X, T = np.meshgrid(x, t) # Create the mesh grid
|
| 369 |
Z = complex_heat_eq_solution(X, T, a, b, c, d)
|
| 370 |
+
|
| 371 |
traces = []
|
| 372 |
traces.append(
|
| 373 |
go.Surface(
|
|
|
|
| 392 |
),
|
| 393 |
margin=dict(l=0, r=0, t=0, b=0), # Reduce margins
|
| 394 |
)
|
| 395 |
+
|
| 396 |
# Create the figure
|
| 397 |
fig = go.Figure(data=traces, layout=layout)
|
| 398 |
|
|
|
|
| 399 |
fig.update_layout(
|
| 400 |
modebar_remove=[
|
| 401 |
"pan",
|
|
|
|
| 413 |
"orbitRotation",
|
| 414 |
"tableRotation",
|
| 415 |
"toImage",
|
| 416 |
+
"resetCameraDefault3d",
|
| 417 |
]
|
| 418 |
)
|
| 419 |
|
| 420 |
return fig
|
| 421 |
|
| 422 |
+
def plot_all(m, kernel):
|
| 423 |
+
# Generate the plot content (replace this with your actual plot logic)
|
| 424 |
+
approx_fig = plot_heat_equation(m, kernel) # Replace with your function for approx_plot
|
| 425 |
+
error_fig = plot_errors(m, kernel) # Replace with your function for error_plot
|
| 426 |
+
|
| 427 |
+
# Return the figures and make the plots visible
|
| 428 |
+
return (
|
| 429 |
+
gr.update(visible=True, value=approx_fig),
|
| 430 |
+
gr.update(visible=True, value=error_fig),
|
| 431 |
+
)
|
| 432 |
|
| 433 |
# Gradio interface
|
| 434 |
def create_gradio_ui():
|
|
|
|
| 447 |
|
| 448 |
with gr.Row():
|
| 449 |
with gr.Column():
|
| 450 |
+
a_slider = gr.Slider(
|
| 451 |
+
minimum=-10, maximum=-1, step=1, value=-2, label="a"
|
| 452 |
+
)
|
| 453 |
+
b_slider = gr.Slider(
|
| 454 |
+
minimum=-10, maximum=10, step=1, value=-2, label="b"
|
| 455 |
+
)
|
| 456 |
+
c_slider = gr.Slider(
|
| 457 |
+
minimum=-10, maximum=-1, step=1, value=-4, label="c"
|
| 458 |
+
)
|
| 459 |
+
d_slider = gr.Slider(
|
| 460 |
+
minimum=-10, maximum=10, step=1, value=7, label="d"
|
| 461 |
+
)
|
| 462 |
|
| 463 |
plot_output = gr.Plot()
|
| 464 |
|
|
|
|
| 498 |
# Output to show status
|
| 499 |
output = gr.Textbox(label="Status", interactive=False)
|
| 500 |
|
| 501 |
+
with gr.Column():
|
| 502 |
# Button to train coefficients
|
| 503 |
train_button = gr.Button("Train Coefficients")
|
| 504 |
# Function to trigger training and update dropdown
|
|
|
|
| 507 |
inputs=[m_slider, kernel_dropdown],
|
| 508 |
outputs=output,
|
| 509 |
)
|
|
|
|
| 510 |
approx_button = gr.Button("Plot Approximation")
|
|
|
|
|
|
|
|
|
|
| 511 |
|
| 512 |
+
with gr.Row():
|
| 513 |
+
approx_plot = gr.Plot(visible=False)
|
| 514 |
+
error_plot = gr.Plot(visible=False)
|
| 515 |
+
|
| 516 |
+
approx_button.click(
|
| 517 |
+
fn=plot_all,
|
| 518 |
+
inputs=[m_slider, kernel_dropdown],
|
| 519 |
+
outputs=[approx_plot, error_plot],
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
demo.load(fn=clear_npz, inputs=None, outputs=None)
|
| 523 |
+
demo.load(
|
| 524 |
+
fn=plot_function,
|
| 525 |
+
inputs=[a_slider, b_slider, c_slider, d_slider],
|
| 526 |
+
outputs=[plot_output],
|
| 527 |
+
)
|
| 528 |
|
| 529 |
return demo
|
| 530 |
|