"""3-D Poisson solver for linear and nonlinear pressure perturbations.

The governing equation (anelastic/Boussinesq, Trapp 2013) is::

    ∇²p' = F_lin + F_spin + F_splat + F_buoy

Each forcing term is solved independently so that the contributions can be
compared and displayed separately.

Solver: a 2-D FFT in the (periodic) horizontal, followed by a vectorized
Thomas algorithm for the resulting 1-D boundary-value problems in z with
Neumann boundary conditions.
"""

from __future__ import annotations

import numpy as np

from .sounding import G


# --------------------------------------------------------------------------
# Grid derivative helpers
# --------------------------------------------------------------------------

def _grad_x(f: np.ndarray, dx: float) -> np.ndarray:
    """∂f/∂x with periodic x (centered differences along axis 0)."""
    return (np.roll(f, -1, axis=0) - np.roll(f, 1, axis=0)) / (2.0 * dx)


def _grad_y(f: np.ndarray, dx: float) -> np.ndarray:
    """∂f/∂y with periodic y (centered differences along axis 1)."""
    return (np.roll(f, -1, axis=1) - np.roll(f, 1, axis=1)) / (2.0 * dx)


def _grad_z(f: np.ndarray, dz: float) -> np.ndarray:
    """∂f/∂z along axis 2.

    Centered differences in the interior; first-order one-sided differences
    at the bottom and top levels (z is not periodic).
    """
    out = np.empty_like(f)
    out[:, :, 1:-1] = (f[:, :, 2:] - f[:, :, :-2]) / (2.0 * dz)
    out[:, :, 0] = (f[:, :, 1] - f[:, :, 0]) / dz
    out[:, :, -1] = (f[:, :, -1] - f[:, :, -2]) / dz
    return out


# --------------------------------------------------------------------------
# Forcing term constructors
# --------------------------------------------------------------------------

def forcing_linear(
    rho0: np.ndarray,
    dudz_env: np.ndarray,
    dvdz_env: np.ndarray,
    w3d: np.ndarray,
    dx: float,
) -> np.ndarray:
    """Linear (shear-interaction) forcing.

        F_lin = -2ρ₀ [(∂U/∂z)(∂w'/∂x) + (∂V/∂z)(∂w'/∂y)]

    rho0, dudz_env, dvdz_env are 1-D profiles in z; they are broadcast
    against the (Nx, Ny, Nz) vertical-velocity field ``w3d``.
    """
    dwdx = _grad_x(w3d, dx)
    dwdy = _grad_y(w3d, dx)
    rho0_3d = rho0[np.newaxis, np.newaxis, :]  # broadcast to (1, 1, Nz)
    dUdz = dudz_env[np.newaxis, np.newaxis, :]
    dVdz = dvdz_env[np.newaxis, np.newaxis, :]
    return -2.0 * rho0_3d * (dUdz * dwdx + dVdz * dwdy)


def _strain_and_rotation(
    u3d: np.ndarray,
    v3d: np.ndarray,
    w3d: np.ndarray,
    env_u: np.ndarray,
    env_v: np.ndarray,
    dx: float,
    dz: float,
) -> tuple[np.ndarray, np.ndarray]:
    """Return (ΣΣ Sᵢⱼ², ΣΣ Rᵢⱼ²) of the perturbation velocity field.

    The perturbation winds are the full winds minus the 1-D environmental
    profiles ``env_u``/``env_v``; w has no environmental part.
    """
    # Perturbation winds
    up = u3d - env_u[np.newaxis, np.newaxis, :]
    vp = v3d - env_v[np.newaxis, np.newaxis, :]
    wp = w3d  # no environmental w

    # Velocity gradients of the perturbation
    dudx = _grad_x(up, dx)
    dudy = _grad_y(up, dx)
    dudz = _grad_z(up, dz)
    dvdx = _grad_x(vp, dx)
    dvdy = _grad_y(vp, dx)
    dvdz = _grad_z(vp, dz)
    dwdx = _grad_x(wp, dx)
    dwdy = _grad_y(wp, dx)
    dwdz = _grad_z(wp, dz)

    # Strain-rate tensor Sᵢⱼ = ½(∂uᵢ/∂xⱼ + ∂uⱼ/∂xᵢ); off-diagonal terms
    # appear twice in the double sum, hence the factor 2.
    S11 = dudx
    S22 = dvdy
    S33 = dwdz
    S12 = 0.5 * (dudy + dvdx)
    S13 = 0.5 * (dudz + dwdx)
    S23 = 0.5 * (dvdz + dwdy)
    S2 = S11**2 + S22**2 + S33**2 + 2.0 * (S12**2 + S13**2 + S23**2)

    # Rotation-rate tensor: ΣΣ Rᵢⱼ² = ½|ω|² where ω = ∇×u'
    # ζ_x = ∂w/∂y − ∂v/∂z, ζ_y = ∂u/∂z − ∂w/∂x, ζ_z = ∂v/∂x − ∂u/∂y
    zx = dwdy - dvdz
    zy = dudz - dwdx
    zz = dvdx - dudy
    R2 = 0.5 * (zx**2 + zy**2 + zz**2)  # = ΣΣ Rᵢⱼ²

    return S2, R2


