|
|
from __future__ import annotations |
|
|
|
|
|
import logging |
|
|
import os |
|
|
import random |
|
|
import sys |
|
|
import time |
|
|
from functools import partial |
|
|
from math import ceil, exp, log |
|
|
import math as mt |
|
|
import numpy as np |
|
|
import hydra |
|
|
import jax |
|
|
import jax.numpy as jnp |
|
|
from scipy.interpolate import interp1d |
|
|
from jax import device_put, jit, lax, nn, random, scipy, vmap |
|
|
from omegaconf import DictConfig |
|
|
|
|
|
def interpolate_solution(u_fine, x_fine, t_fine, x_coarse, t_coarse):
    """Resample a fine-grid solution onto a coarse space-time grid.

    Linear interpolation (with linear extrapolation outside the sample range)
    is applied first along axis 2 of ``u_fine`` (sampled at ``x_fine``) and
    then along axis 1 (sampled at ``t_fine``).

    :param u_fine: array whose axis 1 matches ``t_fine`` and axis 2 matches ``x_fine``.
    :return: ``u_fine`` evaluated at ``t_coarse`` (axis 1) and ``x_coarse`` (axis 2).
    """

    def _resample(values, src, dst, axis):
        # One 1D linear interpolant along ``axis``, evaluated at ``dst``.
        interpolant = interp1d(
            src, values, axis=axis, kind="linear", fill_value="extrapolate"
        )
        return interpolant(dst)

    on_coarse_space = _resample(u_fine, x_fine, x_coarse, axis=2)
    return _resample(on_coarse_space, t_fine, t_coarse, axis=1)
|
|
|
|
|
|
|
|
def compute_error(coarse_tuple, fine_tuple):
    """Return the RMS difference between a coarse solution and a fine one.

    The fine solution is first resampled onto the coarse space-time grid with
    ``interpolate_solution`` (space along axis 2, time along axis 1).

    :param coarse_tuple: ``(u_coarse, x_coarse, t_coarse)``.
    :param fine_tuple: ``(u_fine, x_fine, t_fine)``.
    :return: ``||u_coarse - u_fine_interp||_2 / sqrt(u_coarse.size)``.
    :raises ValueError: if the resampled fine solution does not have the
        coarse solution's shape (instead of relying on NumPy broadcasting).
    """
    u_coarse, x_coarse, t_coarse = coarse_tuple
    u_fine, x_fine, t_fine = fine_tuple
    u_coarse = np.asarray(u_coarse)

    u_fine_interp = interpolate_solution(u_fine, x_fine, t_fine, x_coarse, t_coarse)

    # Shapes are diagnostic only: log them at debug level instead of printing
    # them on every call.
    logging.getLogger(__name__).debug(
        "compute_error: coarse shape %s, interpolated fine shape %s",
        u_coarse.shape,
        u_fine_interp.shape,
    )
    if u_coarse.shape != u_fine_interp.shape:
        raise ValueError(
            f"shape mismatch: coarse {u_coarse.shape} vs "
            f"interpolated fine {u_fine_interp.shape}"
        )

    error = np.linalg.norm(u_coarse - u_fine_interp) / np.sqrt(u_coarse.size)
    return error
|
|
|
|
|
|
|
|
|
|
|
@partial(jit, static_argnums=(3, 4, 5, 6, 7, 8, 9))
def init_multi_HD(
    xc,
    yc,
    zc,
    numbers=10000,
    k_tot=10,
    init_key=2022,
    num_choise_k=2,
    if_renorm=False,
    umax=1.0e4,
    umin=1.0e-8,
):
    """Generate a batch of random 1D profiles sampled at the cell centres ``xc``.

    Each profile is a sum over ``k_tot`` sine terms with random amplitudes and
    phases; only ``num_choise_k`` randomly chosen wave numbers (drawn with
    replacement) are non-zero, the others contribute a constant
    ``amp * sin(phase)``.  Then, per profile:
    - with probability 0.1 the absolute value is taken,
    - the sign is flipped at random,
    - with probability 0.5 it is multiplied by a smooth tanh window between
      ``xL`` in [0.1, 0.45) and ``xR`` in [0.55, 0.9) (absolute x positions;
      presumably assumes a domain close to [0, 1] -- TODO confirm),
    - if ``if_renorm``: rescaled to [0, 1] and mapped to ``u * m + b`` with
      ``m`` and ``b`` drawn log-uniformly in [umin, umax).

    :param xc: 1D array of cell-centre x coordinates.
    :param yc: unused; kept for interface compatibility.
    :param zc: unused; kept for interface compatibility.
    :param numbers: number of profiles; must be a multiple of ``jax.device_count()``.
    :param k_tot: number of candidate wave numbers ``2 pi k / (xc[-1] - xc[0])``, k = 1..k_tot.
    :param init_key: integer seed for ``jax.random.PRNGKey``.
    :param num_choise_k: number of wave-number draws per profile.
    :param if_renorm: apply the rescaling described above (static under jit).
    :param umax: upper bound of the log-uniform scale/offset draws.
    :param umin: lower bound of the log-uniform scale/offset draws.
    :return: array of shape ``(numbers, len(xc), 1, 1)``.
    """
    assert numbers % jax.device_count() == 0, "numbers should be : GPUs x integer!!"

    # One independent subkey per random draw.  Reusing one key for several
    # draws correlates them (e.g. xL and xR were previously derived from the
    # same uniform sample).
    keys = random.split(random.PRNGKey(init_key), 10)
    k_sel, k_amp, k_phs, k_abs, k_sgn = keys[0], keys[1], keys[2], keys[3], keys[4]
    k_win, k_xl, k_xr, k_m, k_b = keys[5], keys[6], keys[7], keys[8], keys[9]

    # Multiplicity (0, 1, ...) of each candidate wave number per profile.
    selected = random.randint(
        k_sel, shape=(numbers, num_choise_k), minval=0, maxval=k_tot
    )
    selected = nn.one_hot(selected, k_tot, dtype=int).sum(axis=1)
    # NOTE(review): the wave length uses xc[-1] - xc[0] (one cell shorter than
    # the full extent of uniformly spaced cells) -- confirm this is intended.
    kk = jnp.pi * 2.0 * jnp.arange(1, k_tot + 1) * selected / (xc[-1] - xc[0])
    amp = random.uniform(k_amp, shape=(numbers, k_tot, 1))
    phs = 2.0 * jnp.pi * random.uniform(k_phs, shape=(numbers, k_tot, 1))
    _u = amp * jnp.sin(kk[:, :, jnp.newaxis] * xc[jnp.newaxis, jnp.newaxis, :] + phs)
    _u = jnp.sum(_u, axis=1)

    # With probability 0.1 replace a profile by its absolute value.
    take_abs = random.choice(k_abs, 2, p=jnp.array([0.9, 0.1]), shape=(numbers,))
    _u = jnp.where(take_abs[:, None] == 1, jnp.abs(_u), _u)
    # Random overall sign per profile.
    sgn = random.choice(k_sgn, a=jnp.array([1, -1]), shape=(numbers, 1))
    _u = _u * sgn

    # With probability 0.5 multiply by a smooth window of transition width 0.01.
    use_window = random.choice(k_win, 2, p=jnp.array([0.5, 0.5]), shape=(numbers,))
    xL = random.uniform(k_xl, shape=(numbers,), minval=0.1, maxval=0.45)
    xR = random.uniform(k_xr, shape=(numbers,), minval=0.55, maxval=0.9)
    trns = 0.01
    window = 0.5 * (
        jnp.tanh((xc[None, :] - xL[:, None]) / trns)
        - jnp.tanh((xc[None, :] - xR[:, None]) / trns)
    )
    _u = _u * jnp.where(use_window[:, None] == 1, window, 1.0)

    # if_renorm is a static argument, so a plain Python branch is enough.
    if if_renorm:
        _u = _u - jnp.min(_u, axis=1, keepdims=True)
        span = jnp.max(_u, axis=1, keepdims=True)
        # Guard constant profiles (span == 0) against 0/0.
        _u = _u / jnp.where(span > 0.0, span, 1.0)
        log_lo, log_hi = mt.log(umin), mt.log(umax)
        m_val = jnp.exp(
            random.uniform(k_m, shape=(numbers,), minval=log_lo, maxval=log_hi)
        )
        b_val = jnp.exp(
            random.uniform(k_b, shape=(numbers,), minval=log_lo, maxval=log_hi)
        )
        _u = _u * m_val[:, None] + b_val[:, None]

    return _u[..., None, None]
