"""Load a trained geometry network and field network, then plot one E_x slice.

The script reads sample coordinates/times from an ``.npz`` file, restores two
checkpoints (``GeoNetHash`` and ``MaxwellPINN``), evaluates both on an x-z grid
at the mid-plane in y, and saves a contour plot to disk.
"""
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch  # required: `import torch.nn as nn` alone does not bind the name `torch`
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import Image, display


class HashGridEncoder(nn.Module):
    """Multi-resolution hashed feature grid for 3-D coordinates.

    Each level owns an ``nn.Embedding`` table of ``2**log2_hashmap_size`` rows.
    A query point is scaled to the level's resolution, the eight surrounding
    integer grid corners are hashed into the table, and the eight feature
    vectors are trilinearly interpolated. Per-level results are concatenated.

    Args:
        n_levels: Number of resolution levels (one embedding table per level).
        n_features_per_level: Width of each embedding row.
        log2_hashmap_size: log2 of the number of rows in every table.
        base_resolution: Grid resolution of level 0.
        per_level_scale: Geometric growth factor of the resolution per level.
    """

    def __init__(self, n_levels=16, n_features_per_level=2,
                 log2_hashmap_size=19, base_resolution=16,
                 per_level_scale=1.5):
        super().__init__()
        self.n_levels = n_levels
        self.n_features_per_level = n_features_per_level
        self.log2_hashmap_size = log2_hashmap_size
        self.base_resolution = base_resolution
        self.per_level_scale = per_level_scale
        self.hashmap_size = 2 ** self.log2_hashmap_size
        self.embeddings = nn.ModuleList(
            [nn.Embedding(self.hashmap_size, self.n_features_per_level)
             for _ in range(self.n_levels)]
        )
        # Start every table close to zero.
        for emb in self.embeddings:
            nn.init.uniform_(emb.weight, -1e-4, 1e-4)

    def hash_fn(self, c, l):
        """Map integer grid coordinates to embedding-table row indices.

        Args:
            c: int64 tensor whose last dimension holds the 3 grid coordinates.
            l: Level index. Unused by the hash; kept for interface stability.

        Returns:
            int64 tensor of row indices in ``[0, hashmap_size)`` with the last
            dimension of ``c`` reduced away.
        """
        # NOTE(review): the products are summed here, whereas the canonical
        # Instant-NGP hash XORs them. Left unchanged because altering the
        # indexing would presumably invalidate saved embedding weights -
        # confirm before touching.
        p = torch.tensor([1, 2654435761, 805459861],
                         dtype=torch.int64, device=c.device)
        return (c * p).sum(dim=-1) % self.hashmap_size

    def trilinear_interp(self, x, v000, v001, v010, v011,
                         v100, v101, v110, v111):
        """Trilinearly blend eight corner values.

        Args:
            x: Fractional position inside the cell; the last dimension is
                indexed as (x, y, z).
            v000 ... v111: Corner values, named ``v<x><y><z>`` by corner offset.

        Returns:
            The interpolated value, broadcast against the corner tensors.
        """
        wx, wy, wz = x[..., 0:1], x[..., 1:2], x[..., 2:3]
        # Collapse the x axis first ...
        c00 = v000 * (1 - wx) + v100 * wx
        c01 = v001 * (1 - wx) + v101 * wx
        c10 = v010 * (1 - wx) + v110 * wx
        c11 = v011 * (1 - wx) + v111 * wx
        # ... then y ...
        c0 = c00 * (1 - wy) + c10 * wy
        c1 = c01 * (1 - wy) + c11 * wy
        # ... and finally z.
        return c0 * (1 - wz) + c1 * wz

    def forward(self, x_coords):
        """Encode coordinates into concatenated multi-level features.

        Args:
            x_coords: Float tensor indexed as ``(N, 3)`` by the code below.
                The script passes coordinates normalized to [0, 1].

        Returns:
            Tensor of shape ``(N, n_levels * n_features_per_level)``.
        """
        # The eight unit-cube corner offsets are identical for every level, so
        # build them once instead of once per loop iteration.
        corners = torch.tensor(
            [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
             [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]],
            dtype=torch.int64, device=x_coords.device,
        )
        all_features = []
        for l in range(self.n_levels):
            resolution = int(self.base_resolution * (self.per_level_scale ** l))
            scaled = x_coords * resolution
            cell = torch.floor(scaled).to(torch.int64)
            local = scaled - cell  # fractional position inside the cell
            corner_coords = cell.unsqueeze(1) + corners.unsqueeze(0)
            table_idx = self.hash_fn(corner_coords, l)
            corner_feats = self.embeddings[l](table_idx)
            interpolated = self.trilinear_interp(
                local,
                corner_feats[:, 0], corner_feats[:, 1],
                corner_feats[:, 2], corner_feats[:, 3],
                corner_feats[:, 4], corner_feats[:, 5],
                corner_feats[:, 6], corner_feats[:, 7],
            )
            all_features.append(interpolated)
        return torch.cat(all_features, dim=-1)


