# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F
from physicsnemo.models.layers.spectral_layers import fourier_derivatives

from .losses import (
    LpLoss,
    fourier_derivatives_lap,
    fourier_derivatives_ptot,
    fourier_derivatives_vec_pot,
)
from .mhd_pde import MHD_PDE


class LossMHDVecPot_PhysicsNeMo:
    """Calculate loss for MHD equations with vector potential, using physicsnemo derivatives.

    The total loss is a weighted combination of four groups: data, initial
    condition (IC), PDE residual and constraint (divergence) terms. Each
    group is itself a weighted combination of per-field terms for ``u``,
    ``v`` and ``A`` (or ``div_vel`` / ``div_B`` for the constraints).

    Predictions are indexed as ``pred[..., 0]`` (u), ``pred[..., 1]`` (v) and
    ``pred[..., 2]`` (A); dim 1 is treated as the time axis (``nt``) and dims
    2 and 3 as the two spatial axes (``nx``, ``ny``).
    NOTE(review): dim 0 is presumably the batch axis -- confirm with callers.
    """

    def __init__(
        self,
        nu=1e-4,
        eta=1e-4,
        rho0=1.0,
        data_weight=1.0,
        ic_weight=1.0,
        pde_weight=1.0,
        constraint_weight=1.0,
        use_data_loss=True,
        use_ic_loss=True,
        use_pde_loss=True,
        use_constraint_loss=True,
        u_weight=1.0,
        v_weight=1.0,
        A_weight=1.0,
        Du_weight=1.0,
        Dv_weight=1.0,
        DA_weight=1.0,
        div_B_weight=1.0,
        div_vel_weight=1.0,
        Lx=1.0,
        Ly=1.0,
        tend=1.0,
        use_weighted_mean=False,
        **kwargs,  # ignored, so a full config dict can be splatted in
    ):
        """Store the loss configuration and build the symbolic PDE nodes.

        Args:
            nu, eta, rho0: Forwarded unchanged to ``MHD_PDE``.
                NOTE(review): presumably viscosity, resistivity and density.
            data_weight, ic_weight, pde_weight, constraint_weight: Weights of
                the four loss groups in the total loss.
            use_data_loss, use_ic_loss, use_pde_loss, use_constraint_loss:
                When ``False`` the group is not evaluated and its weight is
                forced to 0.
            u_weight, v_weight, A_weight: Per-field weights used by both the
                data loss and the IC loss.
            Du_weight, Dv_weight, DA_weight: Per-equation weights of the PDE
                residual loss.
            div_B_weight, div_vel_weight: Weights of the two divergence
                constraint terms.
            Lx, Ly: Passed as ``[Lx, Ly]`` to the Fourier derivative helpers.
                NOTE(review): presumably the domain lengths -- confirm.
            tend: Time span used to derive the step ``dt = tend / (nt - 1)``.
            use_weighted_mean: If ``True`` every weighted sum is divided by
                the sum of its weights; otherwise it is left as a plain
                weighted sum.
            **kwargs: Accepted and ignored.
        """
        self.nu = nu
        self.eta = eta
        self.rho0 = rho0
        self.data_weight = data_weight
        self.ic_weight = ic_weight
        self.pde_weight = pde_weight
        self.constraint_weight = constraint_weight
        self.use_data_loss = use_data_loss
        self.use_ic_loss = use_ic_loss
        self.use_pde_loss = use_pde_loss
        self.use_constraint_loss = use_constraint_loss
        self.u_weight = u_weight
        self.v_weight = v_weight
        self.Du_weight = Du_weight
        self.Dv_weight = Dv_weight
        self.div_B_weight = div_B_weight
        self.div_vel_weight = div_vel_weight
        self.Lx = Lx
        self.Ly = Ly
        self.tend = tend
        self.use_weighted_mean = use_weighted_mean
        self.A_weight = A_weight
        self.DA_weight = DA_weight

        # Define 2D MHD PDEs
        self.mhd_pde_eq = MHD_PDE(self.nu, self.eta, self.rho0)
        self.mhd_pde_node = self.mhd_pde_eq.make_nodes()

        # A disabled group must not contribute to the weighted-mean divisor.
        if not self.use_data_loss:
            self.data_weight = 0
        if not self.use_ic_loss:
            self.ic_weight = 0
        if not self.use_pde_loss:
            self.pde_weight = 0
        if not self.use_constraint_loss:
            self.constraint_weight = 0

    def __call__(self, pred, true, inputs, return_loss_dict=False):
        """Evaluate the loss.

        Args:
            pred: Model prediction; reshaped to ``true.shape`` before use.
            true: Reference solution.
            inputs: Model inputs, used only by the IC loss.
            return_loss_dict: If ``True`` also return the dictionary of
                individual loss terms.

        Returns:
            The total loss, or ``(loss, loss_dict)`` when
            ``return_loss_dict`` is true.
        """
        # Honour the flag: the cheaper scalar-only path is the default.
        if return_loss_dict:
            return self.compute_losses(pred, true, inputs)
        return self.compute_loss(pred, true, inputs)

    def _combine(self, weights, losses):
        """Return the weighted sum of ``losses`` (a weighted mean if configured)."""
        weight_sum = sum(weights) if self.use_weighted_mean else 1.0
        weighted = sum(w * term for w, term in zip(weights, losses))
        return weighted / weight_sum

    @staticmethod
    def _split_xy(stacked, nt, nx, ny):
        """Slice a stacked derivative tensor into its two halves along dim 1.

        The first ``nt`` entries of dim 1 are used as the x-derivative and the
        next ``nt`` as the y-derivative; dims 2 and 3 are cropped to
        ``nx`` / ``ny``.
        """
        return stacked[:, 0:nt, :nx, :ny], stacked[:, nt : 2 * nt, :nx, :ny]

    def compute_loss(self, pred, true, inputs):
        """Compute the total weighted loss (scalar only)."""
        pred = pred.reshape(true.shape)
        u, v, A = pred[..., 0], pred[..., 1], pred[..., 2]

        # Disabled groups contribute a literal 0 and are never evaluated.
        loss_data = self.data_loss(pred, true) if self.use_data_loss else 0
        loss_ic = self.ic_loss(pred, inputs) if self.use_ic_loss else 0

        if self.use_pde_loss:
            Du, Dv, DA = self.mhd_pde(u, v, A)
            loss_pde = self.mhd_pde_loss(Du, Dv, DA)
        else:
            loss_pde = 0

        if self.use_constraint_loss:
            div_vel, div_B = self.mhd_constraint(u, v, A)
            loss_constraint = self.mhd_constraint_loss(div_vel, div_B)
        else:
            loss_constraint = 0

        return self._combine(
            (
                self.data_weight,
                self.ic_weight,
                self.pde_weight,
                self.constraint_weight,
            ),
            (loss_data, loss_ic, loss_pde, loss_constraint),
        )

    def compute_losses(self, pred, true, inputs):
        """Compute the total weighted loss and a dictionary of every term.

        Returns:
            ``(loss, loss_dict)`` where ``loss_dict`` holds ``"loss"`` plus
            the group and per-field entries of every enabled group.
        """
        pred = pred.reshape(true.shape)
        u, v, A = pred[..., 0], pred[..., 1], pred[..., 2]

        loss_dict = {}

        # Data
        if self.use_data_loss:
            loss_data, loss_u, loss_v, loss_A = self.data_loss(
                pred, true, return_all_losses=True
            )
            loss_dict["loss_data"] = loss_data
            loss_dict["loss_u"] = loss_u
            loss_dict["loss_v"] = loss_v
            loss_dict["loss_A"] = loss_A
        else:
            loss_data = 0

        # IC
        if self.use_ic_loss:
            loss_ic, loss_u_ic, loss_v_ic, loss_A_ic = self.ic_loss(
                pred, inputs, return_all_losses=True
            )
            loss_dict["loss_ic"] = loss_ic
            loss_dict["loss_u_ic"] = loss_u_ic
            loss_dict["loss_v_ic"] = loss_v_ic
            loss_dict["loss_A_ic"] = loss_A_ic
        else:
            loss_ic = 0

        # PDE
        if self.use_pde_loss:
            Du, Dv, DA = self.mhd_pde(u, v, A)
            loss_pde, loss_Du, loss_Dv, loss_DA = self.mhd_pde_loss(
                Du, Dv, DA, return_all_losses=True
            )
            loss_dict["loss_pde"] = loss_pde
            loss_dict["loss_Du"] = loss_Du
            loss_dict["loss_Dv"] = loss_Dv
            loss_dict["loss_DA"] = loss_DA
        else:
            loss_pde = 0

        # Constraints
        if self.use_constraint_loss:
            div_vel, div_B = self.mhd_constraint(u, v, A)
            loss_constraint, loss_div_vel, loss_div_B = self.mhd_constraint_loss(
                div_vel, div_B, return_all_losses=True
            )
            loss_dict["loss_constraint"] = loss_constraint
            loss_dict["loss_div_vel"] = loss_div_vel
            loss_dict["loss_div_B"] = loss_div_B
        else:
            loss_constraint = 0

        loss = self._combine(
            (
                self.data_weight,
                self.ic_weight,
                self.pde_weight,
                self.constraint_weight,
            ),
            (loss_data, loss_ic, loss_pde, loss_constraint),
        )
        loss_dict["loss"] = loss
        return loss, loss_dict

    def data_loss(self, pred, true, return_all_losses=False):
        """Compute the data loss between ``pred`` and ``true``.

        Each of the three fields (last-dim index 0, 1, 2) is compared with
        ``LpLoss(size_average=True)`` and the results are combined using
        ``u_weight``, ``v_weight`` and ``A_weight``.

        Returns:
            ``loss_data``, or ``(loss_data, loss_u, loss_v, loss_A)`` when
            ``return_all_losses`` is true.
        """
        lploss = LpLoss(size_average=True)

        loss_u = lploss(pred[..., 0], true[..., 0])
        loss_v = lploss(pred[..., 1], true[..., 1])
        loss_A = lploss(pred[..., 2], true[..., 2])

        loss_data = self._combine(
            (self.u_weight, self.v_weight, self.A_weight),
            (loss_u, loss_v, loss_A),
        )

        if return_all_losses:
            return loss_data, loss_u, loss_v, loss_A
        return loss_data

    def ic_loss(self, pred, input, return_all_losses=False):
        """Compute the initial-condition loss.

        The first time slice of the prediction (``pred[:, 0]``) is compared
        against ``input[:, 0, ..., 3:]``.
        NOTE(review): the first three input channels are skipped; presumably
        they hold coordinates -- confirm against the dataset.

        Returns:
            ``loss_ic``, or ``(loss_ic, loss_u_ic, loss_v_ic, loss_A_ic)``
            when ``return_all_losses`` is true.
        """
        lploss = LpLoss(size_average=True)
        ic_pred = pred[:, 0]
        ic_true = input[:, 0, ..., 3:]

        loss_u_ic = lploss(ic_pred[..., 0], ic_true[..., 0])
        loss_v_ic = lploss(ic_pred[..., 1], ic_true[..., 1])
        loss_A_ic = lploss(ic_pred[..., 2], ic_true[..., 2])

        loss_ic = self._combine(
            (self.u_weight, self.v_weight, self.A_weight),
            (loss_u_ic, loss_v_ic, loss_A_ic),
        )

        if return_all_losses:
            return loss_ic, loss_u_ic, loss_v_ic, loss_A_ic
        return loss_ic

    def mhd_pde_loss(self, Du, Dv, DA, return_all_losses=False):
        """Compute the PDE residual loss.

        Each residual is penalised with the MSE against a zero tensor and the
        three terms are combined using ``Du_weight``, ``Dv_weight`` and
        ``DA_weight``.

        Returns:
            ``loss_pde``, or ``(loss_pde, loss_Du, loss_Dv, loss_DA)`` when
            ``return_all_losses`` is true.
        """
        loss_Du = F.mse_loss(Du, torch.zeros_like(Du))
        loss_Dv = F.mse_loss(Dv, torch.zeros_like(Dv))
        loss_DA = F.mse_loss(DA, torch.zeros_like(DA))

        loss_pde = self._combine(
            (self.Du_weight, self.Dv_weight, self.DA_weight),
            (loss_Du, loss_Dv, loss_DA),
        )

        if return_all_losses:
            return loss_pde, loss_Du, loss_Dv, loss_DA
        return loss_pde

    def mhd_constraint(self, u, v, A):
        """Compute the divergence constraints.

        Returns:
            ``(div_vel, div_B)`` as evaluated by the PDE nodes from
            ``u__x``/``v__y`` and ``Bx__x``/``By__y`` respectively.
        """
        nt, nx, ny = u.size(1), u.size(2), u.size(3)

        f_du, _ = fourier_derivatives(u, [self.Lx, self.Ly])
        f_dv, _ = fourier_derivatives(v, [self.Lx, self.Ly])
        f_dBx, f_dBy, _, _, _ = fourier_derivatives_vec_pot(A, [self.Lx, self.Ly])

        u_x, _ = self._split_xy(f_du, nt, nx, ny)
        _, v_y = self._split_xy(f_dv, nt, nx, ny)
        Bx_x, _ = self._split_xy(f_dBx, nt, nx, ny)
        _, By_y = self._split_xy(f_dBy, nt, nx, ny)

        # NOTE(review): node indices are hard-coded and presumably tied to the
        # ordering produced by MHD_PDE.make_nodes() -- confirm when editing.
        div_B = self.mhd_pde_node[12].evaluate({"Bx__x": Bx_x, "By__y": By_y})["div_B"]
        div_vel = self.mhd_pde_node[13].evaluate({"u__x": u_x, "v__y": v_y})["div_vel"]

        return div_vel, div_B

    def mhd_constraint_loss(self, div_vel, div_B, return_all_losses=False):
        """Compute the constraint loss.

        Both divergences are penalised with the MSE against a zero tensor and
        combined using ``div_vel_weight`` and ``div_B_weight``.

        Returns:
            ``loss_constraint``, or
            ``(loss_constraint, loss_div_vel, loss_div_B)`` when
            ``return_all_losses`` is true.
        """
        loss_div_vel = F.mse_loss(div_vel, torch.zeros_like(div_vel))
        loss_div_B = F.mse_loss(div_B, torch.zeros_like(div_B))

        loss_constraint = self._combine(
            (self.div_vel_weight, self.div_B_weight),
            (loss_div_vel, loss_div_B),
        )

        if return_all_losses:
            return loss_constraint, loss_div_vel, loss_div_B
        return loss_constraint

    def mhd_pde(self, u, v, A, p=None):
        """Compute PDE residuals for MHD using the vector potential.

        Args:
            u, v, A: Field tensors; dim 1 is the time axis and must have at
                least three entries.
            p: Forwarded unchanged to ``fourier_derivatives_ptot``.

        Returns:
            ``(Du, Dv, DA)`` evaluated on the interior time steps
            (``nt - 2`` entries along dim 1).

        Raises:
            ValueError: If fewer than three time steps are supplied.
        """
        nt, nx, ny = u.size(1), u.size(2), u.size(3)
        if nt < 3:
            # nt == 1 would divide by zero below; nt == 2 leaves no interior
            # time step for the central difference, giving an empty residual.
            raise ValueError(
                f"mhd_pde needs at least 3 time steps along dim 1, got {nt}"
            )
        dt = self.tend / (nt - 1)

        # compute fourier derivatives
        f_du, _ = fourier_derivatives(u, [self.Lx, self.Ly])
        f_dv, _ = fourier_derivatives(v, [self.Lx, self.Ly])
        f_dBx, f_dBy, f_dA, f_dB, B2_h = fourier_derivatives_vec_pot(
            A, [self.Lx, self.Ly]
        )

        u_x, u_y = self._split_xy(f_du, nt, nx, ny)
        v_x, v_y = self._split_xy(f_dv, nt, nx, ny)
        A_x, A_y = self._split_xy(f_dA, nt, nx, ny)

        Bx, By = self._split_xy(f_dB, nt, nx, ny)
        Bx_x, Bx_y = self._split_xy(f_dBx, nt, nx, ny)
        By_x, By_y = self._split_xy(f_dBy, nt, nx, ny)

        u_lap = fourier_derivatives_lap(u, [self.Lx, self.Ly])
        v_lap = fourier_derivatives_lap(v, [self.Lx, self.Ly])
        A_lap = fourier_derivatives_lap(A, [self.Lx, self.Ly])

        # note that for pressure, the zero mode (the mean) cannot be zero for
        # invertibility so it is set to 1
        div_vel_grad_vel = u_x**2 + 2 * u_y * v_x + v_y**2
        div_B_grad_B = Bx_x**2 + 2 * Bx_y * By_x + By_y**2
        f_dptot = fourier_derivatives_ptot(
            p, div_vel_grad_vel, div_B_grad_B, B2_h, self.rho0, [self.Lx, self.Ly]
        )
        ptot_x, ptot_y = self._split_xy(f_dptot, nt, nx, ny)

        # Plug inputs into dictionary
        all_inputs = {
            "u": u,
            "u__x": u_x,
            "u__y": u_y,
            "v": v,
            "v__x": v_x,
            "v__y": v_y,
            "Bx": Bx,
            "Bx__x": Bx_x,
            "Bx__y": Bx_y,
            "By": By,
            "By__x": By_x,
            "By__y": By_y,
            "A__x": A_x,
            "A__y": A_y,
            "ptot__x": ptot_x,
            "ptot__y": ptot_y,
            "u__lap": u_lap,
            "v__lap": v_lap,
            "A__lap": A_lap,
        }

        # Substitute values into PDE equations
        # NOTE(review): node indices are hard-coded and presumably tied to the
        # ordering produced by MHD_PDE.make_nodes() -- confirm when editing.
        u_rhs = self.mhd_pde_node[14].evaluate(all_inputs)["u_rhs"]
        v_rhs = self.mhd_pde_node[15].evaluate(all_inputs)["v_rhs"]
        A_rhs = self.mhd_pde_node[23].evaluate(all_inputs)["A_rhs"]

        u_t = self.Du_t(u, dt)
        v_t = self.Du_t(v, dt)
        A_t = self.Du_t(A, dt)

        # Find difference; the right-hand sides are cropped to the interior
        # time steps so they line up with the central-difference derivative.
        Du = self.mhd_pde_node[18].evaluate({"u__t": u_t, "u_rhs": u_rhs[:, 1:-1]})[
            "Du"
        ]
        Dv = self.mhd_pde_node[19].evaluate({"v__t": v_t, "v_rhs": v_rhs[:, 1:-1]})[
            "Dv"
        ]
        DA = self.mhd_pde_node[24].evaluate({"A__t": A_t, "A_rhs": A_rhs[:, 1:-1]})[
            "DA"
        ]

        return Du, Dv, DA

    def Du_t(self, u, dt):
        """Compute the time derivative along dim 1 by central difference.

        Returns a tensor with two fewer entries along dim 1 than ``u``
        (the first and last time steps have no centred stencil).
        """
        return (u[:, 2:] - u[:, :-2]) / (2 * dt)