# File size: 7,917 Bytes
# NOTE: this file appears to have been copied from a web source viewer; the
# revision-hash and line-number gutter rows that came with it were removed.
import os
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import Image, display
class HashGridEncoder(nn.Module):
    """Multi-resolution hash-grid encoding of 3-D points.

    For each of ``n_levels`` resolutions the point is scaled onto an integer
    grid, the 8 corners of its enclosing cell are hashed into that level's
    embedding table, and the corner features are trilinearly interpolated.
    The per-level features are concatenated, so the output width is
    ``n_levels * n_features_per_level``.
    """

    # Corner offsets (dx, dy, dz). Row k corresponds to argument v{dx}{dy}{dz}
    # of ``trilinear_interp`` (e.g. row 1 -> v001, row 4 -> v100), so this
    # order must not change.
    _CORNER_OFFSETS = (
        (0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1),
        (1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1),
    )

    def __init__(self, n_levels=16, n_features_per_level=2,
                 log2_hashmap_size=19, base_resolution=16,
                 per_level_scale=1.5):
        """Create one ``2**log2_hashmap_size``-row embedding table per level.

        Args:
            n_levels: number of grid resolutions.
            n_features_per_level: feature width stored per table row.
            log2_hashmap_size: log2 of the number of rows per table.
            base_resolution: grid resolution of level 0.
            per_level_scale: resolution growth factor between levels.
        """
        super().__init__()
        self.n_levels = n_levels
        self.n_features_per_level = n_features_per_level
        self.log2_hashmap_size = log2_hashmap_size
        self.base_resolution = base_resolution
        self.per_level_scale = per_level_scale
        self.hashmap_size = 2 ** self.log2_hashmap_size
        self.embeddings = nn.ModuleList(
            [nn.Embedding(self.hashmap_size, self.n_features_per_level)
             for _ in range(self.n_levels)]
        )
        # Start every table near zero.
        for emb in self.embeddings:
            nn.init.uniform_(emb.weight, -1e-4, 1e-4)

    def hash_fn(self, c, l):
        """Map integer grid coordinates to embedding-table row indices.

        Args:
            c: int64 tensor whose last dimension holds (ix, iy, iz).
            l: level index; accepted for interface compatibility, unused.

        Returns:
            int64 tensor of shape ``c.shape[:-1]`` with values in
            ``[0, hashmap_size)``. Tensor ``%`` takes the sign of the
            divisor, so cells with negative coordinates still map to valid
            rows.

        NOTE: this sum-of-products hash defines which row each cell reads;
        changing it invalidates any trained embedding tables.
        """
        p = torch.tensor([1, 2654435761, 805459861], dtype=torch.int64, device=c.device)
        return (c * p).sum(dim=-1) % self.hashmap_size

    def trilinear_interp(self, x, v000, v001, v010, v011, v100, v101, v110, v111):
        """Trilinearly blend the 8 corner values of a cell.

        Args:
            x: fractional position inside the cell, last dimension (fx, fy, fz).
            vABC: value at the corner offset by A along x, B along y and
                C along z.

        Returns:
            The interpolated value, broadcast from ``x[..., :1]`` and the
            corner values.
        """
        tx, ty, tz = x[..., 0:1], x[..., 1:2], x[..., 2:3]
        # Collapse the x axis on the four x-parallel edges ...
        c00 = v000 * (1 - tx) + v100 * tx
        c01 = v001 * (1 - tx) + v101 * tx
        c10 = v010 * (1 - tx) + v110 * tx
        c11 = v011 * (1 - tx) + v111 * tx
        # ... then the y axis ...
        c0 = c00 * (1 - ty) + c10 * ty
        c1 = c01 * (1 - ty) + c11 * ty
        # ... then the z axis.
        return c0 * (1 - tz) + c1 * tz

    def forward(self, x_coords):
        """Encode a batch of points.

        Args:
            x_coords: float tensor of shape (N, 3); presumably normalized to
                [0, 1] by the caller (TODO confirm), though any value hashes.

        Returns:
            Tensor of shape (N, n_levels * n_features_per_level).
        """
        # Loop-invariant: build the corner offsets once, not once per level.
        corners = torch.tensor(self._CORNER_OFFSETS, dtype=torch.int64, device=x_coords.device)
        all_features = []
        for level in range(self.n_levels):
            resolution = int(self.base_resolution * (self.per_level_scale ** level))
            x_scaled = x_coords * resolution
            x_floor = torch.floor(x_scaled).to(torch.int64)
            x_frac = x_scaled - x_floor  # position inside the cell, in [0, 1)
            # (N, 1, 3) + (1, 8, 3) -> (N, 8, 3) integer corner coordinates.
            corner_coords = x_floor.unsqueeze(1) + corners.unsqueeze(0)
            rows = self.hash_fn(corner_coords, level)
            corner_feats = self.embeddings[level](rows)  # (N, 8, F)
            all_features.append(
                self.trilinear_interp(x_frac, *(corner_feats[:, k] for k in range(8)))
            )
        return torch.cat(all_features, dim=-1)
