|
|
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import Image, display
|
|
|
|
|
class HashGridEncoder(nn.Module):
    """Multi-resolution hash-grid positional encoder (instant-ngp style).

    Maps 3D coordinates (expected in [0, 1]^3) to a concatenation of
    per-level feature vectors, each trilinearly interpolated from learned
    embeddings looked up at the hashed voxel corners of that level's grid.
    """

    def __init__(self, n_levels=16, n_features_per_level=2,
                 log2_hashmap_size=19, base_resolution=16,
                 per_level_scale=1.5):
        super(HashGridEncoder, self).__init__()
        self.n_levels = n_levels
        self.n_features_per_level = n_features_per_level
        self.log2_hashmap_size = log2_hashmap_size
        self.base_resolution = base_resolution
        self.per_level_scale = per_level_scale
        self.hashmap_size = 2 ** self.log2_hashmap_size
        # One learned hash table per resolution level.
        self.embeddings = nn.ModuleList(
            [nn.Embedding(self.hashmap_size, self.n_features_per_level)
             for _ in range(self.n_levels)])
        for emb in self.embeddings:
            nn.init.uniform_(emb.weight, -1e-4, 1e-4)
        # Constant tensors hoisted out of the per-call path.
        # persistent=False keeps them out of the state_dict so existing
        # checkpoints still load unchanged.
        self.register_buffer(
            '_primes',
            torch.tensor([1, 2654435761, 805459861], dtype=torch.int64),
            persistent=False)
        self.register_buffer(
            '_corners',
            torch.tensor([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
                          [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]],
                         dtype=torch.int64),
            persistent=False)

    def hash_fn(self, c, l):
        """Hash integer corner coordinates `c` (..., 3) to table indices.

        NOTE(review): `l` (the level index) is accepted for interface
        compatibility but unused — all levels share one hash function.
        The hash is a prime-weighted SUM (canonical instant-ngp XORs the
        products instead); kept as-is to stay compatible with weights
        trained against this implementation.
        """
        p = self._primes.to(c.device)
        return (c * p).sum(dim=-1) % self.hashmap_size

    def trilinear_interp(self, x, v000, v001, v010, v011, v100, v101, v110, v111):
        """Blend the 8 corner feature vectors at fractional offset x (..., 3)."""
        wx = x[..., 0:1]
        wy = x[..., 1:2]
        wz = x[..., 2:3]
        # Interpolate along x ...
        c00 = v000 * (1 - wx) + v100 * wx
        c01 = v001 * (1 - wx) + v101 * wx
        c10 = v010 * (1 - wx) + v110 * wx
        c11 = v011 * (1 - wx) + v111 * wx
        # ... then along y ...
        c0 = c00 * (1 - wy) + c10 * wy
        c1 = c01 * (1 - wy) + c11 * wy
        # ... then along z.
        return c0 * (1 - wz) + c1 * wz

    def forward(self, x_coords):
        """Encode (N, 3) coords -> (N, n_levels * n_features_per_level)."""
        all_features = []
        corners = self._corners.to(x_coords.device)
        for l in range(self.n_levels):
            # Grid resolution grows geometrically with the level.
            r = int(self.base_resolution * (self.per_level_scale ** l))
            x_scaled = x_coords * r
            x_floor = torch.floor(x_scaled).to(torch.int64)
            x_frac = x_scaled - x_floor
            # (N, 8, 3): integer coordinates of the surrounding voxel corners.
            corner_coords = x_floor.unsqueeze(1) + corners.unsqueeze(0)
            idx = self.hash_fn(corner_coords, l)
            feats = self.embeddings[l](idx)  # (N, 8, F)
            interp = self.trilinear_interp(
                x_frac,
                feats[:, 0], feats[:, 1], feats[:, 2], feats[:, 3],
                feats[:, 4], feats[:, 5], feats[:, 6], feats[:, 7])
            all_features.append(interp)
        return torch.cat(all_features, dim=-1)
|
|
|
|
|
class GeoNetHash(nn.Module):
    """Geometry network: hash-grid encoder followed by a small MLP.

    Produces one scalar (SDF-like) value per input point. The coordinate
    bounding box seen by `normalize_coords` is stored so later queries can
    be mapped into the same [0, 1]^3 domain via `normalize_inputs`.
    """

    def __init__(self, box_size_scale=1.0):
        super(GeoNetHash, self).__init__()
        self.box_size_scale = box_size_scale
        self.encoder = HashGridEncoder(n_levels=16, n_features_per_level=2)
        self.mlp = nn.Sequential(
            nn.Linear(self.encoder.n_levels * self.encoder.n_features_per_level, 64),
            nn.ReLU(),
            nn.Linear(64, 1))
        # Fix: initialize the normalization bounds so that normalize_inputs
        # raises its intended "GeoNet not normalized" error instead of an
        # AttributeError when called before normalize_coords.
        self.min = None
        self.max = None

    def normalize_coords(self, coords):
        """Fit the bounding box of `coords` (N, 3) and return them scaled to [0, 1]."""
        self.min = coords.min(dim=0, keepdim=True)[0]
        self.max = coords.max(dim=0, keepdim=True)[0]
        # Nudge degenerate (flat) dimensions to avoid division by zero.
        self.max[self.max == self.min] += 1e-6
        return (coords - self.min) / (self.max - self.min)

    def normalize_inputs(self, coords):
        """Scale `coords` using the previously fitted bounding box.

        Raises:
            Exception: if normalize_coords has not been called yet.
        """
        if self.min is None or self.max is None:
            raise Exception("GeoNet not normalized.")
        return (coords - self.min) / (self.max - self.min)

    def forward(self, coords_input_norm):
        """Predict one scalar per normalized (N, 3) coordinate -> (N, 1)."""
        features = self.encoder(coords_input_norm)
        return self.mlp(features)
