# Transformer-PINN / inference.py
# guanwencan -- "Upload 5 files" (5e4dee3, verified)
import os
import numpy as np
import torch
import pickle
import matplotlib.pyplot as plt
from model import CrackTransformerPINN
from data_loader import DamageCalculator
class CrackPredictor:
    """Inference wrapper around a trained ``CrackTransformerPINN`` checkpoint.

    The constructor loads the pickled scalers and the model weights once.
    The instance then offers single-sample prediction (``predict``),
    prediction with the model's physics outputs (``predict_with_physics``),
    batch prediction (``predict_batch``) and two polar-plot helpers.
    """

    # Entries copied from the model's physics dictionary by
    # ``predict_with_physics``, in the order they appear in its result.
    _PHYSICS_KEYS = (
        'D0', 'lambda', 'D_n', 'tau_oct', 'yield_stress',
        'C1', 'C2', 'D_q', 'm', 'F0',
    )

    def __init__(self, model_path, scaler_path, device='cpu'):
        """Load scalers and model weights and put the model in eval mode.

        Args:
            model_path: Path to a ``torch.save``d checkpoint. Either a dict
                holding ``'model_config'`` and ``'model_state_dict'``, or a
                bare state dict (the fallback architecture below is then used).
            scaler_path: Path to a pickle holding a dict with the keys
                ``'scaler_X'``, ``'scaler_y'`` and ``'angle_bins'``.
            device: Torch device the model and input tensors are placed on.
        """
        self.device = device

        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only pass a scaler_path that comes from a trusted source.
        with open(scaler_path, 'rb') as f:
            scalers = pickle.load(f)
        self.scaler_X = scalers['scaler_X']
        self.scaler_y = scalers['scaler_y']
        self.angle_bins = scalers['angle_bins']

        # SECURITY: torch.load is pickle-based as well; the same trust
        # requirement applies to model_path.
        checkpoint = torch.load(model_path, map_location=device)
        if 'model_config' in checkpoint:
            config = checkpoint['model_config']
            model = CrackTransformerPINN(
                input_dim=config['input_dim'],
                output_dim=config['output_dim'],
                hidden_dims=config['hidden_dims'],
                dropout=config['dropout'],
            )
            state_dict = checkpoint['model_state_dict']
        else:
            # No stored config: treat the checkpoint itself as the state dict
            # and rebuild the model with hard-coded fallback hyper-parameters.
            model = CrackTransformerPINN(
                input_dim=5,
                output_dim=len(self.angle_bins),
                hidden_dims=[128, 256, 256, 128],
                dropout=0.2,
            )
            state_dict = checkpoint
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        self.model = model

    def _run_model(self, X_norm, return_physics=False):
        """Run the model on already-scaled inputs with gradients disabled."""
        with torch.no_grad():
            X_tensor = torch.as_tensor(
                X_norm, dtype=torch.float32, device=self.device
            )
            return self.model(X_tensor, return_physics=return_physics)

    @staticmethod
    def _first_scalar(tensor):
        """Return the first element of ``tensor`` as a NumPy scalar."""
        # detach() keeps the conversion working even if the value is still
        # attached to the autograd graph.
        return tensor.detach().cpu().numpy().flatten()[0]

    @staticmethod
    def _draw_polar(ax, theta, values, color):
        """Plot ``values`` over ``theta`` on a polar axis with the shared styling."""
        ax.plot(theta, values, 'o-', linewidth=2, markersize=4, color=color)
        ax.fill(theta, values, alpha=0.3, color=color)
        # Zero angle at the top of the plot, angles increasing clockwise.
        ax.set_theta_zero_location('N')
        ax.set_theta_direction(-1)
        ax.grid(True, alpha=0.3)

    def predict(self, pH, FN, FT, T, phase):
        """Predict the crack angle distribution for one sample.

        Args:
            pH, FN, FT, T: Scalar inputs; also passed to ``DamageCalculator``.
                (Presumably acidity, two force components and temperature --
                confirm against data_loader.)
            phase: Scalar phase indicator; ``compare_stress_types`` uses 0 for
                the unstable development phase and 1 for the peak stress phase.

        Returns:
            dict with ``'angle_distribution'`` (one value per angle bin),
            ``'total_count'``, and ``'D0'`` / ``'lambda'`` computed by
            ``DamageCalculator`` from the raw inputs.
        """
        X = np.array([[pH, FN, FT, T, phase]], dtype=np.float32)
        pred_dist, pred_total = self.predict_batch(X)
        D0 = DamageCalculator.compute_total_damage(pH, FN, FT, T)
        lambda_coef = DamageCalculator.compute_lambda(D0)
        return {
            'angle_distribution': pred_dist[0],
            'total_count': pred_total[0],
            'D0': D0,
            'lambda': lambda_coef,
        }

    def predict_with_physics(self, pH, FN, FT, T, phase):
        """Predict one sample and include the model's physics outputs.

        Takes the same arguments as ``predict``. Unlike ``predict``, ``'D0'``
        and ``'lambda'`` come from the physics dictionary returned by the
        model, not from ``DamageCalculator``.

        Returns:
            dict with ``'angle_distribution'``, ``'total_count'`` and one
            scalar for each key in ``_PHYSICS_KEYS``.
        """
        X = np.array([[pH, FN, FT, T, phase]], dtype=np.float32)
        X_norm = self.scaler_X.transform(X)
        pred_dist_norm, pred_total, physics = self._run_model(
            X_norm, return_physics=True
        )
        pred_dist = self.scaler_y.inverse_transform(
            pred_dist_norm.detach().cpu().numpy()
        )
        result = {
            'angle_distribution': pred_dist[0],
            'total_count': self._first_scalar(pred_total),
        }
        for key in self._PHYSICS_KEYS:
            result[key] = self._first_scalar(physics[key])
        return result

    def predict_batch(self, X):
        """Predict angle distributions and totals for a batch of samples.

        Args:
            X: 2-D array-like of unscaled inputs, one row per sample, with the
                columns in the order used by ``predict``
                (pH, FN, FT, T, phase).

        Returns:
            Tuple ``(pred_dist, pred_total)``: the inverse-scaled angle
            distributions and the flattened array of predicted totals.
        """
        X_norm = self.scaler_X.transform(X)
        pred_dist_norm, pred_total = self._run_model(
            X_norm, return_physics=False
        )
        pred_dist = self.scaler_y.inverse_transform(
            pred_dist_norm.detach().cpu().numpy()
        )
        return pred_dist, pred_total.detach().cpu().numpy().flatten()

    def visualize(self, result, title=None, save_path=None):
        """Show a polar plot of one prediction.

        Args:
            result: dict as returned by ``predict``; ``'angle_distribution'``
                is always read, and ``'total_count'``, ``'D0'`` and
                ``'lambda'`` are read when no ``title`` is given.
            title: Optional custom title; a summary title is built otherwise.
            save_path: If truthy, the figure is also written there at 300 dpi.
        """
        fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(projection='polar'))
        theta = np.deg2rad(self.angle_bins)
        self._draw_polar(ax, theta, result['angle_distribution'], 'blue')
        if title:
            ax.set_title(title, fontsize=14, pad=20)
        else:
            total = result['total_count']
            D0 = result['D0']
            lam = result['lambda']
            ax.set_title(
                f'Predicted Crack Distribution\nTotal: {total:.0f}, D0: {D0:.3f}, lambda: {lam:.3f}',
                fontsize=12, pad=20)
        if save_path:
            # Save through the figure object, not the implicit current figure.
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

    def compare_stress_types(self, pH, FN, FT, T, save_path=None):
        """Plot the predictions for phase 0 and phase 1 side by side.

        Args:
            pH, FN, FT, T: Inputs forwarded to ``predict`` for both phases.
            save_path: If truthy, the figure is also written there at 300 dpi.
        """
        result_unstable = self.predict(pH, FN, FT, T, phase=0)
        result_peak = self.predict(pH, FN, FT, T, phase=1)
        fig, axes = plt.subplots(1, 2, figsize=(14, 6), subplot_kw=dict(projection='polar'))
        theta = np.deg2rad(self.angle_bins)

        self._draw_polar(axes[0], theta, result_unstable['angle_distribution'], 'blue')
        axes[0].set_title(
            f'Unstable Development Phase\nTotal: {result_unstable["total_count"]:.0f}',
            fontsize=12, pad=20)

        self._draw_polar(axes[1], theta, result_peak['angle_distribution'], 'red')
        axes[1].set_title(
            f'Peak Stress Phase\nTotal: {result_peak["total_count"]:.0f}',
            fontsize=12, pad=20)

        # D0 and lambda are computed by predict() from pH, FN, FT and T only,
        # so either phase's result carries the same values.
        D0 = result_unstable['D0']
        lam = result_unstable['lambda']
        fig.suptitle(f'pH={pH}, FN={FN}, FT={FT}, T={T}\nD0={D0:.3f}, lambda={lam:.3f}',
                     fontsize=14, fontweight='bold')
        plt.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()
