willengler-uc committed on
Commit
5294e39
·
verified ·
1 Parent(s): 615d845

Upload RPV TTS models

Browse files
.gitattributes CHANGED
@@ -39,3 +39,4 @@ model_perovskite_Habs-main/model.dill filter=lfs diff=lfs merge=lfs -text
39
  model_Mg_alloy-main/model.dill filter=lfs diff=lfs merge=lfs -text
40
  model_thermalcond_aflow/model.dill filter=lfs diff=lfs merge=lfs -text
41
  model_thermalexp_aflow/model.dill filter=lfs diff=lfs merge=lfs -text
 
 
39
  model_Mg_alloy-main/model.dill filter=lfs diff=lfs merge=lfs -text
40
  model_thermalcond_aflow/model.dill filter=lfs diff=lfs merge=lfs -text
41
  model_thermalexp_aflow/model.dill filter=lfs diff=lfs merge=lfs -text
42
+ model_RPV_TTS/model.dill filter=lfs diff=lfs merge=lfs -text
model_RPV_TTS/.gitattributes ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Auto detect text files and perform LF normalization
2
+ * text=auto
3
+ *.dill filter=lfs diff=lfs merge=lfs -text
4
+ *.pkl filter=lfs diff=lfs merge=lfs -text
model_RPV_TTS/README.md ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # model_RPV_TTS
2
+ Random forest model to predict the transition temperature shift (TTS) of reactor pressure vessel (RPV) steels
model_RPV_TTS/RandomForestRegressor.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:86f924788562e4e81c05df3a557d26673f2238f54fff06666090643f4b0a541c
3
+ size 166993571
model_RPV_TTS/StandardScaler.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd721d5060f293a045bff734f7cf489b44544091830112bb5909835f03a3f4db
3
+ size 1645
model_RPV_TTS/X_train.csv ADDED
The diff for this file is too large to render. See raw diff
 
model_RPV_TTS/model.dill ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eeb1a8465c13f16a5b21afb1ba4d59d696c57bd71c3348c40fe442a95e2badab
3
+ size 343058384
model_RPV_TTS/predict_RPV_TTS.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+ import joblib
5
+ import dill
6
+ from mastml.feature_generators import ElementalFeatureGenerator, OneHotGroupGenerator
7
+
8
def get_preds_ebars_domains(df_test):
    """Return model predictions, recalibrated error bars and domain output.

    All artifacts are read from the ``model_RPV_TTS`` directory, resolved
    relative to the current working directory: ``StandardScaler.pkl``,
    ``RandomForestRegressor.pkl``, ``X_train.csv``, ``recal_dict.csv`` and
    ``model.dill``.

    Args:
        df_test: DataFrame containing (at least) every column found in
            ``X_train.csv``. Extra columns are ignored; columns are reordered
            to match the training file.

    Returns:
        A tuple ``(preds, ebars, domains)`` where ``preds`` is whatever
        ``model.predict`` returns for the scaled inputs, ``ebars`` is a 1-D
        NumPy array with one recalibrated error bar per input row, and
        ``domains`` is whatever the domain model's ``predict`` returns.

    Raises:
        KeyError: if ``df_test`` lacks one of the training feature columns
            (raised by pandas during column selection).
    """
    d = 'model_RPV_TTS'

    # NOTE(review): joblib/dill deserialization can execute arbitrary code.
    # Only load artifacts that come from a trusted source.
    scaler = joblib.load(os.path.join(d, 'StandardScaler.pkl'))
    model = joblib.load(os.path.join(d, 'RandomForestRegressor.pkl'))
    df_features = pd.read_csv(os.path.join(d, 'X_train.csv'))
    recal_params = pd.read_csv(os.path.join(d, 'recal_dict.csv'))

    # Keep only the training features, in training order.
    features = df_features.columns.tolist()
    df_test = df_test[features]

    X = scaler.transform(df_test)

    # Make predictions
    preds = model.predict(X)

    # Raw uncertainty = standard deviation of the individual tree predictions.
    # Each tree predicts the whole batch at once, instead of one predict()
    # call per row and per tree. np.asarray also lets this step work whether
    # the scaler returned a DataFrame or a plain ndarray.
    # NOTE(review): `model` is presumably a wrapper exposing the fitted
    # scikit-learn forest as `.model` -- confirm against the pickled object.
    X_values = np.asarray(X)
    tree_preds = np.array(
        [tree.predict(X_values) for tree in model.model.estimators_]
    )
    raw_std = np.std(tree_preds, axis=0)

    # Recalibrate with the quadratic a*s**2 + b*s + c stored in recal_dict.csv.
    a = recal_params['a'].iloc[0]
    b = recal_params['b'].iloc[0]
    c = recal_params['c'].iloc[0]
    ebars = a * raw_std**2 + b * raw_std + c

    # Get domains
    with open(os.path.join(d, 'model.dill'), 'rb') as f:
        model_domain = dill.load(f)

    domains = model_domain.predict(X)

    return preds, ebars, domains