|
|
|
|
|
|
|
|
def VLlimiter(a, b, c, alpha=2.0):
    """Limited slope from one-sided differences ``a``, ``b`` and central difference ``c``.

    Returns ``sign(c) * (0.5 + 0.5 * sign(a * b)) * min(alpha * min(|a|, |b|), |c|)``.
    The sign factor switches the slope off where ``a`` and ``b`` have opposite
    signs.  (The ``VL`` name presumably refers to van Leer.)
    """
    agreement = 0.5 * jnp.sign(a * b) + 0.5
    capped = jnp.minimum(alpha * jnp.minimum(jnp.abs(a), jnp.abs(b)), jnp.abs(c))
    return jnp.sign(c) * agreement * capped
|
|
|
|
|
def limiting_HD(u, if_second_order):
    """Reconstruct left/right face values of each cell along axis 1 of ``u``.

    ``u`` must be 4D, indexed ``[variable, i, j, k]``.  For cells
    ``i = 1 .. n-2`` (n = ``u.shape[1]``) a slope is computed with
    ``VLlimiter`` from the backward, forward and central differences and
    multiplied by ``if_second_order``; then ``uL = u - slope / 2`` and
    ``uR = u + slope / 2``.  Cells 0 and n-1 keep ``u``.  Wherever the
    reconstructed variable 0 and then variable 4 is not positive, every
    variable at that position falls back to ``u``.

    :return: ``(uL, uR)``, both shaped like ``u``.
    """
    _, n, _, _ = u.shape
    inner = slice(1, n - 1)

    centre = u[:, inner]
    backward = centre - u[:, 0 : n - 2]
    forward = u[:, 2:n] - centre
    central = (u[:, 2:n] - u[:, 0 : n - 2]) * 0.5

    half_slope = 0.5 * (VLlimiter(backward, forward, central) * if_second_order)

    uL = u.at[:, inner].set(centre - half_slope)
    uR = u.at[:, inner].set(centre + half_slope)

    # Positivity fallback, applied to variable 0 first and then variable 4.
    for var in (0, 4):
        uL = jnp.where(uL[var] > 0.0, uL, u)
    for var in (0, 4):
        uR = jnp.where(uR[var] > 0.0, uR, u)

    return uL, uR
|
|
|
|
|
def Courant_HD(u, dx, dy, dz, gamma):
    """Advective time-step limit for primitive variables ``u``.

    ``u[0]`` is the density, ``u[1:4]`` the velocity components and ``u[4]``
    the pressure.  With sound speed ``cs = sqrt(gamma * p / rho)``, the
    limit in each direction is ``d / (max(cs + |v_d|) + 1e-8)``.  The smallest
    of the three is returned.
    """
    sound_speed = jnp.sqrt(gamma * u[4] / u[0])
    per_direction = [
        spacing / (jnp.max(sound_speed + jnp.abs(u[component])) + 1.0e-8)
        for component, spacing in ((1, dx), (2, dy), (3, dz))
    ]
    return jnp.min(jnp.array(per_direction))
|
|
|
|
|
|
|
|
def Courant_vis_HD(dx, dy, dz, eta, zeta):
    """Diffusive time-step limit ``0.5 * d**2 / (4/3 * eta + zeta + 1e-8)``, minimised over dx, dy, dz."""
    denominator = 4.0 / 3.0 * eta + zeta + 1.0e-8
    limits = [0.5 * spacing**2 / denominator for spacing in (dx, dy, dz)]
    return jnp.min(jnp.array(limits))
