# Transformer-PINN / train.py
# Uploaded by guanwencan ("Upload 5 files", commit 5e4dee3, verified)
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import pickle
from data_loader import CrackDataLoader, DamageCalculator
from model import CrackTransformerPINN, CrackPINNTrainer
def plot_training_history(trainer, save_path=None):
    """Visualize the training run: loss curves (linear and log scale) plus convergence rate.

    Args:
        trainer: object exposing ``train_losses`` and ``val_losses`` (per-epoch lists).
        save_path: optional file path; when given, the figure is saved as a PNG.
    """
    fig, (ax_lin, ax_log, ax_conv) = plt.subplots(1, 3, figsize=(15, 4))

    # Panel 1: raw loss curves on a linear scale.
    ax_lin.plot(trainer.train_losses, label='Train Loss', linewidth=2)
    ax_lin.plot(trainer.val_losses, label='Val Loss', linewidth=2)
    ax_lin.set_xlabel('Epoch', fontsize=12)
    ax_lin.set_ylabel('Loss', fontsize=12)
    ax_lin.set_title('Training History', fontsize=14, fontweight='bold')
    ax_lin.legend()
    ax_lin.grid(True, alpha=0.3)

    # Panel 2: same curves on a log y-axis to expose late-stage improvement.
    ax_log.semilogy(trainer.train_losses, label='Train Loss', linewidth=2)
    ax_log.semilogy(trainer.val_losses, label='Val Loss', linewidth=2)
    ax_log.set_xlabel('Epoch', fontsize=12)
    ax_log.set_ylabel('Loss (log scale)', fontsize=12)
    ax_log.set_title('Training History (Log Scale)', fontsize=14, fontweight='bold')
    ax_log.legend()
    ax_log.grid(True, alpha=0.3)

    # Panel 3: epoch-to-epoch loss delta (needs at least two epochs to diff).
    if len(trainer.train_losses) > 1:
        ax_conv.plot(np.diff(trainer.train_losses), linewidth=2, alpha=0.7)
        ax_conv.axhline(y=0, color='r', linestyle='--', alpha=0.5)
        ax_conv.set_xlabel('Epoch', fontsize=12)
        ax_conv.set_ylabel('Loss Change', fontsize=12)
        ax_conv.set_title('Convergence Rate', fontsize=14, fontweight='bold')
        ax_conv.grid(True, alpha=0.3)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Training history saved to: {save_path}")
    plt.close()
