File size: 7,917 Bytes
2261778
 
ad8eda7
2261778
ad8eda7
2261778
ad8eda7
 
2261778
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad8eda7
2261778
 
 
 
 
 
 
 
ad8eda7
2261778
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad8eda7
2261778
 
 
ad8eda7
2261778
ad8eda7
 
2261778
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import Image, display

class HashGridEncoder(nn.Module):
    """Multi-resolution hashed feature grid for 3-D coordinates.

    Every level owns an ``nn.Embedding`` table with ``2**log2_hashmap_size``
    rows of ``n_features_per_level`` features.  For a query point the eight
    integer lattice corners around it are hashed into the table, the eight
    feature vectors are fetched and blended by trilinear interpolation, and
    the per-level results are concatenated along the last dimension.

    Args:
        n_levels: Number of resolution levels (one embedding table each).
        n_features_per_level: Feature width of every table row.
        log2_hashmap_size: log2 of the number of rows in each table.
        base_resolution: Grid resolution used by level 0.
        per_level_scale: Multiplicative resolution growth between levels.
    """

    def __init__(self, n_levels=16, n_features_per_level=2,
                 log2_hashmap_size=19, base_resolution=16,
                 per_level_scale=1.5):
        super().__init__()
        self.n_levels = n_levels
        self.n_features_per_level = n_features_per_level
        self.log2_hashmap_size = log2_hashmap_size
        self.base_resolution = base_resolution
        self.per_level_scale = per_level_scale
        self.hashmap_size = 2 ** self.log2_hashmap_size
        # Build every table first, then re-initialise them; this keeps the
        # order in which the random generator is consumed unchanged.
        self.embeddings = nn.ModuleList(
            [nn.Embedding(self.hashmap_size, self.n_features_per_level)
             for _ in range(self.n_levels)]
        )
        for table in self.embeddings:
            nn.init.uniform_(table.weight, -1e-4, 1e-4)

    def hash_fn(self, c, l):
        """Map integer lattice coordinates to rows of the hash table.

        Args:
            c: int64 tensor whose last dimension holds the three integer
                coordinates.
            l: Level index.  Accepted for interface compatibility; the hash
                itself does not depend on it.

        Returns:
            int64 tensor of ``c.shape[:-1]`` with values in
            ``[0, hashmap_size)``.
        """
        # NOTE: the products are *summed* (not XOR-ed).  Trained checkpoints
        # depend on this exact mapping, so it must not be altered.
        primes = torch.tensor([1, 2654435761, 805459861],
                              dtype=torch.int64, device=c.device)
        return (c * primes).sum(dim=-1) % self.hashmap_size

    def trilinear_interp(self, x, v000, v001, v010, v011, v100, v101, v110, v111):
        """Blend eight corner values with trilinear weights.

        ``vXYZ`` is the value at the corner offset ``(X, Y, Z)``.  ``x`` holds
        the fractional position inside the cell in its last dimension.

        Returns:
            The interpolated value, broadcast against the corner values.
        """
        frac_x = x[..., 0:1]
        frac_y = x[..., 1:2]
        frac_z = x[..., 2:3]
        # Collapse the x axis first ...
        c00 = v000 * (1 - frac_x) + v100 * frac_x
        c01 = v001 * (1 - frac_x) + v101 * frac_x
        c10 = v010 * (1 - frac_x) + v110 * frac_x
        c11 = v011 * (1 - frac_x) + v111 * frac_x
        # ... then y ...
        c0 = c00 * (1 - frac_y) + c10 * frac_y
        c1 = c01 * (1 - frac_y) + c11 * frac_y
        # ... and finally z.
        return c0 * (1 - frac_z) + c1 * frac_z

    def forward(self, x_coords):
        """Encode a batch of points.

        Args:
            x_coords: Tensor of shape ``(N, 3)``.  Presumably already
                normalised to the unit cube -- TODO confirm with callers.

        Returns:
            Tensor of shape ``(N, n_levels * n_features_per_level)``.
        """
        # The corner offsets are identical for every level, so build them
        # once per call instead of once per level.
        corner_offsets = torch.tensor(
            [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
             [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]],
            dtype=torch.int64, device=x_coords.device,
        ).unsqueeze(0)

        level_features = []
        for level in range(self.n_levels):
            resolution = int(self.base_resolution * (self.per_level_scale ** level))
            scaled = x_coords * resolution
            cell_origin = torch.floor(scaled).to(torch.int64)
            cell_frac = scaled - cell_origin
            # (N, 1, 3) + (1, 8, 3) -> (N, 8, 3) integer corner coordinates.
            corner_coords = cell_origin.unsqueeze(1) + corner_offsets
            rows = self.hash_fn(corner_coords, level)
            corner_feats = self.embeddings[level](rows)  # (N, 8, F)
            level_features.append(
                self.trilinear_interp(cell_frac, *corner_feats.unbind(dim=1))
            )
        return torch.cat(level_features, dim=-1)

