timcryt committed on
Commit
5fae7ca
·
verified ·
1 Parent(s): 77f2fd3

Initial commit

Browse files
Files changed (6) hide show
  1. compare_models.py +130 -0
  2. example_eval.py +102 -0
  3. example_plot.py +28 -0
  4. train.py +198 -0
  5. utils_data.py +131 -0
  6. utils_model.py +15 -0
compare_models.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import argparse
import re
import sys

import ase
import ase.io
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import r2_score
from torch import nn
from torch_geometric.nn.models import SchNet, DimeNetPlusPlus
from tqdm import tqdm

# Fixed: the original imported `ModellDimeNet` (typo), but utils_model.py
# defines `ModelDimeNet`, which is the name actually called below.
from utils_model import ModelDimeNet
20
def get_model_and_optimizer(model_type):
    """Build the requested model together with an AdamW optimizer.

    Args:
        model_type: 'SchNet' or 'DimeNet'.

    Returns:
        Tuple (model, optimizer) with optimizer = AdamW(lr=3e-4).

    Raises:
        ValueError: if model_type is not a known model name.
    """
    if model_type == 'SchNet':
        model = SchNet()
    elif model_type == 'DimeNet':
        model = ModelDimeNet()
    else:
        # Previously fell through and crashed later with UnboundLocalError.
        raise ValueError(f"Unknown model type: {model_type!r}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    return model, optimizer
29
+
30
+
31
+ def train_epoch(model, optimizer, geoms, energies_n, mean_grad=32):
32
+ j = 0
33
+ model.train();
34
+
35
+ for geom, energy in zip(geoms, energies_n.clone().detach()):
36
+ if j == 0:
37
+ optimizer.zero_grad()
38
+ coords = torch.tensor(geom.get_positions(), dtype=torch.float32)
39
+ atoms = torch.tensor(geom.get_atomic_numbers())
40
+ batch = torch.zeros_like(atoms)
41
+
42
+ en = energy.clone().detach()
43
+ pred = model(atoms, coords, batch)
44
+ loss = F.huber_loss(pred.squeeze(), en)
45
+ (loss / mean_grad).backward(); j += 1
46
+ if j == mean_grad - 1:
47
+ optimizer.step()
48
+ j = 0
49
+
50
def test_epoch(model, optimizer, geoms, energies_n):
    """Evaluate `model` on every geometry; return R^2 and MAE.

    Returns:
        dict with keys 'r2_score' and 'mae' (mean absolute error per molecule).
    """
    preds = []
    trues = []
    total_abs_err = 0
    count = 0
    model.eval()

    for geom, energy in zip(geoms, energies_n.clone().detach()):
        positions = torch.tensor(geom.get_positions(), dtype=torch.float32)
        numbers = torch.tensor(geom.get_atomic_numbers())
        mol_index = torch.zeros_like(numbers)

        target = energy.clone().detach()
        with torch.no_grad():
            prediction = model(numbers, positions, mol_index)
        preds.append(prediction.item())
        trues.append(target.item())
        total_abs_err += F.l1_loss(prediction.squeeze(), target).item()
        count += 1

    return {
        'r2_score': r2_score(np.array(trues), np.array(preds)),
        'mae': total_abs_err / count,
    }
73
+
74
def train(model, optimizer, geoms, energies_n, n_epochs=100):
    """Train for `n_epochs` epochs; return the best R^2 and MAE observed."""
    best_r2 = -1e100
    best_mae = 1e100

    for _ in tqdm(range(n_epochs)):
        train_epoch(model, optimizer, geoms, energies_n)
        metrics = test_epoch(model, optimizer, geoms, energies_n)

        best_r2 = max(best_r2, metrics['r2_score'])
        best_mae = min(best_mae, metrics['mae'])

    return best_r2, best_mae
89
+
90
def main(trajectory_file, model_type):
    """Train `model_type` on an XYZ trajectory and print the best metrics.

    Args:
        trajectory_file: path to a multi-frame XYZ file whose comment lines
            carry the frame energy.
        model_type: 'SchNet' or 'DimeNet' (see get_model_and_optimizer).
    """
    geoms = ase.io.read(trajectory_file, format='xyz', index=':')

    with open(trajectory_file) as f:
        cont = f.read()

    # Re-parse the raw file to pull energies out of each frame's comment line.
    # XYZ frame layout: "<n_atoms>\n<comment>\n<n_atoms coordinate lines>".
    energies = []
    lines = cont.split('\n'); i = 0
    while i < len(lines):
        try:
            n = int(lines[i].strip())
        except ValueError:
            # Not an atom count — trailing blank line / end of data.
            break
        comment = lines[i+1]
        # NOTE(review): this pattern expects lowercase "energy:", while
        # utils_data.read_data matches capitalized "Energy:" — confirm which
        # spelling the data files actually use.
        energy = float(re.findall('energy\\:\\s+(-?\\d*\.\\d*)', comment)[0])
        energies.append(energy)
        i += n + 2

    energies = torch.tensor(energies)
    # Shift so the minimum energy is zero, then scale by 627.5 — presumably
    # the Hartree -> kcal/mol conversion (the MAE below is printed in
    # kcal/mol); TODO confirm input units.
    energies_n = (energies - energies.min()) * 627.5

    model, optimizer = get_model_and_optimizer(model_type)

    # Note: trains and evaluates on the same data (no held-out split here).
    best_r2score, best_mae = train(model, optimizer, geoms, energies_n)

    print(f'R2_score: {best_r2score:.4f}')
    print(f'MAE: {best_mae:.3f} kcal/mol')
117
+
118
# Fixed typo: was `avaliable_models`.
available_models = ['SchNet', 'DimeNet']

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Обработчик файлов с различными моделями")
    parser.add_argument("filename", help="Путь к обрабатываемому файлу")
    parser.add_argument("model",
                        choices=available_models,
                        help=f"Выбор модели из доступных: {', '.join(available_models)}")

    args = parser.parse_args()

    main(args.filename, args.model)
130
+
example_eval.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import argparse
3
+
4
+ import numpy as np
5
+
6
+ from utils_data import read_data, MolDataset, collate_mol, get_train_test_data
7
+ from utils_model import ModelDimeNet
8
+
9
def main(denormalize, checkpoint_path, data_filename):
    """Predict energies on the fine-tune test split and fill `template.rst`.

    Args:
        denormalize: when True, targets were stored normalized
            (sign(x)*|x|**0.1) and predictions are mapped back with the
            inverse sign(x)*|x|**10.
        checkpoint_path: path to a ModelDimeNet state_dict checkpoint.
        data_filename: path to the XYZ dataset parsed by read_data.

    Side effects:
        Reads 'template.rst' and writes 'filled.rst' in the working directory.
    """
    model = ModelDimeNet()
    model.load_state_dict(torch.load(checkpoint_path, weights_only=True))

    all_numbers, all_coords, energies, groups = read_data(data_filename)
    # NOTE(review): normalize=denormalize means dataset energies are stored
    # in the normalized scale exactly when we plan to denormalize predictions
    # afterwards — confirm this pairing against the training configuration.
    ds_all = MolDataset(all_numbers, all_coords, energies, normalize=denormalize)

    ds_train, ds_test = get_train_test_data(ds_all, groups, 'finetune')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)

    ens = []

    for numbers, coords, energy in ds_test:
        # For entries that are part of the train subset, keep the reference
        # energy; otherwise use the model prediction.
        if len(ens) in ds_train.indices:
            ens.append(energy)
        else:
            coords = torch.tensor(coords, dtype=torch.float32).to(device)
            atoms = torch.tensor(numbers).to(device)
            batch = torch.zeros_like(atoms).to(device)

            with torch.no_grad():
                ens.append(model(atoms, coords, batch).item())

    if denormalize:
        # Fixed: `ens` is a plain list, so the original `ens.sign()` raised
        # AttributeError. Convert to floats and apply the inverse of the
        # MolDataset normalization element-wise via numpy.
        ens_arr = np.array([float(e) for e in ens])
        ens = list(np.sign(ens_arr) * np.abs(ens_arr) ** 10)

    # Drop the reference (first) entry; the rest are scan-point energies.
    ensa = np.array(ens[1:])

    n_atoms = 3
    n_modes = n_atoms * 3 - 6  # 3N-6 vibrational modes for a nonlinear molecule
    modes_i = []

    # 16 scan points per 1-D mode, laid out first in the energy array.
    for i in range(n_modes):
        modes_i.append(ensa[16*i:16*i+16])

    all_a = []

    # 2-D coupling blocks follow; subtract both 1-D contributions.
    m = 0
    for i in range(n_modes):
        for j in range(i + 1, n_modes):
            for k in range(16):
                all_a.append(ensa[16*n_modes+m:16*n_modes+m+16] - modes_i[j] - modes_i[i][k])
            m += 16

    arr = np.concatenate([modes_i, all_a]).reshape((-1))

    with open('template.rst', 'r') as f:
        content = f.read()

    # Fixed: write through a context manager instead of relying on `del fo`
    # to flush/close the output file.
    with open('filled.rst', 'w') as fo:
        i = 0
        for line in content.split('\n'):
            if i < len(arr):
                # Substitute the next value into the first '{}' placeholder;
                # only advance when the line actually contained one.
                new_line = line.replace('{}', f'{arr[i]:.10f}')
                if line != new_line:
                    i += 1
            else:
                new_line = line
            fo.write(new_line + '\n')