def main():
    """Run the command-line inference demo.

    Loads the predictor from ``./output``, prints predictions for four
    hard-coded test cases, then prints the physics outputs for one more
    input. Returns early with a message if the model or scaler file is
    missing.
    """
    model_path = "./output/crack_transformer_pinn.pth"
    scaler_path = "./output/scalers.pkl"
    if not os.path.exists(model_path):
        print("Model file not found. Please train the model first using train.py")
        return
    # CrackPredictor opens the scaler file unconditionally, so a missing file
    # would otherwise surface as an unhandled FileNotFoundError.
    if not os.path.exists(scaler_path):
        print(f"Scaler file not found: {scaler_path}")
        return

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    predictor = CrackPredictor(model_path, scaler_path, device=device)

    print("=" * 60)
    print("Transformer-PINN Crack Prediction - Inference Demo")
    print("=" * 60)

    test_cases = [
        {'pH': 3, 'FN': 30, 'FT': 40, 'T': 25, 'phase': 0},
        {'pH': 3, 'FN': 40, 'FT': 20, 'T': 300, 'phase': 1},
        {'pH': 7, 'FN': 10, 'FT': 40, 'T': 300, 'phase': 1},
        {'pH': 5, 'FN': 10, 'FT': 20, 'T': 900, 'phase': 0},
    ]
    for i, params in enumerate(test_cases):
        result = predictor.predict(**params)
        print(f"\nTest Case {i+1}:")
        print(f" Input: pH={params['pH']}, FN={params['FN']}, FT={params['FT']}, T={params['T']}, phase={params['phase']}")
        print(f" D0 (Initial Damage): {result['D0']:.4f}")
        print(f" Lambda (Damage Coefficient): {result['lambda']:.4f}")
        print(f" Predicted Total Cracks: {result['total_count']:.0f}")
        # Report the angle bin with the largest predicted value.
        print(f" Peak Angle: {predictor.angle_bins[result['angle_distribution'].argmax()]:.1f} degrees")
        print(f" Peak Count: {result['angle_distribution'].max():.0f}")

    print("\n" + "=" * 60)
    print("Testing Physics Output")
    print("=" * 60)

    physics_result = predictor.predict_with_physics(pH=3, FN=30, FT=40, T=25, phase=1)
    print("\nPhysics Parameters:")
    print(" Mogi-Coulomb:")
    print(f" tau_oct: {physics_result['tau_oct']:.4f}")
    print(f" yield_stress: {physics_result['yield_stress']:.4f}")
    print(f" C1: {physics_result['C1']:.4f}")
    print(f" C2: {physics_result['C2']:.4f}")
    print(" Weibull Distribution:")
    print(f" D_q: {physics_result['D_q']:.4f}")
    print(f" m: {physics_result['m']:.4f}")
    print(f" F0: {physics_result['F0']:.4f}")
    print(" Energy Damage:")
    print(f" D_n: {physics_result['D_n']:.4f}")

    print("\n" + "=" * 60)
    print("Inference complete!")
    print("=" * 60)


if __name__ == "__main__":
    main()