def plot_prediction_examples(X_test, y_test, trainer, angle_bins, loader, n_examples=4, save_path=None):
    """Plot polar comparisons of true vs. predicted crack-count distributions.

    Fixes vs. previous version: the unused second return value of
    ``trainer.predict`` is discarded explicitly, and the subplot grid is sized
    from ``n_examples`` (2 columns) instead of being hard-coded to 2x2, so
    ``n_examples > 4`` no longer silently truncates. Default behavior
    (``n_examples=4``) is unchanged: a 2x2 grid at figsize (14, 10).

    Args:
        X_test, y_test: normalized test inputs/targets (2D arrays).
        trainer: model wrapper exposing ``predict``.
        angle_bins: angle bin centers in degrees for the polar axis.
        loader: data loader holding the fitted ``scaler_X`` / ``scaler_y``.
        n_examples: number of test samples to plot.
        save_path: optional PNG output path.
    """
    n_examples = min(n_examples, len(X_test))
    # Undo normalization so titles/plots show physical values.
    X_test_original = loader.scaler_X.inverse_transform(X_test)
    y_test_original = loader.scaler_y.inverse_transform(y_test)
    # predict() also returns predicted totals, which this plot does not use.
    y_pred_norm, _ = trainer.predict(X_test[:n_examples])
    y_pred = loader.scaler_y.inverse_transform(y_pred_norm)

    # Grid: 2 columns, as many rows as needed (ceiling division, no math import).
    n_cols = 2
    n_rows = max(1, (n_examples + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(7 * n_cols, 5 * n_rows),
                             subplot_kw=dict(projection='polar'))
    axes = np.atleast_1d(axes).flatten()

    theta = np.deg2rad(angle_bins)
    for i in range(n_examples):
        ax = axes[i]
        ax.plot(theta, y_test_original[i], 'o-', label='True', linewidth=2, markersize=4, alpha=0.7)
        ax.plot(theta, y_pred[i], 's-', label='Predicted', linewidth=2, markersize=3, alpha=0.7)
        # Feature layout (from loader): pH, freeze-thaw cycles, fatigue cycles,
        # temperature, phase flag — presumably; verify against data_loader.
        pH = X_test_original[i, 0]
        FN = X_test_original[i, 1]
        FT = X_test_original[i, 2]
        T = X_test_original[i, 3]
        phase = X_test_original[i, 4]
        phase_str = "Unstable" if phase < 0.5 else "Peak"
        D0 = DamageCalculator.compute_total_damage(pH, FN, FT, T)
        lambda_coef = DamageCalculator.compute_lambda(D0)
        true_total = y_test_original[i].sum()
        pred_total = y_pred[i].sum()
        title = f"pH={pH:.0f}, FN={FN:.0f}, FT={FT:.0f}, T={T:.0f}C\n"
        title += f"D0={D0:.3f}, lambda={lambda_coef:.3f}\n"
        title += f"{phase_str} | True: {true_total:.0f}, Pred: {pred_total:.0f}"
        ax.set_title(title, fontsize=9, pad=20)
        ax.legend(loc='upper right', fontsize=8)
        # North-up, clockwise orientation (compass convention).
        ax.set_theta_zero_location('N')
        ax.set_theta_direction(-1)
        ax.grid(True, alpha=0.3)

    # Blank out any unused cells in the grid.
    for ax in axes[n_examples:]:
        ax.set_visible(False)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Prediction examples saved to: {save_path}")
    plt.close()
def plot_damage_analysis(X, y, save_path=None):
    """Scatter-plot relationships between damage factor D0, input features, and crack totals.

    Replaces the manual append loop with a comprehension when computing D0 per
    sample; plotting output is unchanged.

    Args:
        X: feature matrix; columns 0-3 are pH, FN, FT, T (col 2 unused here).
        y: per-angle crack counts; summed per row to a total crack count.
        save_path: optional PNG output path.
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    total_cracks = y.sum(axis=1)
    # Initial damage factor D0 per sample, from the physics-based calculator.
    D0_values = np.array([
        DamageCalculator.compute_total_damage(X[i, 0], X[i, 1], X[i, 2], X[i, 3])
        for i in range(X.shape[0])
    ])

    # D0 vs total crack count.
    axes[0, 0].scatter(D0_values, total_cracks, alpha=0.6, edgecolors='black', linewidth=0.5)
    axes[0, 0].set_xlabel('Initial Damage Factor D0', fontsize=12)
    axes[0, 0].set_ylabel('Total Crack Count', fontsize=12)
    axes[0, 0].set_title('D0 vs Total Cracks', fontsize=14, fontweight='bold')
    axes[0, 0].grid(True, alpha=0.3)

    # pH vs total cracks, colored by D0.
    axes[0, 1].scatter(X[:, 0], total_cracks, alpha=0.6, c=D0_values, cmap='viridis')
    axes[0, 1].set_xlabel('pH Value', fontsize=12)
    axes[0, 1].set_ylabel('Total Crack Count', fontsize=12)
    axes[0, 1].set_title('pH vs Total Cracks', fontsize=14, fontweight='bold')
    axes[0, 1].grid(True, alpha=0.3)

    # Freeze-thaw cycles vs total cracks, colored by D0.
    axes[1, 0].scatter(X[:, 1], total_cracks, alpha=0.6, c=D0_values, cmap='viridis')
    axes[1, 0].set_xlabel('Freeze-thaw Cycles (FN)', fontsize=12)
    axes[1, 0].set_ylabel('Total Crack Count', fontsize=12)
    axes[1, 0].set_title('FN vs Total Cracks', fontsize=14, fontweight='bold')
    axes[1, 0].grid(True, alpha=0.3)

    # Temperature vs total cracks; keep the handle for the shared colorbar.
    scatter = axes[1, 1].scatter(X[:, 3], total_cracks, alpha=0.6, c=D0_values, cmap='viridis')
    axes[1, 1].set_xlabel('Damage Temperature (T)', fontsize=12)
    axes[1, 1].set_ylabel('Total Crack Count', fontsize=12)
    axes[1, 1].set_title('T vs Total Cracks', fontsize=14, fontweight='bold')
    axes[1, 1].grid(True, alpha=0.3)
    plt.colorbar(scatter, ax=axes[1, 1], label='D0')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Damage analysis saved to: {save_path}")
    plt.close()
def main():
    """End-to-end training pipeline: load data, split, normalize, train, test, save, plot.

    Fixes vs. previous version: the bare ``except:`` around data loading is
    narrowed to ``except Exception as e`` (no longer swallows
    KeyboardInterrupt/SystemExit, and the actual error is printed before
    falling back to synthetic data), and directory creation uses
    ``os.makedirs(..., exist_ok=True)`` to avoid the race-prone exists check.
    """
    print("=" * 80)
    print("Transformer-PINN Crack Prediction Model")
    print("Based on: Mechanism of micro-damage evolution in rocks")
    print("under multiple coupled cyclic stresses")
    print("=" * 80)
    base_path = "./data"
    output_dir = "./output"
    os.makedirs(output_dir, exist_ok=True)

    # Step 1: load real data; fall back to synthetic data on any loader failure.
    print("\n" + "=" * 80)
    print("Step 1: Loading/Generating Data")
    print("=" * 80)
    loader = CrackDataLoader(base_path, stress_type="major")
    try:
        X, y, angle_bins, damage_list = loader.load_all_data(phase="both")
    except Exception as e:
        # Best-effort fallback kept deliberately broad, but the cause is now
        # surfaced instead of silently discarded.
        print(f"Real data not found ({e}). Generating synthetic data...")
        X, y, angle_bins = loader.create_synthetic_data(n_samples=200, output_dim=72)
    stats = loader.get_statistics(X, y)
    print("\nData statistics:")
    for key, value in stats.items():
        print(f"  {key}: {value}")

    # Step 2: two-stage split yields 64/16/20 train/val/test proportions.
    print("\n" + "=" * 80)
    print("Step 2: Splitting Dataset (Train:Val:Test = 64:16:20)")
    print("=" * 80)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    X_train, X_val, y_train, y_val = train_test_split(
        X_train, y_train, test_size=0.2, random_state=42
    )
    print(f"Training set: {X_train.shape[0]} samples")
    print(f"Validation set: {X_val.shape[0]} samples")
    print(f"Test set: {X_test.shape[0]} samples")

    # Step 3: scalers are fit on train/val inside normalize_data; the test set
    # is transformed with the already-fitted scalers (no leakage).
    print("\n" + "=" * 80)
    print("Step 3: Normalizing Data")
    print("=" * 80)
    X_train_norm, y_train_norm, X_val_norm, y_val_norm = loader.normalize_data(
        X_train, y_train, X_val, y_val
    )
    X_test_norm = loader.scaler_X.transform(X_test)
    y_test_norm = loader.scaler_y.transform(y_test)
    print("Normalization complete")

    # Step 4: build the physics-informed transformer model.
    print("\n" + "=" * 80)
    print("Step 4: Creating Model")
    print("=" * 80)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")
    model = CrackTransformerPINN(
        input_dim=5,
        output_dim=y.shape[1],
        hidden_dims=[128, 256, 256, 128],
        dropout=0.2
    )
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {n_params:,}")
    print("\nModel components:")
    print("  - Transformer Encoder (8 heads, 4 layers)")
    print("  - Mogi-Coulomb Yield Criterion Layer")
    print("  - Weibull Strength Distribution Layer")
    print("  - Energy-based Damage Evolution Layer")
    print("  - PINN Decoder with Physics Constraints")

    # Step 5: train with early stopping (patience=50).
    print("\n" + "=" * 80)
    print("Step 5: Training Model")
    print("=" * 80)
    trainer = CrackPINNTrainer(
        model,
        device=device,
        lr=1e-3,
        weight_decay=1e-4
    )
    trainer.fit(
        X_train_norm, y_train_norm,
        X_val_norm, y_val_norm,
        epochs=300,
        batch_size=8,
        patience=50
    )

    # Step 6: evaluate on the held-out test set via the trainer's validate().
    print("\n" + "=" * 80)
    print("Step 6: Testing Model")
    print("=" * 80)
    test_loss, test_metrics = trainer.validate(
        torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(
                torch.FloatTensor(X_test_norm),
                torch.FloatTensor(y_test_norm)
            ),
            batch_size=8,
            shuffle=False
        )
    )
    print(f"Test set performance:")
    print(f"  Loss: {test_loss:.4f}")
    print(f"  R2: {test_metrics['r2']:.4f}")
    print(f"  RMSE: {test_metrics['rmse']:.2f}")
    print(f"  Total Count MAE: {test_metrics['total_count_mae']:.2f}")

    # Step 7: persist weights + config (for reconstruction) and the fitted scalers.
    print("\n" + "=" * 80)
    print("Step 7: Saving Model")
    print("=" * 80)
    model_path = os.path.join(output_dir, "crack_transformer_pinn.pth")
    torch.save({
        'model_state_dict': model.state_dict(),
        'model_config': {
            'input_dim': 5,
            'output_dim': y.shape[1],
            'hidden_dims': [128, 256, 256, 128],
            'dropout': 0.2
        },
        'test_metrics': test_metrics
    }, model_path)
    print(f"Model saved to: {model_path}")
    scaler_path = os.path.join(output_dir, "scalers.pkl")
    with open(scaler_path, 'wb') as f:
        pickle.dump({
            'scaler_X': loader.scaler_X,
            'scaler_y': loader.scaler_y,
            'angle_bins': angle_bins
        }, f)
    print(f"Scalers saved to: {scaler_path}")

    # Step 8: diagnostic figures.
    print("\n" + "=" * 80)
    print("Step 8: Generating Visualizations")
    print("=" * 80)
    history_path = os.path.join(output_dir, "training_history.png")
    plot_training_history(trainer, save_path=history_path)
    examples_path = os.path.join(output_dir, "prediction_examples.png")
    plot_prediction_examples(
        X_test_norm, y_test_norm,
        trainer, angle_bins, loader,
        n_examples=4,
        save_path=examples_path
    )
    damage_path = os.path.join(output_dir, "damage_analysis.png")
    plot_damage_analysis(X, y, save_path=damage_path)

    print("\n" + "=" * 80)
    print("Training Pipeline Complete!")
    print("=" * 80)
    print(f"\nGenerated files:")
    print(f"  1. Model checkpoint: {model_path}")
    print(f"  2. Scalers: {scaler_path}")
    print(f"  3. Training history: {history_path}")
    print(f"  4. Prediction examples: {examples_path}")
    print(f"  5. Damage analysis: {damage_path}")
    print("\nPhysics constraints applied:")
    print("  - Mogi-Coulomb yield criterion: tau_oct = C1 + C2 * sigma_m2")
    print("  - Weibull strength: D_q = 1 - exp(-(F/F0)^m)")
    print("  - Energy damage: D_n = (2/pi) * arctan(b * U_p)")
    print("  - Total damage: D_total = 1 - (1-D_ft)(1-D_ch)(1-D_th)")
# Run the full training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()