guanwencan commited on
Commit
5e4dee3
·
verified ·
1 Parent(s): 6d85675

Upload 5 files

Browse files
Files changed (5) hide show
  1. data_loader.py +239 -0
  2. inference.py +238 -0
  3. model.py +519 -0
  4. requirements.txt +5 -0
  5. train.py +331 -0
data_loader.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import numpy as np
4
+ import pandas as pd
5
+ from glob import glob
6
+ from sklearn.preprocessing import StandardScaler
7
+
8
+
9
+ class DamageCalculator:
10
+
11
+ @staticmethod
12
+ def compute_freeze_thaw_damage(FN, FT, a1=0.002, b1=1.0, c1=0.02):
13
+ return a1 * (FN ** b1) * np.exp(c1 * FT)
14
+
15
+ @staticmethod
16
+ def compute_chemical_damage(pH, a2=0.01, b2=1.5):
17
+ return a2 * np.abs(pH - 7.0) ** b2
18
+
19
+ @staticmethod
20
+ def compute_thermal_damage(T, T0=100.0, a3=0.0003, b3=1.2):
21
+ if T < T0:
22
+ return 0.0
23
+ return a3 * ((T - T0) ** b3)
24
+
25
+ @staticmethod
26
+ def compute_total_damage(pH, FN, FT, T):
27
+ D_ft = DamageCalculator.compute_freeze_thaw_damage(FN, FT)
28
+ D_ch = DamageCalculator.compute_chemical_damage(pH)
29
+ D_th = DamageCalculator.compute_thermal_damage(T)
30
+
31
+ D_total = 1.0 - (1.0 - D_ft) * (1.0 - D_ch) * (1.0 - D_th)
32
+ return np.clip(D_total, 0.0, 0.99)
33
+
34
+ @staticmethod
35
+ def compute_lambda(D0):
36
+ return 1.0 - D0
37
+
38
+
39
+ class CrackDataLoader:
40
+
41
+ def __init__(self, base_path, stress_type="major"):
42
+ self.base_path = base_path
43
+ self.stress_type = stress_type
44
+
45
+ if stress_type == "major":
46
+ self.data_dir = os.path.join(base_path, "major_principal_stress")
47
+ else:
48
+ self.data_dir = os.path.join(base_path, "minor_principal_stress")
49
+
50
+ self.scaler_X = StandardScaler()
51
+ self.scaler_y = StandardScaler()
52
+ self.damage_calculator = DamageCalculator()
53
+
54
+ def parse_filename(self, filename):
55
+ pattern = r'(\d+)-(\d+)-(\d+)-(\d+)'
56
+ match = re.search(pattern, filename)
57
+
58
+ if match:
59
+ pH = int(match.group(1))
60
+ FN = int(match.group(2))
61
+ FT = int(match.group(3))
62
+ T = int(match.group(4))
63
+
64
+ return {
65
+ 'pH': pH,
66
+ 'FN': FN,
67
+ 'FT': FT,
68
+ 'T': T
69
+ }
70
+ else:
71
+ raise ValueError(f"Cannot parse filename: {filename}")
72
+
73
+ def load_single_csv(self, csv_path):
74
+ data = pd.read_csv(csv_path, header=None, names=['angle', 'count'])
75
+ angles = data['angle'].values
76
+ counts = data['count'].values
77
+ return angles, counts
78
+
79
+ def load_all_data(self, phase="both"):
80
+ X_list = []
81
+ y_list = []
82
+ damage_list = []
83
+
84
+ if phase == "both":
85
+ subdirs = ["unstable_development", "peak_stress"]
86
+ elif phase == "early":
87
+ subdirs = ["unstable_development"]
88
+ elif phase == "peak":
89
+ subdirs = ["peak_stress"]
90
+ else:
91
+ raise ValueError(f"Unknown phase: {phase}")
92
+
93
+ for subdir in subdirs:
94
+ subdir_path = os.path.join(self.data_dir, subdir)
95
+
96
+ if not os.path.exists(subdir_path):
97
+ print(f"Warning: Directory does not exist {subdir_path}")
98
+ continue
99
+
100
+ phase_code = 0 if "unstable" in subdir else 1
101
+
102
+ csv_files = glob(os.path.join(subdir_path, "*.csv"))
103
+
104
+ print(f"Loading {len(csv_files)} files from {subdir}...")
105
+
106
+ for csv_file in csv_files:
107
+ try:
108
+ params = self.parse_filename(os.path.basename(csv_file))
109
+
110
+ angles, counts = self.load_single_csv(csv_file)
111
+
112
+ D0 = DamageCalculator.compute_total_damage(
113
+ params['pH'], params['FN'], params['FT'], params['T']
114
+ )
115
+ lambda_coef = DamageCalculator.compute_lambda(D0)
116
+
117
+ features = np.array([
118
+ params['pH'],
119
+ params['FN'],
120
+ params['FT'],
121
+ params['T'],
122
+ phase_code
123
+ ], dtype=np.float32)
124
+
125
+ X_list.append(features)
126
+ y_list.append(counts)
127
+ damage_list.append({'D0': D0, 'lambda': lambda_coef})
128
+
129
+ except Exception as e:
130
+ print(f"Skipping file {csv_file}: {e}")
131
+ continue
132
+
133
+ if len(X_list) == 0:
134
+ raise ValueError("No data loaded successfully!")
135
+
136
+ X = np.array(X_list)
137
+
138
+ y_length = len(y_list[0])
139
+ y_padded = []
140
+
141
+ for y_sample in y_list:
142
+ if len(y_sample) < y_length:
143
+ y_sample = np.pad(y_sample, (0, y_length - len(y_sample)), 'constant')
144
+ elif len(y_sample) > y_length:
145
+ y_sample = y_sample[:y_length]
146
+ y_padded.append(y_sample)
147
+
148
+ y = np.array(y_padded)
149
+
150
+ angles, _ = self.load_single_csv(csv_files[0])
151
+ angle_bins = angles[:y_length]
152
+
153
+ print(f"\nData loading complete:")
154
+ print(f" Samples: {X.shape[0]}")
155
+ print(f" Input features: {X.shape[1]} (pH, FN, FT, T, phase)")
156
+ print(f" Output dimension: {y.shape[1]} (angle bins)")
157
+ print(f" Angle range: {angle_bins[0]:.1f} - {angle_bins[-1]:.1f}")
158
+ print(f" Total cracks range: {y.sum(axis=1).min():.0f} - {y.sum(axis=1).max():.0f}")
159
+
160
+ return X, y, angle_bins, damage_list
161
+
162
+ def create_synthetic_data(self, n_samples=100, output_dim=72):
163
+ pH_values = [1, 3, 5, 7]
164
+ FN_values = [5, 10, 20, 40]
165
+ FT_values = [10, 20, 30, 40]
166
+ T_values = [25, 300, 600, 900]
167
+ phase_values = [0, 1]
168
+
169
+ X_list = []
170
+ y_list = []
171
+
172
+ for _ in range(n_samples):
173
+ pH = np.random.choice(pH_values)
174
+ FN = np.random.choice(FN_values)
175
+ FT = np.random.choice(FT_values)
176
+ T = np.random.choice(T_values)
177
+ phase = np.random.choice(phase_values)
178
+
179
+ D0 = DamageCalculator.compute_total_damage(pH, FN, FT, T)
180
+
181
+ if self.stress_type == "major":
182
+ peak_angle = 90.0 + np.random.normal(0, 10)
183
+ spread = 15.0 + D0 * 20.0
184
+ else:
185
+ peak_angle = 45.0 + np.random.normal(0, 15)
186
+ spread = 20.0 + D0 * 25.0
187
+
188
+ angles = np.linspace(0, 175, output_dim)
189
+ distribution = np.exp(-0.5 * ((angles - peak_angle) / spread) ** 2)
190
+ distribution = distribution * (100 + D0 * 200) * (1 + 0.5 * phase)
191
+ distribution = distribution + np.random.normal(0, 5, output_dim)
192
+ distribution = np.maximum(distribution, 0)
193
+
194
+ X_list.append([pH, FN, FT, T, phase])
195
+ y_list.append(distribution)
196
+
197
+ X = np.array(X_list, dtype=np.float32)
198
+ y = np.array(y_list, dtype=np.float32)
199
+ angle_bins = np.linspace(0, 175, output_dim)
200
+
201
+ return X, y, angle_bins
202
+
203
+ def normalize_data(self, X_train, y_train, X_test=None, y_test=None):
204
+ X_train_norm = self.scaler_X.fit_transform(X_train)
205
+ y_train_norm = self.scaler_y.fit_transform(y_train)
206
+
207
+ if X_test is not None and y_test is not None:
208
+ X_test_norm = self.scaler_X.transform(X_test)
209
+ y_test_norm = self.scaler_y.transform(y_test)
210
+ return X_train_norm, y_train_norm, X_test_norm, y_test_norm
211
+ else:
212
+ return X_train_norm, y_train_norm
213
+
214
+ def denormalize_output(self, y_norm):
215
+ return self.scaler_y.inverse_transform(y_norm)
216
+
217
+ def get_statistics(self, X, y):
218
+ stats = {
219
+ 'n_samples': X.shape[0],
220
+ 'input_dim': X.shape[1],
221
+ 'output_dim': y.shape[1],
222
+ 'pH_range': (X[:, 0].min(), X[:, 0].max()),
223
+ 'FN_range': (X[:, 1].min(), X[:, 1].max()),
224
+ 'FT_range': (X[:, 2].min(), X[:, 2].max()),
225
+ 'T_range': (X[:, 3].min(), X[:, 3].max()),
226
+ 'total_cracks_range': (y.sum(axis=1).min(), y.sum(axis=1).max()),
227
+ 'total_cracks_mean': y.sum(axis=1).mean(),
228
+ 'total_cracks_std': y.sum(axis=1).std(),
229
+ }
230
+
231
+ D0_values = []
232
+ for i in range(X.shape[0]):
233
+ D0 = DamageCalculator.compute_total_damage(X[i, 0], X[i, 1], X[i, 2], X[i, 3])
234
+ D0_values.append(D0)
235
+
236
+ stats['D0_range'] = (min(D0_values), max(D0_values))
237
+ stats['D0_mean'] = np.mean(D0_values)
238
+
239
+ return stats
inference.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import pickle
5
+ import matplotlib.pyplot as plt
6
+
7
+ from model import CrackTransformerPINN
8
+ from data_loader import DamageCalculator
9
+
10
+
11
+ class CrackPredictor:
12
+
13
+ def __init__(self, model_path, scaler_path, device='cpu'):
14
+ self.device = device
15
+
16
+ with open(scaler_path, 'rb') as f:
17
+ scalers = pickle.load(f)
18
+
19
+ self.scaler_X = scalers['scaler_X']
20
+ self.scaler_y = scalers['scaler_y']
21
+ self.angle_bins = scalers['angle_bins']
22
+
23
+ checkpoint = torch.load(model_path, map_location=device)
24
+
25
+ if 'model_config' in checkpoint:
26
+ config = checkpoint['model_config']
27
+ self.model = CrackTransformerPINN(
28
+ input_dim=config['input_dim'],
29
+ output_dim=config['output_dim'],
30
+ hidden_dims=config['hidden_dims'],
31
+ dropout=config['dropout']
32
+ )
33
+ self.model.load_state_dict(checkpoint['model_state_dict'])
34
+ else:
35
+ self.model = CrackTransformerPINN(
36
+ input_dim=5,
37
+ output_dim=len(self.angle_bins),
38
+ hidden_dims=[128, 256, 256, 128],
39
+ dropout=0.2
40
+ )
41
+ self.model.load_state_dict(checkpoint)
42
+
43
+ self.model.to(device)
44
+ self.model.eval()
45
+
46
+ def predict(self, pH, FN, FT, T, phase):
47
+ X = np.array([[pH, FN, FT, T, phase]], dtype=np.float32)
48
+
49
+ X_norm = self.scaler_X.transform(X)
50
+
51
+ with torch.no_grad():
52
+ X_tensor = torch.FloatTensor(X_norm).to(self.device)
53
+ pred_dist_norm, pred_total = self.model(X_tensor, return_physics=False)
54
+
55
+ pred_dist_norm = pred_dist_norm.cpu().numpy()
56
+ pred_total = pred_total.cpu().numpy().flatten()
57
+
58
+ pred_dist = self.scaler_y.inverse_transform(pred_dist_norm)
59
+
60
+ D0 = DamageCalculator.compute_total_damage(pH, FN, FT, T)
61
+ lambda_coef = DamageCalculator.compute_lambda(D0)
62
+
63
+ return {
64
+ 'angle_distribution': pred_dist[0],
65
+ 'total_count': pred_total[0],
66
+ 'D0': D0,
67
+ 'lambda': lambda_coef
68
+ }
69
+
70
+ def predict_with_physics(self, pH, FN, FT, T, phase):
71
+ X = np.array([[pH, FN, FT, T, phase]], dtype=np.float32)
72
+
73
+ X_norm = self.scaler_X.transform(X)
74
+
75
+ with torch.no_grad():
76
+ X_tensor = torch.FloatTensor(X_norm).to(self.device)
77
+ pred_dist_norm, pred_total, physics = self.model(X_tensor, return_physics=True)
78
+
79
+ pred_dist_norm = pred_dist_norm.cpu().numpy()
80
+
81
+ pred_dist = self.scaler_y.inverse_transform(pred_dist_norm)
82
+
83
+ result = {
84
+ 'angle_distribution': pred_dist[0],
85
+ 'total_count': pred_total.cpu().numpy().flatten()[0],
86
+ 'D0': physics['D0'].cpu().numpy().flatten()[0],
87
+ 'lambda': physics['lambda'].cpu().numpy().flatten()[0],
88
+ 'D_n': physics['D_n'].cpu().numpy().flatten()[0],
89
+ 'tau_oct': physics['tau_oct'].cpu().numpy().flatten()[0],
90
+ 'yield_stress': physics['yield_stress'].cpu().numpy().flatten()[0],
91
+ 'C1': physics['C1'].cpu().numpy().flatten()[0],
92
+ 'C2': physics['C2'].cpu().numpy().flatten()[0],
93
+ 'D_q': physics['D_q'].cpu().numpy().flatten()[0],
94
+ 'm': physics['m'].cpu().numpy().flatten()[0],
95
+ 'F0': physics['F0'].cpu().numpy().flatten()[0]
96
+ }
97
+
98
+ return result
99
+
100
+ def predict_batch(self, X):
101
+ X_norm = self.scaler_X.transform(X)
102
+
103
+ with torch.no_grad():
104
+ X_tensor = torch.FloatTensor(X_norm).to(self.device)
105
+ pred_dist_norm, pred_total = self.model(X_tensor, return_physics=False)
106
+
107
+ pred_dist_norm = pred_dist_norm.cpu().numpy()
108
+ pred_total = pred_total.cpu().numpy().flatten()
109
+
110
+ pred_dist = self.scaler_y.inverse_transform(pred_dist_norm)
111
+
112
+ return pred_dist, pred_total
113
+
114
+ def visualize(self, result, title=None, save_path=None):
115
+ fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(projection='polar'))
116
+
117
+ theta = np.deg2rad(self.angle_bins)
118
+ pred_dist = result['angle_distribution']
119
+
120
+ ax.plot(theta, pred_dist, 'o-', linewidth=2, markersize=4, color='blue')
121
+ ax.fill(theta, pred_dist, alpha=0.3, color='blue')
122
+
123
+ if title:
124
+ ax.set_title(title, fontsize=14, pad=20)
125
+ else:
126
+ total = result['total_count']
127
+ D0 = result['D0']
128
+ lam = result['lambda']
129
+ ax.set_title(f'Predicted Crack Distribution\nTotal: {total:.0f}, D0: {D0:.3f}, lambda: {lam:.3f}',
130
+ fontsize=12, pad=20)
131
+
132
+ ax.set_theta_zero_location('N')
133
+ ax.set_theta_direction(-1)
134
+ ax.grid(True, alpha=0.3)
135
+
136
+ if save_path:
137
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
138
+
139
+ plt.show()
140
+
141
+ def compare_stress_types(self, pH, FN, FT, T, save_path=None):
142
+ result_unstable = self.predict(pH, FN, FT, T, phase=0)
143
+ result_peak = self.predict(pH, FN, FT, T, phase=1)
144
+
145
+ fig, axes = plt.subplots(1, 2, figsize=(14, 6), subplot_kw=dict(projection='polar'))
146
+
147
+ theta = np.deg2rad(self.angle_bins)
148
+
149
+ axes[0].plot(theta, result_unstable['angle_distribution'], 'o-',
150
+ linewidth=2, markersize=4, color='blue')
151
+ axes[0].fill(theta, result_unstable['angle_distribution'], alpha=0.3, color='blue')
152
+ axes[0].set_title(f'Unstable Development Phase\nTotal: {result_unstable["total_count"]:.0f}',
153
+ fontsize=12, pad=20)
154
+ axes[0].set_theta_zero_location('N')
155
+ axes[0].set_theta_direction(-1)
156
+ axes[0].grid(True, alpha=0.3)
157
+
158
+ axes[1].plot(theta, result_peak['angle_distribution'], 'o-',
159
+ linewidth=2, markersize=4, color='red')
160
+ axes[1].fill(theta, result_peak['angle_distribution'], alpha=0.3, color='red')
161
+ axes[1].set_title(f'Peak Stress Phase\nTotal: {result_peak["total_count"]:.0f}',
162
+ fontsize=12, pad=20)
163
+ axes[1].set_theta_zero_location('N')
164
+ axes[1].set_theta_direction(-1)
165
+ axes[1].grid(True, alpha=0.3)
166
+
167
+ D0 = result_unstable['D0']
168
+ lam = result_unstable['lambda']
169
+ fig.suptitle(f'pH={pH}, FN={FN}, FT={FT}, T={T}\nD0={D0:.3f}, lambda={lam:.3f}',
170
+ fontsize=14, fontweight='bold')
171
+
172
+ plt.tight_layout()
173
+
174
+ if save_path:
175
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
176
+
177
+ plt.show()
178
+
179
+
180
+ def main():
181
+ model_path = "./output/crack_transformer_pinn.pth"
182
+ scaler_path = "./output/scalers.pkl"
183
+
184
+ if not os.path.exists(model_path):
185
+ print("Model file not found. Please train the model first using train.py")
186
+ return
187
+
188
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
189
+ predictor = CrackPredictor(model_path, scaler_path, device=device)
190
+
191
+ print("=" * 60)
192
+ print("Transformer-PINN Crack Prediction - Inference Demo")
193
+ print("=" * 60)
194
+
195
+ test_cases = [
196
+ {'pH': 3, 'FN': 30, 'FT': 40, 'T': 25, 'phase': 0},
197
+ {'pH': 3, 'FN': 40, 'FT': 20, 'T': 300, 'phase': 1},
198
+ {'pH': 7, 'FN': 10, 'FT': 40, 'T': 300, 'phase': 1},
199
+ {'pH': 5, 'FN': 10, 'FT': 20, 'T': 900, 'phase': 0},
200
+ ]
201
+
202
+ for i, params in enumerate(test_cases):
203
+ result = predictor.predict(**params)
204
+
205
+ print(f"\nTest Case {i+1}:")
206
+ print(f" Input: pH={params['pH']}, FN={params['FN']}, FT={params['FT']}, T={params['T']}, phase={params['phase']}")
207
+ print(f" D0 (Initial Damage): {result['D0']:.4f}")
208
+ print(f" Lambda (Damage Coefficient): {result['lambda']:.4f}")
209
+ print(f" Predicted Total Cracks: {result['total_count']:.0f}")
210
+ print(f" Peak Angle: {predictor.angle_bins[result['angle_distribution'].argmax()]:.1f} degrees")
211
+ print(f" Peak Count: {result['angle_distribution'].max():.0f}")
212
+
213
+ print("\n" + "=" * 60)
214
+ print("Testing Physics Output")
215
+ print("=" * 60)
216
+
217
+ physics_result = predictor.predict_with_physics(pH=3, FN=30, FT=40, T=25, phase=1)
218
+
219
+ print("\nPhysics Parameters:")
220
+ print(f" Mogi-Coulomb:")
221
+ print(f" tau_oct: {physics_result['tau_oct']:.4f}")
222
+ print(f" yield_stress: {physics_result['yield_stress']:.4f}")
223
+ print(f" C1: {physics_result['C1']:.4f}")
224
+ print(f" C2: {physics_result['C2']:.4f}")
225
+ print(f" Weibull Distribution:")
226
+ print(f" D_q: {physics_result['D_q']:.4f}")
227
+ print(f" m: {physics_result['m']:.4f}")
228
+ print(f" F0: {physics_result['F0']:.4f}")
229
+ print(f" Energy Damage:")
230
+ print(f" D_n: {physics_result['D_n']:.4f}")
231
+
232
+ print("\n" + "=" * 60)
233
+ print("Inference complete!")
234
+ print("=" * 60)
235
+
236
+
237
+ if __name__ == "__main__":
238
+ main()
model.py ADDED
@@ -0,0 +1,519 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torch.utils.data import TensorDataset, DataLoader
6
+ from sklearn.model_selection import train_test_split
7
+
8
+
9
+ class MogiCoulombLayer(nn.Module):
10
+
11
+ def __init__(self, hidden_dim):
12
+ super(MogiCoulombLayer, self).__init__()
13
+ self.C1_net = nn.Linear(hidden_dim, 1)
14
+ self.C2_net = nn.Linear(hidden_dim, 1)
15
+
16
+ def forward(self, features, sigma1, sigma2, sigma3):
17
+ C1 = torch.relu(self.C1_net(features))
18
+ C2 = torch.sigmoid(self.C2_net(features))
19
+
20
+ tau_oct = (1.0/3.0) * torch.sqrt(
21
+ (sigma1 - sigma2)**2 + (sigma2 - sigma3)**2 + (sigma1 - sigma3)**2
22
+ )
23
+ sigma_m2 = (sigma1 + sigma3) / 2.0
24
+
25
+ yield_stress = C1 + C2 * sigma_m2
26
+
27
+ return tau_oct, yield_stress, C1, C2
28
+
29
+
30
+ class WeibullStrengthLayer(nn.Module):
31
+
32
+ def __init__(self, hidden_dim):
33
+ super(WeibullStrengthLayer, self).__init__()
34
+ self.m_net = nn.Sequential(
35
+ nn.Linear(hidden_dim, 32),
36
+ nn.Tanh(),
37
+ nn.Linear(32, 1),
38
+ nn.Softplus()
39
+ )
40
+ self.F0_net = nn.Sequential(
41
+ nn.Linear(hidden_dim, 32),
42
+ nn.Tanh(),
43
+ nn.Linear(32, 1),
44
+ nn.Softplus()
45
+ )
46
+
47
+ def forward(self, features, F):
48
+ m = self.m_net(features) + 1.0
49
+ F0 = self.F0_net(features) + 0.1
50
+
51
+ D_q = 1.0 - torch.exp(-torch.pow(F / F0, m))
52
+
53
+ return D_q, m, F0
54
+
55
+
56
+ class EnergyDamageLayer(nn.Module):
57
+
58
+ def __init__(self, hidden_dim):
59
+ super(EnergyDamageLayer, self).__init__()
60
+ self.a_net = nn.Sequential(
61
+ nn.Linear(hidden_dim, 32),
62
+ nn.Tanh(),
63
+ nn.Linear(32, 1),
64
+ nn.Softplus()
65
+ )
66
+ self.b_net = nn.Sequential(
67
+ nn.Linear(hidden_dim, 32),
68
+ nn.Tanh(),
69
+ nn.Linear(32, 1),
70
+ nn.Softplus()
71
+ )
72
+
73
+ def forward(self, features, delta_sigma, D0):
74
+ a = self.a_net(features) + 0.1
75
+ b = self.b_net(features) + 0.01
76
+
77
+ effective_stress = delta_sigma / (1.0 - D0 + 1e-8)
78
+
79
+ U_p = a * torch.exp(b * effective_stress)
80
+
81
+ D_n = (2.0 / np.pi) * torch.atan(b * U_p)
82
+
83
+ return D_n, U_p, a, b
84
+
85
+
86
+ class CrackTransformerPINN(nn.Module):
87
+
88
+ def __init__(self, input_dim=5, output_dim=72, hidden_dims=[128, 256, 256, 128], dropout=0.2):
89
+ super(CrackTransformerPINN, self).__init__()
90
+
91
+ self.input_dim = input_dim
92
+ self.output_dim = output_dim
93
+
94
+ self.input_embedding = nn.Sequential(
95
+ nn.Linear(input_dim, hidden_dims[0]),
96
+ nn.LayerNorm(hidden_dims[0]),
97
+ nn.GELU(),
98
+ nn.Dropout(dropout * 0.5)
99
+ )
100
+
101
+ self.damage_encoder = nn.Sequential(
102
+ nn.Linear(input_dim, hidden_dims[0]),
103
+ nn.Tanh(),
104
+ nn.Linear(hidden_dims[0], hidden_dims[0])
105
+ )
106
+
107
+ encoder_layer = nn.TransformerEncoderLayer(
108
+ d_model=hidden_dims[0],
109
+ nhead=8,
110
+ dim_feedforward=hidden_dims[0] * 4,
111
+ dropout=dropout,
112
+ activation='gelu',
113
+ batch_first=True,
114
+ norm_first=True
115
+ )
116
+
117
+ self.transformer_encoder = nn.TransformerEncoder(
118
+ encoder_layer,
119
+ num_layers=4,
120
+ norm=nn.LayerNorm(hidden_dims[0])
121
+ )
122
+
123
+ self.mogi_coulomb = MogiCoulombLayer(hidden_dims[0])
124
+ self.weibull_strength = WeibullStrengthLayer(hidden_dims[0])
125
+ self.energy_damage = EnergyDamageLayer(hidden_dims[0])
126
+
127
+ self.angle_decoder = nn.ModuleList()
128
+ prev_dim = hidden_dims[0] * 2
129
+
130
+ for hidden_dim in hidden_dims[1:]:
131
+ self.angle_decoder.append(nn.Linear(prev_dim, hidden_dim))
132
+ self.angle_decoder.append(nn.LayerNorm(hidden_dim))
133
+ self.angle_decoder.append(nn.Tanh())
134
+ self.angle_decoder.append(nn.Dropout(dropout))
135
+ prev_dim = hidden_dim
136
+
137
+ self.angle_output = nn.Sequential(
138
+ nn.Linear(prev_dim, output_dim),
139
+ nn.ReLU()
140
+ )
141
+
142
+ self.total_count_head = nn.Sequential(
143
+ nn.Linear(hidden_dims[0] * 2, 64),
144
+ nn.Tanh(),
145
+ nn.Linear(64, 1),
146
+ nn.ReLU()
147
+ )
148
+
149
+ self.damage_factor_head = nn.Sequential(
150
+ nn.Linear(hidden_dims[0], 32),
151
+ nn.Tanh(),
152
+ nn.Linear(32, 1),
153
+ nn.Sigmoid()
154
+ )
155
+
156
+ def compute_initial_damage(self, pH, FN, FT, T):
157
+ D_ft = 0.002 * FN * torch.exp(0.02 * FT)
158
+ D_ch = 0.01 * torch.abs(pH - 7.0) ** 1.5
159
+ D_th = torch.where(
160
+ T > 100.0,
161
+ 0.0003 * (T - 100.0) ** 1.2,
162
+ torch.zeros_like(T)
163
+ )
164
+
165
+ D_total = 1.0 - (1.0 - D_ft) * (1.0 - D_ch) * (1.0 - D_th)
166
+ D_total = torch.clamp(D_total, 0.0, 0.99)
167
+
168
+ return D_total
169
+
170
+ def forward(self, x, return_physics=False):
171
+ batch_size = x.shape[0]
172
+
173
+ pH = x[:, 0:1]
174
+ FN = x[:, 1:2]
175
+ FT = x[:, 2:3]
176
+ T = x[:, 3:4]
177
+ phase = x[:, 4:5]
178
+
179
+ D0 = self.compute_initial_damage(pH, FN, FT, T)
180
+ lambda_coef = 1.0 - D0
181
+
182
+ x_embedded = self.input_embedding(x)
183
+ damage_features = self.damage_encoder(x)
184
+
185
+ x_seq = x_embedded.unsqueeze(1)
186
+ encoded = self.transformer_encoder(x_seq)
187
+ encoded = encoded.squeeze(1)
188
+
189
+ combined = torch.cat([encoded, damage_features], dim=-1)
190
+
191
+ h = combined
192
+ for layer in self.angle_decoder:
193
+ h = layer(h)
194
+
195
+ angle_dist = self.angle_output(h)
196
+
197
+ total_count = self.total_count_head(combined)
198
+
199
+ predicted_D0 = self.damage_factor_head(encoded)
200
+
201
+ if return_physics:
202
+ sigma1 = 100.0 * torch.ones(batch_size, 1, device=x.device)
203
+ sigma2 = 50.0 * torch.ones(batch_size, 1, device=x.device)
204
+ sigma3 = 30.0 * torch.ones(batch_size, 1, device=x.device)
205
+ delta_sigma = 20.0 * torch.ones(batch_size, 1, device=x.device)
206
+ F_contact = 10.0 * torch.ones(batch_size, 1, device=x.device)
207
+
208
+ tau_oct, yield_stress, C1, C2 = self.mogi_coulomb(encoded, sigma1, sigma2, sigma3)
209
+ D_q, m, F0 = self.weibull_strength(encoded, F_contact)
210
+ D_n, U_p, a, b = self.energy_damage(encoded, delta_sigma, D0)
211
+
212
+ physics_outputs = {
213
+ 'D0': D0,
214
+ 'lambda': lambda_coef,
215
+ 'predicted_D0': predicted_D0,
216
+ 'tau_oct': tau_oct,
217
+ 'yield_stress': yield_stress,
218
+ 'C1': C1,
219
+ 'C2': C2,
220
+ 'D_q': D_q,
221
+ 'm': m,
222
+ 'F0': F0,
223
+ 'D_n': D_n,
224
+ 'U_p': U_p,
225
+ 'a': a,
226
+ 'b': b
227
+ }
228
+
229
+ return angle_dist, total_count, physics_outputs
230
+
231
+ return angle_dist, total_count
232
+
233
+
234
+ class CrackPINNLoss(nn.Module):
235
+
236
+ def __init__(self, lambda_data=1.0, lambda_physics=0.5, lambda_smooth=0.1,
237
+ lambda_damage=0.3, lambda_mogi=0.2, lambda_reg=1e-4):
238
+ super(CrackPINNLoss, self).__init__()
239
+
240
+ self.lambda_data = lambda_data
241
+ self.lambda_physics = lambda_physics
242
+ self.lambda_smooth = lambda_smooth
243
+ self.lambda_damage = lambda_damage
244
+ self.lambda_mogi = lambda_mogi
245
+ self.lambda_reg = lambda_reg
246
+
247
+ self.mse_loss = nn.MSELoss()
248
+
249
+ def data_loss(self, pred_dist, true_dist):
250
+ return self.mse_loss(pred_dist, true_dist)
251
+
252
+ def physics_loss(self, pred_dist, pred_total, true_dist):
253
+ pred_sum = pred_dist.sum(dim=1, keepdim=True)
254
+ true_sum = true_dist.sum(dim=1, keepdim=True)
255
+
256
+ loss_consistency = self.mse_loss(pred_sum, pred_total)
257
+ loss_total = self.mse_loss(pred_total, true_sum)
258
+
259
+ return loss_consistency + loss_total
260
+
261
+ def smoothness_loss(self, pred_dist):
262
+ diff = pred_dist[:, 1:] - pred_dist[:, :-1]
263
+ return torch.mean(diff ** 2)
264
+
265
+ def damage_consistency_loss(self, physics_outputs):
266
+ D0 = physics_outputs['D0']
267
+ predicted_D0 = physics_outputs['predicted_D0']
268
+
269
+ loss_D0 = self.mse_loss(predicted_D0, D0)
270
+
271
+ lambda_coef = physics_outputs['lambda']
272
+ loss_lambda = torch.mean(torch.relu(-lambda_coef) + torch.relu(lambda_coef - 1.0))
273
+
274
+ return loss_D0 + loss_lambda
275
+
276
+ def mogi_coulomb_loss(self, physics_outputs):
277
+ tau_oct = physics_outputs['tau_oct']
278
+ yield_stress = physics_outputs['yield_stress']
279
+
280
+ loss_yield = torch.mean(torch.relu(tau_oct - yield_stress))
281
+
282
+ C1 = physics_outputs['C1']
283
+ C2 = physics_outputs['C2']
284
+ loss_params = torch.mean(torch.relu(-C1)) + torch.mean(torch.relu(C2 - 1.0))
285
+
286
+ return loss_yield + 0.1 * loss_params
287
+
288
+ def regularization_loss(self, model):
289
+ l2_reg = torch.tensor(0.0, device=next(model.parameters()).device)
290
+ for param in model.parameters():
291
+ l2_reg += torch.norm(param, p=2) ** 2
292
+ return l2_reg
293
+
294
+ def forward(self, pred_dist, pred_total, true_dist, model, physics_outputs=None):
295
+ loss_data = self.data_loss(pred_dist, true_dist)
296
+ loss_physics = self.physics_loss(pred_dist, pred_total, true_dist)
297
+ loss_smooth = self.smoothness_loss(pred_dist)
298
+ loss_reg = self.regularization_loss(model)
299
+
300
+ loss_damage = torch.tensor(0.0, device=pred_dist.device)
301
+ loss_mogi = torch.tensor(0.0, device=pred_dist.device)
302
+
303
+ if physics_outputs is not None:
304
+ loss_damage = self.damage_consistency_loss(physics_outputs)
305
+ loss_mogi = self.mogi_coulomb_loss(physics_outputs)
306
+
307
+ total_loss = (self.lambda_data * loss_data +
308
+ self.lambda_physics * loss_physics +
309
+ self.lambda_smooth * loss_smooth +
310
+ self.lambda_damage * loss_damage +
311
+ self.lambda_mogi * loss_mogi +
312
+ self.lambda_reg * loss_reg)
313
+
314
+ loss_dict = {
315
+ 'total': total_loss.item(),
316
+ 'data': loss_data.item(),
317
+ 'physics': loss_physics.item(),
318
+ 'smooth': loss_smooth.item(),
319
+ 'damage': loss_damage.item(),
320
+ 'mogi': loss_mogi.item(),
321
+ 'reg': loss_reg.item()
322
+ }
323
+
324
+ return total_loss, loss_dict
325
+
326
+
327
+ class CrackPINNTrainer:
328
+
329
+ def __init__(self, model, device='cpu', lr=1e-3, weight_decay=1e-4):
330
+ self.model = model.to(device)
331
+ self.device = device
332
+
333
+ self.optimizer = optim.AdamW(
334
+ model.parameters(),
335
+ lr=lr,
336
+ weight_decay=weight_decay,
337
+ betas=(0.9, 0.999)
338
+ )
339
+
340
+ self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
341
+ self.optimizer,
342
+ mode='min',
343
+ factor=0.5,
344
+ patience=10,
345
+ verbose=True
346
+ )
347
+
348
+ self.criterion = CrackPINNLoss(
349
+ lambda_data=1.0,
350
+ lambda_physics=0.5,
351
+ lambda_smooth=0.1,
352
+ lambda_damage=0.3,
353
+ lambda_mogi=0.2,
354
+ lambda_reg=1e-4
355
+ )
356
+
357
+ self.train_losses = []
358
+ self.val_losses = []
359
+
360
+ def train_epoch(self, train_loader):
361
+ self.model.train()
362
+
363
+ epoch_losses = []
364
+ loss_components = {
365
+ 'data': [], 'physics': [], 'smooth': [],
366
+ 'damage': [], 'mogi': [], 'reg': []
367
+ }
368
+
369
+ for X_batch, y_batch in train_loader:
370
+ X_batch = X_batch.to(self.device)
371
+ y_batch = y_batch.to(self.device)
372
+
373
+ pred_dist, pred_total, physics_outputs = self.model(X_batch, return_physics=True)
374
+
375
+ loss, loss_dict = self.criterion(
376
+ pred_dist, pred_total, y_batch, self.model, physics_outputs
377
+ )
378
+
379
+ self.optimizer.zero_grad()
380
+ loss.backward()
381
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
382
+ self.optimizer.step()
383
+
384
+ epoch_losses.append(loss_dict['total'])
385
+ for key in loss_components.keys():
386
+ loss_components[key].append(loss_dict[key])
387
+
388
+ avg_loss = np.mean(epoch_losses)
389
+ avg_components = {key: np.mean(values) for key, values in loss_components.items()}
390
+
391
+ return avg_loss, avg_components
392
+
393
+ def validate(self, val_loader):
394
+ self.model.eval()
395
+
396
+ val_losses = []
397
+ all_preds = []
398
+ all_trues = []
399
+
400
+ with torch.no_grad():
401
+ for X_batch, y_batch in val_loader:
402
+ X_batch = X_batch.to(self.device)
403
+ y_batch = y_batch.to(self.device)
404
+
405
+ pred_dist, pred_total = self.model(X_batch, return_physics=False)
406
+
407
+ loss, _ = self.criterion(pred_dist, pred_total, y_batch, self.model)
408
+
409
+ val_losses.append(loss.item())
410
+ all_preds.append(pred_dist.cpu().numpy())
411
+ all_trues.append(y_batch.cpu().numpy())
412
+
413
+ avg_loss = np.mean(val_losses)
414
+
415
+ all_preds = np.concatenate(all_preds, axis=0)
416
+ all_trues = np.concatenate(all_trues, axis=0)
417
+
418
+ ss_res = np.sum((all_trues - all_preds) ** 2)
419
+ ss_tot = np.sum((all_trues - np.mean(all_trues)) ** 2)
420
+ r2 = 1 - (ss_res / (ss_tot + 1e-8))
421
+
422
+ rmse = np.sqrt(np.mean((all_trues - all_preds) ** 2))
423
+
424
+ pred_total_counts = all_preds.sum(axis=1)
425
+ true_total_counts = all_trues.sum(axis=1)
426
+ total_count_error = np.mean(np.abs(pred_total_counts - true_total_counts))
427
+
428
+ metrics = {
429
+ 'r2': r2,
430
+ 'rmse': rmse,
431
+ 'total_count_mae': total_count_error
432
+ }
433
+
434
+ return avg_loss, metrics
435
+
436
+ def fit(self, X_train, y_train, X_val, y_val, epochs=200, batch_size=16, patience=30):
437
+ train_dataset = TensorDataset(
438
+ torch.FloatTensor(X_train),
439
+ torch.FloatTensor(y_train)
440
+ )
441
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
442
+
443
+ val_dataset = TensorDataset(
444
+ torch.FloatTensor(X_val),
445
+ torch.FloatTensor(y_val)
446
+ )
447
+ val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
448
+
449
+ best_val_loss = float('inf')
450
+ patience_counter = 0
451
+ best_model_state = None
452
+
453
+ print("\nStarting training...")
454
+ print("=" * 80)
455
+
456
+ for epoch in range(epochs):
457
+ train_loss, train_components = self.train_epoch(train_loader)
458
+ val_loss, val_metrics = self.validate(val_loader)
459
+
460
+ self.train_losses.append(train_loss)
461
+ self.val_losses.append(val_loss)
462
+
463
+ self.scheduler.step(val_loss)
464
+
465
+ if (epoch + 1) % 10 == 0 or epoch == 0:
466
+ print(f"Epoch {epoch+1}/{epochs}")
467
+ print(f" Train Loss: {train_loss:.4f} "
468
+ f"(data: {train_components['data']:.4f}, "
469
+ f"phys: {train_components['physics']:.4f}, "
470
+ f"damage: {train_components['damage']:.4f})")
471
+ print(f" Val Loss: {val_loss:.4f} | "
472
+ f"R2: {val_metrics['r2']:.4f} | "
473
+ f"RMSE: {val_metrics['rmse']:.2f}")
474
+
475
+ if val_loss < best_val_loss:
476
+ best_val_loss = val_loss
477
+ patience_counter = 0
478
+ best_model_state = self.model.state_dict().copy()
479
+ else:
480
+ patience_counter += 1
481
+
482
+ if patience_counter >= patience:
483
+ print(f"\nEarly stopping. Best val loss: {best_val_loss:.4f}")
484
+ break
485
+
486
+ if best_model_state is not None:
487
+ self.model.load_state_dict(best_model_state)
488
+
489
+ print("=" * 80)
490
+ print(f"Training complete. Best val loss: {best_val_loss:.4f}")
491
+
492
+ def predict(self, X):
493
+ self.model.eval()
494
+
495
+ with torch.no_grad():
496
+ X_tensor = torch.FloatTensor(X).to(self.device)
497
+ pred_dist, pred_total = self.model(X_tensor, return_physics=False)
498
+
499
+ pred_dist = pred_dist.cpu().numpy()
500
+ pred_total = pred_total.cpu().numpy().flatten()
501
+
502
+ return pred_dist, pred_total
503
+
504
+ def predict_with_physics(self, X):
505
+ self.model.eval()
506
+
507
+ with torch.no_grad():
508
+ X_tensor = torch.FloatTensor(X).to(self.device)
509
+ pred_dist, pred_total, physics = self.model(X_tensor, return_physics=True)
510
+
511
+ result = {
512
+ 'angle_distribution': pred_dist.cpu().numpy(),
513
+ 'total_count': pred_total.cpu().numpy().flatten(),
514
+ 'D0': physics['D0'].cpu().numpy().flatten(),
515
+ 'lambda': physics['lambda'].cpu().numpy().flatten(),
516
+ 'D_n': physics['D_n'].cpu().numpy().flatten()
517
+ }
518
+
519
+ return result
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ numpy>=1.21.0
3
+ pandas>=1.3.0
4
+ scikit-learn>=1.0.0
5
+ matplotlib>=3.5.0
train.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import matplotlib.pyplot as plt
5
+ from sklearn.model_selection import train_test_split
6
+ import pickle
7
+
8
+ from data_loader import CrackDataLoader, DamageCalculator
9
+ from model import CrackTransformerPINN, CrackPINNTrainer
10
+
11
+
12
+ def plot_training_history(trainer, save_path=None):
13
+ fig, axes = plt.subplots(1, 3, figsize=(15, 4))
14
+
15
+ axes[0].plot(trainer.train_losses, label='Train Loss', linewidth=2)
16
+ axes[0].plot(trainer.val_losses, label='Val Loss', linewidth=2)
17
+ axes[0].set_xlabel('Epoch', fontsize=12)
18
+ axes[0].set_ylabel('Loss', fontsize=12)
19
+ axes[0].set_title('Training History', fontsize=14, fontweight='bold')
20
+ axes[0].legend()
21
+ axes[0].grid(True, alpha=0.3)
22
+
23
+ axes[1].semilogy(trainer.train_losses, label='Train Loss', linewidth=2)
24
+ axes[1].semilogy(trainer.val_losses, label='Val Loss', linewidth=2)
25
+ axes[1].set_xlabel('Epoch', fontsize=12)
26
+ axes[1].set_ylabel('Loss (log scale)', fontsize=12)
27
+ axes[1].set_title('Training History (Log Scale)', fontsize=14, fontweight='bold')
28
+ axes[1].legend()
29
+ axes[1].grid(True, alpha=0.3)
30
+
31
+ if len(trainer.train_losses) > 1:
32
+ train_improvement = np.diff(trainer.train_losses)
33
+ axes[2].plot(train_improvement, linewidth=2, alpha=0.7)
34
+ axes[2].axhline(y=0, color='r', linestyle='--', alpha=0.5)
35
+ axes[2].set_xlabel('Epoch', fontsize=12)
36
+ axes[2].set_ylabel('Loss Change', fontsize=12)
37
+ axes[2].set_title('Convergence Rate', fontsize=14, fontweight='bold')
38
+ axes[2].grid(True, alpha=0.3)
39
+
40
+ plt.tight_layout()
41
+
42
+ if save_path:
43
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
44
+ print(f"Training history saved to: {save_path}")
45
+
46
+ plt.close()
47
+
48
+
49
+ def plot_prediction_examples(X_test, y_test, trainer, angle_bins, loader, n_examples=4, save_path=None):
50
+ X_test_original = loader.scaler_X.inverse_transform(X_test)
51
+ y_test_original = loader.scaler_y.inverse_transform(y_test)
52
+
53
+ y_pred_norm, pred_totals = trainer.predict(X_test[:n_examples])
54
+ y_pred = loader.scaler_y.inverse_transform(y_pred_norm)
55
+
56
+ fig, axes = plt.subplots(2, 2, figsize=(14, 10), subplot_kw=dict(projection='polar'))
57
+ axes = axes.flatten()
58
+
59
+ for i in range(min(n_examples, len(axes))):
60
+ ax = axes[i]
61
+
62
+ theta = np.deg2rad(angle_bins)
63
+
64
+ ax.plot(theta, y_test_original[i], 'o-', label='True', linewidth=2, markersize=4, alpha=0.7)
65
+ ax.plot(theta, y_pred[i], 's-', label='Predicted', linewidth=2, markersize=3, alpha=0.7)
66
+
67
+ pH = X_test_original[i, 0]
68
+ FN = X_test_original[i, 1]
69
+ FT = X_test_original[i, 2]
70
+ T = X_test_original[i, 3]
71
+ phase = X_test_original[i, 4]
72
+ phase_str = "Unstable" if phase < 0.5 else "Peak"
73
+
74
+ D0 = DamageCalculator.compute_total_damage(pH, FN, FT, T)
75
+ lambda_coef = DamageCalculator.compute_lambda(D0)
76
+
77
+ true_total = y_test_original[i].sum()
78
+ pred_total = y_pred[i].sum()
79
+
80
+ title = f"pH={pH:.0f}, FN={FN:.0f}, FT={FT:.0f}, T={T:.0f}C\n"
81
+ title += f"D0={D0:.3f}, lambda={lambda_coef:.3f}\n"
82
+ title += f"{phase_str} | True: {true_total:.0f}, Pred: {pred_total:.0f}"
83
+
84
+ ax.set_title(title, fontsize=9, pad=20)
85
+ ax.legend(loc='upper right', fontsize=8)
86
+ ax.set_theta_zero_location('N')
87
+ ax.set_theta_direction(-1)
88
+ ax.grid(True, alpha=0.3)
89
+
90
+ plt.tight_layout()
91
+
92
+ if save_path:
93
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
94
+ print(f"Prediction examples saved to: {save_path}")
95
+
96
+ plt.close()
97
+
98
+
99
+ def plot_damage_analysis(X, y, save_path=None):
100
+ fig, axes = plt.subplots(2, 2, figsize=(12, 10))
101
+
102
+ D0_values = []
103
+ total_cracks = y.sum(axis=1)
104
+
105
+ for i in range(X.shape[0]):
106
+ D0 = DamageCalculator.compute_total_damage(X[i, 0], X[i, 1], X[i, 2], X[i, 3])
107
+ D0_values.append(D0)
108
+
109
+ D0_values = np.array(D0_values)
110
+
111
+ axes[0, 0].scatter(D0_values, total_cracks, alpha=0.6, edgecolors='black', linewidth=0.5)
112
+ axes[0, 0].set_xlabel('Initial Damage Factor D0', fontsize=12)
113
+ axes[0, 0].set_ylabel('Total Crack Count', fontsize=12)
114
+ axes[0, 0].set_title('D0 vs Total Cracks', fontsize=14, fontweight='bold')
115
+ axes[0, 0].grid(True, alpha=0.3)
116
+
117
+ axes[0, 1].scatter(X[:, 0], total_cracks, alpha=0.6, c=D0_values, cmap='viridis')
118
+ axes[0, 1].set_xlabel('pH Value', fontsize=12)
119
+ axes[0, 1].set_ylabel('Total Crack Count', fontsize=12)
120
+ axes[0, 1].set_title('pH vs Total Cracks', fontsize=14, fontweight='bold')
121
+ axes[0, 1].grid(True, alpha=0.3)
122
+
123
+ axes[1, 0].scatter(X[:, 1], total_cracks, alpha=0.6, c=D0_values, cmap='viridis')
124
+ axes[1, 0].set_xlabel('Freeze-thaw Cycles (FN)', fontsize=12)
125
+ axes[1, 0].set_ylabel('Total Crack Count', fontsize=12)
126
+ axes[1, 0].set_title('FN vs Total Cracks', fontsize=14, fontweight='bold')
127
+ axes[1, 0].grid(True, alpha=0.3)
128
+
129
+ scatter = axes[1, 1].scatter(X[:, 3], total_cracks, alpha=0.6, c=D0_values, cmap='viridis')
130
+ axes[1, 1].set_xlabel('Damage Temperature (T)', fontsize=12)
131
+ axes[1, 1].set_ylabel('Total Crack Count', fontsize=12)
132
+ axes[1, 1].set_title('T vs Total Cracks', fontsize=14, fontweight='bold')
133
+ axes[1, 1].grid(True, alpha=0.3)
134
+
135
+ plt.colorbar(scatter, ax=axes[1, 1], label='D0')
136
+
137
+ plt.tight_layout()
138
+
139
+ if save_path:
140
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
141
+ print(f"Damage analysis saved to: {save_path}")
142
+
143
+ plt.close()
144
+
145
+
146
+ def main():
147
+ print("=" * 80)
148
+ print("Transformer-PINN Crack Prediction Model")
149
+ print("Based on: Mechanism of micro-damage evolution in rocks")
150
+ print("under multiple coupled cyclic stresses")
151
+ print("=" * 80)
152
+
153
+ base_path = "./data"
154
+ output_dir = "./output"
155
+
156
+ if not os.path.exists(output_dir):
157
+ os.makedirs(output_dir)
158
+
159
+ print("\n" + "=" * 80)
160
+ print("Step 1: Loading/Generating Data")
161
+ print("=" * 80)
162
+
163
+ loader = CrackDataLoader(base_path, stress_type="major")
164
+
165
+ try:
166
+ X, y, angle_bins, damage_list = loader.load_all_data(phase="both")
167
+ except:
168
+ print("Real data not found. Generating synthetic data...")
169
+ X, y, angle_bins = loader.create_synthetic_data(n_samples=200, output_dim=72)
170
+
171
+ stats = loader.get_statistics(X, y)
172
+ print("\nData statistics:")
173
+ for key, value in stats.items():
174
+ print(f" {key}: {value}")
175
+
176
+ print("\n" + "=" * 80)
177
+ print("Step 2: Splitting Dataset (Train:Val:Test = 64:16:20)")
178
+ print("=" * 80)
179
+
180
+ X_train, X_test, y_train, y_test = train_test_split(
181
+ X, y, test_size=0.2, random_state=42
182
+ )
183
+
184
+ X_train, X_val, y_train, y_val = train_test_split(
185
+ X_train, y_train, test_size=0.2, random_state=42
186
+ )
187
+
188
+ print(f"Training set: {X_train.shape[0]} samples")
189
+ print(f"Validation set: {X_val.shape[0]} samples")
190
+ print(f"Test set: {X_test.shape[0]} samples")
191
+
192
+ print("\n" + "=" * 80)
193
+ print("Step 3: Normalizing Data")
194
+ print("=" * 80)
195
+
196
+ X_train_norm, y_train_norm, X_val_norm, y_val_norm = loader.normalize_data(
197
+ X_train, y_train, X_val, y_val
198
+ )
199
+
200
+ X_test_norm = loader.scaler_X.transform(X_test)
201
+ y_test_norm = loader.scaler_y.transform(y_test)
202
+
203
+ print("Normalization complete")
204
+
205
+ print("\n" + "=" * 80)
206
+ print("Step 4: Creating Model")
207
+ print("=" * 80)
208
+
209
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
210
+ print(f"Using device: {device}")
211
+
212
+ model = CrackTransformerPINN(
213
+ input_dim=5,
214
+ output_dim=y.shape[1],
215
+ hidden_dims=[128, 256, 256, 128],
216
+ dropout=0.2
217
+ )
218
+
219
+ n_params = sum(p.numel() for p in model.parameters())
220
+ print(f"Model parameters: {n_params:,}")
221
+
222
+ print("\nModel components:")
223
+ print(" - Transformer Encoder (8 heads, 4 layers)")
224
+ print(" - Mogi-Coulomb Yield Criterion Layer")
225
+ print(" - Weibull Strength Distribution Layer")
226
+ print(" - Energy-based Damage Evolution Layer")
227
+ print(" - PINN Decoder with Physics Constraints")
228
+
229
+ print("\n" + "=" * 80)
230
+ print("Step 5: Training Model")
231
+ print("=" * 80)
232
+
233
+ trainer = CrackPINNTrainer(
234
+ model,
235
+ device=device,
236
+ lr=1e-3,
237
+ weight_decay=1e-4
238
+ )
239
+
240
+ trainer.fit(
241
+ X_train_norm, y_train_norm,
242
+ X_val_norm, y_val_norm,
243
+ epochs=300,
244
+ batch_size=8,
245
+ patience=50
246
+ )
247
+
248
+ print("\n" + "=" * 80)
249
+ print("Step 6: Testing Model")
250
+ print("=" * 80)
251
+
252
+ test_loss, test_metrics = trainer.validate(
253
+ torch.utils.data.DataLoader(
254
+ torch.utils.data.TensorDataset(
255
+ torch.FloatTensor(X_test_norm),
256
+ torch.FloatTensor(y_test_norm)
257
+ ),
258
+ batch_size=8,
259
+ shuffle=False
260
+ )
261
+ )
262
+
263
+ print(f"Test set performance:")
264
+ print(f" Loss: {test_loss:.4f}")
265
+ print(f" R2: {test_metrics['r2']:.4f}")
266
+ print(f" RMSE: {test_metrics['rmse']:.2f}")
267
+ print(f" Total Count MAE: {test_metrics['total_count_mae']:.2f}")
268
+
269
+ print("\n" + "=" * 80)
270
+ print("Step 7: Saving Model")
271
+ print("=" * 80)
272
+
273
+ model_path = os.path.join(output_dir, "crack_transformer_pinn.pth")
274
+ torch.save({
275
+ 'model_state_dict': model.state_dict(),
276
+ 'model_config': {
277
+ 'input_dim': 5,
278
+ 'output_dim': y.shape[1],
279
+ 'hidden_dims': [128, 256, 256, 128],
280
+ 'dropout': 0.2
281
+ },
282
+ 'test_metrics': test_metrics
283
+ }, model_path)
284
+ print(f"Model saved to: {model_path}")
285
+
286
+ scaler_path = os.path.join(output_dir, "scalers.pkl")
287
+ with open(scaler_path, 'wb') as f:
288
+ pickle.dump({
289
+ 'scaler_X': loader.scaler_X,
290
+ 'scaler_y': loader.scaler_y,
291
+ 'angle_bins': angle_bins
292
+ }, f)
293
+ print(f"Scalers saved to: {scaler_path}")
294
+
295
+ print("\n" + "=" * 80)
296
+ print("Step 8: Generating Visualizations")
297
+ print("=" * 80)
298
+
299
+ history_path = os.path.join(output_dir, "training_history.png")
300
+ plot_training_history(trainer, save_path=history_path)
301
+
302
+ examples_path = os.path.join(output_dir, "prediction_examples.png")
303
+ plot_prediction_examples(
304
+ X_test_norm, y_test_norm,
305
+ trainer, angle_bins, loader,
306
+ n_examples=4,
307
+ save_path=examples_path
308
+ )
309
+
310
+ damage_path = os.path.join(output_dir, "damage_analysis.png")
311
+ plot_damage_analysis(X, y, save_path=damage_path)
312
+
313
+ print("\n" + "=" * 80)
314
+ print("Training Pipeline Complete!")
315
+ print("=" * 80)
316
+ print(f"\nGenerated files:")
317
+ print(f" 1. Model checkpoint: {model_path}")
318
+ print(f" 2. Scalers: {scaler_path}")
319
+ print(f" 3. Training history: {history_path}")
320
+ print(f" 4. Prediction examples: {examples_path}")
321
+ print(f" 5. Damage analysis: {damage_path}")
322
+
323
+ print("\nPhysics constraints applied:")
324
+ print(" - Mogi-Coulomb yield criterion: tau_oct = C1 + C2 * sigma_m2")
325
+ print(" - Weibull strength: D_q = 1 - exp(-(F/F0)^m)")
326
+ print(" - Energy damage: D_n = (2/pi) * arctan(b * U_p)")
327
+ print(" - Total damage: D_total = 1 - (1-D_ft)(1-D_ch)(1-D_th)")
328
+
329
+
330
+ if __name__ == "__main__":
331
+ main()