|
|
|
|
|
def bc_HD(u, mode):
    """Fill the two ghost layers of ``u`` (indexed ``[variable, x, y, z]``).

    Only the faces are filled: for each axis the other two spatial axes are
    restricted to ``2:-2``, so edge/corner ghost cells are left untouched.

    Modes:
    - ``"periodic"``: wrap-around in x, y and z.
    - ``"trans"``: mirror copies in x, y and z (``u[0] <- u[3]``,
      ``u[1] <- u[2]``, ``u[-2] <- u[-3]``, ``u[-1] <- u[-4]``).
    - ``"KHI"``: periodic in x, ``"trans"``-style in y and z.

    :return: a new array with the ghost cells filled.
    :raises ValueError: for an unknown ``mode``.
    """
    _, Nx, Ny, Nz = u.shape
    Nx -= 2
    Ny -= 2
    Nz -= 2
    if mode == "periodic":
        u = u.at[:, 0:2, 2:-2, 2:-2].set(u[:, Nx - 2 : Nx, 2:-2, 2:-2])
        u = u.at[:, 2:-2, 0:2, 2:-2].set(u[:, 2:-2, Ny - 2 : Ny, 2:-2])
        u = u.at[:, 2:-2, 2:-2, 0:2].set(u[:, 2:-2, 2:-2, Nz - 2 : Nz])
        u = u.at[:, Nx : Nx + 2, 2:-2, 2:-2].set(u[:, 2:4, 2:-2, 2:-2])
        u = u.at[:, 2:-2, Ny : Ny + 2, 2:-2].set(u[:, 2:-2, 2:4, 2:-2])
        u = u.at[:, 2:-2, 2:-2, Nz : Nz + 2].set(u[:, 2:-2, 2:-2, 2:4])

    elif mode == "trans":
        # JAX arrays are updated functionally through `.at[...]`; the former
        # `.loc[...]` does not exist on JAX arrays and raised AttributeError.
        u = u.at[:, 0, 2:-2, 2:-2].set(u[:, 3, 2:-2, 2:-2])
        u = u.at[:, 2:-2, 0, 2:-2].set(u[:, 2:-2, 3, 2:-2])
        u = u.at[:, 2:-2, 2:-2, 0].set(u[:, 2:-2, 2:-2, 3])
        u = u.at[:, 1, 2:-2, 2:-2].set(u[:, 2, 2:-2, 2:-2])
        u = u.at[:, 2:-2, 1, 2:-2].set(u[:, 2:-2, 2, 2:-2])
        u = u.at[:, 2:-2, 2:-2, 1].set(u[:, 2:-2, 2:-2, 2])

        u = u.at[:, -2, 2:-2, 2:-2].set(u[:, -3, 2:-2, 2:-2])
        u = u.at[:, 2:-2, -2, 2:-2].set(u[:, 2:-2, -3, 2:-2])
        u = u.at[:, 2:-2, 2:-2, -2].set(u[:, 2:-2, 2:-2, -3])
        u = u.at[:, -1, 2:-2, 2:-2].set(u[:, -4, 2:-2, 2:-2])
        u = u.at[:, 2:-2, -1, 2:-2].set(u[:, 2:-2, -4, 2:-2])
        u = u.at[:, 2:-2, 2:-2, -1].set(u[:, 2:-2, 2:-2, -4])

    elif mode == "KHI":
        # Periodic in x, mirror copies in y and z.
        u = u.at[:, 0:2, 2:-2, 2:-2].set(u[:, Nx - 2 : Nx, 2:-2, 2:-2])
        u = u.at[:, 2:-2, 0, 2:-2].set(u[:, 2:-2, 3, 2:-2])
        u = u.at[:, 2:-2, 2:-2, 0].set(u[:, 2:-2, 2:-2, 3])
        u = u.at[:, 2:-2, 1, 2:-2].set(u[:, 2:-2, 2, 2:-2])
        u = u.at[:, 2:-2, 2:-2, 1].set(u[:, 2:-2, 2:-2, 2])

        u = u.at[:, Nx : Nx + 2, 2:-2, 2:-2].set(u[:, 2:4, 2:-2, 2:-2])
        u = u.at[:, 2:-2, -2, 2:-2].set(u[:, 2:-2, -3, 2:-2])
        u = u.at[:, 2:-2, 2:-2, -2].set(u[:, 2:-2, 2:-2, -3])
        u = u.at[:, 2:-2, -1, 2:-2].set(u[:, 2:-2, -4, 2:-2])
        u = u.at[:, 2:-2, 2:-2, -1].set(u[:, 2:-2, 2:-2, -4])

    else:
        raise ValueError(f"unknown boundary condition mode: {mode!r}")

    return u
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)

# Process-wide runtime settings: on-demand GPU memory growth for TensorFlow
# (if present) and the fraction of GPU memory the JAX/XLA client preallocates.
os.environ.update(
    {
        "TF_FORCE_GPU_ALLOW_GROWTH": "true",
        "XLA_PYTHON_CLIENT_MEM_FRACTION": ".9",
    }
)

sys.path.append("..")


def _pass(carry):
    """No-op branch for ``lax.cond``: hand ``carry`` back unchanged."""
    return carry
