# AnsysLPFMTrame-App / utils / Jansen-score.py
# udbhav
# Recreate Trame_app branch with clean history
# 67fb03c
import pyvista as pv
import numpy as np
from scipy.spatial.distance import jensenshannon
from scipy.ndimage import gaussian_filter
import matplotlib.pyplot as plt
import os
# -------------------- PyVista Config --------------------
# Global PyVista settings: off-screen flag on, "document" plot theme.
pv.set_plot_theme("document")
pv.OFF_SCREEN = True

# -------------------- Paths --------------------
# Dataset name; it appears in both the input folder and the metrics folder.
dataset = "plane_transonic"
train_folder = f"/raid/ansysai/pkakka/6-Transformers/comparePhysicsLM/Data/{dataset}/"

# Metrics live two levels above the dataset folder; the joined path is kept
# un-normalized (it still contains the "../.." components).
_metrics_rel = f"../../metrics/{dataset}/"
train_save_dir = os.path.join(train_folder, _metrics_rel)
os.makedirs(train_save_dir, exist_ok=True)
# -------------------- Utility Functions --------------------
def get_points(file_path, max_points=5000, rng=None):
    """Extract and subsample the point cloud from a VTP file.

    Args:
        file_path: Path of the mesh file handed to ``pv.read``.
        max_points: Maximum number of points to return. Meshes with more
            points are randomly subsampled without replacement.
        rng: Optional ``numpy.random.Generator`` used for the subsampling so
            results can be made reproducible. When ``None`` (the default) the
            global ``np.random`` state is used.

    Returns:
        The mesh's ``points`` array, limited to at most ``max_points`` rows.

    Raises:
        ValueError: If reading or subsampling fails for any reason; the
            underlying exception is attached as ``__cause__``.
    """
    try:
        mesh = pv.read(file_path)
        points = mesh.points
        if len(points) > max_points:
            # Both np.random and Generator expose choice(a, size, replace=...).
            sampler = np.random if rng is None else rng
            indices = sampler.choice(len(points), max_points, replace=False)
            points = points[indices]
        return points
    except Exception as e:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise ValueError(f"Error reading {file_path}: {e}") from e
# -------------------- Training Histogram --------------------
def compute_training_dist(train_folder, output_file='train_dist.npz', num_bins=25, smooth_sigma=1, output_dir=None):
    """Build and save the smoothed, normalized 3D histogram of the training points.

    Case names are read from ``<train_folder>/1_VTK_surface/train.txt``; for
    each name the file ``<train_folder>/1_VTK_surface/<name>/<name>.vtp`` is
    loaded through ``get_points`` and all points are pooled into one cloud.

    Args:
        train_folder: Dataset root containing the ``1_VTK_surface`` folder.
        output_file: File name of the ``.npz`` archive to write.
        num_bins: Number of histogram bins per axis.
        smooth_sigma: Sigma passed to ``gaussian_filter``.
        output_dir: Directory the archive is written to. Defaults to the
            module-level ``train_save_dir``, which is where
            ``compute_js_score`` looks for it.

    Returns:
        The path of the saved archive.

    Raises:
        ValueError: If ``train.txt`` is missing or none of the listed VTP
            files exist.
    """
    train_txt_path = os.path.join(train_folder, '1_VTK_surface/train.txt')
    if not os.path.exists(train_txt_path):
        raise ValueError(f"train.txt not found at {train_txt_path}")

    with open(train_txt_path, 'r') as f:
        folder_names = [line.strip() for line in f if line.strip()]

    train_files = []
    for name in folder_names:
        vtp_file = os.path.join(train_folder, '1_VTK_surface', name, f'{name}.vtp')
        if os.path.exists(vtp_file):
            train_files.append(vtp_file)
        else:
            print(f"Warning: VTP not found: {vtp_file}")
    if not train_files:
        raise ValueError("No training VTPs found.")

    # Combine all training points into a single (N, 3) cloud.
    train_points = np.concatenate([get_points(f) for f in train_files], axis=0)

    # Per-axis bin edges derived from the pooled training points.
    bin_edges = [np.histogram_bin_edges(train_points[:, i], bins=num_bins) for i in range(3)]
    train_hist, _ = np.histogramdd(train_points, bins=bin_edges, density=True)

    # Gaussian smoothing, then flatten and rescale so the bins sum to 1.
    train_hist = gaussian_filter(train_hist, sigma=smooth_sigma)
    train_hist = train_hist.flatten()
    train_hist /= train_hist.sum()

    # Write next to the other metrics so compute_js_score can find the file;
    # os.path.join avoids relying on a trailing slash in any folder name.
    if output_dir is None:
        output_dir = train_save_dir
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, output_file)
    np.savez(output_path, hist=train_hist, edges0=bin_edges[0], edges1=bin_edges[1], edges2=bin_edges[2])
    print(f"Training histogram saved: {output_path} ({train_points.shape[0]} points)")
    return output_path