|
|
|
|
|
class PositionalEncoder4D(nn.Module):
    """NeRF-style sinusoidal positional encoding for (x, y, z, t) inputs.

    Output layout per sample: [raw input | sin at all freqs | cos at all freqs],
    giving output_dims = input_dims * (2 * num_freqs + 1).
    """

    def __init__(self, input_dims, num_freqs):
        super(PositionalEncoder4D, self).__init__()
        self.input_dims = input_dims
        self.num_freqs = num_freqs
        # Frequencies 2^0 .. 2^(num_freqs - 1). Registered as a non-persistent
        # buffer so it follows module.to(device)/.cuda(), while keeping the
        # state_dict (and therefore existing checkpoints) unchanged.
        self.register_buffer(
            'freq_bands',
            2. ** torch.linspace(0., num_freqs - 1, num_freqs),
            persistent=False)
        self.output_dims = self.input_dims * (2 * self.num_freqs + 1)

    def forward(self, x):
        """Encode x of shape (N, input_dims) -> (N, output_dims)."""
        encoding = [x]
        # (N, D, F): every input dimension scaled by every frequency band.
        x_freq = x.unsqueeze(-1) * self.freq_bands.to(x.device)
        encoding.append(torch.sin(x_freq).view(x.shape[0], -1))
        encoding.append(torch.cos(x_freq).view(x.shape[0], -1))
        return torch.cat(encoding, dim=1)
|
|
|
|
|
class MaxwellPINN(nn.Module):
    """Physics-informed network mapping a space-time point to 6 field values.

    Inputs are sinusoidally encoded with PositionalEncoder4D, then passed
    through a 3-hidden-layer Tanh MLP whose head emits 6 components.
    """

    def __init__(self, num_freqs=8, hidden_dim=128):
        super(MaxwellPINN, self).__init__()
        self.encoder = PositionalEncoder4D(input_dims=4, num_freqs=num_freqs)
        # Built as a list so the hidden depth is explicit; module indices
        # (network.0 .. network.6) match the original layout.
        layers = [nn.Linear(self.encoder.output_dims, hidden_dim), nn.Tanh()]
        for _ in range(2):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.Tanh()])
        layers.append(nn.Linear(hidden_dim, 6))
        self.network = nn.Sequential(*layers)

    def forward(self, x, y, z, t):
        """x, y, z, t each of shape (N, 1); returns (N, 6) predictions."""
        coords = torch.cat([x, y, z, t], dim=1)
        encoded = self.encoder(coords)
        return self.network(encoded)
|
|
|
|
|
|
|
|
# --- Runtime configuration --------------------------------------------------
# Inference-only evaluation: run everything on CPU.
DEVICE = torch.device("cpu")

print(f"test: {DEVICE}")

# Checkpoint paths: geometry network (SDF) and physics network (Maxwell PINN).
GEO_MODEL_PATH = "geonet_real_v30.pth"

PHYS_MODEL_PATH = "physnet_v31_real.pth"

# Reference data archive with the sample coordinates and time stamps.
NPZ_FILE_PATH = "ground_truth.npz"

print(f"load{NPZ_FILE_PATH}...")
|
|
# --- Load the ground-truth sampling data ------------------------------------
try:
    # allow_pickle is needed if the archive stores object-dtype arrays.
    ground_truth_data = np.load(NPZ_FILE_PATH, allow_pickle=True)
    coords_np = ground_truth_data['coords']
    t_np = ground_truth_data['t']

    # Spatial samples as (N, 3) float32 and times as an (N, 1) column.
    coords_tensor = torch.tensor(coords_np, dtype=torch.float32).to(DEVICE)
    t_tensor = torch.tensor(t_np, dtype=torch.float32).view(-1, 1).to(DEVICE)
    # Upper bound of the simulated time span, used to pick the plot time.
    T_MAX = t_tensor.max().item()

    print("loaded")
except FileNotFoundError:
    print(f"file error'{NPZ_FILE_PATH}' file not found")
    # Fix: exit() is an interactive helper injected by the `site` module and
    # returns status 0; raise SystemExit(1) so a missing data file reliably
    # stops the script with a failure exit code.
    raise SystemExit(1)