class GeoNetHash(nn.Module):
    """Hash-grid encoder followed by a small MLP producing one scalar per point.

    The plotting script treats the output as a signed-distance-like value: its
    zero level set is drawn as the geometry outline and negative values mask
    the field.

    Args:
        box_size_scale: Stored on the instance; not used by the code shown here.
    """

    def __init__(self, box_size_scale=1.0):
        super().__init__()
        self.box_size_scale = box_size_scale
        # Normalization bounds, filled in by ``normalize_coords``. They are
        # plain attributes (not buffers) so the ``state_dict`` keys stay the
        # same. Initialising them lets ``normalize_inputs`` raise its intended
        # error instead of an AttributeError.
        self.min = None
        self.max = None
        self.encoder = HashGridEncoder(n_levels=16, n_features_per_level=2)
        self.mlp = nn.Sequential(
            nn.Linear(self.encoder.n_levels * self.encoder.n_features_per_level, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def normalize_coords(self, coords):
        """Record per-axis min/max of ``coords`` and return them scaled to [0, 1].

        Args:
            coords: Tensor reduced over dimension 0 to obtain the bounds.

        Returns:
            ``(coords - min) / (max - min)``.
        """
        self.min = coords.min(dim=0, keepdim=True)[0]
        self.max = coords.max(dim=0, keepdim=True)[0]
        # Avoid division by zero on degenerate (constant) axes.
        self.max[self.max == self.min] += 1e-6
        return (coords - self.min) / (self.max - self.min)

    def normalize_inputs(self, coords):
        """Scale ``coords`` using the bounds stored by ``normalize_coords``.

        Raises:
            RuntimeError: If ``normalize_coords`` has not been called yet.
        """
        if self.min is None or self.max is None:
            raise RuntimeError("GeoNet not normalized.")
        return (coords - self.min) / (self.max - self.min)

    def forward(self, coords_input_norm):
        """Return the MLP output for already-normalized coordinates."""
        features = self.encoder(coords_input_norm)
        return self.mlp(features)


class PositionalEncoder4D(nn.Module):
    """Sin/cos positional encoding with power-of-two frequency bands.

    Args:
        input_dims: Number of input channels.
        num_freqs: Number of frequency bands (``2**0 .. 2**(num_freqs-1)``).
    """

    def __init__(self, input_dims, num_freqs):
        super().__init__()
        self.input_dims = input_dims
        self.num_freqs = num_freqs
        self.freq_bands = 2. ** torch.linspace(0., num_freqs - 1, num_freqs)
        # Raw input + sin and cos for every (channel, frequency) pair.
        self.output_dims = self.input_dims * (2 * self.num_freqs + 1)

    def forward(self, x):
        """Return ``[x, sin(x * f), cos(x * f)]`` concatenated along dim 1."""
        # freq_bands is a plain tensor attribute, so it is moved to the input's
        # device on each call.
        x_freq = x.unsqueeze(-1) * self.freq_bands.to(x.device)
        encoding = [
            x,
            torch.sin(x_freq).view(x.shape[0], -1),
            torch.cos(x_freq).view(x.shape[0], -1),
        ]
        return torch.cat(encoding, dim=1)


class MaxwellPINN(nn.Module):
    """Tanh MLP over positionally encoded (x, y, z, t), emitting 6 channels.

    Only channel 0 is consumed by the script below, where it is plotted as E_x.
    The meaning of the remaining channels is not visible here (presumably the
    other E/B components - confirm against the training code).

    Args:
        num_freqs: Frequency bands passed to ``PositionalEncoder4D``.
        hidden_dim: Width of the hidden layers.
    """

    def __init__(self, num_freqs=8, hidden_dim=128):
        super().__init__()
        self.encoder = PositionalEncoder4D(input_dims=4, num_freqs=num_freqs)
        self.network = nn.Sequential(
            nn.Linear(self.encoder.output_dims, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, 6)
        )

    def forward(self, x, y, z, t):
        """Concatenate the four column inputs along dim 1 and run the network."""
        coords = torch.cat([x, y, z, t], dim=1)
        x_encoded = self.encoder(coords)
        return self.network(x_encoded)


DEVICE = torch.device("cpu")
print(f"test: {DEVICE}")

GEO_MODEL_PATH = "geonet_real_v30.pth"
PHYS_MODEL_PATH = "physnet_v31_real.pth"
NPZ_FILE_PATH = "ground_truth.npz"

print(f"load{NPZ_FILE_PATH}...")
try:
    # NOTE(security): allow_pickle=True lets a crafted file execute arbitrary
    # code on load. Only use with trusted files; if 'coords' and 't' are plain
    # numeric arrays this flag can likely be dropped - confirm.
    # The context manager closes the underlying archive once the arrays are read.
    with np.load(NPZ_FILE_PATH, allow_pickle=True) as ground_truth_data:
        coords_np = ground_truth_data['coords']
        t_np = ground_truth_data['t']

    coords_tensor = torch.tensor(coords_np, dtype=torch.float32).to(DEVICE)
    t_tensor = torch.tensor(t_np, dtype=torch.float32).view(-1, 1).to(DEVICE)
    T_MAX = t_tensor.max().item()
    print("loaded")
except FileNotFoundError:
    print(f"file error'{NPZ_FILE_PATH}' file not found")
    # `exit()` is an interactive-shell helper from `site`; raise SystemExit
    # directly and report failure with a non-zero status.
    raise SystemExit(1)

try:
    map_location = torch.device('cpu')

    # GeoNet
    model_geo = GeoNetHash().to(DEVICE)
    # Populates the min/max bounds that normalize_inputs relies on later.
    model_geo.normalize_coords(coords_tensor)
    # NOTE(security): torch.load unpickles the checkpoint; only load trusted files.
    model_geo.load_state_dict(torch.load(GEO_MODEL_PATH, map_location=map_location))
    model_geo.eval()

    # PhysNet
    model_phys = MaxwellPINN(num_freqs=8, hidden_dim=128).to(DEVICE)
    model_phys.load_state_dict(torch.load(PHYS_MODEL_PATH, map_location=map_location))
    model_phys.eval()
    print("loaded (GeoNet . PhysNet) ")
except Exception as e:
    print(f"\n error loading model {e}")
    print("pth files not found")
    raise SystemExit(1)

print("\n--- graphs:")
with torch.no_grad():
    resolution = 100

    # Build an x-z grid spanning the data extent, fixed at the y mid-plane.
    min_vals = coords_tensor.min(dim=0)[0]
    max_vals = coords_tensor.max(dim=0)[0]
    x_range = np.linspace(min_vals[0].item(), max_vals[0].item(), resolution)
    z_range = np.linspace(min_vals[2].item(), max_vals[2].item(), resolution)
    xx, zz = np.meshgrid(x_range, z_range)
    yy = np.full_like(xx, (min_vals[1] + max_vals[1]).item() / 2.0)

    T_VISUALIZATION = T_MAX * 0.75
    tt = np.ones_like(xx) * T_VISUALIZATION

    x_grid = torch.tensor(xx.flatten(), dtype=torch.float32).to(DEVICE).view(-1, 1)
    y_grid = torch.tensor(yy.flatten(), dtype=torch.float32).to(DEVICE).view(-1, 1)
    z_grid = torch.tensor(zz.flatten(), dtype=torch.float32).to(DEVICE).view(-1, 1)
    t_grid = torch.tensor(tt.flatten(), dtype=torch.float32).to(DEVICE).view(-1, 1)

    # Geometry network works on normalized coordinates.
    coords_viz = torch.cat([x_grid, y_grid, z_grid], dim=1)
    coords_viz_norm = model_geo.normalize_inputs(coords_viz)
    sdf_grid = model_geo(coords_viz_norm).cpu().numpy().reshape(resolution, resolution)

    # Field network takes the raw (un-normalized) grid coordinates.
    pred_all = model_phys(x_grid, y_grid, z_grid, t_grid)
    Ex_pred_grid = pred_all[:, 0].cpu().numpy().reshape(resolution, resolution)

    # Zero the field wherever the geometry output is negative.
    Ex_pred_grid[sdf_grid < 0.0] = 0.0

    # Symmetric colour limits, with a floor so an all-zero field still plots.
    v_max = np.max(np.abs(Ex_pred_grid))
    v_max = max(v_max, 1e-4)

    plt.figure(figsize=(8, 6))
    cf = plt.contourf(xx, zz, Ex_pred_grid, levels=50, cmap='RdBu_r', vmin=-v_max, vmax=v_max)

    # Outline the zero level set of the geometry output.
    plt.contour(xx, zz, sdf_grid, levels=[0], colors='black', linewidths=3)

    title_str = f'Predicted $\\tilde{{E}}_x$ field (v35) at time $t={T_VISUALIZATION:.2f}$'
    plt.title(title_str)

    plt.xlabel('x (m)')
    plt.ylabel('z (m)')
    plt.colorbar(cf, label='E_x field value (Normalized)')
    plt.axis('equal')
    plt.tight_layout()

    output_filename = "v35_final_plot_defense.png"
    plt.savefig(output_filename, dpi=300)
    plt.close()

print(f"\n✅ completed'{output_filename}' saved.")
display(Image(filename=output_filename))