Upload 5 files
Browse files- data_loader.py +239 -0
- inference.py +238 -0
- model.py +519 -0
- requirements.txt +5 -0
- train.py +331 -0
data_loader.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from glob import glob
|
| 6 |
+
from sklearn.preprocessing import StandardScaler
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class DamageCalculator:
    """Empirical initial-damage models for freeze-thaw, chemical and thermal loading.

    Each partial damage value lies in [0, 1); the partial damages are combined
    multiplicatively into a total initial damage D0, and lambda = 1 - D0 is the
    remaining-integrity coefficient used elsewhere in the pipeline.
    """

    @staticmethod
    def compute_freeze_thaw_damage(FN, FT, a1=0.002, b1=1.0, c1=0.02):
        """Freeze-thaw damage D_ft = a1 * FN**b1 * exp(c1 * FT).

        Args:
            FN: number of freeze-thaw cycles.
            FT: freezing temperature parameter.
            a1, b1, c1: empirical fit coefficients.
        """
        return a1 * (FN ** b1) * np.exp(c1 * FT)

    @staticmethod
    def compute_chemical_damage(pH, a2=0.01, b2=1.5):
        """Chemical damage: zero at neutral pH 7, growing as |pH - 7|**b2."""
        return a2 * np.abs(pH - 7.0) ** b2

    @staticmethod
    def compute_thermal_damage(T, T0=100.0, a3=0.0003, b3=1.2):
        """Thermal damage: zero at or below threshold T0, a3*(T - T0)**b3 above it.

        FIX: the original used a scalar-only `if T < T0` branch, which raised
        on numpy array inputs even though every sibling method is vectorized.
        `np.maximum` gives identical scalar results and also accepts arrays.
        """
        excess = np.maximum(np.asarray(T, dtype=float) - T0, 0.0)
        result = a3 * excess ** b3
        # Preserve the original plain-float return type for scalar inputs.
        return float(result) if np.ndim(result) == 0 else result

    @staticmethod
    def compute_total_damage(pH, FN, FT, T):
        """Combine the three partial damages into the total initial damage D0.

        Multiplicative (series) coupling: D = 1 - prod(1 - D_i), clipped to
        [0, 0.99] so downstream (1 - D) divisors never reach zero.
        """
        D_ft = DamageCalculator.compute_freeze_thaw_damage(FN, FT)
        D_ch = DamageCalculator.compute_chemical_damage(pH)
        D_th = DamageCalculator.compute_thermal_damage(T)

        D_total = 1.0 - (1.0 - D_ft) * (1.0 - D_ch) * (1.0 - D_th)
        return np.clip(D_total, 0.0, 0.99)

    @staticmethod
    def compute_lambda(D0):
        """Remaining-integrity coefficient lambda = 1 - D0."""
        return 1.0 - D0
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class CrackDataLoader:
    """Loads crack-angle distribution CSVs and derives physics features.

    Expected layout: <base_path>/{major,minor}_principal_stress/
    {unstable_development,peak_stress}/<pH>-<FN>-<FT>-<T>.csv, each CSV being
    header-less rows of (angle, count).
    """

    def __init__(self, base_path, stress_type="major"):
        """
        Args:
            base_path: root directory containing the stress-type subfolders.
            stress_type: "major" selects major_principal_stress; any other
                value selects minor_principal_stress.
        """
        self.base_path = base_path
        self.stress_type = stress_type

        if stress_type == "major":
            self.data_dir = os.path.join(base_path, "major_principal_stress")
        else:
            self.data_dir = os.path.join(base_path, "minor_principal_stress")

        # Separate scalers for inputs and target distributions; fitted in
        # normalize_data and reused by denormalize_output.
        self.scaler_X = StandardScaler()
        self.scaler_y = StandardScaler()
        self.damage_calculator = DamageCalculator()

    def parse_filename(self, filename):
        """Extract the environment parameters encoded in a 'pH-FN-FT-T' filename.

        Returns:
            dict with integer keys 'pH', 'FN', 'FT', 'T'.

        Raises:
            ValueError: if the name does not contain four dash-separated integers.
        """
        match = re.search(r'(\d+)-(\d+)-(\d+)-(\d+)', filename)
        if not match:
            # BUG FIX: the original message was an f-string with no
            # placeholder, so the offending filename never appeared.
            raise ValueError(f"Cannot parse filename: {filename}")

        return {
            'pH': int(match.group(1)),
            'FN': int(match.group(2)),
            'FT': int(match.group(3)),
            'T': int(match.group(4)),
        }

    def load_single_csv(self, csv_path):
        """Read one header-less (angle, count) CSV; returns (angles, counts) arrays."""
        data = pd.read_csv(csv_path, header=None, names=['angle', 'count'])
        angles = data['angle'].values
        counts = data['count'].values
        return angles, counts

    def load_all_data(self, phase="both"):
        """Load every CSV under the selected phase subdirectories.

        Args:
            phase: "both", "early" (unstable_development only) or "peak"
                (peak_stress only).

        Returns:
            (X, y, angle_bins, damage_list): X is (n, 5) float32 features
            [pH, FN, FT, T, phase_code]; y is (n, bins) counts padded or
            truncated to the first sample's length; angle_bins gives the
            angle value of each y column; damage_list holds per-sample
            {'D0', 'lambda'}.

        Raises:
            ValueError: for an unknown phase or when no file loads.
        """
        X_list = []
        y_list = []
        damage_list = []
        first_angles = None  # angle column of the first successfully read CSV

        if phase == "both":
            subdirs = ["unstable_development", "peak_stress"]
        elif phase == "early":
            subdirs = ["unstable_development"]
        elif phase == "peak":
            subdirs = ["peak_stress"]
        else:
            raise ValueError(f"Unknown phase: {phase}")

        for subdir in subdirs:
            subdir_path = os.path.join(self.data_dir, subdir)

            if not os.path.exists(subdir_path):
                print(f"Warning: Directory does not exist {subdir_path}")
                continue

            # Phase code: 0 = unstable development, 1 = peak stress.
            phase_code = 0 if "unstable" in subdir else 1

            csv_files = glob(os.path.join(subdir_path, "*.csv"))

            print(f"Loading {len(csv_files)} files from {subdir}...")

            for csv_file in csv_files:
                try:
                    params = self.parse_filename(os.path.basename(csv_file))

                    angles, counts = self.load_single_csv(csv_file)

                    D0 = DamageCalculator.compute_total_damage(
                        params['pH'], params['FN'], params['FT'], params['T']
                    )
                    lambda_coef = DamageCalculator.compute_lambda(D0)

                    features = np.array([
                        params['pH'],
                        params['FN'],
                        params['FT'],
                        params['T'],
                        phase_code
                    ], dtype=np.float32)

                    X_list.append(features)
                    y_list.append(counts)
                    damage_list.append({'D0': D0, 'lambda': lambda_coef})

                    if first_angles is None:
                        first_angles = angles

                except Exception as e:
                    # Best-effort loading: report the bad file and move on.
                    print(f"Skipping file {csv_file}: {e}")
                    continue

        if len(X_list) == 0:
            raise ValueError("No data loaded successfully!")

        X = np.array(X_list)

        # Align every sample to the first sample's bin count.
        y_length = len(y_list[0])
        y_padded = []

        for y_sample in y_list:
            if len(y_sample) < y_length:
                y_sample = np.pad(y_sample, (0, y_length - len(y_sample)), 'constant')
            elif len(y_sample) > y_length:
                y_sample = y_sample[:y_length]
            y_padded.append(y_sample)

        y = np.array(y_padded)

        # BUG FIX: the original re-read csv_files[0] after the loop, where
        # csv_files leaked from the *last* iteration and could be empty (or
        # unbound if a subdirectory was missing). Reuse the angles captured
        # from the first successfully loaded file instead.
        angle_bins = first_angles[:y_length]

        print(f"\nData loading complete:")
        print(f" Samples: {X.shape[0]}")
        print(f" Input features: {X.shape[1]} (pH, FN, FT, T, phase)")
        print(f" Output dimension: {y.shape[1]} (angle bins)")
        print(f" Angle range: {angle_bins[0]:.1f} - {angle_bins[-1]:.1f}")
        print(f" Total cracks range: {y.sum(axis=1).min():.0f} - {y.sum(axis=1).max():.0f}")

        return X, y, angle_bins, damage_list

    def create_synthetic_data(self, n_samples=100, output_dim=72):
        """Generate Gaussian-shaped synthetic crack distributions for testing.

        The peak angle and spread depend on stress_type and on the physical
        damage D0 of randomly drawn (pH, FN, FT, T, phase) combinations.

        Returns:
            (X, y, angle_bins) with X (n, 5) float32, y (n, output_dim)
            float32, angle_bins spanning 0-175 degrees.
        """
        pH_values = [1, 3, 5, 7]
        FN_values = [5, 10, 20, 40]
        FT_values = [10, 20, 30, 40]
        T_values = [25, 300, 600, 900]
        phase_values = [0, 1]

        X_list = []
        y_list = []

        for _ in range(n_samples):
            pH = np.random.choice(pH_values)
            FN = np.random.choice(FN_values)
            FT = np.random.choice(FT_values)
            T = np.random.choice(T_values)
            phase = np.random.choice(phase_values)

            D0 = DamageCalculator.compute_total_damage(pH, FN, FT, T)

            # Major principal stress cracks cluster near 90 degrees; minor
            # near 45. More damage widens the spread.
            if self.stress_type == "major":
                peak_angle = 90.0 + np.random.normal(0, 10)
                spread = 15.0 + D0 * 20.0
            else:
                peak_angle = 45.0 + np.random.normal(0, 15)
                spread = 20.0 + D0 * 25.0

            angles = np.linspace(0, 175, output_dim)
            distribution = np.exp(-0.5 * ((angles - peak_angle) / spread) ** 2)
            # Amplitude grows with damage and is boosted in the peak phase.
            distribution = distribution * (100 + D0 * 200) * (1 + 0.5 * phase)
            distribution = distribution + np.random.normal(0, 5, output_dim)
            distribution = np.maximum(distribution, 0)

            X_list.append([pH, FN, FT, T, phase])
            y_list.append(distribution)

        X = np.array(X_list, dtype=np.float32)
        y = np.array(y_list, dtype=np.float32)
        angle_bins = np.linspace(0, 175, output_dim)

        return X, y, angle_bins

    def normalize_data(self, X_train, y_train, X_test=None, y_test=None):
        """Fit the scalers on the training split and transform both splits.

        Returns (X_train_norm, y_train_norm) or, when a test split is given,
        (X_train_norm, y_train_norm, X_test_norm, y_test_norm). The test
        split is only transformed, never fitted, to avoid leakage.
        """
        X_train_norm = self.scaler_X.fit_transform(X_train)
        y_train_norm = self.scaler_y.fit_transform(y_train)

        if X_test is not None and y_test is not None:
            X_test_norm = self.scaler_X.transform(X_test)
            y_test_norm = self.scaler_y.transform(y_test)
            return X_train_norm, y_train_norm, X_test_norm, y_test_norm
        else:
            return X_train_norm, y_train_norm

    def denormalize_output(self, y_norm):
        """Map normalized model outputs back to raw crack counts."""
        return self.scaler_y.inverse_transform(y_norm)

    def get_statistics(self, X, y):
        """Summarize feature ranges, crack-count totals and damage D0 over a dataset."""
        stats = {
            'n_samples': X.shape[0],
            'input_dim': X.shape[1],
            'output_dim': y.shape[1],
            'pH_range': (X[:, 0].min(), X[:, 0].max()),
            'FN_range': (X[:, 1].min(), X[:, 1].max()),
            'FT_range': (X[:, 2].min(), X[:, 2].max()),
            'T_range': (X[:, 3].min(), X[:, 3].max()),
            'total_cracks_range': (y.sum(axis=1).min(), y.sum(axis=1).max()),
            'total_cracks_mean': y.sum(axis=1).mean(),
            'total_cracks_std': y.sum(axis=1).std(),
        }

        D0_values = [
            DamageCalculator.compute_total_damage(X[i, 0], X[i, 1], X[i, 2], X[i, 3])
            for i in range(X.shape[0])
        ]

        stats['D0_range'] = (min(D0_values), max(D0_values))
        stats['D0_mean'] = np.mean(D0_values)

        return stats
