| import os | |
| import numpy as np | |
| import torch | |
| import matplotlib.pyplot as plt | |
| from sklearn.model_selection import train_test_split | |
| import pickle | |
| from data_loader import CrackDataLoader, DamageCalculator | |
| from model import CrackTransformerPINN, CrackPINNTrainer | |
def plot_training_history(trainer, save_path=None):
    """Render loss curves (linear and log scale) plus per-epoch loss change.

    Parameters
    ----------
    trainer : object
        Must expose ``train_losses`` and ``val_losses`` sequences
        (one entry per epoch).
    save_path : str or None
        When given, the figure is written there at 300 dpi; the figure
        is always closed afterwards.
    """
    fig, (ax_lin, ax_log, ax_conv) = plt.subplots(1, 3, figsize=(15, 4))

    # Panel 1: raw loss curves on a linear axis.
    ax_lin.plot(trainer.train_losses, label='Train Loss', linewidth=2)
    ax_lin.plot(trainer.val_losses, label='Val Loss', linewidth=2)
    ax_lin.set_xlabel('Epoch', fontsize=12)
    ax_lin.set_ylabel('Loss', fontsize=12)
    ax_lin.set_title('Training History', fontsize=14, fontweight='bold')
    ax_lin.legend()
    ax_lin.grid(True, alpha=0.3)

    # Panel 2: same curves on a log axis to expose late-stage improvement.
    ax_log.semilogy(trainer.train_losses, label='Train Loss', linewidth=2)
    ax_log.semilogy(trainer.val_losses, label='Val Loss', linewidth=2)
    ax_log.set_xlabel('Epoch', fontsize=12)
    ax_log.set_ylabel('Loss (log scale)', fontsize=12)
    ax_log.set_title('Training History (Log Scale)', fontsize=14, fontweight='bold')
    ax_log.legend()
    ax_log.grid(True, alpha=0.3)

    # Panel 3: first difference of the training loss as a convergence gauge
    # (needs at least two epochs for np.diff to produce anything).
    if len(trainer.train_losses) > 1:
        loss_delta = np.diff(trainer.train_losses)
        ax_conv.plot(loss_delta, linewidth=2, alpha=0.7)
        ax_conv.axhline(y=0, color='r', linestyle='--', alpha=0.5)
        ax_conv.set_xlabel('Epoch', fontsize=12)
        ax_conv.set_ylabel('Loss Change', fontsize=12)
        ax_conv.set_title('Convergence Rate', fontsize=14, fontweight='bold')
        ax_conv.grid(True, alpha=0.3)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Training history saved to: {save_path}")
    plt.close()