def forcing_splat(
    rho0: np.ndarray,
    u3d: np.ndarray,
    v3d: np.ndarray,
    w3d: np.ndarray,
    env_u: np.ndarray,
    env_v: np.ndarray,
    dx: float,
    dz: float,
) -> np.ndarray:
    """Nonlinear splat (deformation) forcing F_splat = -ρ₀ ΣΣ Sᵢⱼ²."""
    S2, _ = _strain_and_rotation(u3d, v3d, w3d, env_u, env_v, dx, dz)
    return -rho0[np.newaxis, np.newaxis, :] * S2


def forcing_spin(
    rho0: np.ndarray,
    u3d: np.ndarray,
    v3d: np.ndarray,
    w3d: np.ndarray,
    env_u: np.ndarray,
    env_v: np.ndarray,
    dx: float,
    dz: float,
) -> np.ndarray:
    """Nonlinear spin (rotation) forcing F_spin = +ρ₀ ΣΣ Rᵢⱼ² = ρ₀/2 |ω'|²."""
    _, R2 = _strain_and_rotation(u3d, v3d, w3d, env_u, env_v, dx, dz)
    return rho0[np.newaxis, np.newaxis, :] * R2


def forcing_buoyancy(
    rho0: np.ndarray,
    theta0: np.ndarray,
    theta_prime3d: np.ndarray,
    dz: float,
) -> np.ndarray:
    """Buoyancy pressure forcing F_buoy = +ρ₀ (g/θ₀) ∂θ'/∂z.

    This is the source term ∂(ρ₀B)/∂z with B = g θ'/θ₀. It comes from taking
    the divergence of ρ₀ × (momentum equation), which gives
    ∇²p' = -∇·(ρ₀ u·∇u) + ∂(ρ₀B)/∂z. The sign is therefore positive: a warm
    bubble yields high buoyancy pressure above it and low pressure below,
    so the resulting acceleration partially opposes B.

    As in ``forcing_linear``, ρ₀ (and here θ₀) are kept outside the vertical
    derivative.
    """
    dthp_dz = _grad_z(theta_prime3d, dz)
    return (
        rho0[np.newaxis, np.newaxis, :]
        * (G / theta0[np.newaxis, np.newaxis, :])
        * dthp_dz
    )


# --------------------------------------------------------------------------
# Vectorized Thomas algorithm (TDMA)
# --------------------------------------------------------------------------

def _tdma_batch(
    a: np.ndarray,
    b: np.ndarray,
    c: np.ndarray,
    d: np.ndarray,
) -> np.ndarray:
    """Thomas algorithm for M independent tridiagonal systems of size N.

    a, b, c, d : (M, N) arrays (complex or real) holding the sub-diagonal,
        main diagonal, super-diagonal and right-hand side of each system.
        a[:, 0] and c[:, -1] are unused (boundary rows).

    Returns x : (M, N), in the common result dtype of all four inputs.

    No pivoting is performed, so every pivot must be nonzero. The systems
    built by ``solve_poisson_3d`` are (weakly) diagonally dominant.
    """
    _, N = d.shape
    # Use the promoted dtype of all inputs so that complex coefficients are
    # never silently truncated into a real work array.
    dtype = np.result_type(a, b, c, d)
    cp = np.zeros(d.shape, dtype=dtype)
    dp = np.zeros(d.shape, dtype=dtype)

    # Forward sweep
    w = b[:, 0].copy()
    cp[:, 0] = c[:, 0] / w
    dp[:, 0] = d[:, 0] / w
    for k in range(1, N):
        w = b[:, k] - a[:, k] * cp[:, k - 1]
        cp[:, k] = c[:, k] / w
        dp[:, k] = (d[:, k] - a[:, k] * dp[:, k - 1]) / w

    # Back substitution
    x = np.zeros(d.shape, dtype=dtype)
    x[:, N - 1] = dp[:, N - 1]
    for k in range(N - 2, -1, -1):
        x[:, k] = dp[:, k] - cp[:, k] * x[:, k + 1]
    return x


# --------------------------------------------------------------------------
# 3-D Poisson solver
# --------------------------------------------------------------------------