|
inference.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import pickle
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
|
| 7 |
+
from model import CrackTransformerPINN
|
| 8 |
+
from data_loader import DamageCalculator
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class CrackPredictor:
    """Inference wrapper around a trained CrackTransformerPINN.

    Expects two artifacts produced by training:
      - scaler_path: a pickle containing 'scaler_X', 'scaler_y' (fitted
        sklearn scalers) and 'angle_bins' (angle value per output column).
      - model_path: a torch checkpoint, either a dict with 'model_config'
        and 'model_state_dict', or a bare state_dict (legacy format).
    """

    def __init__(self, model_path, scaler_path, device='cpu'):
        self.device = device

        # Restore the input/output scalers and the angle axis saved at
        # training time so predictions can be de-normalized consistently.
        with open(scaler_path, 'rb') as f:
            scalers = pickle.load(f)

        self.scaler_X = scalers['scaler_X']
        self.scaler_y = scalers['scaler_y']
        self.angle_bins = scalers['angle_bins']

        # NOTE(review): torch.load on an untrusted checkpoint executes
        # pickled code; assumes the checkpoint is self-produced.
        checkpoint = torch.load(model_path, map_location=device)

        if 'model_config' in checkpoint:
            # Preferred format: rebuild the model from its saved config.
            config = checkpoint['model_config']
            self.model = CrackTransformerPINN(
                input_dim=config['input_dim'],
                output_dim=config['output_dim'],
                hidden_dims=config['hidden_dims'],
                dropout=config['dropout']
            )
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            # Legacy format: the checkpoint is a bare state_dict; fall back
            # to the default architecture sized by the saved angle bins.
            self.model = CrackTransformerPINN(
                input_dim=5,
                output_dim=len(self.angle_bins),
                hidden_dims=[128, 256, 256, 128],
                dropout=0.2
            )
            self.model.load_state_dict(checkpoint)

        self.model.to(device)
        # eval() disables dropout so predictions are deterministic.
        self.model.eval()

    def predict(self, pH, FN, FT, T, phase):
        """Predict the crack-angle distribution for one condition.

        Returns a dict with 'angle_distribution' (de-normalized counts per
        angle bin), 'total_count', and the analytic damage values 'D0' and
        'lambda' computed from the inputs (not from the network).
        """
        X = np.array([[pH, FN, FT, T, phase]], dtype=np.float32)

        X_norm = self.scaler_X.transform(X)

        with torch.no_grad():
            X_tensor = torch.FloatTensor(X_norm).to(self.device)
            pred_dist_norm, pred_total = self.model(X_tensor, return_physics=False)

        pred_dist_norm = pred_dist_norm.cpu().numpy()
        pred_total = pred_total.cpu().numpy().flatten()

        # Undo the target normalization applied during training.
        pred_dist = self.scaler_y.inverse_transform(pred_dist_norm)

        # Analytic (empirical-formula) damage, independent of the network.
        D0 = DamageCalculator.compute_total_damage(pH, FN, FT, T)
        lambda_coef = DamageCalculator.compute_lambda(D0)

        return {
            'angle_distribution': pred_dist[0],
            'total_count': pred_total[0],
            'D0': D0,
            'lambda': lambda_coef
        }

    def predict_with_physics(self, pH, FN, FT, T, phase):
        """Like predict(), but also returns the model's internal physics
        quantities (Mogi-Coulomb, Weibull and energy-damage outputs).

        Here 'D0'/'lambda' come from the model's physics branch rather than
        the analytic DamageCalculator formulas used in predict().
        """
        X = np.array([[pH, FN, FT, T, phase]], dtype=np.float32)

        X_norm = self.scaler_X.transform(X)

        with torch.no_grad():
            X_tensor = torch.FloatTensor(X_norm).to(self.device)
            pred_dist_norm, pred_total, physics = self.model(X_tensor, return_physics=True)

        pred_dist_norm = pred_dist_norm.cpu().numpy()

        pred_dist = self.scaler_y.inverse_transform(pred_dist_norm)

        # Flatten each single-sample physics tensor to a Python scalar.
        result = {
            'angle_distribution': pred_dist[0],
            'total_count': pred_total.cpu().numpy().flatten()[0],
            'D0': physics['D0'].cpu().numpy().flatten()[0],
            'lambda': physics['lambda'].cpu().numpy().flatten()[0],
            'D_n': physics['D_n'].cpu().numpy().flatten()[0],
            'tau_oct': physics['tau_oct'].cpu().numpy().flatten()[0],
            'yield_stress': physics['yield_stress'].cpu().numpy().flatten()[0],
            'C1': physics['C1'].cpu().numpy().flatten()[0],
            'C2': physics['C2'].cpu().numpy().flatten()[0],
            'D_q': physics['D_q'].cpu().numpy().flatten()[0],
            'm': physics['m'].cpu().numpy().flatten()[0],
            'F0': physics['F0'].cpu().numpy().flatten()[0]
        }

        return result

    def predict_batch(self, X):
        """Vectorized prediction for an (n, 5) feature matrix.

        Returns (pred_dist, pred_total): de-normalized distributions of
        shape (n, bins) and a flat array of n total counts.
        """
        X_norm = self.scaler_X.transform(X)

        with torch.no_grad():
            X_tensor = torch.FloatTensor(X_norm).to(self.device)
            pred_dist_norm, pred_total = self.model(X_tensor, return_physics=False)

        pred_dist_norm = pred_dist_norm.cpu().numpy()
        pred_total = pred_total.cpu().numpy().flatten()

        pred_dist = self.scaler_y.inverse_transform(pred_dist_norm)

        return pred_dist, pred_total

    def visualize(self, result, title=None, save_path=None):
        """Polar (rose) plot of one predict() result; optionally saves a PNG."""
        fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(projection='polar'))

        theta = np.deg2rad(self.angle_bins)
        pred_dist = result['angle_distribution']

        ax.plot(theta, pred_dist, 'o-', linewidth=2, markersize=4, color='blue')
        ax.fill(theta, pred_dist, alpha=0.3, color='blue')

        if title:
            ax.set_title(title, fontsize=14, pad=20)
        else:
            # Default title summarizes the headline prediction numbers.
            total = result['total_count']
            D0 = result['D0']
            lam = result['lambda']
            ax.set_title(f'Predicted Crack Distribution\nTotal: {total:.0f}, D0: {D0:.3f}, lambda: {lam:.3f}',
                         fontsize=12, pad=20)

        # Orient like a compass: 0 degrees at top, angles clockwise.
        ax.set_theta_zero_location('N')
        ax.set_theta_direction(-1)
        ax.grid(True, alpha=0.3)

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')

        plt.show()

    def compare_stress_types(self, pH, FN, FT, T, save_path=None):
        """Side-by-side polar plots of the unstable (phase=0) vs peak
        (phase=1) predictions for one environmental condition."""
        result_unstable = self.predict(pH, FN, FT, T, phase=0)
        result_peak = self.predict(pH, FN, FT, T, phase=1)

        fig, axes = plt.subplots(1, 2, figsize=(14, 6), subplot_kw=dict(projection='polar'))

        theta = np.deg2rad(self.angle_bins)

        axes[0].plot(theta, result_unstable['angle_distribution'], 'o-',
                     linewidth=2, markersize=4, color='blue')
        axes[0].fill(theta, result_unstable['angle_distribution'], alpha=0.3, color='blue')
        axes[0].set_title(f'Unstable Development Phase\nTotal: {result_unstable["total_count"]:.0f}',
                          fontsize=12, pad=20)
        axes[0].set_theta_zero_location('N')
        axes[0].set_theta_direction(-1)
        axes[0].grid(True, alpha=0.3)

        axes[1].plot(theta, result_peak['angle_distribution'], 'o-',
                     linewidth=2, markersize=4, color='red')
        axes[1].fill(theta, result_peak['angle_distribution'], alpha=0.3, color='red')
        axes[1].set_title(f'Peak Stress Phase\nTotal: {result_peak["total_count"]:.0f}',
                          fontsize=12, pad=20)
        axes[1].set_theta_zero_location('N')
        axes[1].set_theta_direction(-1)
        axes[1].grid(True, alpha=0.3)

        # D0/lambda depend only on (pH, FN, FT, T), so either result works.
        D0 = result_unstable['D0']
        lam = result_unstable['lambda']
        fig.suptitle(f'pH={pH}, FN={FN}, FT={FT}, T={T}\nD0={D0:.3f}, lambda={lam:.3f}',
                     fontsize=14, fontweight='bold')

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')

        plt.show()
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def main():
    """Inference demo: print point predictions for a few representative
    environmental conditions, then dump the physics sub-model outputs."""
    model_path = "./output/crack_transformer_pinn.pth"
    scaler_path = "./output/scalers.pkl"

    # BUG FIX: the original only checked model_path; a missing scalers.pkl
    # crashed inside CrackPredictor.__init__ instead of printing the hint.
    if not os.path.exists(model_path) or not os.path.exists(scaler_path):
        print("Model file not found. Please train the model first using train.py")
        return

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    predictor = CrackPredictor(model_path, scaler_path, device=device)

    print("=" * 60)
    print("Transformer-PINN Crack Prediction - Inference Demo")
    print("=" * 60)

    # (pH, freeze-thaw cycles, freezing temperature, temperature, phase code)
    test_cases = [
        {'pH': 3, 'FN': 30, 'FT': 40, 'T': 25, 'phase': 0},
        {'pH': 3, 'FN': 40, 'FT': 20, 'T': 300, 'phase': 1},
        {'pH': 7, 'FN': 10, 'FT': 40, 'T': 300, 'phase': 1},
        {'pH': 5, 'FN': 10, 'FT': 20, 'T': 900, 'phase': 0},
    ]

    for i, params in enumerate(test_cases):
        result = predictor.predict(**params)

        print(f"\nTest Case {i+1}:")
        print(f" Input: pH={params['pH']}, FN={params['FN']}, FT={params['FT']}, T={params['T']}, phase={params['phase']}")
        print(f" D0 (Initial Damage): {result['D0']:.4f}")
        print(f" Lambda (Damage Coefficient): {result['lambda']:.4f}")
        print(f" Predicted Total Cracks: {result['total_count']:.0f}")
        print(f" Peak Angle: {predictor.angle_bins[result['angle_distribution'].argmax()]:.1f} degrees")
        print(f" Peak Count: {result['angle_distribution'].max():.0f}")

    print("\n" + "=" * 60)
    print("Testing Physics Output")
    print("=" * 60)

    # Physics path: same prediction but with the model's internal
    # Mogi-Coulomb / Weibull / energy-damage quantities exposed.
    physics_result = predictor.predict_with_physics(pH=3, FN=30, FT=40, T=25, phase=1)

    print("\nPhysics Parameters:")
    print(f" Mogi-Coulomb:")
    print(f" tau_oct: {physics_result['tau_oct']:.4f}")
    print(f" yield_stress: {physics_result['yield_stress']:.4f}")
    print(f" C1: {physics_result['C1']:.4f}")
    print(f" C2: {physics_result['C2']:.4f}")
    print(f" Weibull Distribution:")
    print(f" D_q: {physics_result['D_q']:.4f}")
    print(f" m: {physics_result['m']:.4f}")
    print(f" F0: {physics_result['F0']:.4f}")
    print(f" Energy Damage:")
    print(f" D_n: {physics_result['D_n']:.4f}")

    print("\n" + "=" * 60)
    print("Inference complete!")
    print("=" * 60)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