|
|
|
|
|
|
|
|
|
|
|
def run_step(cfg, nx, dt_save):
    """Simulate a batch of 1D compressible Navier-Stokes problems at one resolution.

    ``nx`` cells are used in x (``ny = nz = 1``), each array carrying two ghost
    layers per side.  Each of the ``cfg.args.numbers`` samples is evolved
    independently: limited reconstruction + HLLC fluxes with a two-stage
    (midpoint) time update, followed by explicit viscous sub-steps.  Snapshots
    are stored every ``dt_save``.

    :param cfg: Hydra config; reads ``cfg.args.{gamma, bc, xL, xR, yL, yR, zL,
        zR, show_steps, ini_time, fin_time, if_rand_param, zeta, eta, CFL,
        p_floor, if_second_order, numbers, init_key}``.
    :param nx: number of cells in x.
    :param dt_save: time interval between stored snapshots.
    :return: ``(ux, density, pressure, xc, tc)``.  The first three stack the
        squeezed per-sample histories, i.e. ``(numbers, it_tot, nx)`` when
        ``it_tot, nx > 1``.  Snapshot ``i`` is taken at ``ini_time + i * dt_save``
        (the last one at ``fin_time``).  ``tc = arange(it_tot + 1) * dt_save``
        has one entry more than the stored time axis.
    """
    ny = 1
    nz = 1
    gamma = cfg.args.gamma
    gammi1 = gamma - 1.0
    gamminv1 = 1.0 / gammi1
    gamgamm1inv = gamma * gamminv1

    BCs = ["trans", "periodic", "KHI"]
    assert cfg.args.bc in BCs, "bc should be in 'trans, periodic, KHI'"

    dx = (cfg.args.xR - cfg.args.xL) / nx
    dx_inv = 1.0 / dx
    dy = (cfg.args.yR - cfg.args.yL) / ny
    dy_inv = 1.0 / dy
    dz = (cfg.args.zR - cfg.args.zL) / nz
    dz_inv = 1.0 / dz

    # Cell edges and cell centres.
    xe = jnp.linspace(cfg.args.xL, cfg.args.xR, nx + 1)
    ye = jnp.linspace(cfg.args.yL, cfg.args.yR, ny + 1)
    ze = jnp.linspace(cfg.args.zL, cfg.args.zR, nz + 1)
    xc = xe[:-1] + 0.5 * dx
    yc = ye[:-1] + 0.5 * dy
    zc = ze[:-1] + 0.5 * dz

    show_steps = cfg.args.show_steps
    ini_time = cfg.args.ini_time
    fin_time = cfg.args.fin_time

    # Number of stored snapshots (initial state included).
    it_tot = ceil((fin_time - ini_time) / dt_save) + 1
    tc = jnp.arange(it_tot + 1) * dt_save

    if cfg.args.if_rand_param:
        # The module-level name `random` is `jax.random` (it shadows the stdlib
        # module), so `random.uniform(a, b)` would be `jax.random.uniform` with
        # a float as the key and crash.  Draw the scalars with NumPy instead;
        # seeding with init_key keeps them identical across the resolutions
        # compared by main().
        param_rng = np.random.default_rng(cfg.args.init_key)
        zeta = exp(param_rng.uniform(log(0.001), log(10)))
        eta = exp(param_rng.uniform(log(0.001), log(10)))
    else:
        zeta = cfg.args.zeta
        eta = cfg.args.eta
    logger.info(f"zeta: {zeta:>5f}, eta: {eta:>5f}")
    visc = zeta + eta / 3.0

    def evolve(Q):
        """Evolve one sample ``Q`` (``[5, nx+4, ny+4, nz+4]``) and record snapshots."""
        t = ini_time
        tsave = t
        steps = 0
        i_save = 0
        dt = 0.0

        tm_ini = time.time()

        DDD = jnp.zeros([it_tot, nx, ny, nz])
        VVx = jnp.zeros([it_tot, nx, ny, nz])
        VVy = jnp.zeros([it_tot, nx, ny, nz])
        VVz = jnp.zeros([it_tot, nx, ny, nz])
        PPP = jnp.zeros([it_tot, nx, ny, nz])

        DDD = DDD.at[0].set(Q[0, 2:-2, 2:-2, 2:-2])
        VVx = VVx.at[0].set(Q[1, 2:-2, 2:-2, 2:-2])
        VVy = VVy.at[0].set(Q[2, 2:-2, 2:-2, 2:-2])
        VVz = VVz.at[0].set(Q[3, 2:-2, 2:-2, 2:-2])
        PPP = PPP.at[0].set(Q[4, 2:-2, 2:-2, 2:-2])

        cond_fun = lambda x: x[0] < fin_time

        def _body_fun(carry):
            def _save(_carry):
                Q, tsave, i_save, DDD, VVx, VVy, VVz, PPP = _carry

                DDD = DDD.at[i_save].set(Q[0, 2:-2, 2:-2, 2:-2])
                VVx = VVx.at[i_save].set(Q[1, 2:-2, 2:-2, 2:-2])
                VVy = VVy.at[i_save].set(Q[2, 2:-2, 2:-2, 2:-2])
                VVz = VVz.at[i_save].set(Q[3, 2:-2, 2:-2, 2:-2])
                PPP = PPP.at[i_save].set(Q[4, 2:-2, 2:-2, 2:-2])

                tsave += dt_save
                i_save += 1
                return (Q, tsave, i_save, DDD, VVx, VVy, VVz, PPP)

            t, tsave, steps, i_save, dt, Q, DDD, VVx, VVy, VVz, PPP = carry

            # Store a snapshot once the saving time has been reached.
            carry = (Q, tsave, i_save, DDD, VVx, VVy, VVz, PPP)
            Q, tsave, i_save, DDD, VVx, VVy, VVz, PPP = lax.cond(
                t >= tsave, _save, _pass, carry
            )

            # Advance show_steps time steps; dt is clamped so that tsave and
            # fin_time are hit exactly.
            carry = (Q, t, dt, steps, tsave)
            Q, t, dt, steps, tsave = lax.fori_loop(0, show_steps, simulation_fn, carry)

            return (t, tsave, steps, i_save, dt, Q, DDD, VVx, VVy, VVz, PPP)

        carry = t, tsave, steps, i_save, dt, Q, DDD, VVx, VVy, VVz, PPP
        t, tsave, steps, i_save, dt, Q, DDD, VVx, VVy, VVz, PPP = lax.while_loop(
            cond_fun, _body_fun, carry
        )

        # JAX dispatches asynchronously; wait for the result so the logged
        # time covers the computation.
        Q.block_until_ready()
        tm_fin = time.time()
        logger.info(f"total elapsed time is {tm_fin - tm_ini} sec")
        DDD = DDD.at[-1].set(Q[0, 2:-2, 2:-2, 2:-2])
        VVx = VVx.at[-1].set(Q[1, 2:-2, 2:-2, 2:-2])
        VVy = VVy.at[-1].set(Q[2, 2:-2, 2:-2, 2:-2])
        VVz = VVz.at[-1].set(Q[3, 2:-2, 2:-2, 2:-2])
        PPP = PPP.at[-1].set(Q[4, 2:-2, 2:-2, 2:-2])
        return t, DDD, VVx, VVy, VVz, PPP

    @jit
    def simulation_fn(i, carry):
        """One time step: CFL-limited dt, midpoint hydro update, viscous sub-steps."""
        Q, t, dt, steps, tsave = carry
        dt = (
            Courant_HD(Q[:, 2:-2, 2:-2, 2:-2], dx, dy, dz, cfg.args.gamma)
            * cfg.args.CFL
        )
        dt = jnp.min(jnp.array([dt, cfg.args.fin_time - t, tsave - t]))

        def _update(carry):
            Q, dt = carry

            # Midpoint scheme: fluxes from the half-step state update Q by dt.
            Q_tmp = bc_HD(Q, mode=cfg.args.bc)
            Q_tmp = update(Q, Q_tmp, dt * 0.5)

            Q_tmp = bc_HD(Q_tmp, mode=cfg.args.bc)
            Q = update(Q, Q_tmp, dt)

            # Viscous terms: explicit sub-steps of size <= dt_vis covering dt.
            dt_vis = Courant_vis_HD(dx, dy, dz, eta, zeta) * cfg.args.CFL
            dt_vis = jnp.min(jnp.array([dt_vis, dt]))
            t_vis = 0.0

            carry = Q, dt, dt_vis, t_vis
            Q, dt, dt_vis, t_vis = lax.while_loop(
                lambda x: x[1] - x[3] > 1.0e-8, update_vis, carry
            )
            return Q, dt

        carry = Q, dt
        Q, dt = lax.cond(dt > 1.0e-8, _update, _pass, carry)

        t += dt
        steps += 1
        return Q, t, dt, steps, tsave

    @jit
    def update(Q, Q_tmp, dt):
        """Advance the interior of ``Q`` by ``dt`` using fluxes computed from ``Q_tmp``.

        Variables are primitive (density, vx, vy, vz, pressure); the update is
        done on the conserved quantities and converted back.  Pressure is
        replaced by ``cfg.args.p_floor`` wherever it is not above 1e-8.
        """
        D0 = Q[0]
        Mx = Q[1] * Q[0]
        My = Q[2] * Q[0]
        Mz = Q[3] * Q[0]
        E0 = Q[4] * gamminv1 + 0.5 * (Mx * Q[1] + My * Q[2] + Mz * Q[3])

        D0 = D0[2:-2, 2:-2, 2:-2]
        Mx = Mx[2:-2, 2:-2, 2:-2]
        My = My[2:-2, 2:-2, 2:-2]
        Mz = Mz[2:-2, 2:-2, 2:-2]
        E0 = E0[2:-2, 2:-2, 2:-2]

        fx = flux_x(Q_tmp)
        fy = flux_y(Q_tmp)
        fz = flux_z(Q_tmp)

        dtdx, dtdy, dtdz = dt * dx_inv, dt * dy_inv, dt * dz_inv
        D0 -= (
            dtdx * (fx[0, 1:, 2:-2, 2:-2] - fx[0, :-1, 2:-2, 2:-2])
            + dtdy * (fy[0, 2:-2, 1:, 2:-2] - fy[0, 2:-2, :-1, 2:-2])
            + dtdz * (fz[0, 2:-2, 2:-2, 1:] - fz[0, 2:-2, 2:-2, :-1])
        )

        Mx -= (
            dtdx * (fx[1, 1:, 2:-2, 2:-2] - fx[1, :-1, 2:-2, 2:-2])
            + dtdy * (fy[1, 2:-2, 1:, 2:-2] - fy[1, 2:-2, :-1, 2:-2])
            + dtdz * (fz[1, 2:-2, 2:-2, 1:] - fz[1, 2:-2, 2:-2, :-1])
        )

        My -= (
            dtdx * (fx[2, 1:, 2:-2, 2:-2] - fx[2, :-1, 2:-2, 2:-2])
            + dtdy * (fy[2, 2:-2, 1:, 2:-2] - fy[2, 2:-2, :-1, 2:-2])
            + dtdz * (fz[2, 2:-2, 2:-2, 1:] - fz[2, 2:-2, 2:-2, :-1])
        )

        Mz -= (
            dtdx * (fx[3, 1:, 2:-2, 2:-2] - fx[3, :-1, 2:-2, 2:-2])
            + dtdy * (fy[3, 2:-2, 1:, 2:-2] - fy[3, 2:-2, :-1, 2:-2])
            + dtdz * (fz[3, 2:-2, 2:-2, 1:] - fz[3, 2:-2, 2:-2, :-1])
        )

        E0 -= (
            dtdx * (fx[4, 1:, 2:-2, 2:-2] - fx[4, :-1, 2:-2, 2:-2])
            + dtdy * (fy[4, 2:-2, 1:, 2:-2] - fy[4, 2:-2, :-1, 2:-2])
            + dtdz * (fz[4, 2:-2, 2:-2, 1:] - fz[4, 2:-2, 2:-2, :-1])
        )

        Q = Q.at[0, 2:-2, 2:-2, 2:-2].set(D0)
        Q = Q.at[1, 2:-2, 2:-2, 2:-2].set(Mx / D0)
        Q = Q.at[2, 2:-2, 2:-2, 2:-2].set(My / D0)
        Q = Q.at[3, 2:-2, 2:-2, 2:-2].set(Mz / D0)
        Q = Q.at[4, 2:-2, 2:-2, 2:-2].set(
            gammi1 * (E0 - 0.5 * (Mx**2 + My**2 + Mz**2) / D0)
        )
        return Q.at[4].set(jnp.where(Q[4] > 1.0e-8, Q[4], cfg.args.p_floor))

    @jit
    def update_vis(carry):
        """One viscous sub-step (x, then y, then z sweep) of size min(dt, dt_vis, dt - t_vis)."""

        def _update_vis_x(carry):
            Q, dt = carry

            D0 = Q[0]
            Mx = Q[1] * D0
            My = Q[2] * D0
            Mz = Q[3] * D0
            E0 = Q[4] * gamminv1 + 0.5 * (Mx * Q[1] + My * Q[2] + Mz * Q[3])

            dtdx = dt * dx_inv

            # Density averaged onto x-faces.
            Dm = 0.5 * (D0[2:-1, 2:-2, 2:-2] + D0[1:-2, 2:-2, 2:-2])

            fMx = (
                (eta + visc)
                * Dm
                * dx_inv
                * (Q[1, 2:-1, 2:-2, 2:-2] - Q[1, 1:-2, 2:-2, 2:-2])
            )
            fMy = eta * Dm * dx_inv * (Q[2, 2:-1, 2:-2, 2:-2] - Q[2, 1:-2, 2:-2, 2:-2])
            fMz = eta * Dm * dx_inv * (Q[3, 2:-1, 2:-2, 2:-2] - Q[3, 1:-2, 2:-2, 2:-2])
            fE = 0.5 * (eta + visc) * Dm * dx_inv * (
                Q[1, 2:-1, 2:-2, 2:-2] ** 2 - Q[1, 1:-2, 2:-2, 2:-2] ** 2
            ) + 0.5 * eta * Dm * dx_inv * (
                (Q[2, 2:-1, 2:-2, 2:-2] ** 2 - Q[2, 1:-2, 2:-2, 2:-2] ** 2)
                + (Q[3, 2:-1, 2:-2, 2:-2] ** 2 - Q[3, 1:-2, 2:-2, 2:-2] ** 2)
            )

            D0 = D0[2:-2, 2:-2, 2:-2]
            Mx = Mx[2:-2, 2:-2, 2:-2]
            My = My[2:-2, 2:-2, 2:-2]
            Mz = Mz[2:-2, 2:-2, 2:-2]
            E0 = E0[2:-2, 2:-2, 2:-2]

            Mx += dtdx * (fMx[1:, :, :] - fMx[:-1, :, :])
            My += dtdx * (fMy[1:, :, :] - fMy[:-1, :, :])
            Mz += dtdx * (fMz[1:, :, :] - fMz[:-1, :, :])
            E0 += dtdx * (fE[1:, :, :] - fE[:-1, :, :])

            Q = Q.at[1, 2:-2, 2:-2, 2:-2].set(Mx / D0)
            Q = Q.at[2, 2:-2, 2:-2, 2:-2].set(My / D0)
            Q = Q.at[3, 2:-2, 2:-2, 2:-2].set(Mz / D0)
            Q = Q.at[4, 2:-2, 2:-2, 2:-2].set(
                gammi1 * (E0 - 0.5 * (Mx**2 + My**2 + Mz**2) / D0)
            )

            return Q, dt

        def _update_vis_y(carry):
            Q, dt = carry

            D0 = Q[0]
            Mx = Q[1] * D0
            My = Q[2] * D0
            Mz = Q[3] * D0
            E0 = Q[4] * gamminv1 + 0.5 * (Mx * Q[1] + My * Q[2] + Mz * Q[3])

            dtdy = dt * dy_inv

            # Density averaged onto y-faces.
            Dm = 0.5 * (D0[2:-2, 2:-1, 2:-2] + D0[2:-2, 1:-2, 2:-2])

            fMx = eta * Dm * dy_inv * (Q[1, 2:-2, 2:-1, 2:-2] - Q[1, 2:-2, 1:-2, 2:-2])
            fMy = (
                (eta + visc)
                * Dm
                * dy_inv
                * (Q[2, 2:-2, 2:-1, 2:-2] - Q[2, 2:-2, 1:-2, 2:-2])
            )
            fMz = eta * Dm * dy_inv * (Q[3, 2:-2, 2:-1, 2:-2] - Q[3, 2:-2, 1:-2, 2:-2])
            fE = 0.5 * (eta + visc) * Dm * dy_inv * (
                Q[2, 2:-2, 2:-1, 2:-2] ** 2 - Q[2, 2:-2, 1:-2, 2:-2] ** 2
            ) + 0.5 * eta * Dm * dy_inv * (
                (Q[3, 2:-2, 2:-1, 2:-2] ** 2 - Q[3, 2:-2, 1:-2, 2:-2] ** 2)
                + (Q[1, 2:-2, 2:-1, 2:-2] ** 2 - Q[1, 2:-2, 1:-2, 2:-2] ** 2)
            )

            D0 = D0[2:-2, 2:-2, 2:-2]
            Mx = Mx[2:-2, 2:-2, 2:-2]
            My = My[2:-2, 2:-2, 2:-2]
            Mz = Mz[2:-2, 2:-2, 2:-2]
            E0 = E0[2:-2, 2:-2, 2:-2]

            Mx += dtdy * (fMx[:, 1:, :] - fMx[:, :-1, :])
            My += dtdy * (fMy[:, 1:, :] - fMy[:, :-1, :])
            Mz += dtdy * (fMz[:, 1:, :] - fMz[:, :-1, :])
            E0 += dtdy * (fE[:, 1:, :] - fE[:, :-1, :])

            Q = Q.at[1, 2:-2, 2:-2, 2:-2].set(Mx / D0)
            Q = Q.at[2, 2:-2, 2:-2, 2:-2].set(My / D0)
            Q = Q.at[3, 2:-2, 2:-2, 2:-2].set(Mz / D0)
            Q = Q.at[4, 2:-2, 2:-2, 2:-2].set(
                gammi1 * (E0 - 0.5 * (Mx**2 + My**2 + Mz**2) / D0)
            )

            return Q, dt

        def _update_vis_z(carry):
            Q, dt = carry

            D0 = Q[0]
            Mx = Q[1] * D0
            My = Q[2] * D0
            Mz = Q[3] * D0
            E0 = Q[4] * gamminv1 + 0.5 * (Mx * Q[1] + My * Q[2] + Mz * Q[3])

            dtdz = dt * dz_inv

            # Density averaged onto z-faces.
            Dm = 0.5 * (D0[2:-2, 2:-2, 2:-1] + D0[2:-2, 2:-2, 1:-2])

            fMx = eta * Dm * dz_inv * (Q[1, 2:-2, 2:-2, 2:-1] - Q[1, 2:-2, 2:-2, 1:-2])
            fMy = eta * Dm * dz_inv * (Q[2, 2:-2, 2:-2, 2:-1] - Q[2, 2:-2, 2:-2, 1:-2])
            fMz = (
                (eta + visc)
                * Dm
                * dz_inv
                * (Q[3, 2:-2, 2:-2, 2:-1] - Q[3, 2:-2, 2:-2, 1:-2])
            )
            fE = 0.5 * (eta + visc) * Dm * dz_inv * (
                Q[3, 2:-2, 2:-2, 2:-1] ** 2 - Q[3, 2:-2, 2:-2, 1:-2] ** 2
            ) + 0.5 * eta * Dm * dz_inv * (
                (Q[1, 2:-2, 2:-2, 2:-1] ** 2 - Q[1, 2:-2, 2:-2, 1:-2] ** 2)
                + (Q[2, 2:-2, 2:-2, 2:-1] ** 2 - Q[2, 2:-2, 2:-2, 1:-2] ** 2)
            )

            D0 = D0[2:-2, 2:-2, 2:-2]
            Mx = Mx[2:-2, 2:-2, 2:-2]
            My = My[2:-2, 2:-2, 2:-2]
            Mz = Mz[2:-2, 2:-2, 2:-2]
            E0 = E0[2:-2, 2:-2, 2:-2]

            Mx += dtdz * (fMx[:, :, 1:] - fMx[:, :, :-1])
            My += dtdz * (fMy[:, :, 1:] - fMy[:, :, :-1])
            Mz += dtdz * (fMz[:, :, 1:] - fMz[:, :, :-1])
            E0 += dtdz * (fE[:, :, 1:] - fE[:, :, :-1])

            Q = Q.at[1, 2:-2, 2:-2, 2:-2].set(Mx / D0)
            Q = Q.at[2, 2:-2, 2:-2, 2:-2].set(My / D0)
            Q = Q.at[3, 2:-2, 2:-2, 2:-2].set(Mz / D0)
            Q = Q.at[4, 2:-2, 2:-2, 2:-2].set(
                gammi1 * (E0 - 0.5 * (Mx**2 + My**2 + Mz**2) / D0)
            )

            return Q, dt

        Q, dt, dt_vis, t_vis = carry
        Q = bc_HD(Q, mode=cfg.args.bc)
        dt_ev = jnp.min(jnp.array([dt, dt_vis, dt - t_vis]))

        carry = Q, dt_ev
        carry = _update_vis_x(carry)
        carry = _update_vis_y(carry)
        Q, _ = _update_vis_z(carry)

        t_vis += dt_ev

        return Q, dt, dt_vis, t_vis

    @jit
    def flux_x(Q):
        """Interface fluxes along x."""
        QL, QR = limiting_HD(Q, if_second_order=cfg.args.if_second_order)
        return HLLC(QL, QR, direc=0)

    @jit
    def flux_y(Q):
        """Interface fluxes along y (y moved to axis 1 and back)."""
        _Q = jnp.transpose(Q, (0, 2, 3, 1))
        QL, QR = limiting_HD(_Q, if_second_order=cfg.args.if_second_order)
        return jnp.transpose(HLLC(QL, QR, direc=1), (0, 3, 1, 2))

    @jit
    def flux_z(Q):
        """Interface fluxes along z (z moved to axis 1 and back)."""
        _Q = jnp.transpose(Q, (0, 3, 1, 2))
        QL, QR = limiting_HD(_Q, if_second_order=cfg.args.if_second_order)
        return jnp.transpose(HLLC(QL, QR, direc=2), (0, 2, 3, 1))

    @partial(jit, static_argnums=(2,))
    def HLL(QL, QR, direc):
        """HLL interface flux along axis 1 (alternative to HLLC; not used by flux_x/y/z).

        The interface between cells ``i`` and ``i + 1`` takes its left state
        from ``QR[:, i]`` and its right state from ``QL[:, i + 1]``.
        """
        iX, iY, iZ = direc + 1, (direc + 1) % 3 + 1, (direc + 2) % 3 + 1
        cfL = jnp.sqrt(gamma * QL[4] / QL[0])
        cfR = jnp.sqrt(gamma * QR[4] / QR[0])
        Sfl = jnp.minimum(QL[iX, 2:-1], QR[iX, 1:-2]) - jnp.maximum(
            cfL[2:-1], cfR[1:-2]
        )
        Sfr = jnp.maximum(QL[iX, 2:-1], QR[iX, 1:-2]) + jnp.maximum(
            cfL[2:-1], cfR[1:-2]
        )
        dcfi = 1.0 / (Sfr - Sfl + 1.0e-8)

        UL, UR = jnp.zeros_like(QL), jnp.zeros_like(QR)
        UL = UL.at[0].set(QL[0])
        UL = UL.at[iX].set(QL[0] * QL[iX])
        UL = UL.at[iY].set(QL[0] * QL[iY])
        UL = UL.at[iZ].set(QL[0] * QL[iZ])
        UL = UL.at[4].set(
            gamminv1 * QL[4]
            + 0.5 * (UL[iX] * QL[iX] + UL[iY] * QL[iY] + UL[iZ] * QL[iZ])
        )
        UR = UR.at[0].set(QR[0])
        UR = UR.at[iX].set(QR[0] * QR[iX])
        UR = UR.at[iY].set(QR[0] * QR[iY])
        UR = UR.at[iZ].set(QR[0] * QR[iZ])
        UR = UR.at[4].set(
            gamminv1 * QR[4]
            + 0.5 * (UR[iX] * QR[iX] + UR[iY] * QR[iY] + UR[iZ] * QR[iZ])
        )

        fL, fR = jnp.zeros_like(QL), jnp.zeros_like(QR)
        fL = fL.at[0].set(UL[iX])
        fL = fL.at[iX].set(UL[iX] * QL[iX] + QL[4])
        fL = fL.at[iY].set(UL[iX] * QL[iY])
        fL = fL.at[iZ].set(UL[iX] * QL[iZ])
        fL = fL.at[4].set((UL[4] + QL[4]) * QL[iX])
        fR = fR.at[0].set(UR[iX])
        fR = fR.at[iX].set(UR[iX] * QR[iX] + QR[4])
        fR = fR.at[iY].set(UR[iX] * QR[iY])
        fR = fR.at[iZ].set(UR[iX] * QR[iZ])
        fR = fR.at[4].set((UR[4] + QR[4]) * QR[iX])

        fHLL = dcfi * (
            Sfr * fR[:, 1:-2]
            - Sfl * fL[:, 2:-1]
            + Sfl * Sfr * (UL[:, 2:-1] - UR[:, 1:-2])
        )

        # Upwind fluxes when both wave speeds have the same sign.
        f_Riemann = jnp.where(Sfl > 0.0, fR[:, 1:-2], fHLL)
        return jnp.where(Sfr < 0.0, fL[:, 2:-1], f_Riemann)

    @partial(jit, static_argnums=(2,))
    def HLLC(QL, QR, direc):
        """HLLC approximate Riemann flux along axis 1.

        The interface between cells ``i`` and ``i + 1`` takes its left state
        from ``QR[:, i]`` and its right state from ``QL[:, i + 1]``.  ``direc``
        (0/1/2) selects the normal velocity component ``Q[direc + 1]``.  ``Va``
        is the contact speed and ``Pa`` the star-region pressure.
        """
        iX, iY, iZ = direc + 1, (direc + 1) % 3 + 1, (direc + 2) % 3 + 1
        cfL = jnp.sqrt(gamma * QL[4] / QL[0])
        cfR = jnp.sqrt(gamma * QR[4] / QR[0])
        Sfl = jnp.minimum(QL[iX, 2:-1], QR[iX, 1:-2]) - jnp.maximum(
            cfL[2:-1], cfR[1:-2]
        )
        Sfr = jnp.maximum(QL[iX, 2:-1], QR[iX, 1:-2]) + jnp.maximum(
            cfL[2:-1], cfR[1:-2]
        )

        UL, UR = jnp.zeros_like(QL), jnp.zeros_like(QR)
        UL = UL.at[0].set(QL[0])
        UL = UL.at[iX].set(QL[0] * QL[iX])
        UL = UL.at[iY].set(QL[0] * QL[iY])
        UL = UL.at[iZ].set(QL[0] * QL[iZ])
        UL = UL.at[4].set(
            gamminv1 * QL[4]
            + 0.5 * (UL[iX] * QL[iX] + UL[iY] * QL[iY] + UL[iZ] * QL[iZ])
        )
        UR = UR.at[0].set(QR[0])
        UR = UR.at[iX].set(QR[0] * QR[iX])
        UR = UR.at[iY].set(QR[0] * QR[iY])
        UR = UR.at[iZ].set(QR[0] * QR[iZ])
        UR = UR.at[4].set(
            gamminv1 * QR[4]
            + 0.5 * (UR[iX] * QR[iX] + UR[iY] * QR[iY] + UR[iZ] * QR[iZ])
        )

        Va = (
            (Sfr - QL[iX, 2:-1]) * UL[iX, 2:-1]
            - (Sfl - QR[iX, 1:-2]) * UR[iX, 1:-2]
            - QL[4, 2:-1]
            + QR[4, 1:-2]
        )
        Va /= (Sfr - QL[iX, 2:-1]) * QL[0, 2:-1] - (Sfl - QR[iX, 1:-2]) * QR[0, 1:-2]
        Pa = QR[4, 1:-2] + QR[0, 1:-2] * (Sfl - QR[iX, 1:-2]) * (Va - QR[iX, 1:-2])

        # Star-region densities on each side of the contact.
        Dal = QR[0, 1:-2] * (Sfl - QR[iX, 1:-2]) / (Sfl - Va)
        Dar = QL[0, 2:-1] * (Sfr - QL[iX, 2:-1]) / (Sfr - Va)

        fL, fR = jnp.zeros_like(QL), jnp.zeros_like(QR)
        fL = fL.at[0].set(UL[iX])
        fL = fL.at[iX].set(UL[iX] * QL[iX] + QL[4])
        fL = fL.at[iY].set(UL[iX] * QL[iY])
        fL = fL.at[iZ].set(UL[iX] * QL[iZ])
        fL = fL.at[4].set((UL[4] + QL[4]) * QL[iX])
        fR = fR.at[0].set(UR[iX])
        fR = fR.at[iX].set(UR[iX] * QR[iX] + QR[4])
        fR = fR.at[iY].set(UR[iX] * QR[iY])
        fR = fR.at[iZ].set(UR[iX] * QR[iZ])
        fR = fR.at[4].set((UR[4] + QR[4]) * QR[iX])

        far, fal = jnp.zeros_like(QL[:, 2:-1]), jnp.zeros_like(QR[:, 1:-2])
        far = far.at[0].set(Dar * Va)
        far = far.at[iX].set(Dar * Va**2 + Pa)
        far = far.at[iY].set(Dar * Va * QL[iY, 2:-1])
        far = far.at[iZ].set(Dar * Va * QL[iZ, 2:-1])
        far = far.at[4].set(
            (
                gamgamm1inv * Pa
                + 0.5 * Dar * (Va**2 + QL[iY, 2:-1] ** 2 + QL[iZ, 2:-1] ** 2)
            )
            * Va
        )
        fal = fal.at[0].set(Dal * Va)
        fal = fal.at[iX].set(Dal * Va**2 + Pa)
        fal = fal.at[iY].set(Dal * Va * QR[iY, 1:-2])
        fal = fal.at[iZ].set(Dal * Va * QR[iZ, 1:-2])
        fal = fal.at[4].set(
            (
                gamgamm1inv * Pa
                + 0.5 * Dal * (Va**2 + QR[iY, 1:-2] ** 2 + QR[iZ, 1:-2] ** 2)
            )
            * Va
        )

        # Outer states when both signal speeds share a sign, otherwise the
        # star state on the side of the contact selected by sign(Va).
        f_Riemann = jnp.where(Sfl > 0.0, fR[:, 1:-2], fL[:, 2:-1])
        f_Riemann = jnp.where(Sfl * Va < 0.0, fal, f_Riemann)
        return jnp.where(Sfr * Va < 0.0, far, f_Riemann)

    Q = jnp.zeros([cfg.args.numbers, 5, nx + 4, ny + 4, nz + 4])

    # NOTE(review): only variable 0 (density) is initialised here; velocity
    # and pressure start at zero.  `update` then lifts the pressure to
    # cfg.args.p_floor, and with zero velocity and uniform pressure the state
    # appears to stay static.  Confirm this is intended.
    Q = Q.at[:, 0, 2:-2, 2:-2, 2:-2].set(
        init_multi_HD(
            xc,
            yc,
            zc,
            numbers=cfg.args.numbers,
            k_tot=3,
            init_key=cfg.args.init_key,
            num_choise_k=2,
            umin=1.0e0,
            umax=1.0e1,
            if_renorm=True,
        )
    )
    Q = device_put(Q)

    # Only density, x-velocity and pressure are returned.
    DDDs = []
    VVxs = []
    PPPs = []
    for Q_sample in Q:
        t, DDD, VVx, _, _, PPP = evolve(Q_sample)
        DDDs.append(jnp.squeeze(DDD))
        VVxs.append(jnp.squeeze(VVx))
        PPPs.append(jnp.squeeze(PPP))

    density = jnp.stack(DDDs)
    ux = jnp.stack(VVxs)
    pressure = jnp.stack(PPPs)

    return ux, density, pressure, xc, tc