75
+
76
if __name__ == '__main__':
    # Command-line entry point.
    parser = argparse.ArgumentParser()

    # Required positional arguments.
    parser.add_argument('checkpoint_path', type=str)
    parser.add_argument('data_filename', type=str)

    # Boolean flag.
    parser.add_argument('--denormalize', action='store_true')

    cli_args = parser.parse_args()

    main(
        denormalize=cli_args.denormalize,
        checkpoint_path=cli_args.checkpoint_path,
        data_filename=cli_args.data_filename,
    )
example_plot.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import matplotlib.pyplot as plt
3
+ import argparse
4
+
5
def main(picklefile):
    """Plot per-epoch MAE and R^2 curves from a pickled metrics list.

    Args:
        picklefile: path to a pickle containing a list of
            (r2_score, mae, ...) tuples, one entry per epoch.
    """
    with open(picklefile, 'rb') as f:
        all_metrics = pickle.load(f)

    epochs = list(range(len(all_metrics)))

    # MAE over epochs.
    plt.plot(epochs, [a[1] for a in all_metrics])
    plt.grid(True)
    plt.ylim(0.0, 0.05)
    plt.xlabel('Эпоха')
    # Fixed: was 'MAE$' — the stray '$' is an unbalanced mathtext delimiter.
    plt.ylabel('MAE')
    plt.show()

    # R^2 over epochs.
    plt.plot(epochs, [a[0] for a in all_metrics])
    plt.grid(True)
    plt.ylim(0.0, 1.0)
    plt.xlabel('Эпоха')
    plt.ylabel('$R^2$')
    plt.show()