def plot_prediction_examples(X_test, y_test, trainer, angle_bins, loader, n_examples=4, save_path=None):
    """Draw polar plots comparing true vs. predicted crack-angle counts.

    ``X_test``/``y_test`` are expected in normalized form and are mapped
    back to physical units with the loader's fitted scalers before
    plotting. Up to ``n_examples`` samples (capped at the 2x2 grid size)
    are shown; each panel title summarises the sample's conditions,
    derived damage quantities, and total crack counts.
    """
    # Undo normalization so axis values and titles are in physical units.
    inputs_raw = loader.scaler_X.inverse_transform(X_test)
    targets_raw = loader.scaler_y.inverse_transform(y_test)
    preds_norm, _pred_totals = trainer.predict(X_test[:n_examples])
    preds_raw = loader.scaler_y.inverse_transform(preds_norm)

    fig, axes = plt.subplots(2, 2, figsize=(14, 10), subplot_kw=dict(projection='polar'))
    panels = axes.flatten()
    # Angle bins arrive in degrees; polar axes want radians.
    theta = np.deg2rad(angle_bins)

    for idx in range(min(n_examples, len(panels))):
        panel = panels[idx]
        panel.plot(theta, targets_raw[idx], 'o-', label='True', linewidth=2, markersize=4, alpha=0.7)
        panel.plot(theta, preds_raw[idx], 's-', label='Predicted', linewidth=2, markersize=3, alpha=0.7)

        # The five input features: pH, freeze-thaw cycles, fatigue/thermal
        # loads, temperature, and a phase indicator (<0.5 => "Unstable").
        pH, FN, FT, T, phase = inputs_raw[idx, :5]
        phase_str = "Unstable" if phase < 0.5 else "Peak"

        D0 = DamageCalculator.compute_total_damage(pH, FN, FT, T)
        lambda_coef = DamageCalculator.compute_lambda(D0)
        true_total = targets_raw[idx].sum()
        pred_total = preds_raw[idx].sum()

        title = f"pH={pH:.0f}, FN={FN:.0f}, FT={FT:.0f}, T={T:.0f}C\n"
        title += f"D0={D0:.3f}, lambda={lambda_coef:.3f}\n"
        title += f"{phase_str} | True: {true_total:.0f}, Pred: {pred_total:.0f}"
        panel.set_title(title, fontsize=9, pad=20)

        panel.legend(loc='upper right', fontsize=8)
        # North-up, clockwise orientation to match compass-style rose plots.
        panel.set_theta_zero_location('N')
        panel.set_theta_direction(-1)
        panel.grid(True, alpha=0.3)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Prediction examples saved to: {save_path}")
    plt.close()