# -------------------- JS Score for Test --------------------
def compute_js_score(test_file, train_dist_file='train_dist.npz', smooth_sigma=1):
    """Score how closely a test geometry's point distribution matches training.

    The test points are binned on the training bin edges stored in
    ``<train_save_dir>/<train_dist_file>``, smoothed, normalized, and compared
    with the stored training histogram.

    Args:
        test_file: Path of the test VTP file (read through ``get_points``).
        train_dist_file: Name of the ``.npz`` archive inside ``train_save_dir``
            holding ``hist``, ``edges0``, ``edges1`` and ``edges2``.
        smooth_sigma: Sigma passed to ``gaussian_filter``.

    Returns:
        ``1 - d`` where ``d`` is the value returned by
        ``scipy.spatial.distance.jensenshannon``. Note that SciPy returns the
        Jensen-Shannon *distance* (square root of the divergence, natural
        log by default), so higher scores mean closer to the training data.

    Raises:
        ValueError: If no test point falls inside the training bin range, so
            no valid distribution can be formed.
    """
    # np.load on an .npz keeps the archive open; close it once arrays are read.
    with np.load(os.path.join(train_save_dir, train_dist_file)) as data:
        train_hist = data['hist']
        bin_edges = [data['edges0'], data['edges1'], data['edges2']]

    test_points = get_points(test_file)

    # 3D histogram on the training edges; points outside the edges are not
    # counted. An empty histogram makes density=True divide 0 by 0, so the
    # warning is silenced here and the case is reported explicitly below.
    with np.errstate(divide='ignore', invalid='ignore'):
        test_hist, _ = np.histogramdd(test_points, bins=bin_edges, density=True)
    test_hist = gaussian_filter(test_hist, sigma=smooth_sigma)  # smooth test histogram

    # Flatten and normalize.
    test_hist = test_hist.flatten()
    total = test_hist.sum()
    if not np.isfinite(total) or total <= 0:
        raise ValueError(
            f"No test points fall inside the training histogram range for {test_file}"
        )
    test_hist /= total

    div = jensenshannon(train_hist, test_hist)
    return 1 - div