|
|
|
|
|
# --- Load the trained models ------------------------------------------------
try:
    # Checkpoints may have been saved on GPU; force CPU deserialization.
    map_location = torch.device('cpu')

    # Geometry network. The normalization bounding box must be fitted to the
    # real coordinates BEFORE inference so queries are mapped into the same
    # [0, 1]^3 domain the hash encoder was trained on.
    model_geo = GeoNetHash().to(DEVICE)
    model_geo.normalize_coords(coords_tensor)
    model_geo.load_state_dict(torch.load(GEO_MODEL_PATH, map_location=map_location))
    model_geo.eval()

    # Physics network (Maxwell PINN); hyperparameters must match training.
    model_phys = MaxwellPINN(num_freqs=8, hidden_dim=128).to(DEVICE)
    model_phys.load_state_dict(torch.load(PHYS_MODEL_PATH, map_location=map_location))
    model_phys.eval()
    print("loaded (GeoNet . PhysNet) ")

except Exception as e:
    # Top-level boundary: report and stop.
    print(f"\n error loading model {e}")
    print("pth files not found")
    # Fix: raise SystemExit(1) instead of the interactive exit() helper so
    # failure yields a nonzero exit code in any execution context.
    raise SystemExit(1)
|
|
|
|
|
print("\n--- graphs:")

# --- Visualization: x-z slice of the predicted E_x field at a fixed time ----
with torch.no_grad():
    # Slice sampling density (resolution x resolution grid).
    resolution = 100

    # Per-axis bounding box of the real data; the slice spans full x/z extent.
    min_vals = coords_tensor.min(dim=0)[0]
    max_vals = coords_tensor.max(dim=0)[0]

    x_range = np.linspace(min_vals[0].item(), max_vals[0].item(), resolution)
    z_range = np.linspace(min_vals[2].item(), max_vals[2].item(), resolution)
    xx, zz = np.meshgrid(x_range, z_range)
    # Fix y at the midpoint of the data's y-extent.
    yy = np.full_like(xx, (min_vals[1] + max_vals[1]).item() / 2.0)

    # Evaluate the field at 75% of the simulated time span.
    T_VISUALIZATION = T_MAX * 0.75
    tt = np.ones_like(xx) * T_VISUALIZATION

    # Flatten the 2D grid into per-point (N, 1) column tensors, matching the
    # MaxwellPINN.forward(x, y, z, t) signature.
    x_grid = torch.tensor(xx.flatten(), dtype=torch.float32).to(DEVICE).view(-1, 1)
    y_grid = torch.tensor(yy.flatten(), dtype=torch.float32).to(DEVICE).view(-1, 1)
    z_grid = torch.tensor(zz.flatten(), dtype=torch.float32).to(DEVICE).view(-1, 1)
    t_grid = torch.tensor(tt.flatten(), dtype=torch.float32).to(DEVICE).view(-1, 1)

    # Geometry query: SDF values on the slice (coords normalized with the
    # bounding box fitted at model-load time).
    coords_viz = torch.cat([x_grid, y_grid, z_grid], dim=1)
    coords_viz_norm = model_geo.normalize_inputs(coords_viz)
    sdf_grid = model_geo(coords_viz_norm).cpu().numpy().reshape(resolution, resolution)

    # Physics query: channel 0 of the 6 outputs is taken as E_x.
    pred_all = model_phys(x_grid, y_grid, z_grid, t_grid)
    Ex_pred_grid = pred_all[:, 0].cpu().numpy().reshape(resolution, resolution)

    # Zero the field inside the object. NOTE(review): assumes negative SDF
    # means "interior" — TODO confirm GeoNet's sign convention.
    Ex_pred_grid[sdf_grid < 0.0] = 0.0

    # Symmetric color scale around zero; the floor avoids a degenerate range
    # when the predicted field is (near) uniformly zero.
    v_max = np.max(np.abs(Ex_pred_grid))
    v_max = max(v_max, 1e-4)

    plt.figure(figsize=(8, 6))
    cf = plt.contourf(xx, zz, Ex_pred_grid, levels=50, cmap='RdBu_r', vmin=-v_max, vmax=v_max)

    # Outline the object boundary: the SDF zero level set.
    plt.contour(xx, zz, sdf_grid, levels=[0], colors='black', linewidths=3)

    title_str = f'Predicted $\\tilde{{E}}_x$ field (v35) at time $t={T_VISUALIZATION:.2f}$'
    plt.title(title_str)

    plt.xlabel('x (m)'); plt.ylabel('z (m)'); plt.colorbar(cf, label='E_x field value (Normalized)')
    plt.axis('equal')
    plt.tight_layout()

    # Save to disk, then show the saved file inline (IPython/Jupyter).
    output_filename = "v35_final_plot_defense.png"
    plt.savefig(output_filename, dpi=300)
    plt.close()

    print(f"\n✅ completed'{output_filename}' saved.")
    display(Image(filename=output_filename))