22
+
23
+ if __name__ == "__main__":
24
+ parser = argparse.ArgumentParser(description='Process a pickle file')
25
+ parser.add_argument('picklefile', type=str, help='Path to pickle file')
26
+ args = parser.parse_args()
27
+
28
+ main(args.picklefile)
train.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sklearn.metrics import r2_score
2
+ import numpy as np
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from tqdm import tqdm
7
+ import argparse
8
+
9
+ import pickle
10
+
11
+ from utils_data import read_data, MolDataset, collate_mol, get_train_test_data
12
+ from utils_model import ModelDimeNet
13
+
14
+
15
def get_model():
    """Construct a fresh DimeNet-based model instance."""
    return ModelDimeNet()
18
+
19
+
20
def get_optimizer(model, e_start):
    """Create an RMSprop optimizer with learning rate 10**-e_start."""
    learning_rate = 10 ** -e_start
    return torch.optim.RMSprop(model.parameters(), lr=learning_rate)
24
+
25
def get_loss(mode):
    """Build the loss function described by `mode`.

    Args:
        mode: tuple; mode[0] is 'mae' or 'adaptive'. For 'adaptive',
            mode[1] is the exponent applied to the target magnitude.

    Returns:
        Callable (pred, en) -> scalar tensor.

    Raises:
        ValueError: if mode[0] is not a known loss name.
    """
    if mode[0] == 'mae':
        return lambda pred, en: (pred - en).abs().mean()

    if mode[0] == 'adaptive':
        # Relative error: scales each residual by the target magnitude
        # raised to mode[1] (1e-5 guards against division by zero).
        return lambda pred, en: ((pred - en).abs() / (en.abs() + 1e-5) ** mode[1]).mean()

    # Previously fell through and returned None, producing a confusing
    # "'NoneType' object is not callable" much later; fail fast instead.
    raise ValueError(f"Unknown loss mode: {mode[0]!r}")