class GeoNetHash(nn.Module):
    """Hash-grid encoder followed by a small MLP producing one scalar per point.

    The network consumes coordinates that have been rescaled with the
    per-axis bounds stored by :meth:`normalize_coords`.  The scalar output is
    presumably a signed-distance-like geometry value -- TODO confirm.

    Args:
        box_size_scale: Stored on the instance; not used by any method here.
    """

    def __init__(self, box_size_scale=1.0):
        super().__init__()
        self.box_size_scale = box_size_scale
        # Per-axis bounds; ``None`` means normalize_coords() has not run yet.
        # Without this initialisation the "not normalized" guard below could
        # never fire (attribute access would raise AttributeError first).
        self.min = None
        self.max = None
        self.encoder = HashGridEncoder(n_levels=16, n_features_per_level=2)
        self.mlp = nn.Sequential(
            nn.Linear(self.encoder.n_levels * self.encoder.n_features_per_level, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def normalize_coords(self, coords):
        """Record the per-axis min/max of ``coords`` and return them rescaled.

        Args:
            coords: Tensor whose dimension 0 enumerates points.

        Returns:
            ``(coords - min) / (max - min)`` using the freshly stored bounds.
        """
        self.min = coords.min(dim=0, keepdim=True)[0]
        self.max = coords.max(dim=0, keepdim=True)[0]
        # Nudge degenerate axes so the division below never hits zero.
        self.max[self.max == self.min] += 1e-6
        return (coords - self.min) / (self.max - self.min)

    def normalize_inputs(self, coords):
        """Rescale ``coords`` with the bounds stored by :meth:`normalize_coords`.

        Raises:
            RuntimeError: If :meth:`normalize_coords` has not been called.
        """
        # getattr keeps the guard effective even on instances that were
        # created without running __init__.
        lower = getattr(self, "min", None)
        upper = getattr(self, "max", None)
        if lower is None or upper is None:
            raise RuntimeError("GeoNet not normalized.")
        return (coords - lower) / (upper - lower)

    def forward(self, coords_input_norm):
        """Encode already-normalised coordinates and run them through the MLP."""
        features = self.encoder(coords_input_norm)
        return self.mlp(features)

class PositionalEncoder4D(nn.Module):
    """Sinusoidal positional encoding with power-of-two frequencies.

    For an input of shape ``(N, input_dims)`` the output concatenates, along
    dimension 1, the raw input, ``sin(x * f)`` and ``cos(x * f)`` for every
    frequency ``f = 2**0 .. 2**(num_freqs - 1)``.

    Args:
        input_dims: Number of input features per sample.
        num_freqs: Number of frequency bands.

    Attributes:
        output_dims: ``input_dims * (2 * num_freqs + 1)``.
    """

    def __init__(self, input_dims, num_freqs):
        super().__init__()
        self.input_dims = input_dims
        self.num_freqs = num_freqs
        # Kept as a plain tensor (not a buffer) so state_dict keys of models
        # embedding this encoder stay unchanged.
        self.freq_bands = 2. ** torch.linspace(0., num_freqs - 1, num_freqs)
        self.output_dims = self.input_dims * (2 * self.num_freqs + 1)

    def forward(self, x):
        """Encode ``x`` of shape ``(N, input_dims)`` to ``(N, output_dims)``."""
        scaled = x.unsqueeze(-1) * self.freq_bands.to(x.device)  # (N, D, F)
        # flatten(start_dim=1) matches view(N, -1) for non-empty batches and,
        # unlike view with -1, also works when N == 0.
        pieces = [
            x,
            torch.sin(scaled).flatten(start_dim=1),
            torch.cos(scaled).flatten(start_dim=1),
        ]
        return torch.cat(pieces, dim=1)

class MaxwellPINN(nn.Module):
    """Fully connected network over positionally encoded ``(x, y, z, t)``.

    The four inputs are concatenated column-wise, passed through a
    ``PositionalEncoder4D`` and then through three tanh-activated hidden
    layers followed by a linear head with six outputs.  The six outputs are
    presumably electromagnetic field components -- TODO confirm.

    Args:
        num_freqs: Number of frequency bands used by the encoder.
        hidden_dim: Width of every hidden layer.
    """

    def __init__(self, num_freqs=8, hidden_dim=128):
        super().__init__()
        self.encoder = PositionalEncoder4D(input_dims=4, num_freqs=num_freqs)

        # Layers are created in the same order as a literal Sequential would
        # create them, so indices (and therefore state_dict keys) match.
        widths = [self.encoder.output_dims, hidden_dim, hidden_dim, hidden_dim]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.Tanh())
        stack.append(nn.Linear(hidden_dim, 6))
        self.network = nn.Sequential(*stack)

    def forward(self, x, y, z, t):
        """Return the network output for column tensors ``x``, ``y``, ``z``, ``t``.

        Each argument is concatenated along dimension 1, so all four must
        share their leading dimension.
        """
        stacked = torch.cat([x, y, z, t], dim=1)
        encoded = self.encoder(stacked)
        return self.network(encoded)


# ---------------------------------------------------------------------------
# Inference script: load the sample data and both trained networks, evaluate
# them on an x-z slice, and save/display a contour plot of output channel 0.
# ---------------------------------------------------------------------------
DEVICE = torch.device("cpu")
print(f"test: {DEVICE}")
GEO_MODEL_PATH = "geonet_real_v30.pth"
PHYS_MODEL_PATH = "physnet_v31_real.pth"
NPZ_FILE_PATH = "ground_truth.npz"
print(f"load{NPZ_FILE_PATH}...")
try:
    # NOTE(review): allow_pickle=True lets the archive execute arbitrary code
    # when loaded; keep it only if the file is trusted and really contains
    # object arrays.
    # The context manager closes the underlying zip handle once the arrays
    # have been read.
    with np.load(NPZ_FILE_PATH, allow_pickle=True) as ground_truth_data:
        coords_np = ground_truth_data['coords']
        t_np = ground_truth_data['t']
except FileNotFoundError:
    print(f"file error'{NPZ_FILE_PATH}' file not found")
    # exit() is an interactive helper from the site module; raise SystemExit
    # directly and report failure with a non-zero status.
    raise SystemExit(1) from None

coords_tensor = torch.tensor(coords_np, dtype=torch.float32).to(DEVICE)
t_tensor = torch.tensor(t_np, dtype=torch.float32).view(-1, 1).to(DEVICE)
T_MAX = t_tensor.max().item()

print("loaded")

try:
    map_location = torch.device('cpu')

    # GeoNet: the normalisation bounds are recomputed from the data before
    # the weights are restored.
    model_geo = GeoNetHash().to(DEVICE)
    model_geo.normalize_coords(coords_tensor)
    model_geo.load_state_dict(torch.load(GEO_MODEL_PATH, map_location=map_location))
    model_geo.eval()

    # PhysNet
    model_phys = MaxwellPINN(num_freqs=8, hidden_dim=128).to(DEVICE)
    model_phys.load_state_dict(torch.load(PHYS_MODEL_PATH, map_location=map_location))
    model_phys.eval()
    print("loaded (GeoNet . PhysNet) ")

except Exception as e:
    print(f"\n error loading model {e}")
    print("pth files not found")
    raise SystemExit(1) from None

print("\n--- graphs:")

with torch.no_grad():
    resolution = 100

    min_vals = coords_tensor.min(dim=0)[0]
    max_vals = coords_tensor.max(dim=0)[0]

    # Regular x-z grid spanning the data extent; y is pinned to the midpoint
    # of its range.
    x_range = np.linspace(min_vals[0].item(), max_vals[0].item(), resolution)
    z_range = np.linspace(min_vals[2].item(), max_vals[2].item(), resolution)
    xx, zz = np.meshgrid(x_range, z_range)
    yy = np.full_like(xx, (min_vals[1] + max_vals[1]).item() / 2.0)

    # Snapshot taken at three quarters of the largest time in the data.
    T_VISUALIZATION = T_MAX * 0.75
    tt = np.ones_like(xx) * T_VISUALIZATION

    # Flatten every grid into an (N, 1) float32 column for the networks.
    x_grid = torch.tensor(xx.flatten(), dtype=torch.float32).to(DEVICE).view(-1, 1)
    y_grid = torch.tensor(yy.flatten(), dtype=torch.float32).to(DEVICE).view(-1, 1)
    z_grid = torch.tensor(zz.flatten(), dtype=torch.float32).to(DEVICE).view(-1, 1)
    t_grid = torch.tensor(tt.flatten(), dtype=torch.float32).to(DEVICE).view(-1, 1)

    coords_viz = torch.cat([x_grid, y_grid, z_grid], dim=1)
    coords_viz_norm = model_geo.normalize_inputs(coords_viz)
    sdf_grid = model_geo(coords_viz_norm).cpu().numpy().reshape(resolution, resolution)

    pred_all = model_phys(x_grid, y_grid, z_grid, t_grid)
    Ex_pred_grid = pred_all[:, 0].cpu().numpy().reshape(resolution, resolution)

    # Zero the field wherever the geometry network is negative (presumably
    # the interior of the object -- TODO confirm sign convention).
    Ex_pred_grid[sdf_grid < 0.0] = 0.0

    # Symmetric colour range, floored so an all-zero field still plots.
    v_max = np.max(np.abs(Ex_pred_grid))
    v_max = max(v_max, 1e-4)

    plt.figure(figsize=(8, 6))
    cf = plt.contourf(xx, zz, Ex_pred_grid, levels=50, cmap='RdBu_r', vmin=-v_max, vmax=v_max)

    # Zero level set of the geometry network drawn as a black outline.
    plt.contour(xx, zz, sdf_grid, levels=[0], colors='black', linewidths=3)

    title_str = f'Predicted $\\tilde{{E}}_x$ field (v35) at time $t={T_VISUALIZATION:.2f}$'
    plt.title(title_str)

    plt.xlabel('x (m)')
    plt.ylabel('z (m)')
    plt.colorbar(cf, label='E_x field value (Normalized)')
    plt.axis('equal')
    plt.tight_layout()

    output_filename = "v35_final_plot_defense.png"
    plt.savefig(output_filename, dpi=300)
    plt.close()

print(f"\n✅ completed'{output_filename}' saved.")
display(Image(filename=output_filename))