"""Training pipeline for the Transformer-PINN crack prediction model.

Loads (or synthesizes) crack-distribution data, trains the model, evaluates it
on a held-out test set, saves the checkpoint/scalers, and renders diagnostics.
"""

import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.model_selection import train_test_split

from data_loader import CrackDataLoader, DamageCalculator
from model import CrackTransformerPINN, CrackPINNTrainer


def plot_training_history(trainer, save_path=None):
    """Plot train/val loss (linear and log scale) plus per-epoch loss change.

    Args:
        trainer: fitted trainer exposing ``train_losses`` and ``val_losses``
            lists (one entry per epoch).
        save_path: optional file path; when given, the figure is saved as PNG.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    axes[0].plot(trainer.train_losses, label='Train Loss', linewidth=2)
    axes[0].plot(trainer.val_losses, label='Val Loss', linewidth=2)
    axes[0].set_xlabel('Epoch', fontsize=12)
    axes[0].set_ylabel('Loss', fontsize=12)
    axes[0].set_title('Training History', fontsize=14, fontweight='bold')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    axes[1].semilogy(trainer.train_losses, label='Train Loss', linewidth=2)
    axes[1].semilogy(trainer.val_losses, label='Val Loss', linewidth=2)
    axes[1].set_xlabel('Epoch', fontsize=12)
    axes[1].set_ylabel('Loss (log scale)', fontsize=12)
    axes[1].set_title('Training History (Log Scale)', fontsize=14, fontweight='bold')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    # Convergence rate: first difference of the training loss; negative values
    # mean the loss is still improving. Needs at least two epochs.
    if len(trainer.train_losses) > 1:
        train_improvement = np.diff(trainer.train_losses)
        axes[2].plot(train_improvement, linewidth=2, alpha=0.7)
        axes[2].axhline(y=0, color='r', linestyle='--', alpha=0.5)
        axes[2].set_xlabel('Epoch', fontsize=12)
        axes[2].set_ylabel('Loss Change', fontsize=12)
        axes[2].set_title('Convergence Rate', fontsize=14, fontweight='bold')
        axes[2].grid(True, alpha=0.3)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Training history saved to: {save_path}")
    plt.close()


def plot_prediction_examples(X_test, y_test, trainer, angle_bins, loader,
                             n_examples=4, save_path=None):
    """Render polar plots comparing true vs. predicted angular crack counts.

    Args:
        X_test: normalized test inputs (as produced by ``loader.scaler_X``).
        y_test: normalized test targets (as produced by ``loader.scaler_y``).
        trainer: trainer whose ``predict`` returns normalized predictions.
        angle_bins: angular bin centers in degrees.
        loader: data loader holding the fitted scalers for inverse transforms.
        n_examples: number of samples to plot (capped at 4 subplot slots).
        save_path: optional file path; when given, the figure is saved as PNG.
    """
    # Undo normalization so titles and radial values are in physical units.
    X_test_original = loader.scaler_X.inverse_transform(X_test)
    y_test_original = loader.scaler_y.inverse_transform(y_test)

    # Second return value (predicted totals) is recomputed below from the
    # denormalized distribution, so it is intentionally discarded here.
    y_pred_norm, _ = trainer.predict(X_test[:n_examples])
    y_pred = loader.scaler_y.inverse_transform(y_pred_norm)

    fig, axes = plt.subplots(2, 2, figsize=(14, 10),
                             subplot_kw=dict(projection='polar'))
    axes = axes.flatten()

    for i in range(min(n_examples, len(axes))):
        ax = axes[i]
        theta = np.deg2rad(angle_bins)
        ax.plot(theta, y_test_original[i], 'o-', label='True',
                linewidth=2, markersize=4, alpha=0.7)
        ax.plot(theta, y_pred[i], 's-', label='Predicted',
                linewidth=2, markersize=3, alpha=0.7)

        # Input feature layout: [pH, FN, FT, T, phase] — must match the
        # loader's column order (TODO: confirm against data_loader).
        pH = X_test_original[i, 0]
        FN = X_test_original[i, 1]
        FT = X_test_original[i, 2]
        T = X_test_original[i, 3]
        phase = X_test_original[i, 4]
        # Phase is encoded as a binary-ish flag; < 0.5 denotes the unstable stage.
        phase_str = "Unstable" if phase < 0.5 else "Peak"

        D0 = DamageCalculator.compute_total_damage(pH, FN, FT, T)
        lambda_coef = DamageCalculator.compute_lambda(D0)

        true_total = y_test_original[i].sum()
        pred_total = y_pred[i].sum()

        title = f"pH={pH:.0f}, FN={FN:.0f}, FT={FT:.0f}, T={T:.0f}C\n"
        title += f"D0={D0:.3f}, lambda={lambda_coef:.3f}\n"
        title += f"{phase_str} | True: {true_total:.0f}, Pred: {pred_total:.0f}"
        ax.set_title(title, fontsize=9, pad=20)
        ax.legend(loc='upper right', fontsize=8)
        # North-up, clockwise orientation (compass convention).
        ax.set_theta_zero_location('N')
        ax.set_theta_direction(-1)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Prediction examples saved to: {save_path}")
    plt.close()


def plot_damage_analysis(X, y, save_path=None):
    """Scatter-plot total crack counts against D0, pH, FN, and T.

    Args:
        X: unnormalized input features, columns [pH, FN, FT, T, phase].
        y: crack-count distributions; summed per sample for the y-axis.
        save_path: optional file path; when given, the figure is saved as PNG.
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    total_cracks = y.sum(axis=1)
    # Initial damage factor per sample, derived from the physics model.
    D0_values = np.array([
        DamageCalculator.compute_total_damage(X[i, 0], X[i, 1], X[i, 2], X[i, 3])
        for i in range(X.shape[0])
    ])

    axes[0, 0].scatter(D0_values, total_cracks, alpha=0.6,
                       edgecolors='black', linewidth=0.5)
    axes[0, 0].set_xlabel('Initial Damage Factor D0', fontsize=12)
    axes[0, 0].set_ylabel('Total Crack Count', fontsize=12)
    axes[0, 0].set_title('D0 vs Total Cracks', fontsize=14, fontweight='bold')
    axes[0, 0].grid(True, alpha=0.3)

    axes[0, 1].scatter(X[:, 0], total_cracks, alpha=0.6, c=D0_values, cmap='viridis')
    axes[0, 1].set_xlabel('pH Value', fontsize=12)
    axes[0, 1].set_ylabel('Total Crack Count', fontsize=12)
    axes[0, 1].set_title('pH vs Total Cracks', fontsize=14, fontweight='bold')
    axes[0, 1].grid(True, alpha=0.3)

    axes[1, 0].scatter(X[:, 1], total_cracks, alpha=0.6, c=D0_values, cmap='viridis')
    axes[1, 0].set_xlabel('Freeze-thaw Cycles (FN)', fontsize=12)
    axes[1, 0].set_ylabel('Total Crack Count', fontsize=12)
    axes[1, 0].set_title('FN vs Total Cracks', fontsize=14, fontweight='bold')
    axes[1, 0].grid(True, alpha=0.3)

    scatter = axes[1, 1].scatter(X[:, 3], total_cracks, alpha=0.6,
                                 c=D0_values, cmap='viridis')
    axes[1, 1].set_xlabel('Damage Temperature (T)', fontsize=12)
    axes[1, 1].set_ylabel('Total Crack Count', fontsize=12)
    axes[1, 1].set_title('T vs Total Cracks', fontsize=14, fontweight='bold')
    axes[1, 1].grid(True, alpha=0.3)
    plt.colorbar(scatter, ax=axes[1, 1], label='D0')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Damage analysis saved to: {save_path}")
    plt.close()


