# CodePDE / solvers / advection / solver_central_FDM.py
# Author: LDA1020
# Commit: 56c4b9b ("feat: code release")
import numpy as np
def _lax_wendroff_update(u, u_prev, u_next, a, b):
    """Return ``u - a*(u_next - u_prev) + b*(u_next - 2*u + u_prev)``.

    Pure elementwise arithmetic, so the same expression works on NumPy,
    PyTorch and JAX arrays.
    """
    return u - a * (u_next - u_prev) + b * (u_next - 2 * u + u_prev)


def _select_backend():
    """Pick an array backend: PyTorch if importable, else JAX, else NumPy.

    Returns:
        tuple: ``(to_backend, step, to_numpy)`` where ``to_backend(arr)``
        converts a NumPy array to the backend's array type, ``step(u, a, b)``
        applies one Lax-Wendroff update with periodic wrap-around along
        axis 1, and ``to_numpy(u)`` converts a backend array back to NumPy.
    """
    try:
        import torch
    except ImportError:
        pass
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using PyTorch backend on {device}")

        def torch_step(u, a, b):
            # roll(+1) brings u[i-1] to position i, roll(-1) brings u[i+1].
            return _lax_wendroff_update(
                u, torch.roll(u, 1, dims=1), torch.roll(u, -1, dims=1), a, b)

        # float32 is kept from the original implementation for speed on GPU.
        return (
            lambda arr: torch.tensor(arr, dtype=torch.float32, device=device),
            torch_step,
            lambda u: u.cpu().numpy(),
        )

    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        pass
    else:
        print("Using JAX backend")

        # a and b are traced arguments, so changing them between output
        # intervals does not trigger recompilation.
        @jax.jit
        def jax_step(u, a, b):
            return _lax_wendroff_update(
                u, jnp.roll(u, 1, axis=1), jnp.roll(u, -1, axis=1), a, b)

        return (
            lambda arr: jnp.array(arr, dtype=jnp.float32),
            jax_step,
            np.asarray,
        )

    print("Using NumPy backend")

    def numpy_step(u, a, b):
        return _lax_wendroff_update(
            u, np.roll(u, 1, axis=1), np.roll(u, -1, axis=1), a, b)

    return np.array, numpy_step, np.asarray


def solver(u0_batch, t_coordinate, beta, cfl_factor=0.8):
    """Solve the 1-D periodic advection equation u_t + beta * u_x = 0.

    Space is discretised with second-order central differences on the
    periodic domain [0, 1) (``dx = 1 / N``). Time is advanced with the
    Lax-Wendroff scheme: the forward-Euler central-difference update plus
    the ``(beta*dt)**2 / 2 * u_xx`` correction. Plain forward Euler with
    central differences (FTCS) is unconditionally unstable for advection;
    the correction makes the scheme stable for CFL numbers |beta|*dt/dx <= 1.

    Each interval between consecutive output times is split into equal
    sub-steps no longer than ``cfl_factor * dx / |beta|``, so every stored
    snapshot lands exactly on its requested time.

    Args:
        u0_batch (np.ndarray): Initial condition [batch_size, N],
            where batch_size is the number of different initial conditions,
            and N is the number of spatial grid points.
        t_coordinate (np.ndarray): Time coordinates of shape [T+1].
            It begins with t_0=0 and follows the time steps t_1, ..., t_T;
            must be non-decreasing.
        beta (float): Constant advection speed (any sign, including 0).
        cfl_factor (float): Target CFL number, in (0, 1]. Default 0.8.

    Returns:
        solutions (np.ndarray): Shape [batch_size, T+1, N].
            solutions[:, 0, :] contains the initial conditions (u0_batch),
            solutions[:, i, :] contains the solutions at time t_coordinate[i].

    Raises:
        ValueError: if ``u0_batch`` is not 2-D, ``t_coordinate`` is empty,
            not 1-D or decreasing, or ``cfl_factor`` is outside (0, 1].
    """
    u0 = np.asarray(u0_batch, dtype=np.float64)
    times = np.asarray(t_coordinate, dtype=np.float64)
    if u0.ndim != 2:
        raise ValueError(
            f"u0_batch must have shape [batch_size, N], got {u0.shape}")
    if times.ndim != 1 or times.size == 0:
        raise ValueError("t_coordinate must be a non-empty 1-D array")
    intervals = np.diff(times)
    if np.any(intervals < 0):
        raise ValueError("t_coordinate must be non-decreasing")
    if not 0.0 < cfl_factor <= 1.0:
        raise ValueError(f"cfl_factor must be in (0, 1], got {cfl_factor}")

    batch_size, N = u0.shape
    T = times.size - 1
    dx = 1.0 / N  # periodic domain [0, 1) sampled at N points

    speed = abs(beta)
    # Largest stable step. With beta == 0 nothing moves, so any step is fine.
    dt_max = cfl_factor * dx / speed if speed > 0 else np.inf
    # At least one sub-step per interval keeps the bookkeeping uniform
    # (zero-length intervals then take a single dt = 0 step, a no-op).
    substeps = [max(1, int(np.ceil(span / dt_max))) for span in intervals]
    total_steps = sum(substeps)

    to_backend, step, to_numpy = _select_backend()
    print(f"Spatial step size (dx): {dx:.6f}")
    print(f"Number of internal time steps: {total_steps}")
    if T > 0:
        max_cfl = max(speed * span / n_sub / dx
                      for span, n_sub in zip(intervals, substeps))
        print(f"Max CFL number: {max_cfl:.6f}")

    solutions = np.zeros((batch_size, T + 1, N))
    solutions[:, 0, :] = u0

    u = to_backend(u0)
    done = 0
    for idx, (span, n_sub) in enumerate(zip(intervals, substeps), start=1):
        courant = beta * (span / n_sub) / dx
        # Plain Python floats keep backend dtypes (e.g. float32) unchanged.
        a = float(0.5 * courant)            # central-difference advection
        b = float(0.5 * courant * courant)  # Lax-Wendroff stabilising term
        for _ in range(n_sub):
            u = step(u, a, b)
        done += n_sub
        solutions[:, idx, :] = to_numpy(u)
        if idx % 10 == 0 or idx == T:
            print(f"Stored solution at time {times[idx]:.4f} "
                  f"(Step {done}/{total_steps})")

    return solutions