# DualMaxwell / example_usage.py
# Last change: commit 2261778 by ayda138000 ("Update example_usage.py").
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import Image, display
class HashGridEncoder(nn.Module):
    """Multiresolution hash-grid encoding of 3-D points.

    For each of ``n_levels`` grid resolutions (``base_resolution`` scaled by
    ``per_level_scale ** level`` and truncated to int), the point's enclosing
    cell is located, its 8 corners are hashed into that level's embedding
    table of ``2 ** log2_hashmap_size`` rows, and the corner features are
    blended by trilinear interpolation. The per-level results are
    concatenated, giving ``n_levels * n_features_per_level`` features per
    point.
    """

    def __init__(self, n_levels=16, n_features_per_level=2,
                 log2_hashmap_size=19, base_resolution=16,
                 per_level_scale=1.5):
        super().__init__()
        self.n_levels = n_levels
        self.n_features_per_level = n_features_per_level
        self.log2_hashmap_size = log2_hashmap_size
        self.base_resolution = base_resolution
        self.per_level_scale = per_level_scale
        self.hashmap_size = 2 ** log2_hashmap_size
        tables = [nn.Embedding(self.hashmap_size, n_features_per_level)
                  for _ in range(n_levels)]
        self.embeddings = nn.ModuleList(tables)
        # Re-initialise every table to tiny uniform values so all features
        # start close to zero.
        for table in self.embeddings:
            nn.init.uniform_(table.weight, -1e-4, 1e-4)

    def hash_fn(self, c, l):
        """Map integer grid coordinates to hash-table row indices.

        Args:
            c: int64 tensor whose last dimension holds (i, j, k).
            l: level index; accepted for interface compatibility, unused.

        Returns:
            ``sum(c * (1, 2654435761, 805459861)) % hashmap_size`` over the
            last dimension.
        """
        multipliers = torch.tensor([1, 2654435761, 805459861],
                                   dtype=torch.int64, device=c.device)
        weighted = c * multipliers
        return weighted.sum(dim=-1) % self.hashmap_size

    def trilinear_interp(self, x, v000, v001, v010, v011, v100, v101, v110, v111):
        """Blend 8 corner values at fractional offsets ``x``.

        Corner ``v{a}{b}{c}`` sits at offset (a, b, c) along axes 0, 1, 2 of
        ``x[..., 0:3]``; each corner tensor must broadcast against
        ``x[..., 0:1]``.
        """
        fx, fy, fz = x[..., 0:1], x[..., 1:2], x[..., 2:3]
        gx = 1 - fx
        # Collapse axis 0: four edges.
        e00 = v000 * gx + v100 * fx
        e01 = v001 * gx + v101 * fx
        e10 = v010 * gx + v110 * fx
        e11 = v011 * gx + v111 * fx
        # Collapse axis 1: two faces.
        f0 = e00 * (1 - fy) + e10 * fy
        f1 = e01 * (1 - fy) + e11 * fy
        # Collapse axis 2.
        return f0 * (1 - fz) + f1 * fz

    def forward(self, x_coords):
        """Encode points ``x_coords`` of shape (B, 3).

        Inputs are presumably in [0, 1] (GeoNetHash feeds normalized
        coordinates) — confirm for other callers.

        Returns:
            Tensor of shape (B, n_levels * n_features_per_level).
        """
        # Corner offsets in (a, b, c) order 000, 001, ..., 111 — the same order
        # trilinear_interp expects its corner arguments in.
        offsets = torch.tensor(
            [[a, b, c] for a in (0, 1) for b in (0, 1) for c in (0, 1)],
            dtype=torch.int64, device=x_coords.device,
        )
        per_level = []
        for level, table in enumerate(self.embeddings):
            res = int(self.base_resolution * (self.per_level_scale ** level))
            scaled = x_coords * res
            cell = torch.floor(scaled).to(torch.int64)
            frac = scaled - cell
            corner_cells = cell.unsqueeze(1) + offsets.unsqueeze(0)  # (B, 8, 3)
            rows = self.hash_fn(corner_cells, level)                  # (B, 8)
            corner_feats = table(rows)                                # (B, 8, F)
            per_level.append(self.trilinear_interp(frac, *corner_feats.unbind(dim=1)))
        return torch.cat(per_level, dim=-1)