31
+
32
+ def train_epoch(model, optimizer, dl_train, loss_fn, device):
33
+ model.train()
34
+
35
+ for atoms, coords, energy, batch in dl_train:
36
+ optimizer.zero_grad()
37
+
38
+ atoms = atoms.to(device)
39
+ coords = coords.to(device)
40
+ energy = energy.to(device)
41
+ batch = batch.to(device)
42
+
43
+ en = energy.squeeze()
44
+ pred = model(atoms, coords, batch).squeeze()
45
+ loss = loss_fn(pred, en)
46
+ loss.backward()
47
+ optimizer.step()
48
+
49
+
50
def test_epoch(model, optimizer, dl_test, device):
    """Evaluate on `dl_test`; return dict with 'r2_score' and 'mae'."""
    preds = []
    trues = []
    err_sum = 0
    n_mols = 0
    model.eval()

    for numbers, positions, target, mol_index in dl_test:
        numbers = numbers.to(device)
        positions = positions.to(device)
        target = target.to(device)
        mol_index = mol_index.to(device)

        true_en = target.squeeze()
        with torch.no_grad():
            predicted = model(numbers, positions, mol_index).squeeze()
        preds.append(predicted.cpu().numpy())
        trues.append(true_en.cpu().numpy())
        # Weight by batch size so the final division yields a per-molecule MAE.
        err_sum += F.l1_loss(predicted.squeeze(), true_en).item() * len(predicted)
        n_mols += len(predicted)

    trues = np.concatenate(trues)
    preds = np.concatenate(preds)

    return {
        'r2_score': r2_score(np.array(trues), np.array(preds)),
        'mae': err_sum / n_mols,
    }
78
+
79
def refresh_lr(optimizer, i, n, e_start, downscale=2.0):
    """Decay the LR log-linearly from 10**-e_start toward 10**-(e_start+downscale).

    Returns the new learning rate after applying it to every param group.
    """
    new_lr = 10 ** -(e_start + i / n * downscale)
    for group in optimizer.param_groups:
        group['lr'] = new_lr
    return new_lr
84
+
85
+
86
def train(n_epoch, model, optimizer, loss_fn, e_start, dl_train, dl_test, device, checkpoint_prefix):
    """Train for `n_epoch` epochs with a decaying LR, checkpointing each epoch.

    Returns:
        List of (r2_score, mae, lr_used) tuples, one per epoch. Note the
        first entry records `e_start` itself in the lr slot, mirroring the
        original bookkeeping.
    """
    history = []
    upcoming_lr = e_start

    for epoch in tqdm(range(n_epoch)):
        train_epoch(model, optimizer, dl_train, loss_fn, device)

        metrics = test_epoch(model, optimizer, dl_test, device)

        lr_used = upcoming_lr
        upcoming_lr = refresh_lr(optimizer, epoch, n_epoch, e_start)

        history.append((
            metrics['r2_score'],
            metrics['mae'],
            lr_used,
        ))

        # Overwrite the checkpoint every epoch (keeps only the latest weights).
        torch.save(model.state_dict(), checkpoint_prefix + '.ckpt')

    return history