class GeoNetHash(nn.Module):
    """Geometry network: hash-grid encoding followed by a small ReLU MLP.

    Maps normalized 3-D coordinates to one scalar per point (presumably a
    signed-distance value; confirm against the training code).
    """

    def __init__(self, box_size_scale=1.0):
        super().__init__()
        self.box_size_scale = box_size_scale  # stored; not used by this class
        self.encoder = HashGridEncoder(n_levels=16, n_features_per_level=2)
        self.mlp = nn.Sequential(
            nn.Linear(self.encoder.n_levels * self.encoder.n_features_per_level, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )
        # Normalization bounds, filled in by normalize_coords(). They are plain
        # tensor attributes (not buffers), so they are not part of state_dict
        # and must be recomputed after loading weights.
        self.min = None
        self.max = None

    def _span(self):
        """Return per-dimension ``max - min``, mapping zero widths to 1e-6.

        normalize_coords() widens degenerate dimensions by adding 1e-6 to
        ``self.max``, but in float32 that addition rounds away for bounds of
        magnitude >= 32, which would leave 0/0 = NaN. Guard the span here.
        """
        span = self.max - self.min
        return torch.where(span == 0, torch.full_like(span, 1e-6), span)

    def normalize_coords(self, coords):
        """Fit min/max bounds to ``coords`` (N, D) and return them scaled to [0, 1].

        Side effect: stores the bounds in ``self.min`` / ``self.max`` (shape
        (1, D)) for later use by normalize_inputs().
        """
        self.min = coords.min(dim=0, keepdim=True)[0]
        self.max = coords.max(dim=0, keepdim=True)[0]
        # Widen zero-width dimensions so the division is defined.
        self.max[self.max == self.min] += 1e-6
        return (coords - self.min) / self._span()

    def normalize_inputs(self, coords):
        """Scale ``coords`` with the bounds fitted by normalize_coords().

        Raises:
            RuntimeError: if normalize_coords() has not been called yet.
        """
        if getattr(self, "min", None) is None or getattr(self, "max", None) is None:
            raise RuntimeError("GeoNet not normalized.")
        return (coords - self.min) / self._span()

    def forward(self, coords_input_norm):
        """Return the MLP output, one scalar per normalized input point."""
        features = self.encoder(coords_input_norm)
        return self.mlp(features)
class PositionalEncoder4D(nn.Module):
    """Sin/cos Fourier-feature encoding with octave-spaced frequencies.

    Each row ``x`` of width ``input_dims`` becomes ``[x, sin(...), cos(...)]``.
    The sin and cos blocks each hold ``x_d * 2**k`` for k = 0 .. num_freqs-1,
    flattened dimension-major (all frequencies of dimension 0 come first).
    Output width: ``input_dims * (2 * num_freqs + 1)``.
    """

    def __init__(self, input_dims, num_freqs):
        super().__init__()
        self.input_dims = input_dims
        self.num_freqs = num_freqs
        # Frequencies 1, 2, 4, ..., 2**(num_freqs - 1). This is a plain
        # attribute, not a buffer, so it does not appear in state_dict.
        exponents = torch.linspace(0.0, num_freqs - 1, num_freqs)
        self.freq_bands = 2.0 ** exponents
        self.output_dims = input_dims * (2 * num_freqs + 1)

    def forward(self, x):
        """Encode a 2-D batch ``x`` of shape (N, input_dims)."""
        batch = x.shape[0]
        bands = self.freq_bands.to(x.device)
        angles = x[..., None] * bands  # (N, input_dims, num_freqs)
        sines = torch.sin(angles).view(batch, -1)
        cosines = torch.cos(angles).view(batch, -1)
        return torch.cat((x, sines, cosines), dim=1)
class MaxwellPINN(nn.Module):
    """Physics network: Fourier encoding of (x, y, z, t) followed by a tanh MLP.

    Emits 6 output channels per point (presumably field components; confirm
    against the training code).
    """

    def __init__(self, num_freqs=8, hidden_dim=128):
        super().__init__()
        self.encoder = PositionalEncoder4D(input_dims=4, num_freqs=num_freqs)
        # Linear/Tanh alternation with 3 hidden Linear layers. The Sequential
        # indices (Linear at 0, 2, 4, 6) form the state_dict keys and must stay
        # unchanged for previously saved weights to load.
        layers = [nn.Linear(self.encoder.output_dims, hidden_dim), nn.Tanh()]
        for _ in range(2):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.Tanh()]
        layers.append(nn.Linear(hidden_dim, 6))
        self.network = nn.Sequential(*layers)

    def forward(self, x, y, z, t):
        """Evaluate on column tensors x, y, z, t (each (N, 1)); returns (N, 6)."""
        coords = torch.cat((x, y, z, t), dim=1)
        return self.network(self.encoder(coords))
