# Source: PhysicsNeMo-MHD / mhd/losses/loss_mhd_vec_pot_physicsnemo.py
# Commit 830a558 by carmelog: "init: magnetohydrodynamics with physicsnemo"
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from physicsnemo.models.layers.spectral_layers import fourier_derivatives
from .losses import (LpLoss, fourier_derivatives_lap, fourier_derivatives_ptot,
fourier_derivatives_vec_pot)
from .mhd_pde import MHD_PDE
class LossMHDVecPot_PhysicsNeMo(object):
    """Calculate loss for MHD equations with vector potential, using physicsnemo derivatives.

    Predictions carry three fields in the last dimension, read as
    ``u = [..., 0]``, ``v = [..., 1]`` and ``A = [..., 2]`` (two velocity
    components and the vector potential). Per-field tensors are indexed with
    dim 1 = time, dim 2 = x and dim 3 = y (dim 0 is presumably the batch).

    The total loss is a weighted combination of four terms:

    * data loss       -- ``LpLoss`` of each field against ``true``;
    * IC loss         -- ``LpLoss`` of the first predicted time step against the
                         initial condition stored in ``inputs``;
    * PDE loss        -- mean squared residual of the u, v and A equations;
    * constraint loss -- mean squared ``div(velocity)`` and ``div(B)``.

    Spatial derivatives are spectral (``fourier_derivatives`` and helpers from
    ``.losses``). The time derivative is a second-order central difference.
    When ``use_weighted_mean`` is True, every weighted sum (top level and within
    each term) is divided by the sum of its weights.
    """

    def __init__(
        self,
        nu=1e-4,
        eta=1e-4,
        rho0=1.0,
        data_weight=1.0,
        ic_weight=1.0,
        pde_weight=1.0,
        constraint_weight=1.0,
        use_data_loss=True,
        use_ic_loss=True,
        use_pde_loss=True,
        use_constraint_loss=True,
        u_weight=1.0,
        v_weight=1.0,
        A_weight=1.0,
        Du_weight=1.0,
        Dv_weight=1.0,
        DA_weight=1.0,
        div_B_weight=1.0,
        div_vel_weight=1.0,
        Lx=1.0,
        Ly=1.0,
        tend=1.0,
        use_weighted_mean=False,
        **kwargs,
    ):  # **kwargs lets a whole config dict be splatted in; unknown keys are ignored.
        """Store the configuration and build the symbolic MHD equations.

        Args:
            nu, eta, rho0: coefficients forwarded to ``MHD_PDE`` (presumably
                viscosity, magnetic diffusivity and reference density).
            data_weight, ic_weight, pde_weight, constraint_weight: weights of
                the four top-level loss terms.
            use_data_loss, use_ic_loss, use_pde_loss, use_constraint_loss:
                enable each term. A disabled term is skipped and its weight is
                forced to 0.
            u_weight, v_weight, A_weight: per-field weights shared by the data
                and IC losses.
            Du_weight, Dv_weight, DA_weight: per-equation PDE-residual weights.
            div_B_weight, div_vel_weight: weights of the divergence constraints.
            Lx, Ly: domain extents passed to the spectral-derivative helpers.
            tend: final time. The time step is ``tend / (nt - 1)``.
            use_weighted_mean: divide each weighted sum by its total weight.
            **kwargs: ignored.
        """
        # Coefficients for the symbolic PDE definition.
        self.nu = nu
        self.eta = eta
        self.rho0 = rho0
        # Top-level term weights and on/off switches.
        self.data_weight = data_weight
        self.ic_weight = ic_weight
        self.pde_weight = pde_weight
        self.constraint_weight = constraint_weight
        self.use_data_loss = use_data_loss
        self.use_ic_loss = use_ic_loss
        self.use_pde_loss = use_pde_loss
        self.use_constraint_loss = use_constraint_loss
        # Per-field weights (data and IC losses).
        self.u_weight = u_weight
        self.v_weight = v_weight
        self.A_weight = A_weight
        # Per-equation weights (PDE residual loss).
        self.Du_weight = Du_weight
        self.Dv_weight = Dv_weight
        self.DA_weight = DA_weight
        # Constraint weights.
        self.div_B_weight = div_B_weight
        self.div_vel_weight = div_vel_weight
        # Domain extents and final time.
        self.Lx = Lx
        self.Ly = Ly
        self.tend = tend
        self.use_weighted_mean = use_weighted_mean
        # Define 2D MHD PDEs. The nodes returned by make_nodes() are looked up
        # by integer index in mhd_constraint / mhd_pde, so those indices must
        # match MHD_PDE's node ordering.
        self.mhd_pde_eq = MHD_PDE(self.nu, self.eta, self.rho0)
        self.mhd_pde_node = self.mhd_pde_eq.make_nodes()
        # A disabled term must not contribute to the total or to the weight sum.
        if not self.use_data_loss:
            self.data_weight = 0
        if not self.use_ic_loss:
            self.ic_weight = 0
        if not self.use_pde_loss:
            self.pde_weight = 0
        if not self.use_constraint_loss:
            self.constraint_weight = 0

    def __call__(self, pred, true, inputs, return_loss_dict=False):
        """Return ``(loss, loss_dict)``, as computed by ``compute_losses``.

        ``return_loss_dict`` is accepted for interface compatibility but is
        ignored: the loss dictionary is always returned.
        """
        loss, loss_dict = self.compute_losses(pred, true, inputs)
        return loss, loss_dict

    def compute_loss(self, pred, true, inputs):
        """Compute weighted loss (the total only).

        Delegates to ``compute_losses`` so the two code paths cannot diverge.
        The per-term values it additionally computes do not change the total.
        """
        loss, _ = self.compute_losses(pred, true, inputs)
        return loss

    def compute_losses(self, pred, true, inputs):
        """Compute weighted loss and dictionary.

        Args:
            pred: model output; reshaped to ``true.shape`` before use.
            true: target fields with ``(u, v, A)`` in the last dimension.
            inputs: model inputs. The IC loss reads ``inputs[:, 0, ..., 3:]``.

        Returns:
            ``(loss, loss_dict)``. ``loss_dict`` always holds ``"loss"``. For
            each enabled term it also holds that term and its components
            (``loss_data``/``loss_u``/``loss_v``/``loss_A``,
            ``loss_ic``/``loss_u_ic``/``loss_v_ic``/``loss_A_ic``,
            ``loss_pde``/``loss_Du``/``loss_Dv``/``loss_DA``,
            ``loss_constraint``/``loss_div_vel``/``loss_div_B``).
            Disabled terms contribute 0.
        """
        pred = pred.reshape(true.shape)
        u = pred[..., 0]
        v = pred[..., 1]
        A = pred[..., 2]
        loss_dict = {}

        loss_data = 0
        if self.use_data_loss:
            loss_data, loss_u, loss_v, loss_A = self.data_loss(
                pred, true, return_all_losses=True
            )
            loss_dict.update(
                loss_data=loss_data, loss_u=loss_u, loss_v=loss_v, loss_A=loss_A
            )

        loss_ic = 0
        if self.use_ic_loss:
            loss_ic, loss_u_ic, loss_v_ic, loss_A_ic = self.ic_loss(
                pred, inputs, return_all_losses=True
            )
            loss_dict.update(
                loss_ic=loss_ic,
                loss_u_ic=loss_u_ic,
                loss_v_ic=loss_v_ic,
                loss_A_ic=loss_A_ic,
            )

        loss_pde = 0
        if self.use_pde_loss:
            Du, Dv, DA = self.mhd_pde(u, v, A)
            loss_pde, loss_Du, loss_Dv, loss_DA = self.mhd_pde_loss(
                Du, Dv, DA, return_all_losses=True
            )
            loss_dict.update(
                loss_pde=loss_pde, loss_Du=loss_Du, loss_Dv=loss_Dv, loss_DA=loss_DA
            )

        loss_constraint = 0
        if self.use_constraint_loss:
            div_vel, div_B = self.mhd_constraint(u, v, A)
            loss_constraint, loss_div_vel, loss_div_B = self.mhd_constraint_loss(
                div_vel, div_B, return_all_losses=True
            )
            loss_dict.update(
                loss_constraint=loss_constraint,
                loss_div_vel=loss_div_vel,
                loss_div_B=loss_div_B,
            )

        loss = self._weighted_sum(
            (self.data_weight, loss_data),
            (self.ic_weight, loss_ic),
            (self.pde_weight, loss_pde),
            (self.constraint_weight, loss_constraint),
        )
        loss_dict["loss"] = loss
        return loss, loss_dict

    def _weighted_sum(self, *weighted_terms):
        """Combine ``(weight, value)`` pairs into a single weighted sum.

        With ``use_weighted_mean`` the sum is divided by the total weight.
        If that total is 0 (e.g. every term disabled), the sum is left
        unnormalised instead of raising ZeroDivisionError or producing NaN.
        """
        total = sum(weight * value for weight, value in weighted_terms)
        weight_sum = 1.0
        if self.use_weighted_mean:
            weight_sum = sum(weight for weight, _ in weighted_terms)
            if weight_sum == 0:
                weight_sum = 1.0
        return total / weight_sum

    def data_loss(self, pred, true, return_all_losses=False):
        """Compute data loss.

        Applies ``LpLoss(size_average=True)`` to each of the u, v and A
        channels and combines the results with ``u_weight``, ``v_weight`` and
        ``A_weight``.

        Returns:
            The combined loss, or ``(loss_data, loss_u, loss_v, loss_A)`` when
            ``return_all_losses`` is truthy.
        """
        lploss = LpLoss(size_average=True)
        loss_u = lploss(pred[..., 0], true[..., 0])
        loss_v = lploss(pred[..., 1], true[..., 1])
        loss_A = lploss(pred[..., 2], true[..., 2])
        loss_data = self._weighted_sum(
            (self.u_weight, loss_u),
            (self.v_weight, loss_v),
            (self.A_weight, loss_A),
        )
        if return_all_losses:
            return loss_data, loss_u, loss_v, loss_A
        return loss_data

    def ic_loss(self, pred, input, return_all_losses=False):
        """Compute initial condition loss.

        Compares the first predicted time step ``pred[:, 0]`` with
        ``input[:, 0, ..., 3:]``. That slice is assumed to hold the initial
        ``(u, v, A)`` in input channels 3, 4 and 5 (TODO confirm against the
        data loader). The parameter name ``input`` shadows the builtin but is
        kept for keyword-argument compatibility.

        Returns:
            The combined loss, or ``(loss_ic, loss_u_ic, loss_v_ic, loss_A_ic)``
            when ``return_all_losses`` is truthy.
        """
        lploss = LpLoss(size_average=True)
        ic_pred = pred[:, 0]
        ic_true = input[:, 0, ..., 3:]
        loss_u_ic = lploss(ic_pred[..., 0], ic_true[..., 0])
        loss_v_ic = lploss(ic_pred[..., 1], ic_true[..., 1])
        loss_A_ic = lploss(ic_pred[..., 2], ic_true[..., 2])
        loss_ic = self._weighted_sum(
            (self.u_weight, loss_u_ic),
            (self.v_weight, loss_v_ic),
            (self.A_weight, loss_A_ic),
        )
        if return_all_losses:
            return loss_ic, loss_u_ic, loss_v_ic, loss_A_ic
        return loss_ic

    def mhd_pde_loss(self, Du, Dv, DA, return_all_losses=False):
        """Compute PDE loss: weighted mean squared residuals of the three equations.

        Returns:
            The combined loss, or ``(loss_pde, loss_Du, loss_Dv, loss_DA)`` when
            ``return_all_losses`` is truthy.
        """
        loss_Du = F.mse_loss(Du, torch.zeros_like(Du))
        loss_Dv = F.mse_loss(Dv, torch.zeros_like(Dv))
        loss_DA = F.mse_loss(DA, torch.zeros_like(DA))
        loss_pde = self._weighted_sum(
            (self.Du_weight, loss_Du),
            (self.Dv_weight, loss_Dv),
            (self.DA_weight, loss_DA),
        )
        if return_all_losses:
            return loss_pde, loss_Du, loss_Dv, loss_DA
        return loss_pde

    @staticmethod
    def _split_xy(f, nt, nx, ny):
        """Split a stacked spectral-derivative tensor into its x and y parts.

        Channels ``[0, nt)`` are taken as the x-derivative and ``[nt, 2*nt)``
        as the y-derivative. Both are cropped to ``nx`` by ``ny``.
        """
        return f[:, 0:nt, :nx, :ny], f[:, nt : 2 * nt, :nx, :ny]

    def mhd_constraint(self, u, v, A):
        """Compute constraints: ``div(velocity)`` and ``div(B)``.

        Returns:
            ``(div_vel, div_B)`` as evaluated by the MHD_PDE nodes.
        """
        nt, nx, ny = u.size(1), u.size(2), u.size(3)
        f_du, _ = fourier_derivatives(u, [self.Lx, self.Ly])
        f_dv, _ = fourier_derivatives(v, [self.Lx, self.Ly])
        f_dBx, f_dBy, _, _, _ = fourier_derivatives_vec_pot(A, [self.Lx, self.Ly])
        u_x, _ = self._split_xy(f_du, nt, nx, ny)
        _, v_y = self._split_xy(f_dv, nt, nx, ny)
        Bx_x, _ = self._split_xy(f_dBx, nt, nx, ny)
        _, By_y = self._split_xy(f_dBy, nt, nx, ny)
        div_B = self.mhd_pde_node[12].evaluate({"Bx__x": Bx_x, "By__y": By_y})["div_B"]
        div_vel = self.mhd_pde_node[13].evaluate({"u__x": u_x, "v__y": v_y})["div_vel"]
        return div_vel, div_B

    def mhd_constraint_loss(self, div_vel, div_B, return_all_losses=False):
        """Compute constraint loss: weighted mean squared divergences.

        Returns:
            The combined loss, or ``(loss_constraint, loss_div_vel, loss_div_B)``
            when ``return_all_losses`` is truthy.
        """
        loss_div_vel = F.mse_loss(div_vel, torch.zeros_like(div_vel))
        loss_div_B = F.mse_loss(div_B, torch.zeros_like(div_B))
        loss_constraint = self._weighted_sum(
            (self.div_vel_weight, loss_div_vel),
            (self.div_B_weight, loss_div_B),
        )
        if return_all_losses:
            return loss_constraint, loss_div_vel, loss_div_B
        return loss_constraint

    def mhd_pde(self, u, v, A, p=None):
        """Compute PDEs for MHD using vector potential.

        Returns the residuals ``(Du, Dv, DA)``. The time derivative is a central
        difference, which only exists for interior time steps, so the
        right-hand sides are trimmed with ``[:, 1:-1]`` to match (``nt - 2``
        steps each).

        Args:
            u, v, A: fields whose dims 1, 2, 3 are time, x and y.
            p: optional pressure forwarded to ``fourier_derivatives_ptot``.

        Raises:
            ValueError: if there are fewer than 3 time steps. Previously
                ``nt == 1`` divided by zero and ``nt == 2`` gave empty residuals.
        """
        nt = u.size(1)
        if nt < 3:
            raise ValueError(
                "mhd_pde needs at least 3 time steps for the central-difference "
                f"time derivative, got nt={nt}"
            )
        nx = u.size(2)
        ny = u.size(3)
        dt = self.tend / (nt - 1)

        # Spectral first derivatives.
        f_du, _ = fourier_derivatives(u, [self.Lx, self.Ly])
        f_dv, _ = fourier_derivatives(v, [self.Lx, self.Ly])
        f_dBx, f_dBy, f_dA, f_dB, B2_h = fourier_derivatives_vec_pot(
            A, [self.Lx, self.Ly]
        )
        u_x, u_y = self._split_xy(f_du, nt, nx, ny)
        v_x, v_y = self._split_xy(f_dv, nt, nx, ny)
        A_x, A_y = self._split_xy(f_dA, nt, nx, ny)
        Bx, By = self._split_xy(f_dB, nt, nx, ny)
        Bx_x, Bx_y = self._split_xy(f_dBx, nt, nx, ny)
        By_x, By_y = self._split_xy(f_dBy, nt, nx, ny)
        u_lap = fourier_derivatives_lap(u, [self.Lx, self.Ly])
        v_lap = fourier_derivatives_lap(v, [self.Lx, self.Ly])
        A_lap = fourier_derivatives_lap(A, [self.Lx, self.Ly])

        # Source terms for the total-pressure solve: u_x^2 + 2 u_y v_x + v_y^2
        # for the velocity, and the same form for B. The original code notes
        # that the pressure's zero (mean) mode cannot be zero for
        # invertibility, so fourier_derivatives_ptot sets it to 1.
        div_vel_grad_vel = u_x**2 + 2 * u_y * v_x + v_y**2
        div_B_grad_B = Bx_x**2 + 2 * Bx_y * By_x + By_y**2
        f_dptot = fourier_derivatives_ptot(
            p, div_vel_grad_vel, div_B_grad_B, B2_h, self.rho0, [self.Lx, self.Ly]
        )
        ptot_x, ptot_y = self._split_xy(f_dptot, nt, nx, ny)

        all_inputs = {
            "u": u,
            "u__x": u_x,
            "u__y": u_y,
            "v": v,
            "v__x": v_x,
            "v__y": v_y,
            "Bx": Bx,
            "Bx__x": Bx_x,
            "Bx__y": Bx_y,
            "By": By,
            "By__x": By_x,
            "By__y": By_y,
            "A__x": A_x,
            "A__y": A_y,
            "ptot__x": ptot_x,
            "ptot__y": ptot_y,
            "u__lap": u_lap,
            "v__lap": v_lap,
            "A__lap": A_lap,
        }

        # Right-hand sides from the symbolic nodes (indices follow make_nodes()).
        u_rhs = self.mhd_pde_node[14].evaluate(all_inputs)["u_rhs"]
        v_rhs = self.mhd_pde_node[15].evaluate(all_inputs)["v_rhs"]
        A_rhs = self.mhd_pde_node[23].evaluate(all_inputs)["A_rhs"]
        u_t = self.Du_t(u, dt)
        v_t = self.Du_t(v, dt)
        A_t = self.Du_t(A, dt)

        # Residuals on interior time steps.
        Du = self.mhd_pde_node[18].evaluate({"u__t": u_t, "u_rhs": u_rhs[:, 1:-1]})[
            "Du"
        ]
        Dv = self.mhd_pde_node[19].evaluate({"v__t": v_t, "v_rhs": v_rhs[:, 1:-1]})[
            "Dv"
        ]
        DA = self.mhd_pde_node[24].evaluate({"A__t": A_t, "A_rhs": A_rhs[:, 1:-1]})[
            "DA"
        ]
        return Du, Dv, DA

    def Du_t(self, u, dt):
        """Compute time derivative by central difference along dim 1.

        Returns ``(u[:, t+1] - u[:, t-1]) / (2 * dt)`` for interior steps
        ``t = 1 .. nt-2``, so the result has ``nt - 2`` entries along dim 1.
        """
        return (u[:, 2:] - u[:, :-2]) / (2 * dt)