def plot_damage_analysis(X, y, save_path=None):
    """Scatter total crack counts against D0, pH, FN, and T.

    ``X`` columns are assumed to be (pH, FN, FT, T, ...) — TODO confirm
    against the data loader. ``y`` holds per-angle crack counts; rows are
    summed to a single total per sample. Three panels are colour-coded
    by the initial damage factor D0, with a shared colourbar on the last.
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    totals = y.sum(axis=1)

    # Initial damage factor per sample from its first four conditions.
    damage = np.array([
        DamageCalculator.compute_total_damage(row[0], row[1], row[2], row[3])
        for row in X
    ])

    # Panel (0,0): D0 directly against total cracks.
    top_left = axes[0, 0]
    top_left.scatter(damage, totals, alpha=0.6, edgecolors='black', linewidth=0.5)
    top_left.set_xlabel('Initial Damage Factor D0', fontsize=12)
    top_left.set_ylabel('Total Crack Count', fontsize=12)
    top_left.set_title('D0 vs Total Cracks', fontsize=14, fontweight='bold')
    top_left.grid(True, alpha=0.3)

    # Panel (0,1): pH, coloured by D0.
    top_right = axes[0, 1]
    top_right.scatter(X[:, 0], totals, alpha=0.6, c=damage, cmap='viridis')
    top_right.set_xlabel('pH Value', fontsize=12)
    top_right.set_ylabel('Total Crack Count', fontsize=12)
    top_right.set_title('pH vs Total Cracks', fontsize=14, fontweight='bold')
    top_right.grid(True, alpha=0.3)

    # Panel (1,0): freeze-thaw cycles, coloured by D0.
    bottom_left = axes[1, 0]
    bottom_left.scatter(X[:, 1], totals, alpha=0.6, c=damage, cmap='viridis')
    bottom_left.set_xlabel('Freeze-thaw Cycles (FN)', fontsize=12)
    bottom_left.set_ylabel('Total Crack Count', fontsize=12)
    bottom_left.set_title('FN vs Total Cracks', fontsize=14, fontweight='bold')
    bottom_left.grid(True, alpha=0.3)

    # Panel (1,1): temperature; keep the mappable for the colourbar.
    bottom_right = axes[1, 1]
    mappable = bottom_right.scatter(X[:, 3], totals, alpha=0.6, c=damage, cmap='viridis')
    bottom_right.set_xlabel('Damage Temperature (T)', fontsize=12)
    bottom_right.set_ylabel('Total Crack Count', fontsize=12)
    bottom_right.set_title('T vs Total Cracks', fontsize=14, fontweight='bold')
    bottom_right.grid(True, alpha=0.3)
    plt.colorbar(mappable, ax=bottom_right, label='D0')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Damage analysis saved to: {save_path}")
    plt.close()
def main():
    """Run the full Transformer-PINN crack-prediction training pipeline.

    Steps: load (or synthesize) data, split 64/16/20 into
    train/val/test, normalize, build and train the model, evaluate on
    the held-out test set, save the checkpoint and scalers, and render
    diagnostic figures. All artifacts land in ``./output``.

    Fixes vs. previous revision:
    - the bare ``except:`` around data loading swallowed
      KeyboardInterrupt/SystemExit and hid the real failure; it is now
      ``except Exception`` (synthetic-data fallback is preserved);
    - ``os.path.exists`` + ``makedirs`` check-then-create race replaced
      with ``os.makedirs(..., exist_ok=True)``;
    - unused ``damage_list`` binding made explicit with ``_``.
    """
    print("=" * 80)
    print("Transformer-PINN Crack Prediction Model")
    print("Based on: Mechanism of micro-damage evolution in rocks")
    print("under multiple coupled cyclic stresses")
    print("=" * 80)

    base_path = "./data"
    output_dir = "./output"
    # exist_ok avoids the TOCTOU race of exists() followed by makedirs().
    os.makedirs(output_dir, exist_ok=True)

    # ---- Step 1: data ----------------------------------------------------
    print("\n" + "=" * 80)
    print("Step 1: Loading/Generating Data")
    print("=" * 80)
    loader = CrackDataLoader(base_path, stress_type="major")
    try:
        # damage_list is not used downstream; keep the unpack explicit.
        X, y, angle_bins, _damage_list = loader.load_all_data(phase="both")
    except Exception:
        # Deliberate best-effort fallback to synthetic data; Exception (not
        # a bare except) so Ctrl-C / SystemExit still propagate.
        print("Real data not found. Generating synthetic data...")
        X, y, angle_bins = loader.create_synthetic_data(n_samples=200, output_dim=72)

    stats = loader.get_statistics(X, y)
    print("\nData statistics:")
    for key, value in stats.items():
        print(f"  {key}: {value}")

    # ---- Step 2: split ---------------------------------------------------
    print("\n" + "=" * 80)
    print("Step 2: Splitting Dataset (Train:Val:Test = 64:16:20)")
    print("=" * 80)
    # 20% test, then 20% of the remaining 80% for validation -> 64/16/20.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    X_train, X_val, y_train, y_val = train_test_split(
        X_train, y_train, test_size=0.2, random_state=42
    )
    print(f"Training set: {X_train.shape[0]} samples")
    print(f"Validation set: {X_val.shape[0]} samples")
    print(f"Test set: {X_test.shape[0]} samples")

    # ---- Step 3: normalization -------------------------------------------
    print("\n" + "=" * 80)
    print("Step 3: Normalizing Data")
    print("=" * 80)
    # Scalers are fit on train (inside normalize_data); test is only
    # transformed, never fit, to avoid leakage.
    X_train_norm, y_train_norm, X_val_norm, y_val_norm = loader.normalize_data(
        X_train, y_train, X_val, y_val
    )
    X_test_norm = loader.scaler_X.transform(X_test)
    y_test_norm = loader.scaler_y.transform(y_test)
    print("Normalization complete")

    # ---- Step 4: model ---------------------------------------------------
    print("\n" + "=" * 80)
    print("Step 4: Creating Model")
    print("=" * 80)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")
    model = CrackTransformerPINN(
        input_dim=5,
        output_dim=y.shape[1],
        hidden_dims=[128, 256, 256, 128],
        dropout=0.2
    )
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {n_params:,}")
    print("\nModel components:")
    print("  - Transformer Encoder (8 heads, 4 layers)")
    print("  - Mogi-Coulomb Yield Criterion Layer")
    print("  - Weibull Strength Distribution Layer")
    print("  - Energy-based Damage Evolution Layer")
    print("  - PINN Decoder with Physics Constraints")

    # ---- Step 5: training ------------------------------------------------
    print("\n" + "=" * 80)
    print("Step 5: Training Model")
    print("=" * 80)
    trainer = CrackPINNTrainer(
        model,
        device=device,
        lr=1e-3,
        weight_decay=1e-4
    )
    trainer.fit(
        X_train_norm, y_train_norm,
        X_val_norm, y_val_norm,
        epochs=300,
        batch_size=8,
        patience=50
    )

    # ---- Step 6: evaluation ----------------------------------------------
    print("\n" + "=" * 80)
    print("Step 6: Testing Model")
    print("=" * 80)
    test_dataset = torch.utils.data.TensorDataset(
        torch.FloatTensor(X_test_norm),
        torch.FloatTensor(y_test_norm)
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=8, shuffle=False
    )
    test_loss, test_metrics = trainer.validate(test_loader)
    print("Test set performance:")
    print(f"  Loss: {test_loss:.4f}")
    print(f"  R2: {test_metrics['r2']:.4f}")
    print(f"  RMSE: {test_metrics['rmse']:.2f}")
    print(f"  Total Count MAE: {test_metrics['total_count_mae']:.2f}")

    # ---- Step 7: persistence ---------------------------------------------
    print("\n" + "=" * 80)
    print("Step 7: Saving Model")
    print("=" * 80)
    model_path = os.path.join(output_dir, "crack_transformer_pinn.pth")
    # model_config mirrors the constructor args so the checkpoint can be
    # reloaded without this script.
    torch.save({
        'model_state_dict': model.state_dict(),
        'model_config': {
            'input_dim': 5,
            'output_dim': y.shape[1],
            'hidden_dims': [128, 256, 256, 128],
            'dropout': 0.2
        },
        'test_metrics': test_metrics
    }, model_path)
    print(f"Model saved to: {model_path}")

    scaler_path = os.path.join(output_dir, "scalers.pkl")
    with open(scaler_path, 'wb') as f:
        pickle.dump({
            'scaler_X': loader.scaler_X,
            'scaler_y': loader.scaler_y,
            'angle_bins': angle_bins
        }, f)
    print(f"Scalers saved to: {scaler_path}")

    # ---- Step 8: figures -------------------------------------------------
    print("\n" + "=" * 80)
    print("Step 8: Generating Visualizations")
    print("=" * 80)
    history_path = os.path.join(output_dir, "training_history.png")
    plot_training_history(trainer, save_path=history_path)
    examples_path = os.path.join(output_dir, "prediction_examples.png")
    plot_prediction_examples(
        X_test_norm, y_test_norm,
        trainer, angle_bins, loader,
        n_examples=4,
        save_path=examples_path
    )
    damage_path = os.path.join(output_dir, "damage_analysis.png")
    plot_damage_analysis(X, y, save_path=damage_path)

    print("\n" + "=" * 80)
    print("Training Pipeline Complete!")
    print("=" * 80)
    print("\nGenerated files:")
    print(f"  1. Model checkpoint: {model_path}")
    print(f"  2. Scalers: {scaler_path}")
    print(f"  3. Training history: {history_path}")
    print(f"  4. Prediction examples: {examples_path}")
    print(f"  5. Damage analysis: {damage_path}")
    print("\nPhysics constraints applied:")
    print("  - Mogi-Coulomb yield criterion: tau_oct = C1 + C2 * sigma_m2")
    print("  - Weibull strength: D_q = 1 - exp(-(F/F0)^m)")
    print("  - Energy damage: D_n = (2/pi) * arctan(b * U_p)")
    print("  - Total damage: D_total = 1 - (1-D_ft)(1-D_ch)(1-D_th)")
# Script entry point: run the full training pipeline only when executed
# directly, not when imported as a module.
if __name__ == "__main__":
    main()