ayda138000 commited on
Commit
f570fb5
·
verified ·
1 Parent(s): dd12400

Update pinn_electromagnetics/losses.py

Browse files
Files changed (1) hide show
  1. pinn_electromagnetics/losses.py +7 -21
pinn_electromagnetics/losses.py CHANGED
@@ -1,14 +1,13 @@
1
 
2
  import torch
3
- import numpy as np # Used for constants like pi if needed elsewhere
4
 
5
- # --- Derivative Computation (from v28) ---
6
  def compute_all_derivatives(model_phys, x, y, z, t):
7
  outputs = model_phys(x, y, z, t)
8
  Ex, Ey, Ez = outputs[:, 0:1], outputs[:, 1:2], outputs[:, 2:3]
9
  Bx, By, Bz = outputs[:, 3:4], outputs[:, 4:5], outputs[:, 5:6]
10
 
11
- # Use .sum() for scalar output when computing gradients over multiple components
12
  dEx_grads = torch.autograd.grad(Ex.sum(), [x, y, z, t], create_graph=True, allow_unused=True)
13
  dEy_grads = torch.autograd.grad(Ey.sum(), [x, y, z, t], create_graph=True, allow_unused=True)
14
  dEz_grads = torch.autograd.grad(Ez.sum(), [x, y, z, t], create_graph=True, allow_unused=True)
@@ -16,7 +15,7 @@ def compute_all_derivatives(model_phys, x, y, z, t):
16
  dBy_grads = torch.autograd.grad(By.sum(), [x, y, z, t], create_graph=True, allow_unused=True)
17
  dBz_grads = torch.autograd.grad(Bz.sum(), [x, y, z, t], create_graph=True, allow_unused=True)
18
 
19
- # Ensure gradients are not None if allow_unused=True
20
  def get_grad(grad_tuple, idx):
21
  return grad_tuple[idx] if grad_tuple[idx] is not None else torch.zeros_like(x)
22
 
@@ -30,23 +29,18 @@ def compute_all_derivatives(model_phys, x, y, z, t):
30
  }
31
  return outputs, derivs
32
 
33
- # --- Maxwell Loss (from v28) ---
34
  def compute_maxwell_loss(derivs):
35
- # Maxwell's Equations (dimensionless units)
36
 
37
- # Gauss's Law for E: div(E) = 0
38
  loss_gauss_E = (derivs['dEx_dx'] + derivs['dEy_dy'] + derivs['dEz_dz'])**2
39
 
40
- # Gauss's Law for B: div(B) = 0
41
  loss_gauss_B = (derivs['dBx_dx'] + derivs['dBy_dy'] + derivs['dBz_dz'])**2
42
 
43
- # Faraday's Law: curl(E) = -dB/dt
44
  loss_faraday_x = (derivs['dEz_dy'] - derivs['dEy_dz'] + derivs['dBx_dt'])**2
45
  loss_faraday_y = (derivs['dEx_dz'] - derivs['dEz_dx'] + derivs['dBy_dt'])**2
46
  loss_faraday_z = (derivs['dEy_dx'] - derivs['dEx_dy'] + derivs['dBz_dt'])**2
47
  loss_faraday = loss_faraday_x + loss_faraday_y + loss_faraday_z
48
-
49
- # Ampere-Maxwell Law: curl(B) = dE/dt (assuming no current J)
50
  loss_ampere_x = (derivs['dBz_dy'] - derivs['dBy_dz'] - derivs['dEx_dt'])**2
51
  loss_ampere_y = (derivs['dBx_dz'] - derivs['dBz_dx'] - derivs['dEy_dt'])**2
52
  loss_ampere_z = (derivs['dBy_dx'] - derivs['dBx_dy'] - derivs['dEz_dt'])**2
@@ -54,7 +48,6 @@ def compute_maxwell_loss(derivs):
54
 
55
  return torch.mean(loss_gauss_E + loss_gauss_B + loss_faraday + loss_ampere)
56
 