# Run configuration. Everything runs on the CPU, and the file paths are
# relative to the current working directory.
GEO_MODEL_PATH = "geonet_real_v30.pth"
PHYS_MODEL_PATH = "physnet_v31_real.pth"
NPZ_FILE_PATH = "ground_truth.npz"
DEVICE = torch.device("cpu")
print(f"test: {DEVICE}")
# Load the ground-truth sample coordinates ('coords') and times ('t').
print(f"load{NPZ_FILE_PATH}...")
try:
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code. Only use it with trusted files; if
    # 'coords' and 't' are plain numeric arrays, allow_pickle=False is safer.
    # The NpzFile keeps its file handle open until closed; the with-block
    # closes it once both arrays have been read into memory.
    with np.load(NPZ_FILE_PATH, allow_pickle=True) as ground_truth_data:
        coords_np = ground_truth_data['coords']
        t_np = ground_truth_data['t']
except FileNotFoundError:
    print(f"file error'{NPZ_FILE_PATH}' file not found")
    sys.exit(1)  # non-zero status so shells/callers see the failure
coords_tensor = torch.tensor(coords_np, dtype=torch.float32).to(DEVICE)
t_tensor = torch.tensor(t_np, dtype=torch.float32).view(-1, 1).to(DEVICE)  # (N, 1)
T_MAX = t_tensor.max().item()
print("loaded")
# Build both networks and load their trained weights.
try:
    map_location = DEVICE  # load checkpoint tensors straight onto the target device
    # GeoNet. Its normalization bounds are plain attributes rather than
    # state_dict entries, so they are recomputed from the ground-truth
    # coordinates here.
    model_geo = GeoNetHash().to(DEVICE)
    model_geo.normalize_coords(coords_tensor)
    # NOTE(review): torch.load unpickles the checkpoint. On torch versions that
    # support it, pass weights_only=True if these files may be untrusted.
    model_geo.load_state_dict(torch.load(GEO_MODEL_PATH, map_location=map_location))
    model_geo.eval()
    # PhysNet
    model_phys = MaxwellPINN(num_freqs=8, hidden_dim=128).to(DEVICE)
    model_phys.load_state_dict(torch.load(PHYS_MODEL_PATH, map_location=map_location))
    model_phys.eval()
    print("loaded (GeoNet . PhysNet) ")
except FileNotFoundError as e:
    print(f"\n error loading model {e}")
    print("pth files not found")
    sys.exit(1)
except Exception as e:
    # Other failures (e.g. a state_dict key/shape mismatch or a corrupt file)
    # are not missing files, so do not report them as such.
    print(f"\n error loading model {e}")
    sys.exit(1)
print("\n--- graphs:")

# Evaluate both networks on a resolution x resolution grid in the x-z plane.
# The grid passes through the middle of the data's y-range, at 75% of the
# largest sample time.
with torch.no_grad():
    resolution = 100
    min_vals, max_vals = coords_tensor.min(dim=0).values, coords_tensor.max(dim=0).values
    x_range = np.linspace(min_vals[0].item(), max_vals[0].item(), resolution)
    z_range = np.linspace(min_vals[2].item(), max_vals[2].item(), resolution)
    xx, zz = np.meshgrid(x_range, z_range)
    yy = np.full_like(xx, (min_vals[1] + max_vals[1]).item() / 2.0)
    T_VISUALIZATION = 0.75 * T_MAX
    tt = np.full_like(xx, T_VISUALIZATION)
    # Flatten each grid into an (N, 1) float32 column on DEVICE.
    x_grid, y_grid, z_grid, t_grid = (
        torch.tensor(grid.flatten(), dtype=torch.float32).to(DEVICE).view(-1, 1)
        for grid in (xx, yy, zz, tt)
    )
    coords_viz = torch.cat([x_grid, y_grid, z_grid], dim=1)
    coords_viz_norm = model_geo.normalize_inputs(coords_viz)
    sdf_grid = model_geo(coords_viz_norm).cpu().numpy().reshape(resolution, resolution)
    pred_all = model_phys(x_grid, y_grid, z_grid, t_grid)
    Ex_pred_grid = pred_all[:, 0].cpu().numpy().reshape(resolution, resolution)
    # Zero the field wherever the geometry output is negative.
    Ex_pred_grid[sdf_grid < 0.0] = 0.0
    # Symmetric colour limits around zero, with a floor so that an all-zero
    # field still gets a non-degenerate range.
    v_max = max(np.max(np.abs(Ex_pred_grid)), 1e-4)

# Everything below works on NumPy arrays only, so no autograd context is needed.
plt.figure(figsize=(8, 6))
cf = plt.contourf(xx, zz, Ex_pred_grid, levels=50, cmap='RdBu_r', vmin=-v_max, vmax=v_max)
# Draw the zero level set of the geometry output as a thick black outline.
plt.contour(xx, zz, sdf_grid, levels=[0], colors='black', linewidths=3)
title_str = f'Predicted $\\tilde{{E}}_x$ field (v35) at time $t={T_VISUALIZATION:.2f}$'
plt.title(title_str)
plt.xlabel('x (m)')
plt.ylabel('z (m)')
plt.colorbar(cf, label='E_x field value (Normalized)')
plt.axis('equal')
plt.tight_layout()
output_filename = "v35_final_plot_defense.png"
plt.savefig(output_filename, dpi=300)
plt.close()
print(f"\n✅ completed'{output_filename}' saved.")
display(Image(filename=output_filename))