108
+
109
def main(loss_mode, normalize, pretrain, checkpoint_prefix, data_filename):
    """Train the DimeNet model, optionally with a pretraining phase first.

    Args:
        loss_mode: tuple understood by get_loss — ('mae',) or ('adaptive', k).
        normalize: forwarded to MolDataset (sign(x)*|x|**0.1 energy transform).
        pretrain: when True, run a 100-epoch pretraining phase before
            fine-tuning; fine-tuning always runs.
        checkpoint_prefix: prefix for checkpoint and metrics-pickle filenames.
        data_filename: path to the XYZ dataset parsed by read_data.
    """
    all_numbers, all_coords, energies, groups = read_data(data_filename)
    ds_all = MolDataset(all_numbers, all_coords, energies, normalize=normalize)

    loss_fn = get_loss(loss_mode)

    model = get_model()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    # Pretraining phase: larger starting LR (10**-4) on the 'pretrain' split.
    if pretrain:
        e_start = 4

        ds_train, ds_test = get_train_test_data(ds_all, groups, 'pretrain')
        dl_train = torch.utils.data.DataLoader(ds_train, batch_size=32, shuffle=True, collate_fn=collate_mol)
        dl_test = torch.utils.data.DataLoader(ds_test, batch_size=32, shuffle=False, collate_fn=collate_mol)

        optimizer = get_optimizer(model, e_start=e_start)

        all_metrics = train(100, model, optimizer, loss_fn, e_start, dl_train, dl_test, device, checkpoint_prefix + '_pretrain_model')
        with open(checkpoint_prefix + '_pretrain_metrics.pkl', 'wb') as f:
            pickle.dump(all_metrics, f)

    # Fine-tuning phase: smaller starting LR (10**-5), fresh optimizer,
    # continuing from the (possibly pretrained) model weights.
    e_start = 5

    ds_train, ds_test = get_train_test_data(ds_all, groups, 'finetune')
    dl_train = torch.utils.data.DataLoader(ds_train, batch_size=32, shuffle=True, collate_fn=collate_mol)
    dl_test = torch.utils.data.DataLoader(ds_test, batch_size=32, shuffle=False, collate_fn=collate_mol)

    optimizer = get_optimizer(model, e_start=e_start)

    all_metrics = train(100, model, optimizer, loss_fn, e_start, dl_train, dl_test, device, checkpoint_prefix + '_finetune_model')
    with open(checkpoint_prefix + '_finetune_metrics.pkl', 'wb') as f:
        pickle.dump(all_metrics, f)