57
- # --- Data Loss (from v31.2, adapted for NPZ data) ---
58
  def compute_data_loss(model_phys,
59
  coords_tensor, E0_tensor, B0_tensor,
60
  source_pos_tensor, source_orientation_tensor,
@@ -62,7 +55,6 @@ def compute_data_loss(model_phys,
62
  n_samples_ic=1024, n_samples_source=128):
63
  loss_data = 0.0
64
 
65
- # 1. Initial Conditions (t=0): E=E0, B=B0
66
  idx_ic = torch.randperm(coords_tensor.shape[0], device=coords_tensor.device)[:n_samples_ic]
67
  x_ic, y_ic, z_ic = coords_tensor[idx_ic, 0:1], coords_tensor[idx_ic, 1:2], coords_tensor[idx_ic, 2:3]
68
  t_ic = torch.zeros_like(x_ic)
@@ -72,22 +64,16 @@ def compute_data_loss(model_phys,
72
 
73
  loss_data += torch.mean((pred_ic[:, 0:3] - E_actual_ic)**2)
74
  loss_data += torch.mean((pred_ic[:, 3:6] - B_actual_ic)**2)
75
-
76
- # 2. Source Boundary (time-varying E-field, B=0)
77
- # We use all time points for the source signal
78
  n_times_source = t_tensor.shape[0]
79
- # Source position needs to be expanded to match the number of time samples
80
  x_s = source_pos_tensor[0:1].expand(n_times_source, -1)
81
  y_s = source_pos_tensor[1:2].expand(n_times_source, -1)
82
  z_s = source_pos_tensor[2:3].expand(n_times_source, -1)
83
- t_s = t_tensor # shape (n_times_source, 1)
84
 
85
  pred_source = model_phys(x_s, y_s, z_s, t_s)
86
-
87
- # E_actual = signal * orientation. Expand orientation to match time samples.
88
  E_actual_source = source_signal_tensor * source_orientation_tensor.view(1, 3)
89
 
90
  loss_data += torch.mean((pred_source[:, 0:3] - E_actual_source)**2)
91
- loss_data += torch.mean(pred_source[:, 3:6]**2) # B=0 at the source
92
 
93
  return loss_data
 
1
 
2
  import torch
3
+ import numpy as np
4
 
 
5
  def compute_all_derivatives(model_phys, x, y, z, t):
6
  outputs = model_phys(x, y, z, t)
7
  Ex, Ey, Ez = outputs[:, 0:1], outputs[:, 1:2], outputs[:, 2:3]
8
  Bx, By, Bz = outputs[:, 3:4], outputs[:, 4:5], outputs[:, 5:6]
9
 
10
+
11
  dEx_grads = torch.autograd.grad(Ex.sum(), [x, y, z, t], create_graph=True, allow_unused=True)
12
  dEy_grads = torch.autograd.grad(Ey.sum(), [x, y, z, t], create_graph=True, allow_unused=True)
13
  dEz_grads = torch.autograd.grad(Ez.sum(), [x, y, z, t], create_graph=True, allow_unused=True)
 
15
  dBy_grads = torch.autograd.grad(By.sum(), [x, y, z, t], create_graph=True, allow_unused=True)
16
  dBz_grads = torch.autograd.grad(Bz.sum(), [x, y, z, t], create_graph=True, allow_unused=True)
17
 
18
+
19
  def get_grad(grad_tuple, idx):
20
  return grad_tuple[idx] if grad_tuple[idx] is not None else torch.zeros_like(x)
21
 
 
29
  }
30
  return outputs, derivs
31
 
 
32
  def compute_maxwell_loss(derivs):
 
33
 
34
+
35
  loss_gauss_E = (derivs['dEx_dx'] + derivs['dEy_dy'] + derivs['dEz_dz'])**2
36
 
 
37
  loss_gauss_B = (derivs['dBx_dx'] + derivs['dBy_dy'] + derivs['dBz_dz'])**2
38
 
 
39
  loss_faraday_x = (derivs['dEz_dy'] - derivs['dEy_dz'] + derivs['dBx_dt'])**2
40
  loss_faraday_y = (derivs['dEx_dz'] - derivs['dEz_dx'] + derivs['dBy_dt'])**2
41
  loss_faraday_z = (derivs['dEy_dx'] - derivs['dEx_dy'] + derivs['dBz_dt'])**2
42
  loss_faraday = loss_faraday_x + loss_faraday_y + loss_faraday_z
43
+
 
44
  loss_ampere_x = (derivs['dBz_dy'] - derivs['dBy_dz'] - derivs['dEx_dt'])**2
45
  loss_ampere_y = (derivs['dBx_dz'] - derivs['dBz_dx'] - derivs['dEy_dt'])**2
46
  loss_ampere_z = (derivs['dBy_dx'] - derivs['dBx_dy'] - derivs['dEz_dt'])**2
 
48
 
49
  return torch.mean(loss_gauss_E + loss_gauss_B + loss_faraday + loss_ampere)
50
 
 
51
  def compute_data_loss(model_phys,
52
  coords_tensor, E0_tensor, B0_tensor,
53
  source_pos_tensor, source_orientation_tensor,
 
55
  n_samples_ic=1024, n_samples_source=128):
56
  loss_data = 0.0
57
 
 
58
  idx_ic = torch.randperm(coords_tensor.shape[0], device=coords_tensor.device)[:n_samples_ic]
59
  x_ic, y_ic, z_ic = coords_tensor[idx_ic, 0:1], coords_tensor[idx_ic, 1:2], coords_tensor[idx_ic, 2:3]
60
  t_ic = torch.zeros_like(x_ic)
 
64
 
65
  loss_data += torch.mean((pred_ic[:, 0:3] - E_actual_ic)**2)
66
  loss_data += torch.mean((pred_ic[:, 3:6] - B_actual_ic)**2)
 
 
 
67
  n_times_source = t_tensor.shape[0]
 
68
  x_s = source_pos_tensor[0:1].expand(n_times_source, -1)
69
  y_s = source_pos_tensor[1:2].expand(n_times_source, -1)
70
  z_s = source_pos_tensor[2:3].expand(n_times_source, -1)
71
+ t_s = t_tensor
72
 
73
  pred_source = model_phys(x_s, y_s, z_s, t_s)
 
 
74
  E_actual_source = source_signal_tensor * source_orientation_tensor.view(1, 3)
75
 
76
  loss_data += torch.mean((pred_source[:, 0:3] - E_actual_source)**2)
77
+ loss_data += torch.mean(pred_source[:, 3:6]**2)
78
 
79
  return loss_data