|
|
|
|
|
|
|
|
@hydra.main(config_path="config", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    """Grid-convergence study over the (nx, dt_save) pairs in ``cfg.args``.

    ``run_step`` is called for each pair of ``cfg.args.nx`` and
    ``cfg.args.dt_save``.  For every consecutive pair of levels, the RMS
    difference from ``compute_error`` (next level interpolated onto the
    current one) is logged.
    """
    nxs = list(cfg.args.nx)
    dt_saves = list(cfg.args.dt_save)
    outputs = []
    xcs = []
    tcs = []
    for nx, dt_save in zip(nxs, dt_saves):
        logger.info(f"running nx={nx}, dt_save={dt_save}")
        u, density, pressure, xc, tc = run_step(cfg, nx, dt_save)
        # Variables (ux, density, pressure) stacked on the last axis.
        full = np.stack([u, density, pressure], axis=-1)
        outputs.append(full)
        xcs.append(xc)
        # run_step's tc has one entry more than the stored snapshots, and
        # snapshot i is at ini_time + i * dt_save.  Drop the *last* entry:
        # tc[1:] would shift every snapshot by one dt_save and bias the time
        # interpolation whenever dt_save differs between levels.
        tcs.append(tc[:-1])

    levels = list(zip(outputs, xcs, tcs))
    errors = []
    for coarse_tuple, fine_tuple in zip(levels[:-1], levels[1:]):
        errors.append(compute_error(coarse_tuple, fine_tuple))

    for nx_coarse, nx_fine, error in zip(nxs, nxs[1:], errors):
        logger.info(f"error between nx={nx_coarse} and nx={nx_fine}: {error:.6e}")


if __name__ == "__main__":
    main()
|
|
|