42
+
43
+
44
def make_predictions(df_test):
    """Build the prediction table for a set of input rows.

    The ``Product Form`` column is one-hot encoded into the columns
    ``Product Form_0`` .. ``Product Form_5`` (order: F, HAZ, P, PCE, SRM, W),
    the resulting columns are checked against the columns of
    ``model_RPV_TTS/X_train.csv``, and the predictions, error bars and domain
    columns from ``get_preds_ebars_domains`` are assembled together with the
    input features.

    The caller's DataFrame is not modified; all work happens on a copy.

    Args:
        df_test: DataFrame with a ``Product Form`` column (values F, HAZ, P,
            PCE, SRM or W) plus the remaining training feature columns.

    Returns:
        DataFrame with ``Predicted TTS (degC)``, ``Ebar TTS (degC)``, the
        domain model columns (except ``y_pred``, ``y_stdu_pred`` and
        ``y_stdc_pred``) and one column per training feature.

    Raises:
        ValueError: if a ``Product Form`` value is not one of the six codes.
        KeyError: if ``Product Form`` or a required training feature is
            missing from ``df_test``.
    """
    # Process data
    X_train = pd.read_csv(os.path.join('model_RPV_TTS', 'X_train.csv'))
    feature_names = X_train.columns.tolist()

    # Work on a copy so the caller's frame keeps its 'Product Form' column
    # and does not silently gain the one-hot columns.
    df = df_test.copy()

    # Convert Product form encoding to numbers. The position of a code in
    # this tuple is the suffix of its one-hot column.
    product_forms = ('F', 'HAZ', 'P', 'PCE', 'SRM', 'W')
    pf = df['Product Form']
    if not pf.isin(product_forms).all():
        raise ValueError('Product form must be one of F, HAZ, P, PCE, SRM, W')
    for idx, form in enumerate(product_forms):
        # to_numpy() avoids index alignment; values are assigned by position.
        df['Product Form_{}'.format(idx)] = (pf == form).astype(int).to_numpy()

    del df['Product Form']

    # Check the data. Unexpected columns are only reported (they are dropped
    # later when the training features are selected); the first offender is
    # printed, matching the previous behaviour.
    unexpected = [c_in for c_in in df.columns.tolist()
                  if c_in not in feature_names]
    if unexpected:
        print('Error with input feature', unexpected[0])
        print('Input features should be', feature_names)

    # Missing features cannot be recovered from, so fail early with a clear
    # message instead of a bare pandas KeyError further down.
    present = set(df.columns)
    missing = [name for name in feature_names if name not in present]
    if missing:
        raise KeyError(
            'Input data is missing required feature(s): {}'.format(missing)
        )

    # Get the ML predicted values
    preds, ebars, domains = get_preds_ebars_domains(df)

    pred_dict = {'Predicted TTS (degC)': preds,
                 'Ebar TTS (degC)': ebars}

    # Copy the domain model columns, leaving out the ones that are not
    # reported. 'd_pred' is kept in the output.
    dropped = ('y_pred', 'y_stdu_pred', 'y_stdc_pred')
    for col in domains.columns.tolist():
        if col not in dropped:
            pred_dict[col] = domains[col]

    for name in feature_names:
        pred_dict[name] = np.array(df[name]).ravel()

    return pd.DataFrame(pred_dict)
model_RPV_TTS/recal_dict.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ a,b,c
2
+ -0.0014836483656973376,0.7730364503167175,2.4083842823104162
model_RPV_TTS/requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ scikit-learn
2
+ numpy
3
+ pandas
4
+ mastml
5
+ pymatgen