def main():
    """Run the full train/evaluate/save/visualize pipeline."""
    print("=" * 80)
    print("Transformer-PINN Crack Prediction Model")
    print("Based on: Mechanism of micro-damage evolution in rocks")
    print("under multiple coupled cyclic stresses")
    print("=" * 80)

    base_path = "./data"
    output_dir = "./output"
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)

    print("\n" + "=" * 80)
    print("Step 1: Loading/Generating Data")
    print("=" * 80)
    loader = CrackDataLoader(base_path, stress_type="major")
    try:
        # damage_list is not used downstream; discard it explicitly.
        X, y, angle_bins, _ = loader.load_all_data(phase="both")
    except Exception:
        # Narrowed from a bare `except:` so Ctrl-C / SystemExit still propagate.
        # Fallback is deliberate best-effort: synthesize data when real data
        # cannot be loaded.
        print("Real data not found. Generating synthetic data...")
        X, y, angle_bins = loader.create_synthetic_data(n_samples=200, output_dim=72)

    stats = loader.get_statistics(X, y)
    print("\nData statistics:")
    for key, value in stats.items():
        print(f"  {key}: {value}")

    print("\n" + "=" * 80)
    print("Step 2: Splitting Dataset (Train:Val:Test = 64:16:20)")
    print("=" * 80)
    # 20% test, then 20% of the remainder as validation -> 64/16/20 overall.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    X_train, X_val, y_train, y_val = train_test_split(
        X_train, y_train, test_size=0.2, random_state=42
    )
    print(f"Training set: {X_train.shape[0]} samples")
    print(f"Validation set: {X_val.shape[0]} samples")
    print(f"Test set: {X_test.shape[0]} samples")

    print("\n" + "=" * 80)
    print("Step 3: Normalizing Data")
    print("=" * 80)
    # Scalers are fit on train data inside normalize_data; test data is only
    # transformed to avoid leakage.
    X_train_norm, y_train_norm, X_val_norm, y_val_norm = loader.normalize_data(
        X_train, y_train, X_val, y_val
    )
    X_test_norm = loader.scaler_X.transform(X_test)
    y_test_norm = loader.scaler_y.transform(y_test)
    print("Normalization complete")

    print("\n" + "=" * 80)
    print("Step 4: Creating Model")
    print("=" * 80)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")
    model = CrackTransformerPINN(
        input_dim=5,
        output_dim=y.shape[1],
        hidden_dims=[128, 256, 256, 128],
        dropout=0.2
    )
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {n_params:,}")
    print("\nModel components:")
    print("  - Transformer Encoder (8 heads, 4 layers)")
    print("  - Mogi-Coulomb Yield Criterion Layer")
    print("  - Weibull Strength Distribution Layer")
    print("  - Energy-based Damage Evolution Layer")
    print("  - PINN Decoder with Physics Constraints")

    print("\n" + "=" * 80)
    print("Step 5: Training Model")
    print("=" * 80)
    trainer = CrackPINNTrainer(
        model, device=device, lr=1e-3, weight_decay=1e-4
    )
    trainer.fit(
        X_train_norm, y_train_norm, X_val_norm, y_val_norm,
        epochs=300, batch_size=8, patience=50
    )

    print("\n" + "=" * 80)
    print("Step 6: Testing Model")
    print("=" * 80)
    test_loss, test_metrics = trainer.validate(
        torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(
                torch.FloatTensor(X_test_norm),
                torch.FloatTensor(y_test_norm)
            ),
            batch_size=8, shuffle=False
        )
    )
    print("Test set performance:")
    print(f"  Loss: {test_loss:.4f}")
    print(f"  R2: {test_metrics['r2']:.4f}")
    print(f"  RMSE: {test_metrics['rmse']:.2f}")
    print(f"  Total Count MAE: {test_metrics['total_count_mae']:.2f}")

    print("\n" + "=" * 80)
    print("Step 7: Saving Model")
    print("=" * 80)
    model_path = os.path.join(output_dir, "crack_transformer_pinn.pth")
    # Store the config alongside the weights so the model can be rebuilt
    # without consulting this script.
    torch.save({
        'model_state_dict': model.state_dict(),
        'model_config': {
            'input_dim': 5,
            'output_dim': y.shape[1],
            'hidden_dims': [128, 256, 256, 128],
            'dropout': 0.2
        },
        'test_metrics': test_metrics
    }, model_path)
    print(f"Model saved to: {model_path}")

    scaler_path = os.path.join(output_dir, "scalers.pkl")
    with open(scaler_path, 'wb') as f:
        pickle.dump({
            'scaler_X': loader.scaler_X,
            'scaler_y': loader.scaler_y,
            'angle_bins': angle_bins
        }, f)
    print(f"Scalers saved to: {scaler_path}")

    print("\n" + "=" * 80)
    print("Step 8: Generating Visualizations")
    print("=" * 80)
    history_path = os.path.join(output_dir, "training_history.png")
    plot_training_history(trainer, save_path=history_path)

    examples_path = os.path.join(output_dir, "prediction_examples.png")
    plot_prediction_examples(
        X_test_norm, y_test_norm, trainer, angle_bins, loader,
        n_examples=4, save_path=examples_path
    )

    damage_path = os.path.join(output_dir, "damage_analysis.png")
    plot_damage_analysis(X, y, save_path=damage_path)

    print("\n" + "=" * 80)
    print("Training Pipeline Complete!")
    print("=" * 80)
    print("\nGenerated files:")
    print(f"  1. Model checkpoint: {model_path}")
    print(f"  2. Scalers: {scaler_path}")
    print(f"  3. Training history: {history_path}")
    print(f"  4. Prediction examples: {examples_path}")
    print(f"  5. Damage analysis: {damage_path}")
    print("\nPhysics constraints applied:")
    print("  - Mogi-Coulomb yield criterion: tau_oct = C1 + C2 * sigma_m2")
    print("  - Weibull strength: D_q = 1 - exp(-(F/F0)^m)")
    print("  - Energy damage: D_n = (2/pi) * arctan(b * U_p)")
    print("  - Total damage: D_total = 1 - (1-D_ft)(1-D_ch)(1-D_th)")


if __name__ == "__main__":
    main()