Update pinn_electromagnetics/losses.py
Browse files
pinn_electromagnetics/losses.py
CHANGED
|
@@ -1,14 +1,13 @@
|
|
| 1 |
|
| 2 |
import torch
|
| 3 |
-
import numpy as np
|
| 4 |
|
| 5 |
-
# --- Derivative Computation (from v28) ---
|
| 6 |
def compute_all_derivatives(model_phys, x, y, z, t):
|
| 7 |
outputs = model_phys(x, y, z, t)
|
| 8 |
Ex, Ey, Ez = outputs[:, 0:1], outputs[:, 1:2], outputs[:, 2:3]
|
| 9 |
Bx, By, Bz = outputs[:, 3:4], outputs[:, 4:5], outputs[:, 5:6]
|
| 10 |
|
| 11 |
-
|
| 12 |
dEx_grads = torch.autograd.grad(Ex.sum(), [x, y, z, t], create_graph=True, allow_unused=True)
|
| 13 |
dEy_grads = torch.autograd.grad(Ey.sum(), [x, y, z, t], create_graph=True, allow_unused=True)
|
| 14 |
dEz_grads = torch.autograd.grad(Ez.sum(), [x, y, z, t], create_graph=True, allow_unused=True)
|
|
@@ -16,7 +15,7 @@ def compute_all_derivatives(model_phys, x, y, z, t):
|
|
| 16 |
dBy_grads = torch.autograd.grad(By.sum(), [x, y, z, t], create_graph=True, allow_unused=True)
|
| 17 |
dBz_grads = torch.autograd.grad(Bz.sum(), [x, y, z, t], create_graph=True, allow_unused=True)
|
| 18 |
|
| 19 |
-
|
| 20 |
def get_grad(grad_tuple, idx):
|
| 21 |
return grad_tuple[idx] if grad_tuple[idx] is not None else torch.zeros_like(x)
|
| 22 |
|
|
@@ -30,23 +29,18 @@ def compute_all_derivatives(model_phys, x, y, z, t):
|
|
| 30 |
}
|
| 31 |
return outputs, derivs
|
| 32 |
|
| 33 |
-
# --- Maxwell Loss (from v28) ---
|
| 34 |
def compute_maxwell_loss(derivs):
|
| 35 |
-
# Maxwell's Equations (dimensionless units)
|
| 36 |
|
| 37 |
-
|
| 38 |
loss_gauss_E = (derivs['dEx_dx'] + derivs['dEy_dy'] + derivs['dEz_dz'])**2
|
| 39 |
|
| 40 |
-
# Gauss's Law for B: div(B) = 0
|
| 41 |
loss_gauss_B = (derivs['dBx_dx'] + derivs['dBy_dy'] + derivs['dBz_dz'])**2
|
| 42 |
|
| 43 |
-
# Faraday's Law: curl(E) = -dB/dt
|
| 44 |
loss_faraday_x = (derivs['dEz_dy'] - derivs['dEy_dz'] + derivs['dBx_dt'])**2
|
| 45 |
loss_faraday_y = (derivs['dEx_dz'] - derivs['dEz_dx'] + derivs['dBy_dt'])**2
|
| 46 |
loss_faraday_z = (derivs['dEy_dx'] - derivs['dEx_dy'] + derivs['dBz_dt'])**2
|
| 47 |
loss_faraday = loss_faraday_x + loss_faraday_y + loss_faraday_z
|
| 48 |
-
|
| 49 |
-
# Ampere-Maxwell Law: curl(B) = dE/dt (assuming no current J)
|
| 50 |
loss_ampere_x = (derivs['dBz_dy'] - derivs['dBy_dz'] - derivs['dEx_dt'])**2
|
| 51 |
loss_ampere_y = (derivs['dBx_dz'] - derivs['dBz_dx'] - derivs['dEy_dt'])**2
|
| 52 |
loss_ampere_z = (derivs['dBy_dx'] - derivs['dBx_dy'] - derivs['dEz_dt'])**2
|
|
@@ -54,7 +48,6 @@ def compute_maxwell_loss(derivs):
|
|
| 54 |
|
| 55 |
return torch.mean(loss_gauss_E + loss_gauss_B + loss_faraday + loss_ampere)
|
| 56 |
|
| 57 |
-
# --- Data Loss (from v31.2, adapted for NPZ data) ---
|
| 58 |
def compute_data_loss(model_phys,
|
| 59 |
coords_tensor, E0_tensor, B0_tensor,
|
| 60 |
source_pos_tensor, source_orientation_tensor,
|
|
@@ -62,7 +55,6 @@ def compute_data_loss(model_phys,
|
|
| 62 |
n_samples_ic=1024, n_samples_source=128):
|
| 63 |
loss_data = 0.0
|
| 64 |
|
| 65 |
-
# 1. Initial Conditions (t=0): E=E0, B=B0
|
| 66 |
idx_ic = torch.randperm(coords_tensor.shape[0], device=coords_tensor.device)[:n_samples_ic]
|
| 67 |
x_ic, y_ic, z_ic = coords_tensor[idx_ic, 0:1], coords_tensor[idx_ic, 1:2], coords_tensor[idx_ic, 2:3]
|
| 68 |
t_ic = torch.zeros_like(x_ic)
|
|
@@ -72,22 +64,16 @@ def compute_data_loss(model_phys,
|
|
| 72 |
|
| 73 |
loss_data += torch.mean((pred_ic[:, 0:3] - E_actual_ic)**2)
|
| 74 |
loss_data += torch.mean((pred_ic[:, 3:6] - B_actual_ic)**2)
|
| 75 |
-
|
| 76 |
-
# 2. Source Boundary (time-varying E-field, B=0)
|
| 77 |
-
# We use all time points for the source signal
|
| 78 |
n_times_source = t_tensor.shape[0]
|
| 79 |
-
# Source position needs to be expanded to match the number of time samples
|
| 80 |
x_s = source_pos_tensor[0:1].expand(n_times_source, -1)
|
| 81 |
y_s = source_pos_tensor[1:2].expand(n_times_source, -1)
|
| 82 |
z_s = source_pos_tensor[2:3].expand(n_times_source, -1)
|
| 83 |
-
t_s = t_tensor
|
| 84 |
|
| 85 |
pred_source = model_phys(x_s, y_s, z_s, t_s)
|
| 86 |
-
|
| 87 |
-
# E_actual = signal * orientation. Expand orientation to match time samples.
|
| 88 |
E_actual_source = source_signal_tensor * source_orientation_tensor.view(1, 3)
|
| 89 |
|
| 90 |
loss_data += torch.mean((pred_source[:, 0:3] - E_actual_source)**2)
|
| 91 |
-
loss_data += torch.mean(pred_source[:, 3:6]**2)
|
| 92 |
|
| 93 |
return loss_data
|
|
|
|
| 1 |
|
| 2 |
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
|
|
|
|
| 5 |
def compute_all_derivatives(model_phys, x, y, z, t):
|
| 6 |
outputs = model_phys(x, y, z, t)
|
| 7 |
Ex, Ey, Ez = outputs[:, 0:1], outputs[:, 1:2], outputs[:, 2:3]
|
| 8 |
Bx, By, Bz = outputs[:, 3:4], outputs[:, 4:5], outputs[:, 5:6]
|
| 9 |
|
| 10 |
+
|
| 11 |
dEx_grads = torch.autograd.grad(Ex.sum(), [x, y, z, t], create_graph=True, allow_unused=True)
|
| 12 |
dEy_grads = torch.autograd.grad(Ey.sum(), [x, y, z, t], create_graph=True, allow_unused=True)
|
| 13 |
dEz_grads = torch.autograd.grad(Ez.sum(), [x, y, z, t], create_graph=True, allow_unused=True)
|
|
|
|
| 15 |
dBy_grads = torch.autograd.grad(By.sum(), [x, y, z, t], create_graph=True, allow_unused=True)
|
| 16 |
dBz_grads = torch.autograd.grad(Bz.sum(), [x, y, z, t], create_graph=True, allow_unused=True)
|
| 17 |
|
| 18 |
+
|
| 19 |
def get_grad(grad_tuple, idx):
|
| 20 |
return grad_tuple[idx] if grad_tuple[idx] is not None else torch.zeros_like(x)
|
| 21 |
|
|
|
|
| 29 |
}
|
| 30 |
return outputs, derivs
|
| 31 |
|
|
|
|
| 32 |
def compute_maxwell_loss(derivs):
|
|
|
|
| 33 |
|
| 34 |
+
|
| 35 |
loss_gauss_E = (derivs['dEx_dx'] + derivs['dEy_dy'] + derivs['dEz_dz'])**2
|
| 36 |
|
|
|
|
| 37 |
loss_gauss_B = (derivs['dBx_dx'] + derivs['dBy_dy'] + derivs['dBz_dz'])**2
|
| 38 |
|
|
|
|
| 39 |
loss_faraday_x = (derivs['dEz_dy'] - derivs['dEy_dz'] + derivs['dBx_dt'])**2
|
| 40 |
loss_faraday_y = (derivs['dEx_dz'] - derivs['dEz_dx'] + derivs['dBy_dt'])**2
|
| 41 |
loss_faraday_z = (derivs['dEy_dx'] - derivs['dEx_dy'] + derivs['dBz_dt'])**2
|
| 42 |
loss_faraday = loss_faraday_x + loss_faraday_y + loss_faraday_z
|
| 43 |
+
|
|
|
|
| 44 |
loss_ampere_x = (derivs['dBz_dy'] - derivs['dBy_dz'] - derivs['dEx_dt'])**2
|
| 45 |
loss_ampere_y = (derivs['dBx_dz'] - derivs['dBz_dx'] - derivs['dEy_dt'])**2
|
| 46 |
loss_ampere_z = (derivs['dBy_dx'] - derivs['dBx_dy'] - derivs['dEz_dt'])**2
|
|
|
|
| 48 |
|
| 49 |
return torch.mean(loss_gauss_E + loss_gauss_B + loss_faraday + loss_ampere)
|
| 50 |
|
|
|
|
| 51 |
def compute_data_loss(model_phys,
|
| 52 |
coords_tensor, E0_tensor, B0_tensor,
|
| 53 |
source_pos_tensor, source_orientation_tensor,
|
|
|
|
| 55 |
n_samples_ic=1024, n_samples_source=128):
|
| 56 |
loss_data = 0.0
|
| 57 |
|
|
|
|
| 58 |
idx_ic = torch.randperm(coords_tensor.shape[0], device=coords_tensor.device)[:n_samples_ic]
|
| 59 |
x_ic, y_ic, z_ic = coords_tensor[idx_ic, 0:1], coords_tensor[idx_ic, 1:2], coords_tensor[idx_ic, 2:3]
|
| 60 |
t_ic = torch.zeros_like(x_ic)
|
|
|
|
| 64 |
|
| 65 |
loss_data += torch.mean((pred_ic[:, 0:3] - E_actual_ic)**2)
|
| 66 |
loss_data += torch.mean((pred_ic[:, 3:6] - B_actual_ic)**2)
|
|
|
|
|
|
|
|
|
|
| 67 |
n_times_source = t_tensor.shape[0]
|
|
|
|
| 68 |
x_s = source_pos_tensor[0:1].expand(n_times_source, -1)
|
| 69 |
y_s = source_pos_tensor[1:2].expand(n_times_source, -1)
|
| 70 |
z_s = source_pos_tensor[2:3].expand(n_times_source, -1)
|
| 71 |
+
t_s = t_tensor
|
| 72 |
|
| 73 |
pred_source = model_phys(x_s, y_s, z_s, t_s)
|
|
|
|
|
|
|
| 74 |
E_actual_source = source_signal_tensor * source_orientation_tensor.view(1, 3)
|
| 75 |
|
| 76 |
loss_data += torch.mean((pred_source[:, 0:3] - E_actual_source)**2)
|
| 77 |
+
loss_data += torch.mean(pred_source[:, 3:6]**2)
|
| 78 |
|
| 79 |
return loss_data
|