"""Inference and visualization helpers for the Transformer-PINN crack model.

Loads a trained ``CrackTransformerPINN`` checkpoint together with the pickled
feature/target scalers, predicts crack-angle distributions and total crack
counts, and draws the predictions on polar axes.
"""

import os
import pickle

import numpy as np
import torch
import matplotlib.pyplot as plt

from model import CrackTransformerPINN
from data_loader import DamageCalculator


class CrackPredictor:
    """Wrap a trained ``CrackTransformerPINN`` for single and batch inference.

    Attributes:
        device: Torch device string the model and input tensors are moved to.
        scaler_X: Fitted scaler applied to the raw input features.
        scaler_y: Fitted scaler whose ``inverse_transform`` maps the model's
            normalized distribution output back to original units.
        angle_bins: Angle bins (in degrees) loaded from the scaler file; one
            entry per element of the predicted distribution.
        model: The loaded network, switched to evaluation mode.
    """

    # Entries of the model's physics dict exposed by ``predict_with_physics``,
    # in the order they appear in the returned result.
    _PHYSICS_KEYS = (
        'D0', 'lambda', 'D_n', 'tau_oct', 'yield_stress',
        'C1', 'C2', 'D_q', 'm', 'F0',
    )

    def __init__(self, model_path, scaler_path, device='cpu'):
        """Load the scalers and the model checkpoint.

        Args:
            model_path: Path to the checkpoint. Either a dict containing
                ``'model_config'`` and ``'model_state_dict'`` or a bare
                state dict.
            scaler_path: Path to a pickle holding ``'scaler_X'``,
                ``'scaler_y'`` and ``'angle_bins'``.
            device: Torch device used for inference.
        """
        self.device = device

        # SECURITY: pickle.load (and torch.load, which may unpickle) can run
        # arbitrary code. Only point these at files from a trusted source.
        with open(scaler_path, 'rb') as f:
            scalers = pickle.load(f)
        self.scaler_X = scalers['scaler_X']
        self.scaler_y = scalers['scaler_y']
        self.angle_bins = scalers['angle_bins']

        checkpoint = torch.load(model_path, map_location=device)
        if 'model_config' in checkpoint:
            config = checkpoint['model_config']
            self.model = CrackTransformerPINN(
                input_dim=config['input_dim'],
                output_dim=config['output_dim'],
                hidden_dims=config['hidden_dims'],
                dropout=config['dropout'],
            )
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            # No stored config: treat the checkpoint as a bare state dict and
            # rebuild the network with the fallback architecture.
            self.model = CrackTransformerPINN(
                input_dim=5,
                output_dim=len(self.angle_bins),
                hidden_dims=[128, 256, 256, 128],
                dropout=0.2,
            )
            self.model.load_state_dict(checkpoint)

        self.model.to(device)
        self.model.eval()

    def _infer(self, X, return_physics=False):
        """Scale ``X``, run the model and map the outputs back to NumPy.

        Args:
            X: Raw (unscaled) feature matrix accepted by ``scaler_X``.
            return_physics: When true, also request the model's physics dict.

        Returns:
            Tuple ``(pred_dist, pred_total, physics)``: the inverse-scaled
            distribution array, the flattened total-count array, and either
            ``None`` or a dict mapping each key in ``_PHYSICS_KEYS`` to a
            flattened NumPy array.
        """
        X_norm = self.scaler_X.transform(X)
        physics = None
        # All tensor -> NumPy conversions stay inside no_grad so that no
        # output still tracks gradients when ``.numpy()`` is called.
        with torch.no_grad():
            X_tensor = torch.FloatTensor(X_norm).to(self.device)
            if return_physics:
                dist_t, total_t, raw_physics = self.model(
                    X_tensor, return_physics=True
                )
                physics = {
                    key: raw_physics[key].cpu().numpy().flatten()
                    for key in self._PHYSICS_KEYS
                }
            else:
                dist_t, total_t = self.model(X_tensor, return_physics=False)
            pred_dist_norm = dist_t.cpu().numpy()
            pred_total = total_t.cpu().numpy().flatten()
        pred_dist = self.scaler_y.inverse_transform(pred_dist_norm)
        return pred_dist, pred_total, physics

    def predict(self, pH, FN, FT, T, phase):
        """Predict the crack distribution for a single sample.

        ``D0`` and ``lambda`` are computed analytically via
        ``DamageCalculator`` rather than taken from the network.

        Returns:
            Dict with keys ``'angle_distribution'``, ``'total_count'``,
            ``'D0'`` and ``'lambda'``.
        """
        X = np.array([[pH, FN, FT, T, phase]], dtype=np.float32)
        pred_dist, pred_total, _ = self._infer(X)

        D0 = DamageCalculator.compute_total_damage(pH, FN, FT, T)
        lambda_coef = DamageCalculator.compute_lambda(D0)

        return {
            'angle_distribution': pred_dist[0],
            'total_count': pred_total[0],
            'D0': D0,
            'lambda': lambda_coef,
        }

    def predict_with_physics(self, pH, FN, FT, T, phase):
        """Predict a single sample and include the model's physics outputs.

        Returns:
            Dict with ``'angle_distribution'``, ``'total_count'`` and one
            scalar per key in ``_PHYSICS_KEYS``, all taken from the model.
        """
        X = np.array([[pH, FN, FT, T, phase]], dtype=np.float32)
        pred_dist, pred_total, physics = self._infer(X, return_physics=True)

        result = {
            'angle_distribution': pred_dist[0],
            'total_count': pred_total[0],
        }
        for key in self._PHYSICS_KEYS:
            result[key] = physics[key][0]
        return result

    def predict_batch(self, X):
        """Predict for a batch of raw feature rows.

        Args:
            X: Feature matrix in the layout expected by ``scaler_X``
                (``predict`` uses the column order pH, FN, FT, T, phase).

        Returns:
            Tuple ``(pred_dist, pred_total)`` of NumPy arrays.
        """
        pred_dist, pred_total, _ = self._infer(X)
        return pred_dist, pred_total

    @staticmethod
    def _draw_polar(ax, theta, values, color, title, fontsize):
        """Draw one filled distribution curve on a polar axis and style it."""
        ax.plot(theta, values, 'o-', linewidth=2, markersize=4, color=color)
        ax.fill(theta, values, alpha=0.3, color=color)
        ax.set_title(title, fontsize=fontsize, pad=20)
        # Zero angle at the top, angles increasing clockwise.
        ax.set_theta_zero_location('N')
        ax.set_theta_direction(-1)
        ax.grid(True, alpha=0.3)

    @staticmethod
    def _save_and_show(fig, save_path):
        """Optionally write ``fig`` to ``save_path``, then display it."""
        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

    def visualize(self, result, title=None, save_path=None):
        """Plot a prediction result on a polar axis.

        Args:
            result: Dict as returned by ``predict`` or
                ``predict_with_physics``.
            title: Optional custom title. When falsy, a summary title built
                from ``total_count``, ``D0`` and ``lambda`` is used.
            save_path: Optional path to save the figure to before showing it.
        """
        fig, ax = plt.subplots(
            figsize=(8, 8), subplot_kw=dict(projection='polar')
        )
        theta = np.deg2rad(self.angle_bins)

        if title:
            heading, fontsize = title, 14
        else:
            total = result['total_count']
            D0 = result['D0']
            lam = result['lambda']
            heading = (
                f'Predicted Crack Distribution\n'
                f'Total: {total:.0f}, D0: {D0:.3f}, lambda: {lam:.3f}'
            )
            fontsize = 12

        self._draw_polar(
            ax, theta, result['angle_distribution'], 'blue', heading, fontsize
        )
        self._save_and_show(fig, save_path)

    def compare_stress_types(self, pH, FN, FT, T, save_path=None):
        """Plot phase 0 and phase 1 predictions side by side.

        Phase 0 is labelled "Unstable Development Phase" and phase 1
        "Peak Stress Phase". ``D0`` and ``lambda`` in the figure title come
        from the phase-0 result.
        """
        result_unstable = self.predict(pH, FN, FT, T, phase=0)
        result_peak = self.predict(pH, FN, FT, T, phase=1)

        fig, axes = plt.subplots(
            1, 2, figsize=(14, 6), subplot_kw=dict(projection='polar')
        )
        theta = np.deg2rad(self.angle_bins)

        panels = (
            (axes[0], result_unstable, 'blue', 'Unstable Development Phase'),
            (axes[1], result_peak, 'red', 'Peak Stress Phase'),
        )
        for ax, result, color, label in panels:
            self._draw_polar(
                ax,
                theta,
                result['angle_distribution'],
                color,
                f'{label}\nTotal: {result["total_count"]:.0f}',
                12,
            )

        D0 = result_unstable['D0']
        lam = result_unstable['lambda']
        fig.suptitle(
            f'pH={pH}, FN={FN}, FT={FT}, T={T}\nD0={D0:.3f}, lambda={lam:.3f}',
            fontsize=14,
            fontweight='bold',
        )
        plt.tight_layout()
        self._save_and_show(fig, save_path)