# Entry-point guard: run the inference demo only when executed as a script.
if __name__ == "__main__":
    main()
|
model.py
ADDED
|
@@ -0,0 +1,519 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
from torch.utils.data import TensorDataset, DataLoader
|
| 6 |
+
from sklearn.model_selection import train_test_split
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class MogiCoulombLayer(nn.Module):
    """Learned Mogi-Coulomb yield criterion.

    Predicts the yield parameters C1 and C2 from encoded features, computes
    the octahedral shear stress tau_oct from the three principal stresses,
    and evaluates the linear yield stress C1 + C2 * (sigma1 + sigma3) / 2.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # One scalar head per yield parameter.
        self.C1_net = nn.Linear(hidden_dim, 1)
        self.C2_net = nn.Linear(hidden_dim, 1)

    def forward(self, features, sigma1, sigma2, sigma3):
        """Return (tau_oct, yield_stress, C1, C2), each shaped (batch, 1)."""
        # ReLU keeps C1 non-negative; sigmoid bounds C2 to (0, 1).
        C1 = torch.relu(self.C1_net(features))
        C2 = torch.sigmoid(self.C2_net(features))

        # Octahedral shear stress from pairwise principal-stress differences.
        d12 = sigma1 - sigma2
        d23 = sigma2 - sigma3
        d13 = sigma1 - sigma3
        tau_oct = (1.0 / 3.0) * torch.sqrt(d12 * d12 + d23 * d23 + d13 * d13)

        # Mean of major and minor principal stresses.
        sigma_m2 = 0.5 * (sigma1 + sigma3)
        yield_stress = C1 + C2 * sigma_m2

        return tau_oct, yield_stress, C1, C2
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class WeibullStrengthLayer(nn.Module):
    """Learned Weibull strength distribution.

    Predicts the Weibull modulus m and characteristic strength F0 from
    encoded features, then evaluates the cumulative failure fraction
    D_q = 1 - exp(-(F / F0) ** m) for a load measure F.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # Softplus keeps both raw outputs positive; the offsets applied in
        # forward() enforce m > 1 and F0 > 0.1.
        self.m_net = nn.Sequential(
            nn.Linear(hidden_dim, 32),
            nn.Tanh(),
            nn.Linear(32, 1),
            nn.Softplus(),
        )
        self.F0_net = nn.Sequential(
            nn.Linear(hidden_dim, 32),
            nn.Tanh(),
            nn.Linear(32, 1),
            nn.Softplus(),
        )

    def forward(self, features, F):
        """Return (D_q, m, F0), each shaped (batch, 1)."""
        shape_param = self.m_net(features) + 1.0   # Weibull modulus m > 1
        scale_param = self.F0_net(features) + 0.1  # characteristic strength F0

        # Weibull CDF: probability of failure at load F.
        failure_fraction = 1.0 - torch.exp(-torch.pow(F / scale_param, shape_param))

        return failure_fraction, shape_param, scale_param
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class EnergyDamageLayer(nn.Module):
    """Energy-based damage evolution with learned coefficients a and b.

    Plastic energy follows U_p = a * exp(b * sigma_eff), where sigma_eff is
    the stress increment amplified by the remaining integrity 1 / (1 - D0).
    The damage increment D_n maps U_p into [0, 1) via a scaled arctangent.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # Softplus outputs plus the offsets in forward() keep a >= 0.1 and
        # b >= 0.01.
        self.a_net = nn.Sequential(
            nn.Linear(hidden_dim, 32),
            nn.Tanh(),
            nn.Linear(32, 1),
            nn.Softplus(),
        )
        self.b_net = nn.Sequential(
            nn.Linear(hidden_dim, 32),
            nn.Tanh(),
            nn.Linear(32, 1),
            nn.Softplus(),
        )

    def forward(self, features, delta_sigma, D0):
        """Return (D_n, U_p, a, b), each shaped (batch, 1)."""
        coef_a = self.a_net(features) + 0.1
        coef_b = self.b_net(features) + 0.01

        # Effective stress diverges as D0 approaches 1; the 1e-8 term guards
        # the division at D0 == 1.
        sigma_eff = delta_sigma / (1.0 - D0 + 1e-8)
        plastic_energy = coef_a * torch.exp(coef_b * sigma_eff)

        # atan saturates, bounding the damage increment strictly below 1.
        damage_increment = (2.0 / np.pi) * torch.atan(coef_b * plastic_energy)

        return damage_increment, plastic_energy, coef_a, coef_b
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class CrackTransformerPINN(nn.Module):
    """Transformer-based PINN predicting crack-angle distributions.

    Input x is (batch, 5): [pH, FN, FT, T, phase]. Two parallel encoders
    (a transformer branch and a damage MLP branch) are concatenated and
    decoded into an angle-bin distribution and a total crack count. With
    return_physics=True, the physics sub-layers (Mogi-Coulomb, Weibull,
    energy damage) are also evaluated on fixed reference stresses.
    """

    def __init__(self, input_dim=5, output_dim=72, hidden_dims=(128, 256, 256, 128), dropout=0.2):
        """
        Args:
            input_dim: number of input features (pH, FN, FT, T, phase).
            output_dim: number of angle bins in the predicted distribution.
            hidden_dims: first entry is the transformer width; the rest are
                the angle-decoder layer widths. (Changed from a mutable list
                default to a tuple; indexing/slicing behavior is identical.)
            dropout: dropout probability for the transformer and decoder.
        """
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        self.input_embedding = nn.Sequential(
            nn.Linear(input_dim, hidden_dims[0]),
            nn.LayerNorm(hidden_dims[0]),
            nn.GELU(),
            nn.Dropout(dropout * 0.5)
        )

        # Parallel branch encoding the raw inputs into damage-aware features.
        self.damage_encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dims[0]),
            nn.Tanh(),
            nn.Linear(hidden_dims[0], hidden_dims[0])
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dims[0],
            nhead=8,
            dim_feedforward=hidden_dims[0] * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True
        )

        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=4,
            norm=nn.LayerNorm(hidden_dims[0])
        )

        # Physics sub-layers operating on the transformer encoding.
        self.mogi_coulomb = MogiCoulombLayer(hidden_dims[0])
        self.weibull_strength = WeibullStrengthLayer(hidden_dims[0])
        self.energy_damage = EnergyDamageLayer(hidden_dims[0])

        # MLP decoder over the concatenated (transformer + damage) features.
        self.angle_decoder = nn.ModuleList()
        prev_dim = hidden_dims[0] * 2

        for hidden_dim in hidden_dims[1:]:
            self.angle_decoder.append(nn.Linear(prev_dim, hidden_dim))
            self.angle_decoder.append(nn.LayerNorm(hidden_dim))
            self.angle_decoder.append(nn.Tanh())
            self.angle_decoder.append(nn.Dropout(dropout))
            prev_dim = hidden_dim

        # ReLU heads keep counts non-negative.
        self.angle_output = nn.Sequential(
            nn.Linear(prev_dim, output_dim),
            nn.ReLU()
        )

        self.total_count_head = nn.Sequential(
            nn.Linear(hidden_dims[0] * 2, 64),
            nn.Tanh(),
            nn.Linear(64, 1),
            nn.ReLU()
        )

        # Predicts D0 in (0, 1); trained to agree with the analytic D0.
        self.damage_factor_head = nn.Sequential(
            nn.Linear(hidden_dims[0], 32),
            nn.Tanh(),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

    def compute_initial_damage(self, pH, FN, FT, T):
        """Differentiable analog of DamageCalculator.compute_total_damage.

        All inputs are (batch, 1) tensors; returns D0 clamped to [0, 0.99].
        """
        D_ft = 0.002 * FN * torch.exp(0.02 * FT)
        D_ch = 0.01 * torch.abs(pH - 7.0) ** 1.5
        # BUG FIX: the original used
        #   torch.where(T > 100, 0.0003 * (T - 100) ** 1.2, zeros);
        # torch.where evaluates BOTH branches, so (T - 100) ** 1.2 produced
        # NaN for T < 100 and poisoned gradients through where's backward.
        # Clamping first yields identical forward values with finite grads.
        D_th = 0.0003 * torch.clamp(T - 100.0, min=0.0) ** 1.2

        # Multiplicative coupling of the three damage mechanisms.
        D_total = 1.0 - (1.0 - D_ft) * (1.0 - D_ch) * (1.0 - D_th)
        D_total = torch.clamp(D_total, 0.0, 0.99)

        return D_total

    def forward(self, x, return_physics=False):
        """Predict (angle_dist, total_count) and optionally physics outputs.

        Args:
            x: (batch, 5) features [pH, FN, FT, T, phase]; the phase column
               is consumed implicitly through the full feature vector.
            return_physics: when True, additionally returns a dict of
               physics quantities evaluated on fixed reference stresses.
        """
        batch_size = x.shape[0]

        pH = x[:, 0:1]
        FN = x[:, 1:2]
        FT = x[:, 2:3]
        T = x[:, 3:4]

        D0 = self.compute_initial_damage(pH, FN, FT, T)
        lambda_coef = 1.0 - D0

        x_embedded = self.input_embedding(x)
        damage_features = self.damage_encoder(x)

        # Treat each sample as a length-1 sequence for the transformer.
        x_seq = x_embedded.unsqueeze(1)
        encoded = self.transformer_encoder(x_seq)
        encoded = encoded.squeeze(1)

        combined = torch.cat([encoded, damage_features], dim=-1)

        h = combined
        for layer in self.angle_decoder:
            h = layer(h)

        angle_dist = self.angle_output(h)

        total_count = self.total_count_head(combined)

        predicted_D0 = self.damage_factor_head(encoded)

        if return_physics:
            # Fixed reference stress state used to probe the physics layers.
            sigma1 = 100.0 * torch.ones(batch_size, 1, device=x.device)
            sigma2 = 50.0 * torch.ones(batch_size, 1, device=x.device)
            sigma3 = 30.0 * torch.ones(batch_size, 1, device=x.device)
            delta_sigma = 20.0 * torch.ones(batch_size, 1, device=x.device)
            F_contact = 10.0 * torch.ones(batch_size, 1, device=x.device)

            tau_oct, yield_stress, C1, C2 = self.mogi_coulomb(encoded, sigma1, sigma2, sigma3)
            D_q, m, F0 = self.weibull_strength(encoded, F_contact)
            D_n, U_p, a, b = self.energy_damage(encoded, delta_sigma, D0)

            physics_outputs = {
                'D0': D0,
                'lambda': lambda_coef,
                'predicted_D0': predicted_D0,
                'tau_oct': tau_oct,
                'yield_stress': yield_stress,
                'C1': C1,
                'C2': C2,
                'D_q': D_q,
                'm': m,
                'F0': F0,
                'D_n': D_n,
                'U_p': U_p,
                'a': a,
                'b': b
            }

            return angle_dist, total_count, physics_outputs

        return angle_dist, total_count
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class CrackPINNLoss(nn.Module):
    """Composite physics-informed loss for the crack-distribution model.

    total = lambda_data * MSE(dist)
          + lambda_physics * (bin-sum / total-count consistency)
          + lambda_smooth * (adjacent-bin smoothness)
          + lambda_damage * (damage-factor consistency, lambda in [0, 1])
          + lambda_mogi * (Mogi-Coulomb yield penalty)
          + lambda_reg * (squared L2 norm of all model parameters)
    """

    def __init__(self, lambda_data=1.0, lambda_physics=0.5, lambda_smooth=0.1,
                 lambda_damage=0.3, lambda_mogi=0.2, lambda_reg=1e-4):
        super(CrackPINNLoss, self).__init__()

        # Per-term weights of the composite objective.
        self.lambda_data = lambda_data
        self.lambda_physics = lambda_physics
        self.lambda_smooth = lambda_smooth
        self.lambda_damage = lambda_damage
        self.lambda_mogi = lambda_mogi
        self.lambda_reg = lambda_reg

        self.mse_loss = nn.MSELoss()

    def data_loss(self, pred_dist, true_dist):
        """Mean-squared error between predicted and observed angle distributions."""
        return self.mse_loss(pred_dist, true_dist)

    def physics_loss(self, pred_dist, pred_total, true_dist):
        """Tie the summed angle bins to the separately-predicted total count."""
        summed_pred = pred_dist.sum(dim=1, keepdim=True)
        summed_true = true_dist.sum(dim=1, keepdim=True)

        # Internal consistency (sum of bins vs. predicted total) plus
        # accuracy of the predicted total against the observed total.
        consistency_term = self.mse_loss(summed_pred, pred_total)
        total_term = self.mse_loss(pred_total, summed_true)
        return consistency_term + total_term

    def smoothness_loss(self, pred_dist):
        """Discourage jagged bin-to-bin jumps via squared first differences."""
        first_diff = pred_dist[:, 1:] - pred_dist[:, :-1]
        return (first_diff ** 2).mean()

    def damage_consistency_loss(self, physics_outputs):
        """Match the learned damage factor to the analytic D0 and keep lambda in [0, 1]."""
        analytic_D0 = physics_outputs['D0']
        learned_D0 = physics_outputs['predicted_D0']
        lam = physics_outputs['lambda']

        # Hinge penalties push lambda back into the unit interval.
        range_penalty = (torch.relu(-lam) + torch.relu(lam - 1.0)).mean()
        return self.mse_loss(learned_D0, analytic_D0) + range_penalty

    def mogi_coulomb_loss(self, physics_outputs):
        """Penalize octahedral shear stress exceeding the yield surface."""
        over_yield = physics_outputs['tau_oct'] - physics_outputs['yield_stress']
        yield_penalty = torch.relu(over_yield).mean()

        # Soft constraints keep the learned criterion parameters plausible:
        # C1 >= 0 and C2 <= 1.
        param_penalty = (torch.relu(-physics_outputs['C1']).mean()
                         + torch.relu(physics_outputs['C2'] - 1.0).mean())

        return yield_penalty + 0.1 * param_penalty

    def regularization_loss(self, model):
        """Squared L2 norm over all model parameters.

        NOTE(review): this overlaps with AdamW's weight_decay used in the
        trainer; confirm the double regularization is intentional.
        """
        device = next(model.parameters()).device
        accum = torch.tensor(0.0, device=device)
        for param in model.parameters():
            accum = accum + torch.norm(param, p=2) ** 2
        return accum

    def forward(self, pred_dist, pred_total, true_dist, model, physics_outputs=None):
        """Compute the weighted total loss.

        Returns:
            (total_loss tensor, dict of detached scalar components keyed by
            'total', 'data', 'physics', 'smooth', 'damage', 'mogi', 'reg').
        """
        terms = {
            'data': self.data_loss(pred_dist, true_dist),
            'physics': self.physics_loss(pred_dist, pred_total, true_dist),
            'smooth': self.smoothness_loss(pred_dist),
            'reg': self.regularization_loss(model),
            # Physics terms default to zero when no physics outputs are given
            # (e.g. during validation).
            'damage': torch.tensor(0.0, device=pred_dist.device),
            'mogi': torch.tensor(0.0, device=pred_dist.device),
        }

        if physics_outputs is not None:
            terms['damage'] = self.damage_consistency_loss(physics_outputs)
            terms['mogi'] = self.mogi_coulomb_loss(physics_outputs)

        total_loss = (self.lambda_data * terms['data']
                      + self.lambda_physics * terms['physics']
                      + self.lambda_smooth * terms['smooth']
                      + self.lambda_damage * terms['damage']
                      + self.lambda_mogi * terms['mogi']
                      + self.lambda_reg * terms['reg'])

        loss_dict = {name: value.item() for name, value in terms.items()}
        loss_dict['total'] = total_loss.item()

        return total_loss, loss_dict
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
class CrackPINNTrainer:
    """Training/evaluation harness for the crack Transformer-PINN model.

    Owns the AdamW optimizer, a plateau LR scheduler, the composite
    CrackPINNLoss, early stopping with best-weight restoration, and
    prediction helpers. The wrapped model must support
    ``model(x, return_physics=...)`` returning either
    ``(angle_dist, total_count)`` or
    ``(angle_dist, total_count, physics_outputs)``.
    """

    def __init__(self, model, device='cpu', lr=1e-3, weight_decay=1e-4):
        self.model = model.to(device)
        self.device = device

        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=lr,
            weight_decay=weight_decay,
            betas=(0.9, 0.999)
        )

        # FIX: the deprecated `verbose=True` argument was removed here.
        # It is deprecated in torch 2.2 and removed from ReduceLROnPlateau
        # in later 2.x releases, which would raise TypeError under the
        # project's `torch>=2.0.0` requirement.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=10
        )

        self.criterion = CrackPINNLoss(
            lambda_data=1.0,
            lambda_physics=0.5,
            lambda_smooth=0.1,
            lambda_damage=0.3,
            lambda_mogi=0.2,
            lambda_reg=1e-4
        )

        # Per-epoch loss history, populated by fit().
        self.train_losses = []
        self.val_losses = []

    def train_epoch(self, train_loader):
        """Run one optimization pass over ``train_loader``.

        Returns:
            (mean total loss, dict of mean per-component losses).
        """
        self.model.train()

        epoch_losses = []
        loss_components = {
            'data': [], 'physics': [], 'smooth': [],
            'damage': [], 'mogi': [], 'reg': []
        }

        for X_batch, y_batch in train_loader:
            X_batch = X_batch.to(self.device)
            y_batch = y_batch.to(self.device)

            # Physics outputs are requested so the physics-consistency
            # terms of the loss are active during training.
            pred_dist, pred_total, physics_outputs = self.model(X_batch, return_physics=True)

            loss, loss_dict = self.criterion(
                pred_dist, pred_total, y_batch, self.model, physics_outputs
            )

            self.optimizer.zero_grad()
            loss.backward()
            # Clip gradients to stabilize training.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            epoch_losses.append(loss_dict['total'])
            for key in loss_components.keys():
                loss_components[key].append(loss_dict[key])

        avg_loss = np.mean(epoch_losses)
        avg_components = {key: np.mean(values) for key, values in loss_components.items()}

        return avg_loss, avg_components

    def validate(self, val_loader):
        """Evaluate on ``val_loader`` without gradient tracking.

        Returns:
            (mean loss, metrics dict with keys 'r2', 'rmse', 'total_count_mae').
        """
        self.model.eval()

        val_losses = []
        all_preds = []
        all_trues = []

        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch = X_batch.to(self.device)
                y_batch = y_batch.to(self.device)

                pred_dist, pred_total = self.model(X_batch, return_physics=False)

                loss, _ = self.criterion(pred_dist, pred_total, y_batch, self.model)

                val_losses.append(loss.item())
                all_preds.append(pred_dist.cpu().numpy())
                all_trues.append(y_batch.cpu().numpy())

        avg_loss = np.mean(val_losses)

        all_preds = np.concatenate(all_preds, axis=0)
        all_trues = np.concatenate(all_trues, axis=0)

        # Coefficient of determination; epsilon guards a constant target.
        ss_res = np.sum((all_trues - all_preds) ** 2)
        ss_tot = np.sum((all_trues - np.mean(all_trues)) ** 2)
        r2 = 1 - (ss_res / (ss_tot + 1e-8))

        rmse = np.sqrt(np.mean((all_trues - all_preds) ** 2))

        # MAE between summed predicted and observed crack counts per sample.
        pred_total_counts = all_preds.sum(axis=1)
        true_total_counts = all_trues.sum(axis=1)
        total_count_error = np.mean(np.abs(pred_total_counts - true_total_counts))

        metrics = {
            'r2': r2,
            'rmse': rmse,
            'total_count_mae': total_count_error
        }

        return avg_loss, metrics

    def fit(self, X_train, y_train, X_val, y_val, epochs=200, batch_size=16, patience=30):
        """Train with early stopping; restores the best-validation weights."""
        train_dataset = TensorDataset(
            torch.FloatTensor(X_train),
            torch.FloatTensor(y_train)
        )
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        val_dataset = TensorDataset(
            torch.FloatTensor(X_val),
            torch.FloatTensor(y_val)
        )
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

        best_val_loss = float('inf')
        patience_counter = 0
        best_model_state = None

        print("\nStarting training...")
        print("=" * 80)

        for epoch in range(epochs):
            train_loss, train_components = self.train_epoch(train_loader)
            val_loss, val_metrics = self.validate(val_loader)

            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)

            self.scheduler.step(val_loss)

            if (epoch + 1) % 10 == 0 or epoch == 0:
                print(f"Epoch {epoch+1}/{epochs}")
                print(f"  Train Loss: {train_loss:.4f} "
                      f"(data: {train_components['data']:.4f}, "
                      f"phys: {train_components['physics']:.4f}, "
                      f"damage: {train_components['damage']:.4f})")
                print(f"  Val Loss: {val_loss:.4f} | "
                      f"R2: {val_metrics['r2']:.4f} | "
                      f"RMSE: {val_metrics['rmse']:.2f}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                # FIX: `state_dict().copy()` is a *shallow* copy — its tensors
                # alias the live parameters, so subsequent optimizer steps
                # would silently overwrite the "best" snapshot. Clone every
                # tensor to take a true snapshot.
                best_model_state = {
                    name: tensor.detach().clone()
                    for name, tensor in self.model.state_dict().items()
                }
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f"\nEarly stopping. Best val loss: {best_val_loss:.4f}")
                break

        if best_model_state is not None:
            self.model.load_state_dict(best_model_state)

        print("=" * 80)
        print(f"Training complete. Best val loss: {best_val_loss:.4f}")

    def predict(self, X):
        """Return (angle distribution, flat total counts) as numpy arrays."""
        self.model.eval()

        with torch.no_grad():
            X_tensor = torch.FloatTensor(X).to(self.device)
            pred_dist, pred_total = self.model(X_tensor, return_physics=False)

        pred_dist = pred_dist.cpu().numpy()
        pred_total = pred_total.cpu().numpy().flatten()

        return pred_dist, pred_total

    def predict_with_physics(self, X):
        """Predict and expose selected physics diagnostics (D0, lambda, D_n)."""
        self.model.eval()

        with torch.no_grad():
            X_tensor = torch.FloatTensor(X).to(self.device)
            pred_dist, pred_total, physics = self.model(X_tensor, return_physics=True)

        result = {
            'angle_distribution': pred_dist.cpu().numpy(),
            'total_count': pred_total.cpu().numpy().flatten(),
            'D0': physics['D0'].cpu().numpy().flatten(),
            'lambda': physics['lambda'].cpu().numpy().flatten(),
            'D_n': physics['D_n'].cpu().numpy().flatten()
        }

        return result
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
numpy>=1.21.0
|
| 3 |
+
pandas>=1.3.0
|
| 4 |
+
scikit-learn>=1.0.0
|
| 5 |
+
matplotlib>=3.5.0
|
train.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
from sklearn.model_selection import train_test_split
|
| 6 |
+
import pickle
|
| 7 |
+
|
| 8 |
+
from data_loader import CrackDataLoader, DamageCalculator
|
| 9 |
+
from model import CrackTransformerPINN, CrackPINNTrainer
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def plot_training_history(trainer, save_path=None):
    """Plot the training history recorded by a CrackPINNTrainer.

    Draws three panels: raw train/val loss curves, the same curves on a log
    scale, and the per-epoch change in training loss (convergence rate).

    Args:
        trainer: object exposing `train_losses` and `val_losses` lists.
        save_path: optional file path; when given, the figure is saved as a
            300-dpi PNG. The figure is always closed, never shown.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # Panel 1: raw loss curves on a linear scale.
    axes[0].plot(trainer.train_losses, label='Train Loss', linewidth=2)
    axes[0].plot(trainer.val_losses, label='Val Loss', linewidth=2)
    axes[0].set_xlabel('Epoch', fontsize=12)
    axes[0].set_ylabel('Loss', fontsize=12)
    axes[0].set_title('Training History', fontsize=14, fontweight='bold')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Panel 2: same curves on a log scale to expose late-stage progress.
    axes[1].semilogy(trainer.train_losses, label='Train Loss', linewidth=2)
    axes[1].semilogy(trainer.val_losses, label='Val Loss', linewidth=2)
    axes[1].set_xlabel('Epoch', fontsize=12)
    axes[1].set_ylabel('Loss (log scale)', fontsize=12)
    axes[1].set_title('Training History (Log Scale)', fontsize=14, fontweight='bold')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    # Panel 3: epoch-over-epoch loss delta; needs at least two epochs.
    if len(trainer.train_losses) > 1:
        train_improvement = np.diff(trainer.train_losses)
        axes[2].plot(train_improvement, linewidth=2, alpha=0.7)
        # Zero line: points below it mean the loss improved that epoch.
        axes[2].axhline(y=0, color='r', linestyle='--', alpha=0.5)
        axes[2].set_xlabel('Epoch', fontsize=12)
        axes[2].set_ylabel('Loss Change', fontsize=12)
        axes[2].set_title('Convergence Rate', fontsize=14, fontweight='bold')
        axes[2].grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Training history saved to: {save_path}")

    plt.close()
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def plot_prediction_examples(X_test, y_test, trainer, angle_bins, loader, n_examples=4, save_path=None):
    """Polar plots comparing true vs. predicted crack-angle distributions.

    Args:
        X_test, y_test: normalized test inputs/targets (as produced by
            `loader`'s scalers); inverse-transformed here for display.
        trainer: trainer exposing `predict` on normalized inputs.
        angle_bins: angle bin centers in degrees.
        loader: data loader exposing fitted `scaler_X` / `scaler_y`.
        n_examples: number of samples to plot (capped at 4 subplots).
        save_path: optional PNG output path.
    """
    # Back-transform to original units for human-readable axes/titles.
    X_test_original = loader.scaler_X.inverse_transform(X_test)
    y_test_original = loader.scaler_y.inverse_transform(y_test)

    y_pred_norm, pred_totals = trainer.predict(X_test[:n_examples])
    y_pred = loader.scaler_y.inverse_transform(y_pred_norm)

    fig, axes = plt.subplots(2, 2, figsize=(14, 10), subplot_kw=dict(projection='polar'))
    axes = axes.flatten()

    for i in range(min(n_examples, len(axes))):
        ax = axes[i]

        theta = np.deg2rad(angle_bins)

        ax.plot(theta, y_test_original[i], 'o-', label='True', linewidth=2, markersize=4, alpha=0.7)
        ax.plot(theta, y_pred[i], 's-', label='Predicted', linewidth=2, markersize=3, alpha=0.7)

        # Input features by column position: pH, freeze-thaw cycles, FT,
        # temperature, loading-phase flag — assumed column order; TODO confirm
        # against CrackDataLoader.
        pH = X_test_original[i, 0]
        FN = X_test_original[i, 1]
        FT = X_test_original[i, 2]
        T = X_test_original[i, 3]
        phase = X_test_original[i, 4]
        phase_str = "Unstable" if phase < 0.5 else "Peak"

        # NOTE(review): relies on DamageCalculator.compute_total_damage /
        # compute_lambda existing in data_loader — verify; only
        # compute_freeze_thaw_damage is visible from here.
        D0 = DamageCalculator.compute_total_damage(pH, FN, FT, T)
        lambda_coef = DamageCalculator.compute_lambda(D0)

        true_total = y_test_original[i].sum()
        pred_total = y_pred[i].sum()

        title = f"pH={pH:.0f}, FN={FN:.0f}, FT={FT:.0f}, T={T:.0f}C\n"
        title += f"D0={D0:.3f}, lambda={lambda_coef:.3f}\n"
        title += f"{phase_str} | True: {true_total:.0f}, Pred: {pred_total:.0f}"

        ax.set_title(title, fontsize=9, pad=20)
        ax.legend(loc='upper right', fontsize=8)
        # Compass-style orientation: 0 degrees at top, clockwise.
        ax.set_theta_zero_location('N')
        ax.set_theta_direction(-1)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Prediction examples saved to: {save_path}")

    plt.close()
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def plot_damage_analysis(X, y, save_path=None):
    """Scatter plots relating input conditions and damage factor to crack counts.

    Args:
        X: un-normalized feature matrix; columns assumed pH, FN, FT, T, phase
            — TODO confirm against CrackDataLoader.
        y: crack-count distributions; summed per row for totals.
        save_path: optional PNG output path.
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    D0_values = []
    total_cracks = y.sum(axis=1)

    # Analytic initial damage per sample.
    # NOTE(review): relies on DamageCalculator.compute_total_damage existing
    # in data_loader — verify; only compute_freeze_thaw_damage is visible here.
    for i in range(X.shape[0]):
        D0 = DamageCalculator.compute_total_damage(X[i, 0], X[i, 1], X[i, 2], X[i, 3])
        D0_values.append(D0)

    D0_values = np.array(D0_values)

    # Panel (0,0): damage factor vs. total crack count.
    axes[0, 0].scatter(D0_values, total_cracks, alpha=0.6, edgecolors='black', linewidth=0.5)
    axes[0, 0].set_xlabel('Initial Damage Factor D0', fontsize=12)
    axes[0, 0].set_ylabel('Total Crack Count', fontsize=12)
    axes[0, 0].set_title('D0 vs Total Cracks', fontsize=14, fontweight='bold')
    axes[0, 0].grid(True, alpha=0.3)

    # Panels below color each point by its D0 value.
    axes[0, 1].scatter(X[:, 0], total_cracks, alpha=0.6, c=D0_values, cmap='viridis')
    axes[0, 1].set_xlabel('pH Value', fontsize=12)
    axes[0, 1].set_ylabel('Total Crack Count', fontsize=12)
    axes[0, 1].set_title('pH vs Total Cracks', fontsize=14, fontweight='bold')
    axes[0, 1].grid(True, alpha=0.3)

    axes[1, 0].scatter(X[:, 1], total_cracks, alpha=0.6, c=D0_values, cmap='viridis')
    axes[1, 0].set_xlabel('Freeze-thaw Cycles (FN)', fontsize=12)
    axes[1, 0].set_ylabel('Total Crack Count', fontsize=12)
    axes[1, 0].set_title('FN vs Total Cracks', fontsize=14, fontweight='bold')
    axes[1, 0].grid(True, alpha=0.3)

    scatter = axes[1, 1].scatter(X[:, 3], total_cracks, alpha=0.6, c=D0_values, cmap='viridis')
    axes[1, 1].set_xlabel('Damage Temperature (T)', fontsize=12)
    axes[1, 1].set_ylabel('Total Crack Count', fontsize=12)
    axes[1, 1].set_title('T vs Total Cracks', fontsize=14, fontweight='bold')
    axes[1, 1].grid(True, alpha=0.3)

    # One shared colorbar keyed to the last scatter handle.
    plt.colorbar(scatter, ax=axes[1, 1], label='D0')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Damage analysis saved to: {save_path}")

    plt.close()
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def main():
    """End-to-end pipeline: load data, split, normalize, train, test, save, plot.

    Writes the model checkpoint, pickled scalers, and three diagnostic PNGs
    into ./output.
    """
    print("=" * 80)
    print("Transformer-PINN Crack Prediction Model")
    print("Based on: Mechanism of micro-damage evolution in rocks")
    print("under multiple coupled cyclic stresses")
    print("=" * 80)

    base_path = "./data"
    output_dir = "./output"

    os.makedirs(output_dir, exist_ok=True)

    print("\n" + "=" * 80)
    print("Step 1: Loading/Generating Data")
    print("=" * 80)

    loader = CrackDataLoader(base_path, stress_type="major")

    # FIX: the original bare `except:` swallowed *every* exception —
    # including KeyboardInterrupt/SystemExit and genuine bugs in the
    # loading code. Catch Exception and report it before falling back
    # to synthetic data.
    try:
        X, y, angle_bins, damage_list = loader.load_all_data(phase="both")
    except Exception as exc:
        print(f"Real data not found ({exc}). Generating synthetic data...")
        X, y, angle_bins = loader.create_synthetic_data(n_samples=200, output_dim=72)

    stats = loader.get_statistics(X, y)
    print("\nData statistics:")
    for key, value in stats.items():
        print(f"  {key}: {value}")

    print("\n" + "=" * 80)
    print("Step 2: Splitting Dataset (Train:Val:Test = 64:16:20)")
    print("=" * 80)

    # 20% held-out test, then 20% of the remainder for validation
    # -> 64/16/20 overall.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    X_train, X_val, y_train, y_val = train_test_split(
        X_train, y_train, test_size=0.2, random_state=42
    )

    print(f"Training set: {X_train.shape[0]} samples")
    print(f"Validation set: {X_val.shape[0]} samples")
    print(f"Test set: {X_test.shape[0]} samples")

    print("\n" + "=" * 80)
    print("Step 3: Normalizing Data")
    print("=" * 80)

    # Scalers are fitted inside normalize_data on the training split;
    # val/test are only transformed with the fitted scalers.
    X_train_norm, y_train_norm, X_val_norm, y_val_norm = loader.normalize_data(
        X_train, y_train, X_val, y_val
    )

    X_test_norm = loader.scaler_X.transform(X_test)
    y_test_norm = loader.scaler_y.transform(y_test)

    print("Normalization complete")

    print("\n" + "=" * 80)
    print("Step 4: Creating Model")
    print("=" * 80)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    model = CrackTransformerPINN(
        input_dim=5,
        output_dim=y.shape[1],
        hidden_dims=[128, 256, 256, 128],
        dropout=0.2
    )

    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {n_params:,}")

    print("\nModel components:")
    print("  - Transformer Encoder (8 heads, 4 layers)")
    print("  - Mogi-Coulomb Yield Criterion Layer")
    print("  - Weibull Strength Distribution Layer")
    print("  - Energy-based Damage Evolution Layer")
    print("  - PINN Decoder with Physics Constraints")

    print("\n" + "=" * 80)
    print("Step 5: Training Model")
    print("=" * 80)

    trainer = CrackPINNTrainer(
        model,
        device=device,
        lr=1e-3,
        weight_decay=1e-4
    )

    trainer.fit(
        X_train_norm, y_train_norm,
        X_val_norm, y_val_norm,
        epochs=300,
        batch_size=8,
        patience=50
    )

    print("\n" + "=" * 80)
    print("Step 6: Testing Model")
    print("=" * 80)

    # Evaluate on the held-out test split with the trainer's validate loop.
    test_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(
            torch.FloatTensor(X_test_norm),
            torch.FloatTensor(y_test_norm)
        ),
        batch_size=8,
        shuffle=False
    )
    test_loss, test_metrics = trainer.validate(test_loader)

    print("Test set performance:")
    print(f"  Loss: {test_loss:.4f}")
    print(f"  R2: {test_metrics['r2']:.4f}")
    print(f"  RMSE: {test_metrics['rmse']:.2f}")
    print(f"  Total Count MAE: {test_metrics['total_count_mae']:.2f}")

    print("\n" + "=" * 80)
    print("Step 7: Saving Model")
    print("=" * 80)

    # Checkpoint carries the config needed to rebuild the architecture.
    model_path = os.path.join(output_dir, "crack_transformer_pinn.pth")
    torch.save({
        'model_state_dict': model.state_dict(),
        'model_config': {
            'input_dim': 5,
            'output_dim': y.shape[1],
            'hidden_dims': [128, 256, 256, 128],
            'dropout': 0.2
        },
        'test_metrics': test_metrics
    }, model_path)
    print(f"Model saved to: {model_path}")

    # Scalers and angle bins are needed at inference to (de)normalize.
    scaler_path = os.path.join(output_dir, "scalers.pkl")
    with open(scaler_path, 'wb') as f:
        pickle.dump({
            'scaler_X': loader.scaler_X,
            'scaler_y': loader.scaler_y,
            'angle_bins': angle_bins
        }, f)
    print(f"Scalers saved to: {scaler_path}")

    print("\n" + "=" * 80)
    print("Step 8: Generating Visualizations")
    print("=" * 80)

    history_path = os.path.join(output_dir, "training_history.png")
    plot_training_history(trainer, save_path=history_path)

    examples_path = os.path.join(output_dir, "prediction_examples.png")
    plot_prediction_examples(
        X_test_norm, y_test_norm,
        trainer, angle_bins, loader,
        n_examples=4,
        save_path=examples_path
    )

    damage_path = os.path.join(output_dir, "damage_analysis.png")
    plot_damage_analysis(X, y, save_path=damage_path)

    print("\n" + "=" * 80)
    print("Training Pipeline Complete!")
    print("=" * 80)
    print("\nGenerated files:")
    print(f"  1. Model checkpoint: {model_path}")
    print(f"  2. Scalers: {scaler_path}")
    print(f"  3. Training history: {history_path}")
    print(f"  4. Prediction examples: {examples_path}")
    print(f"  5. Damage analysis: {damage_path}")

    print("\nPhysics constraints applied:")
    print("  - Mogi-Coulomb yield criterion: tau_oct = C1 + C2 * sigma_m2")
    print("  - Weibull strength: D_q = 1 - exp(-(F/F0)^m)")
    print("  - Energy damage: D_n = (2/pi) * arctan(b * U_p)")
    print("  - Total damage: D_total = 1 - (1-D_ft)(1-D_ch)(1-D_th)")
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
# Script entry point: run the full training pipeline when executed directly.
if __name__ == "__main__":
    main()
|