146
+
147
+
148
+ if __name__ == "__main__":
149
+ parser = argparse.ArgumentParser(description='Параметры для обучения модели')
150
+
151
+ # Обязательные аргументы
152
+ parser.add_argument('loss_mode',
153
+ choices=['mae', 'adaptive'],
154
+ help="Режим потерь: 'mae' или 'adaptive'")
155
+ parser.add_argument('checkpoint_prefix',
156
+ type=str,
157
+ help="Префикс для чекпоинтов")
158
+ parser.add_argument('data_filename',
159
+ type=str,
160
+ help="Путь к файлу с датасетом")
161
+
162
+ # Флаги (булевые параметры)
163
+ parser.add_argument('--normalize',
164
+ action='store_true',
165
+ help="Применить нормализацию (только для loss_mode='mae')")
166
+ parser.add_argument('--pretrain',
167
+ action='store_true',
168
+ help="Использовать предобучение")
169
+
170
+ # Параметр только для adaptive режима
171
+ parser.add_argument('--loss_k',
172
+ type=float,
173
+ default=None,
174
+ help="Коэффициент k для adaptive loss (требуется при loss_mode='adaptive')")
175
+
176
+ args = parser.parse_args()
177
+
178
+ # Проверка совместимости параметров
179
+ if args.loss_mode == 'adaptive':
180
+ if args.normalize:
181
+ raise ValueError("Параметр --normalize несовместим с loss_mode='adaptive'")
182
+ if args.loss_k is None:
183
+ raise ValueError("Для adaptive loss требуется параметр --loss_k")
184
+ # Формируем кортеж для adaptive режима
185
+ loss_mode_arg = ('adaptive', args.loss_k)
186
+ else: # loss_mode == 'mae'
187
+ if args.loss_k is not None:
188
+ raise ValueError("Параметр --loss_k можно использовать только с loss_mode='adaptive'")
189
+ loss_mode_arg = ('mae', )
190
+
191
+ # Вызов основной функции
192
+ main(
193
+ loss_mode=loss_mode_arg,
194
+ normalize=args.normalize,
195
+ pretrain=args.pretrain,
196
+ checkpoint_prefix=args.checkpoint_prefix,
197
+ data_filename=args.data_filename,
198
+ )
utils_data.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import torch
3
+
4
+
5
def read_data(filename):
    """Parse an extended-XYZ trajectory into atoms, coordinates and energies.

    Each frame is "<n_atoms>\\n<comment>\\n<n_atoms atom lines>". The comment
    carries "Energy: <float>" and optionally grid markers: "Grid: 0" marks the
    reference structure, "Grid: <i>: <v>" a 1-D scan point, and
    "Grid: <i>: <v>, <j>: <w>" a 2-D scan point.

    Args:
        filename: path to the trajectory file.

    Returns:
        Tuple (all_numbers, all_coords, energies, groups):
        - all_numbers: per-frame lists of atomic numbers (int),
        - all_coords: per-frame lists of (x, y, z) float tuples,
        - energies: 1-D tensor of energies relative to the reference frame,
        - groups: per-frame grid-label tuples — (), (v,) or (v, w).
    """
    all_coords = []
    all_numbers = []

    with open(filename) as f:
        cont = f.read()

    # Pre-compiled, raw-string patterns (the original non-raw literals
    # contained the invalid escape "\." which warns on modern Python).
    energy_re = re.compile(r'Energy\:\s+(-?\d*\.\d*)')
    grid0_re = re.compile(r'Grid: 0')
    grid1_re = re.compile(r'Grid: (\d+): (\d+)')
    grid2_re = re.compile(r'Grid: (\d+): (\d+), (\d+): (\d+)')

    energies = []
    groups = []
    lines = cont.split('\n')
    i = 0
    mol_en = None  # energy of the reference frame ("Grid: 0")
    while i < len(lines):
        try:
            n = int(lines[i].strip())
        except ValueError:
            break  # trailing blank line / end of data
        comment = lines[i + 1]
        energy = float(energy_re.findall(comment)[0])
        if grid0_re.findall(comment):
            # Reference structure; assumed to appear before any scan frame.
            mol_en = energy
        grp = ()
        g1 = grid1_re.findall(comment)
        if g1:
            grp = (g1[0][1], )
        g2 = grid2_re.findall(comment)
        if g2:
            grp = (g2[0][1], g2[0][3])
        energies.append(energy - mol_en)
        groups.append(grp)

        frame_coords = []
        frame_numbers = []
        for j in range(n):
            # str.split() with no argument already drops empty fields,
            # replacing the original filter(...) dance.
            at, x, y, z = lines[i + j + 2].split()
            frame_coords.append((float(x), float(y), float(z)))
            frame_numbers.append(int(at))
        all_coords.append(frame_coords)
        all_numbers.append(frame_numbers)

        i += n + 2

    energies = torch.tensor(energies)

    return all_numbers, all_coords, energies, groups
48
+
49
class MolDataset(torch.utils.data.Dataset):
    """Dataset of molecules: atomic numbers, coordinates and one energy each."""

    def __init__(self, all_numbers, all_coords, energies, normalize=False):
        self.numbers = all_numbers
        self.coords = all_coords
        self.energies = energies
        self.normalize = normalize

    def __len__(self):
        return len(self.energies)

    def __getitem__(self, ind):
        target = self.energies[ind]
        if self.normalize:
            # Compress the dynamic range: x -> sign(x) * |x| ** 0.1.
            target = target.sign() * target.abs() ** 0.1
        return (
            torch.tensor(self.numbers[ind]),
            torch.tensor(self.coords[ind], dtype=torch.float32),
            target,
        )
