Spaces:
Sleeping
Sleeping
Lennard Schober
commited on
Commit
·
d7ccd1d
1
Parent(s):
98baa22
Change wording
Browse files
app.py
CHANGED
|
@@ -46,7 +46,9 @@ def plot_heat_equation(m, approx_type, quality, rand_or_det):
|
|
| 46 |
new_nt = 1 * n_t
|
| 47 |
|
| 48 |
try:
|
| 49 |
-
loaded_values = np.load(
|
|
|
|
|
|
|
| 50 |
except:
|
| 51 |
raise gr.Error(f"First train the coefficients for {approx_type} and m = {m}")
|
| 52 |
alpha = loaded_values["alpha"]
|
|
@@ -141,7 +143,9 @@ def plot_errors(m, approx_type, quality, rand_or_det):
|
|
| 141 |
global n_x, n_t
|
| 142 |
|
| 143 |
try:
|
| 144 |
-
loaded_values = np.load(
|
|
|
|
|
|
|
| 145 |
except:
|
| 146 |
raise gr.Error(f"First train the coefficients for {approx_type} and m = {m}")
|
| 147 |
alpha = loaded_values["alpha"]
|
|
@@ -233,19 +237,21 @@ def generate_data():
|
|
| 233 |
return a_train, u_train, x, t
|
| 234 |
|
| 235 |
|
| 236 |
-
def features(a, theta_j,
|
| 237 |
-
"""Compute random features with adjustable
|
| 238 |
-
if
|
| 239 |
return np.sin(k * np.linalg.norm(a - theta_j, axis=-1) + eps)
|
| 240 |
-
elif
|
| 241 |
return np.log(np.linalg.norm(a - theta_j, axis=-1) + eps) / (2 * np.pi)
|
| 242 |
else:
|
| 243 |
-
raise ValueError("Unsupported
|
| 244 |
|
| 245 |
|
| 246 |
-
def design_matrix(a, theta,
|
| 247 |
"""Construct design matrix."""
|
| 248 |
-
return np.array(
|
|
|
|
|
|
|
| 249 |
|
| 250 |
|
| 251 |
def learn_coefficients(Phi, u):
|
|
@@ -253,13 +259,13 @@ def learn_coefficients(Phi, u):
|
|
| 253 |
return np.linalg.lstsq(Phi, u, rcond=None)[0]
|
| 254 |
|
| 255 |
|
| 256 |
-
def approximate_solution(a, alpha, theta,
|
| 257 |
"""Compute the approximation."""
|
| 258 |
-
Phi = design_matrix(a, theta,
|
| 259 |
return Phi @ alpha
|
| 260 |
|
| 261 |
|
| 262 |
-
def train_coefficients(m,
|
| 263 |
global glob_k, glob_a, glob_b, glob_c, n_x, n_t
|
| 264 |
# Start time for training
|
| 265 |
start_time = time.time()
|
|
@@ -278,22 +284,22 @@ def train_coefficients(m, kernel, quality, rand_or_det):
|
|
| 278 |
else:
|
| 279 |
theta = np.column_stack(
|
| 280 |
(
|
| 281 |
-
np.linspace(-
|
| 282 |
-
np.linspace(-5, 5, m), #
|
| 283 |
)
|
| 284 |
)
|
| 285 |
|
| 286 |
# Construct design matrix and learn coefficients
|
| 287 |
-
Phi = design_matrix(a_train, theta,
|
| 288 |
alpha = learn_coefficients(Phi, u_train)
|
| 289 |
|
| 290 |
end_time = f"{time.time() - start_time:.2f}"
|
| 291 |
|
| 292 |
# Save values to the npz folder
|
| 293 |
np.savez(
|
| 294 |
-
f"{
|
| 295 |
alpha=alpha,
|
| 296 |
-
|
| 297 |
Phi=Phi,
|
| 298 |
theta=theta,
|
| 299 |
)
|
|
@@ -310,7 +316,7 @@ def train_coefficients(m, kernel, quality, rand_or_det):
|
|
| 310 |
# U_approx = np.zeros_like(U_real)
|
| 311 |
# for i, xpos in enumerate(x_random):
|
| 312 |
# for j, tpos in enumerate(t_random):
|
| 313 |
-
# Phi_at_x_t = test_approx([xpos, tpos], theta,
|
| 314 |
# U_approx[j, i] = np.dot(Phi_at_x_t, alpha)
|
| 315 |
|
| 316 |
# Compute average error
|
|
@@ -384,13 +390,13 @@ def plot_function(k, a, b, c):
|
|
| 384 |
return fig
|
| 385 |
|
| 386 |
|
| 387 |
-
def plot_all(m,
|
| 388 |
# Generate the plot content (replace this with your actual plot logic)
|
| 389 |
approx_fig = plot_heat_equation(
|
| 390 |
-
m,
|
| 391 |
) # Replace with your function for approx_plot
|
| 392 |
error_fig = plot_errors(
|
| 393 |
-
m,
|
| 394 |
) # Replace with your function for error_plot
|
| 395 |
|
| 396 |
# Return the figures and make the plots visible
|
|
@@ -432,7 +438,7 @@ def create_gradio_ui():
|
|
| 432 |
$$
|
| 433 |
\argmin_{\alpha\in\mathbb{R}^m}\|{\alpha\Phi-u}\|_2^2,
|
| 434 |
$$
|
| 435 |
-
where $\Phi$ contains the features depending on the
|
| 436 |
"""
|
| 437 |
|
| 438 |
# Get the initial available files
|
|
@@ -440,7 +446,13 @@ def create_gradio_ui():
|
|
| 440 |
gr.Markdown("# Approximating a solution to the heat equation using RFM")
|
| 441 |
|
| 442 |
# Function parameter inputs
|
| 443 |
-
gr.Markdown(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 444 |
|
| 445 |
with gr.Row():
|
| 446 |
with gr.Column(min_width=500):
|
|
@@ -482,15 +494,15 @@ def create_gradio_ui():
|
|
| 482 |
|
| 483 |
with gr.Column():
|
| 484 |
with gr.Row():
|
| 485 |
-
#
|
| 486 |
quality_dropdown = gr.Dropdown(
|
| 487 |
label="Choose Quality", choices=["Low", "Mid", "High"], value="Low"
|
| 488 |
)
|
| 489 |
quality_dropdown.change(
|
| 490 |
fn=change_quality, inputs=quality_dropdown, outputs=None
|
| 491 |
)
|
| 492 |
-
|
| 493 |
-
label="Choose
|
| 494 |
)
|
| 495 |
m_slider = gr.Dropdown(
|
| 496 |
label="Number of Random Features (m)",
|
|
@@ -511,7 +523,12 @@ def create_gradio_ui():
|
|
| 511 |
# Function to trigger training and update dropdown
|
| 512 |
train_button.click(
|
| 513 |
fn=train_coefficients,
|
| 514 |
-
inputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
outputs=output,
|
| 516 |
)
|
| 517 |
approx_button = gr.Button("Plot Approximation")
|
|
@@ -524,7 +541,7 @@ def create_gradio_ui():
|
|
| 524 |
|
| 525 |
approx_button.click(
|
| 526 |
fn=plot_all,
|
| 527 |
-
inputs=[m_slider,
|
| 528 |
outputs=[approx_plot, error_plot],
|
| 529 |
)
|
| 530 |
|
|
|
|
| 46 |
new_nt = 1 * n_t
|
| 47 |
|
| 48 |
try:
|
| 49 |
+
loaded_values = np.load(
|
| 50 |
+
f"{approx_type}_m{m}_{str.lower(quality)}_{str.lower(rand_or_det)}.npz"
|
| 51 |
+
)
|
| 52 |
except:
|
| 53 |
raise gr.Error(f"First train the coefficients for {approx_type} and m = {m}")
|
| 54 |
alpha = loaded_values["alpha"]
|
|
|
|
| 143 |
global n_x, n_t
|
| 144 |
|
| 145 |
try:
|
| 146 |
+
loaded_values = np.load(
|
| 147 |
+
f"{approx_type}_m{m}_{str.lower(quality)}_{str.lower(rand_or_det)}.npz"
|
| 148 |
+
)
|
| 149 |
except:
|
| 150 |
raise gr.Error(f"First train the coefficients for {approx_type} and m = {m}")
|
| 151 |
alpha = loaded_values["alpha"]
|
|
|
|
| 237 |
return a_train, u_train, x, t
|
| 238 |
|
| 239 |
|
| 240 |
+
def features(a, theta_j, m, method="SINE", k=1, eps=1e-8):
|
| 241 |
+
"""Compute random features with adjustable method width."""
|
| 242 |
+
if method == "SINE":
|
| 243 |
return np.sin(k * np.linalg.norm(a - theta_j, axis=-1) + eps)
|
| 244 |
+
elif method == "GFF":
|
| 245 |
return np.log(np.linalg.norm(a - theta_j, axis=-1) + eps) / (2 * np.pi)
|
| 246 |
else:
|
| 247 |
+
raise ValueError("Unsupported method type!")
|
| 248 |
|
| 249 |
|
| 250 |
+
def design_matrix(a, theta, method):
|
| 251 |
"""Construct design matrix."""
|
| 252 |
+
return np.array(
|
| 253 |
+
[features(a, theta_j, theta.shape[0], method) for theta_j in theta]
|
| 254 |
+
).T
|
| 255 |
|
| 256 |
|
| 257 |
def learn_coefficients(Phi, u):
|
|
|
|
| 259 |
return np.linalg.lstsq(Phi, u, rcond=None)[0]
|
| 260 |
|
| 261 |
|
| 262 |
+
def approximate_solution(a, alpha, theta, method):
|
| 263 |
"""Compute the approximation."""
|
| 264 |
+
Phi = design_matrix(a, theta, method)
|
| 265 |
return Phi @ alpha
|
| 266 |
|
| 267 |
|
| 268 |
+
def train_coefficients(m, method, quality, rand_or_det):
|
| 269 |
global glob_k, glob_a, glob_b, glob_c, n_x, n_t
|
| 270 |
# Start time for training
|
| 271 |
start_time = time.time()
|
|
|
|
| 284 |
else:
|
| 285 |
theta = np.column_stack(
|
| 286 |
(
|
| 287 |
+
np.linspace(-0.5, 0.5, m), # Nonlinear spacing for x
|
| 288 |
+
np.linspace(-2.5, 2.5, m), # Nonlinear spacing for y
|
| 289 |
)
|
| 290 |
)
|
| 291 |
|
| 292 |
# Construct design matrix and learn coefficients
|
| 293 |
+
Phi = design_matrix(a_train, theta, method)
|
| 294 |
alpha = learn_coefficients(Phi, u_train)
|
| 295 |
|
| 296 |
end_time = f"{time.time() - start_time:.2f}"
|
| 297 |
|
| 298 |
# Save values to the npz folder
|
| 299 |
np.savez(
|
| 300 |
+
f"{method}_m{m}_{str.lower(quality)}_{str.lower(rand_or_det)}.npz",
|
| 301 |
alpha=alpha,
|
| 302 |
+
method=method,
|
| 303 |
Phi=Phi,
|
| 304 |
theta=theta,
|
| 305 |
)
|
|
|
|
| 316 |
# U_approx = np.zeros_like(U_real)
|
| 317 |
# for i, xpos in enumerate(x_random):
|
| 318 |
# for j, tpos in enumerate(t_random):
|
| 319 |
+
# Phi_at_x_t = test_approx([xpos, tpos], theta, method)
|
| 320 |
# U_approx[j, i] = np.dot(Phi_at_x_t, alpha)
|
| 321 |
|
| 322 |
# Compute average error
|
|
|
|
| 390 |
return fig
|
| 391 |
|
| 392 |
|
| 393 |
+
def plot_all(m, method, quality, rand_or_det):
|
| 394 |
# Generate the plot content (replace this with your actual plot logic)
|
| 395 |
approx_fig = plot_heat_equation(
|
| 396 |
+
m, method, quality, rand_or_det
|
| 397 |
) # Replace with your function for approx_plot
|
| 398 |
error_fig = plot_errors(
|
| 399 |
+
m, method, quality, rand_or_det
|
| 400 |
) # Replace with your function for error_plot
|
| 401 |
|
| 402 |
# Return the figures and make the plots visible
|
|
|
|
| 438 |
$$
|
| 439 |
\argmin_{\alpha\in\mathbb{R}^m}\|{\alpha\Phi-u}\|_2^2,
|
| 440 |
$$
|
| 441 |
+
where $\Phi$ contains the features depending on the method.
|
| 442 |
"""
|
| 443 |
|
| 444 |
# Get the initial available files
|
|
|
|
| 446 |
gr.Markdown("# Approximating a solution to the heat equation using RFM")
|
| 447 |
|
| 448 |
# Function parameter inputs
|
| 449 |
+
gr.Markdown(
|
| 450 |
+
markdown_content,
|
| 451 |
+
latex_delimiters=[
|
| 452 |
+
{"left": "$$", "right": "$$", "display": True},
|
| 453 |
+
{"left": "$", "right": "$", "display": False},
|
| 454 |
+
],
|
| 455 |
+
)
|
| 456 |
|
| 457 |
with gr.Row():
|
| 458 |
with gr.Column(min_width=500):
|
|
|
|
| 494 |
|
| 495 |
with gr.Column():
|
| 496 |
with gr.Row():
|
| 497 |
+
# method selection and slider for m
|
| 498 |
quality_dropdown = gr.Dropdown(
|
| 499 |
label="Choose Quality", choices=["Low", "Mid", "High"], value="Low"
|
| 500 |
)
|
| 501 |
quality_dropdown.change(
|
| 502 |
fn=change_quality, inputs=quality_dropdown, outputs=None
|
| 503 |
)
|
| 504 |
+
method_dropdown = gr.Dropdown(
|
| 505 |
+
label="Choose Method", choices=["SINE", "GFF"], value="SINE"
|
| 506 |
)
|
| 507 |
m_slider = gr.Dropdown(
|
| 508 |
label="Number of Random Features (m)",
|
|
|
|
| 523 |
# Function to trigger training and update dropdown
|
| 524 |
train_button.click(
|
| 525 |
fn=train_coefficients,
|
| 526 |
+
inputs=[
|
| 527 |
+
m_slider,
|
| 528 |
+
method_dropdown,
|
| 529 |
+
quality_dropdown,
|
| 530 |
+
rand_det_dropdown,
|
| 531 |
+
],
|
| 532 |
outputs=output,
|
| 533 |
)
|
| 534 |
approx_button = gr.Button("Plot Approximation")
|
|
|
|
| 541 |
|
| 542 |
approx_button.click(
|
| 543 |
fn=plot_all,
|
| 544 |
+
inputs=[m_slider, method_dropdown, quality_dropdown, rand_det_dropdown],
|
| 545 |
outputs=[approx_plot, error_plot],
|
| 546 |
)
|
| 547 |
|