def solve_poisson_3d(F: np.ndarray, dx: float, dz: float) -> np.ndarray:
    """Solve ∇²p = F on a periodic (x, y) × Neumann-z domain.

    F  : (Nx, Ny, Nz) real forcing array.
    dx : horizontal grid spacing (same in x and y) in meters.
    dz : vertical grid spacing in meters.

    Returns p : (Nx, Ny, Nz) real array with zero mean.

    Raises ValueError if F is not 3-D.

    The horizontal Laplacian uses the exact spectral wavenumbers. The
    vertical second derivative uses the 3-point stencil, with ghost-point
    reflection (∂p/∂z = 0) at the bottom and top levels.
    """
    if F.ndim != 3:
        raise ValueError(f"F must be 3-D (Nx, Ny, Nz); got shape {F.shape}")
    Nx, Ny, Nz = F.shape

    # --- 2-D real FFT in the horizontal ---
    F_hat = np.fft.rfft2(F, axes=(0, 1))  # (Nx, Ny//2+1, Nz)
    Nkx = Nx
    Nky = Ny // 2 + 1
    Nmodes = Nkx * Nky

    kx = np.fft.fftfreq(Nx, d=dx) * (2.0 * np.pi)   # (Nx,)
    ky = np.fft.rfftfreq(Ny, d=dx) * (2.0 * np.pi)  # (Nky,)
    KX, KY = np.meshgrid(kx, ky, indexing="ij")     # (Nx, Nky)
    lam2 = (KX**2 + KY**2).reshape(Nmodes)          # (Nmodes,)

    # Flatten the horizontal modes so all 1-D z-problems are solved as a batch.
    rhs = F_hat.reshape(Nmodes, Nz) * dz**2  # (Nmodes, Nz), complex

    # --- Tridiagonal coefficients (Nmodes, Nz) ---
    main_diag = -(2.0 + lam2[:, np.newaxis] * dz**2) * np.ones(
        (Nmodes, Nz), dtype=complex
    )
    upper_diag = np.ones((Nmodes, Nz), dtype=complex)
    lower_diag = np.ones((Nmodes, Nz), dtype=complex)
    # Neumann BC at z=0: ghost-point reflection p[-1] = p[1] → super-diag = 2.
    upper_diag[:, 0] = 2.0
    upper_diag[:, -1] = 0.0  # unused
    # Neumann BC at z=top: ghost-point reflection → sub-diag = 2 on last row.
    lower_diag[:, -1] = 2.0
    lower_diag[:, 0] = 0.0  # unused

    # --- (kx=0, ky=0) mode ---
    # With λ=0 and Neumann BCs the matrix is singular. Its left null vector
    # is ∝ [½, 1, ..., 1, ½] (the boundary rows carry a 2), so the discrete
    # solvability condition is a trapezoid-weighted sum of the RHS being
    # zero. Remove that weighted mean, then replace row 0 by the gauge
    # p[0] = 0. Because the RHS is now compatible, the dropped bottom-row
    # equation is still satisfied by the solution.
    m00 = 0
    weights = np.ones(Nz)
    weights[0] = weights[-1] = 0.5
    rhs[m00] -= np.dot(weights, rhs[m00]) / weights.sum()
    main_diag[m00, 0] = 1.0
    upper_diag[m00, 0] = 0.0
    lower_diag[m00, 0] = 0.0
    rhs[m00, 0] = 0.0

    # --- Batch TDMA ---
    P_hat_2d = _tdma_batch(lower_diag, main_diag, upper_diag, rhs)  # (Nmodes, Nz)

    # --- Inverse 2-D FFT ---
    P_hat = P_hat_2d.reshape(Nkx, Nky, Nz)
    p = np.fft.irfft2(P_hat, s=(Nx, Ny), axes=(0, 1))  # (Nx, Ny, Nz)
    # The solution is only defined up to a constant; fix it to zero mean.
    p -= p.mean()
    return p


# --------------------------------------------------------------------------
# Acceleration diagnostics
# --------------------------------------------------------------------------

def pressure_accelerations(
    p_lin: np.ndarray,
    p_spin: np.ndarray,
    p_splat: np.ndarray,
    p_buoy: np.ndarray,
    rho0: np.ndarray,
    dz: float,
) -> dict:
    """Compute the vertical acceleration from each pressure component.

        a = -(1/ρ₀) ∂p'/∂z   at each (x, y, z)

    Returns a dict with keys "a_lin", "a_spin", "a_splat", "a_buoy" and
    "a_total". "a_total" is computed from the summed pressure field.
    """
    def _accel(p):
        dpdz = _grad_z(p, dz)
        return -dpdz / rho0[np.newaxis, np.newaxis, :]

    return {
        "a_lin": _accel(p_lin),
        "a_spin": _accel(p_spin),
        "a_splat": _accel(p_splat),
        "a_buoy": _accel(p_buoy),
        "a_total": _accel(p_lin + p_spin + p_splat + p_buoy),
    }