69
+
70
def collate_mol(batch):
    """Merge a list of (atoms, coords, energy) samples into one flat batch.

    Args:
        batch: list of tuples (atoms, coords, energy) from MolDataset.

    Returns:
        atoms_cat: [total_atoms] concatenated atomic numbers,
        coords_cat: [total_atoms, 3] concatenated coordinates,
        energies: [batch_size] stacked energies,
        batch_tensor: [total_atoms] index of the owning molecule per atom.
    """
    # One molecule-index vector per sample: [i] repeated n_atoms(i) times.
    mol_ids = [
        torch.full((sample[0].size(0),), idx, dtype=torch.long)
        for idx, sample in enumerate(batch)
    ]
    return (
        torch.cat([sample[0] for sample in batch], dim=0),
        torch.cat([sample[1] for sample in batch], dim=0),
        torch.stack([sample[2] for sample in batch]),
        torch.cat(mol_ids, dim=0),
    )
107
+
108
def get_train_test_data(ds_all, groups, mode, test_idcs=range(28986, 29803)):
    """Split `ds_all` into train/test subsets for pretraining or fine-tuning.

    Args:
        ds_all: the full indexable dataset.
        groups: per-sample grid labels from read_data — (), (v,) or (v, w).
        mode: 'pretrain' or 'finetune'.
        test_idcs: indices reserved for the test split. In 'pretrain' mode
            these are excluded from training; in 'finetune' mode training is
            restricted to them (and filtered by the grid selections below).

    Returns:
        (ds_train, ds_test) as torch.utils.data.Subset objects.

    Raises:
        ValueError: if mode is not 'pretrain' or 'finetune'.
    """
    grid1_selection = ['1', '3', '5', '7', '8', '9', '10', '12', '14', '16']
    grid2_selection = ['1', '5', '8', '9', '12', '16']

    if mode not in ('pretrain', 'finetune'):
        # `assert` is stripped under -O; raise explicitly instead.
        raise ValueError(f"mode must be 'pretrain' or 'finetune', got {mode!r}")

    pretrain = mode == 'pretrain'
    test_set = set(test_idcs)  # O(1) membership tests inside the loop

    train_idces = []
    for i, grp in enumerate(groups):
        if len(grp) == 0:
            selected = True
        elif len(grp) == 1:
            selected = pretrain or grp[0] in grid1_selection
        elif len(grp) == 2:
            # Parentheses make the original operator precedence explicit:
            # pretrain OR (both labels selected).
            selected = pretrain or (grp[0] in grid2_selection and grp[1] in grid2_selection)
        else:
            selected = False
        if selected and (i in test_set) != pretrain:
            train_idces.append(i)

    ds_train = torch.utils.data.Subset(ds_all, train_idces)
    ds_test = torch.utils.data.Subset(ds_all, test_idcs)

    return ds_train, ds_test
utils_model.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch_geometric
4
+ from torch_geometric.nn.models import DimeNetPlusPlus
5
+
6
class ModelDimeNet(nn.Module):
    """DimeNet++ backbone followed by a linear energy head."""

    def __init__(self):
        super().__init__()
        # 256-dim per-molecule embedding from DimeNet++, mapped to a scalar.
        # Attribute names (`net`, `head`) are kept for state_dict compatibility.
        self.net = DimeNetPlusPlus(
            hidden_channels=256,
            out_channels=256,
            num_blocks=4,
            num_spherical=8,
            num_radial=8,
            int_emb_size=64,
            basis_emb_size=64,
            out_emb_channels=64,
        )
        self.head = nn.Linear(256, 1)

    def forward(self, atoms, coords, batch):
        """Predict one scalar per molecule from atoms, coordinates and batch indices."""
        return self.head(self.net(atoms, coords, batch))