class GeoNetHash(nn.Module):
    """Geometry network: hash-grid encoding followed by a small MLP.

    Maps coordinates already scaled into the unit range to one scalar per
    point (the example script thresholds it at 0, like a signed distance).
    Scaling bounds are captured by ``normalize_coords`` and reused by
    ``normalize_inputs``.

    Args:
        box_size_scale: stored on the instance; not used by this class.
    """

    def __init__(self, box_size_scale=1.0):
        super().__init__()
        self.box_size_scale = box_size_scale
        self.encoder = HashGridEncoder(n_levels=16, n_features_per_level=2)
        self.mlp = nn.Sequential(
            nn.Linear(self.encoder.n_levels * self.encoder.n_features_per_level, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )
        # Normalization bounds. These are plain attributes (not registered
        # buffers), so they are not saved in or restored from state_dict.
        # Initialised here so normalize_inputs raises its intended error
        # instead of an AttributeError.
        self.min = None
        self.max = None

    def _span(self):
        """Return per-axis ``max - min``, with exact zeros replaced by 1e-6.

        normalize_coords nudges zero-width axes by +1e-6, but in float32 the
        nudge is rounded away once |max| is large enough (100.0 + 1e-6 ==
        100.0). Without this guard that axis would divide 0 by 0 and give NaN.
        """
        span = self.max - self.min
        return torch.where(span == 0, torch.full_like(span, 1e-6), span)

    def normalize_coords(self, coords):
        """Record the per-axis bounds of ``coords`` and return it scaled to [0, 1].

        Args:
            coords: tensor reduced over dim 0; presumably (N, 3) points — confirm.

        Returns:
            ``(coords - min) / (max - min)`` per axis. Zero-width axes map to 0.
        """
        self.min = coords.min(dim=0, keepdim=True)[0]
        self.max = coords.max(dim=0, keepdim=True)[0]
        # Widen zero-width axes so the division is defined.
        self.max[self.max == self.min] += 1e-6
        return (coords - self.min) / self._span()

    def normalize_inputs(self, coords):
        """Scale ``coords`` with the bounds recorded by ``normalize_coords``.

        Raises:
            RuntimeError: if ``normalize_coords`` has not been called yet.
        """
        if getattr(self, "min", None) is None or getattr(self, "max", None) is None:
            raise RuntimeError("GeoNet not normalized.")
        return (coords - self.min) / self._span()

    def forward(self, coords_input_norm):
        """Return the MLP output (last dim 1) for normalized coordinates."""
        features = self.encoder(coords_input_norm)
        return self.mlp(features)
class PositionalEncoder4D(nn.Module):
    """Fourier-feature encoding: the raw input followed by sines and cosines.

    With ``num_freqs`` = F the frequencies are 2**0, 2**1, ..., 2**(F-1), so
    every input column contributes ``2 * F + 1`` output columns.
    """

    def __init__(self, input_dims, num_freqs):
        super().__init__()
        self.input_dims = input_dims
        self.num_freqs = num_freqs
        # Plain tensor attribute (not a buffer); moved to the input's device
        # on every forward call.
        exponents = torch.linspace(0.0, num_freqs - 1, num_freqs)
        self.freq_bands = 2.0 ** exponents
        self.output_dims = input_dims * (2 * num_freqs + 1)

    def forward(self, x):
        """Encode ``x`` of shape (N, input_dims) into (N, output_dims).

        Column layout is [x | sin block | cos block]. Each block is ordered
        input-column-major, frequency-minor.
        """
        bands = self.freq_bands.to(x.device)
        angles = x.unsqueeze(-1) * bands  # (N, D, F)
        sines = torch.sin(angles).flatten(1)
        cosines = torch.cos(angles).flatten(1)
        return torch.cat((x, sines, cosines), dim=1)
class MaxwellPINN(nn.Module):
    """Field network: (x, y, z, t) -> 6 outputs via Fourier features and a tanh MLP.

    The example script reads only output column 0 (as E_x). The other five
    columns are presumably the remaining field components — confirm against
    the training code.
    """

    def __init__(self, num_freqs=8, hidden_dim=128):
        super().__init__()
        self.encoder = PositionalEncoder4D(input_dims=4, num_freqs=num_freqs)
        # Keep the exact layer order (Linear/Tanh x3, then Linear). The
        # Sequential indices (network.0/.2/.4/.6) are the state_dict keys that
        # saved checkpoints are loaded against.
        layers = [nn.Linear(self.encoder.output_dims, hidden_dim), nn.Tanh()]
        for _ in range(2):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.Tanh()])
        layers.append(nn.Linear(hidden_dim, 6))
        self.network = nn.Sequential(*layers)

    def forward(self, x, y, z, t):
        """Evaluate the network on column tensors ``x, y, z, t``, each (N, 1)."""
        stacked = torch.cat([x, y, z, t], dim=1)
        return self.network(self.encoder(stacked))
# ---------------------------------------------------------------------------
# Example driver: load ground-truth samples and the two trained networks,
# then plot the predicted E_x field on the mid-y (x, z) plane.
# ---------------------------------------------------------------------------
DEVICE = torch.device("cpu")
print(f"test: {DEVICE}")
GEO_MODEL_PATH = "geonet_real_v30.pth"
PHYS_MODEL_PATH = "physnet_v31_real.pth"
NPZ_FILE_PATH = "ground_truth.npz"


def _to_column(values):
    """Flatten a NumPy grid into a float32 (N, 1) tensor on DEVICE."""
    return torch.tensor(values.flatten(), dtype=torch.float32).to(DEVICE).view(-1, 1)