# -------------------- Analyze Test Folder --------------------
def analyze_and_save_scores(test_folder, train_dist_file='train_dist.npz', output_file='test_js_scores.txt', num_bins=25, smooth_sigma=1):
    """Compute JS scores for all test cases and save the results.

    Case names are read from ``<test_folder>/1_VTK_surface/test.txt``. For
    each case a tab-separated line is written to
    ``<train_save_dir>/<output_file>``: the score on success, or an ``ERROR``
    entry when the VTP is missing or scoring fails. When at least one score
    was computed, summary statistics are printed and a histogram is saved as
    ``js_scores_hist.png`` in ``train_save_dir``.

    Args:
        test_folder: Dataset root containing the ``1_VTK_surface`` folder.
        train_dist_file: Archive name forwarded to ``compute_js_score``.
        output_file: Name of the results text file.
        num_bins: Unused; accepted for backward compatibility.
        smooth_sigma: Smoothing sigma forwarded to ``compute_js_score``.

    Returns:
        ``(names, scores)``: the successfully scored case names and their
        scores, in matching order.

    Raises:
        ValueError: If ``test.txt`` is missing or lists no cases.
    """
    # Honor the caller-supplied folder rather than the module-level default.
    test_txt_path = os.path.join(test_folder, '1_VTK_surface/test.txt')
    if not os.path.exists(test_txt_path):
        raise ValueError(f"test.txt not found at {test_txt_path}")

    with open(test_txt_path, 'r') as f:
        folder_names = [line.strip() for line in f if line.strip()]
    if not folder_names:
        raise ValueError("No test cases found.")

    output_path = os.path.join(train_save_dir, output_file)
    scores = []
    names = []
    with open(output_path, 'w') as f_out:
        f_out.write("Test_File\tJS_Score\n")
        for name in folder_names:
            vtp_file = os.path.join(test_folder, '1_VTK_surface', name, f'{name}.vtp')
            if not os.path.exists(vtp_file):
                print(f"Warning: VTP not found: {vtp_file}")
                f_out.write(f"{name}\tERROR: VTP not found\n")
                continue
            # One failing case must not abort the whole run; record and go on.
            try:
                score = compute_js_score(vtp_file, train_dist_file, smooth_sigma)
                scores.append(score)
                names.append(name)
                print(f"{name}: {score:.4f}")
                f_out.write(f"{name}\t{score:.6f}\n")
            except Exception as e:
                print(f"Error for {name}: {e}")
                f_out.write(f"{name}\tERROR: {e}\n")

    if scores:
        print(f"\nAverage Score: {np.mean(scores):.4f} ± {np.std(scores):.4f}")
        print(f"Min/Max: {np.min(scores):.4f} / {np.max(scores):.4f}")

        # Plot histogram; always release the figure, even if saving fails.
        plt.figure(figsize=(6, 4))
        try:
            plt.hist(scores, bins=10, alpha=0.7, edgecolor='black')
            plt.xlabel('JS Score (Higher = Closer to Train)')
            plt.ylabel('Count')
            plt.title('Test Geometry JS Scores')
            plt.savefig(os.path.join(train_save_dir, 'js_scores_hist.png'))
        finally:
            plt.close()

    return names, scores
# -------------------- Optional Visualization --------------------
def visualize_sample(test_folder, show_plot=True):
    """Save a screenshot of the first test geometry (optional, best effort).

    The first case name in ``<test_folder>/1_VTK_surface/test.txt`` is
    rendered with an off-screen PyVista plotter and saved as
    ``sample_geometry_<name>.png`` in ``train_save_dir``. Missing cases or
    rendering problems only produce a printed warning.

    Args:
        test_folder: Dataset root containing the ``1_VTK_surface`` folder.
        show_plot: When ``False`` the function prints a notice and returns
            without touching any file.

    Returns:
        None.
    """
    if not show_plot:
        print("Skipping visualization (set show_plot=True to enable)")
        return

    # Honor the caller-supplied folder rather than the module-level default.
    test_txt_path = os.path.join(test_folder, '1_VTK_surface/test.txt')
    with open(test_txt_path, 'r') as f:
        folder_names = [line.strip() for line in f if line.strip()]

    if not folder_names:
        print("Warning: No test cases found for visualization")
        return

    name = folder_names[0]
    vtp_file = os.path.join(test_folder, '1_VTK_surface', name, f'{name}.vtp')
    if not os.path.exists(vtp_file):
        print(f"Warning: VTP file not found: {vtp_file}")
        return

    try:
        mesh = pv.read(vtp_file)
        plotter = pv.Plotter(off_screen=True)
        try:
            plotter.add_mesh(mesh, color='blue', show_edges=True)
            plotter.add_title(f'Sample Geometry: {name}')
            screenshot_path = os.path.join(train_save_dir, f'sample_geometry_{name}.png')
            plotter.screenshot(screenshot_path)
        finally:
            # Release the plotter even when rendering or saving fails.
            plotter.close()
        print(f"Sample screenshot saved: {screenshot_path}")
    except Exception as e:
        print(f"Warning: Could not visualize geometry: {e}")
# -------------------- Main --------------------
if __name__ == "__main__":
    import traceback

    # Script entry point: build the training histogram, score every test
    # case, then run the (disabled by default) sample visualization.
    try:
        print("Computing training histogram...")
        compute_training_dist(train_folder, num_bins=25, smooth_sigma=1)
        print("Analyzing test cases...")
        names, scores = analyze_and_save_scores(train_folder, num_bins=25, smooth_sigma=1)
        visualize_sample(train_folder, show_plot=False)
        print("JS analysis completed successfully!")
    except Exception as e:
        print(f"Error during analysis: {e}")
        traceback.print_exc()
        # Exit non-zero so shells and schedulers can detect the failure.
        raise SystemExit(1) from e