def main():
    """Run a small inference demo against the trained model on disk."""
    model_path = "./output/crack_transformer_pinn.pth"
    scaler_path = "./output/scalers.pkl"

    if not os.path.exists(model_path):
        print("Model file not found. Please train the model first using train.py")
        return
    # The scalers are loaded unconditionally by CrackPredictor, so a missing
    # scaler file must be reported here instead of surfacing as a traceback.
    if not os.path.exists(scaler_path):
        print("Scaler file not found. Please train the model first using train.py")
        return

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    predictor = CrackPredictor(model_path, scaler_path, device=device)

    banner = "=" * 60
    print(banner)
    print("Transformer-PINN Crack Prediction - Inference Demo")
    print(banner)

    test_cases = [
        {'pH': 3, 'FN': 30, 'FT': 40, 'T': 25, 'phase': 0},
        {'pH': 3, 'FN': 40, 'FT': 20, 'T': 300, 'phase': 1},
        {'pH': 7, 'FN': 10, 'FT': 40, 'T': 300, 'phase': 1},
        {'pH': 5, 'FN': 10, 'FT': 20, 'T': 900, 'phase': 0},
    ]

    for i, params in enumerate(test_cases):
        result = predictor.predict(**params)
        distribution = result['angle_distribution']
        print(f"\nTest Case {i+1}:")
        print(f" Input: pH={params['pH']}, FN={params['FN']}, FT={params['FT']}, T={params['T']}, phase={params['phase']}")
        print(f" D0 (Initial Damage): {result['D0']:.4f}")
        print(f" Lambda (Damage Coefficient): {result['lambda']:.4f}")
        print(f" Predicted Total Cracks: {result['total_count']:.0f}")
        print(f" Peak Angle: {predictor.angle_bins[distribution.argmax()]:.1f} degrees")
        print(f" Peak Count: {distribution.max():.0f}")

    print("\n" + banner)
    print("Testing Physics Output")
    print(banner)

    physics_result = predictor.predict_with_physics(
        pH=3, FN=30, FT=40, T=25, phase=1
    )
    print("\nPhysics Parameters:")
    print(" Mogi-Coulomb:")
    print(f" tau_oct: {physics_result['tau_oct']:.4f}")
    print(f" yield_stress: {physics_result['yield_stress']:.4f}")
    print(f" C1: {physics_result['C1']:.4f}")
    print(f" C2: {physics_result['C2']:.4f}")
    print(" Weibull Distribution:")
    print(f" D_q: {physics_result['D_q']:.4f}")
    print(f" m: {physics_result['m']:.4f}")
    print(f" F0: {physics_result['F0']:.4f}")
    print(" Energy Damage:")
    print(f" D_n: {physics_result['D_n']:.4f}")

    print("\n" + banner)
    print("Inference complete!")
    print(banner)


if __name__ == "__main__":
    main()