# --- Ground-truth samples --------------------------------------------------
print(f"load{NPZ_FILE_PATH}...")
try:
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code. Only load trusted files; a purely
    # numeric archive does not need this flag.
    # The `with` block closes the archive; indexing it reads each array eagerly.
    with np.load(NPZ_FILE_PATH, allow_pickle=True) as ground_truth_data:
        coords_np = ground_truth_data['coords']
        t_np = ground_truth_data['t']
except FileNotFoundError:
    print(f"file error'{NPZ_FILE_PATH}' file not found")
    # Non-zero status so shells and callers can detect the failure.
    raise SystemExit(1)
except KeyError as e:
    print(f"file error '{NPZ_FILE_PATH}' is missing array {e}")
    raise SystemExit(1)

coords_tensor = torch.tensor(coords_np, dtype=torch.float32).to(DEVICE)
t_tensor = torch.tensor(t_np, dtype=torch.float32).view(-1, 1).to(DEVICE)
T_MAX = t_tensor.max().item()
print("loaded")

# --- Trained networks ------------------------------------------------------
map_location = torch.device('cpu')
try:
    # GeoNet. Its min/max bounds are plain attributes, not registered buffers,
    # so load_state_dict does not restore them. They are recomputed from the
    # ground-truth coordinates (assumes these match GeoNet's training data —
    # TODO confirm).
    model_geo = GeoNetHash().to(DEVICE)
    model_geo.normalize_coords(coords_tensor)
    model_geo.load_state_dict(torch.load(GEO_MODEL_PATH, map_location=map_location))
    model_geo.eval()

    # PhysNet
    model_phys = MaxwellPINN(num_freqs=8, hidden_dim=128).to(DEVICE)
    model_phys.load_state_dict(torch.load(PHYS_MODEL_PATH, map_location=map_location))
    model_phys.eval()
except FileNotFoundError as e:
    print(f"\n error loading model {e}")
    print("pth files not found")
    raise SystemExit(1)
except Exception as e:
    # Any other failure (architecture/checkpoint mismatch, corrupt file, ...)
    # is not a missing file, so do not report it as one.
    print(f"\n error loading model {e}")
    raise SystemExit(1)
print("loaded (GeoNet . PhysNet) ")

# --- Evaluate both networks on an (x, z) slice through the mid-y plane -----
print("\n--- graphs:")
resolution = 100
T_VISUALIZATION = T_MAX * 0.75
with torch.no_grad():
    min_vals = coords_tensor.min(dim=0)[0]
    max_vals = coords_tensor.max(dim=0)[0]
    x_range = np.linspace(min_vals[0].item(), max_vals[0].item(), resolution)
    z_range = np.linspace(min_vals[2].item(), max_vals[2].item(), resolution)
    xx, zz = np.meshgrid(x_range, z_range)
    yy = np.full_like(xx, (min_vals[1] + max_vals[1]).item() / 2.0)
    tt = np.ones_like(xx) * T_VISUALIZATION

    x_grid = _to_column(xx)
    y_grid = _to_column(yy)
    z_grid = _to_column(zz)
    t_grid = _to_column(tt)

    # GeoNet expects normalized coordinates; PhysNet is fed raw coordinates.
    coords_viz = torch.cat([x_grid, y_grid, z_grid], dim=1)
    coords_viz_norm = model_geo.normalize_inputs(coords_viz)
    sdf_grid = model_geo(coords_viz_norm).cpu().numpy().reshape(resolution, resolution)

    pred_all = model_phys(x_grid, y_grid, z_grid, t_grid)
    Ex_pred_grid = pred_all[:, 0].cpu().numpy().reshape(resolution, resolution)

# Zero the field wherever the geometry output is negative.
Ex_pred_grid[sdf_grid < 0.0] = 0.0

# Symmetric colour limits centred on zero; the floor avoids a zero-width range.
v_max = max(np.max(np.abs(Ex_pred_grid)), 1e-4)

# --- Plot ------------------------------------------------------------------
fig = plt.figure(figsize=(8, 6))
cf = plt.contourf(xx, zz, Ex_pred_grid, levels=50, cmap='RdBu_r', vmin=-v_max, vmax=v_max)
# Zero level set of the geometry output, drawn as the boundary outline.
plt.contour(xx, zz, sdf_grid, levels=[0], colors='black', linewidths=3)

title_str = f'Predicted $\\tilde{{E}}_x$ field (v35) at time $t={T_VISUALIZATION:.2f}$'
plt.title(title_str)
plt.xlabel('x (m)')
plt.ylabel('z (m)')
plt.colorbar(cf, label='E_x field value (Normalized)')
plt.axis('equal')
plt.tight_layout()

output_filename = "v35_final_plot_defense.png"
plt.savefig(output_filename, dpi=300)
plt.close(fig)
print(f"\n✅ completed'{output_filename}' saved.")
display(Image(filename=output_filename))