| import os | |
| import numpy as np | |
| import torch | |
| import pickle | |
| import matplotlib.pyplot as plt | |
| from model import CrackTransformerPINN | |
| from data_loader import DamageCalculator | |
class CrackPredictor:
    """Inference wrapper around a trained ``CrackTransformerPINN`` model.

    Loads the fitted input/output scalers and a model checkpoint, then offers
    single-sample, batch and physics-aware prediction, plus polar plots of the
    predicted crack-angle distribution.
    """

    # Physics quantities read from the model's physics dict by
    # ``predict_with_physics``, in the order they appear in its result.
    _PHYSICS_KEYS = ('D0', 'lambda', 'D_n', 'tau_oct', 'yield_stress',
                     'C1', 'C2', 'D_q', 'm', 'F0')

    def __init__(self, model_path, scaler_path, device='cpu'):
        """Load scalers and model weights and put the model in eval mode.

        Args:
            model_path: Path to a checkpoint. This is either a dict holding
                ``model_state_dict`` (and optionally ``model_config``) or a
                bare state_dict.
            scaler_path: Path to a pickle holding ``scaler_X``, ``scaler_y``
                and ``angle_bins``.
            device: Torch device string the model is moved to.
        """
        self.device = device
        # NOTE: pickle.load and torch.load unpickle data and can execute
        # arbitrary code. Only load files produced by a trusted training run.
        with open(scaler_path, 'rb') as f:
            scalers = pickle.load(f)
        self.scaler_X = scalers['scaler_X']
        self.scaler_y = scalers['scaler_y']
        self.angle_bins = scalers['angle_bins']

        checkpoint = torch.load(model_path, map_location=device)
        if 'model_config' in checkpoint:
            config = checkpoint['model_config']
        else:
            # Fallback architecture for checkpoints saved without their config.
            config = {
                'input_dim': 5,
                'output_dim': len(self.angle_bins),
                'hidden_dims': [128, 256, 256, 128],
                'dropout': 0.2,
            }
        self.model = CrackTransformerPINN(
            input_dim=config['input_dim'],
            output_dim=config['output_dim'],
            hidden_dims=config['hidden_dims'],
            dropout=config['dropout'],
        )
        # Accept a wrapped checkpoint dict even when it lacks 'model_config',
        # as well as a bare state_dict.
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        self.model.load_state_dict(state_dict)
        self.model.to(device)
        self.model.eval()

    @staticmethod
    def _features(pH, FN, FT, T, phase):
        """Build the single-row float32 feature matrix expected by the model."""
        return np.array([[pH, FN, FT, T, phase]], dtype=np.float32)

    def _forward(self, X, return_physics=False):
        """Normalize ``X``, run the model without gradients, de-normalize output.

        Args:
            X: Feature matrix of shape (n_samples, n_features). A single 1-D
                row is promoted to shape (1, n_features).
            return_physics: Forwarded to the model. When true, the model's
                third output (the physics dict) is returned as well.

        Returns:
            ``(pred_dist, pred_total, physics)``. ``pred_dist`` is the
            de-normalized distribution array, ``pred_total`` is a flattened
            NumPy array, and ``physics`` is the model's physics dict, or
            ``None`` when ``return_physics`` is false.
        """
        # Scalers need 2-D input; only promote 1-D rows so 2-D inputs
        # (arrays, nested lists, DataFrames) are passed through untouched.
        if np.ndim(X) == 1:
            X = np.atleast_2d(X)
        X_norm = self.scaler_X.transform(X)
        X_tensor = torch.as_tensor(np.asarray(X_norm), dtype=torch.float32,
                                   device=self.device)
        with torch.no_grad():
            outputs = self.model(X_tensor, return_physics=return_physics)
        pred_dist_norm, pred_total = outputs[0], outputs[1]
        physics = outputs[2] if return_physics else None
        pred_dist = self.scaler_y.inverse_transform(pred_dist_norm.cpu().numpy())
        return pred_dist, pred_total.cpu().numpy().flatten(), physics

    def predict(self, pH, FN, FT, T, phase):
        """Predict the crack distribution for one load case.

        ``D0`` and ``lambda`` come from ``DamageCalculator`` applied to
        (pH, FN, FT, T), not from the network.

        Returns:
            Dict with keys ``angle_distribution``, ``total_count``, ``D0``
            and ``lambda``.
        """
        pred_dist, pred_total, _ = self._forward(self._features(pH, FN, FT, T, phase))
        D0 = DamageCalculator.compute_total_damage(pH, FN, FT, T)
        return {
            'angle_distribution': pred_dist[0],
            'total_count': pred_total[0],
            'D0': D0,
            'lambda': DamageCalculator.compute_lambda(D0),
        }

    def predict_with_physics(self, pH, FN, FT, T, phase):
        """Predict one load case and include the model's physics outputs.

        Returns:
            Dict with ``angle_distribution``, ``total_count`` and one scalar
            per entry of ``_PHYSICS_KEYS``, each taken from the model's
            physics dict rather than from ``DamageCalculator``.
        """
        pred_dist, pred_total, physics = self._forward(
            self._features(pH, FN, FT, T, phase), return_physics=True)
        result = {
            'angle_distribution': pred_dist[0],
            'total_count': pred_total[0],
        }
        for key in self._PHYSICS_KEYS:
            result[key] = physics[key].cpu().numpy().flatten()[0]
        return result

    def predict_batch(self, X):
        """Predict distributions and totals for a feature matrix ``X``.

        A single 1-D sample row is also accepted.

        Returns:
            ``(pred_dist, pred_total)``: the de-normalized distributions (one
            row per sample) and a flat array of predicted totals.
        """
        pred_dist, pred_total, _ = self._forward(X)
        return pred_dist, pred_total

    @staticmethod
    def _draw_distribution(ax, theta, dist, color):
        """Plot ``dist`` over ``theta`` on polar ``ax`` (north-up, clockwise)."""
        ax.plot(theta, dist, 'o-', linewidth=2, markersize=4, color=color)
        ax.fill(theta, dist, alpha=0.3, color=color)
        ax.set_theta_zero_location('N')
        ax.set_theta_direction(-1)
        ax.grid(True, alpha=0.3)

    def visualize(self, result, title=None, save_path=None):
        """Show a polar plot of ``result['angle_distribution']``.

        Without ``title``, the title summarizes the total count, D0 and
        lambda. If ``save_path`` is given, the figure is also saved there.
        """
        fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(projection='polar'))
        self._draw_distribution(ax, np.deg2rad(self.angle_bins),
                                result['angle_distribution'], 'blue')
        if title:
            ax.set_title(title, fontsize=14, pad=20)
        else:
            ax.set_title(
                f"Predicted Crack Distribution\nTotal: {result['total_count']:.0f}, "
                f"D0: {result['D0']:.3f}, lambda: {result['lambda']:.3f}",
                fontsize=12, pad=20)
        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

    def compare_stress_types(self, pH, FN, FT, T, save_path=None):
        """Plot phase=0 (unstable development) vs phase=1 (peak stress) side by side."""
        result_unstable = self.predict(pH, FN, FT, T, phase=0)
        result_peak = self.predict(pH, FN, FT, T, phase=1)
        fig, axes = plt.subplots(1, 2, figsize=(14, 6), subplot_kw=dict(projection='polar'))
        theta = np.deg2rad(self.angle_bins)
        panels = (
            (axes[0], result_unstable, 'blue', 'Unstable Development Phase'),
            (axes[1], result_peak, 'red', 'Peak Stress Phase'),
        )
        for ax, res, color, label in panels:
            self._draw_distribution(ax, theta, res['angle_distribution'], color)
            ax.set_title(f'{label}\nTotal: {res["total_count"]:.0f}', fontsize=12, pad=20)
        # predict() derives D0/lambda from (pH, FN, FT, T) only, so both
        # phases share them; take them from either result.
        D0 = result_unstable['D0']
        lam = result_unstable['lambda']
        fig.suptitle(f'pH={pH}, FN={FN}, FT={FT}, T={T}\nD0={D0:.3f}, lambda={lam:.3f}',
                     fontsize=14, fontweight='bold')
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()
def main():
    """Run a small inference demo against the artifacts in ``./output``.

    Prints predictions for a few hard-coded load cases, then the physics
    parameters from ``predict_with_physics`` for one case. Returns early with
    a hint if the model or scaler file is missing.
    """
    model_path = "./output/crack_transformer_pinn.pth"
    scaler_path = "./output/scalers.pkl"
    if not os.path.exists(model_path):
        print("Model file not found. Please train the model first using train.py")
        return
    # CrackPredictor opens the scaler file too. Check it here so the user
    # gets a hint instead of a raw FileNotFoundError traceback.
    if not os.path.exists(scaler_path):
        print("Scaler file not found. Please train the model first using train.py")
        return

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    predictor = CrackPredictor(model_path, scaler_path, device=device)

    print("=" * 60)
    print("Transformer-PINN Crack Prediction - Inference Demo")
    print("=" * 60)
    test_cases = [
        {'pH': 3, 'FN': 30, 'FT': 40, 'T': 25, 'phase': 0},
        {'pH': 3, 'FN': 40, 'FT': 20, 'T': 300, 'phase': 1},
        {'pH': 7, 'FN': 10, 'FT': 40, 'T': 300, 'phase': 1},
        {'pH': 5, 'FN': 10, 'FT': 20, 'T': 900, 'phase': 0},
    ]
    for case_no, params in enumerate(test_cases, start=1):
        result = predictor.predict(**params)
        dist = result['angle_distribution']
        print(f"\nTest Case {case_no}:")
        print(f" Input: pH={params['pH']}, FN={params['FN']}, FT={params['FT']}, T={params['T']}, phase={params['phase']}")
        print(f" D0 (Initial Damage): {result['D0']:.4f}")
        print(f" Lambda (Damage Coefficient): {result['lambda']:.4f}")
        print(f" Predicted Total Cracks: {result['total_count']:.0f}")
        print(f" Peak Angle: {predictor.angle_bins[dist.argmax()]:.1f} degrees")
        print(f" Peak Count: {dist.max():.0f}")

    print("\n" + "=" * 60)
    print("Testing Physics Output")
    print("=" * 60)
    physics_result = predictor.predict_with_physics(pH=3, FN=30, FT=40, T=25, phase=1)
    print("\nPhysics Parameters:")
    # (section header, physics keys printed under it), in display order.
    sections = (
        ("Mogi-Coulomb", ('tau_oct', 'yield_stress', 'C1', 'C2')),
        ("Weibull Distribution", ('D_q', 'm', 'F0')),
        ("Energy Damage", ('D_n',)),
    )
    for header, keys in sections:
        print(f" {header}:")
        for key in keys:
            print(f" {key}: {physics_result[key]:.4f}")

    print("\n" + "=" * 60)
    print("Inference complete!")
    print("=" * 60)
# Script entry point: run the inference demo only when executed directly,
# not when this module is imported.
if "__main__" == __name__:
    main()