Initial commit
Browse files- compare_models.py +130 -0
- example_eval.py +102 -0
- example_plot.py +28 -0
- train.py +198 -0
- utils_data.py +131 -0
- utils_model.py +15 -0
compare_models.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
import re
import sys

import ase
import ase.io
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import r2_score
from torch import nn
from torch_geometric.nn.models import SchNet, DimeNetPlusPlus
from tqdm import tqdm

# BUG FIX: was 'ModellDimeNet' (double l), which does not exist in
# utils_model and made this module fail at import time.
from utils_model import ModelDimeNet
def get_model_and_optimizer(model_type):
    """Build a model of the requested architecture plus an AdamW optimizer.

    Args:
        model_type: Either 'SchNet' or 'DimeNet'.

    Returns:
        (model, optimizer) tuple; the optimizer is AdamW with lr=3e-4.

    Raises:
        ValueError: If model_type is not a supported architecture.
    """
    if model_type == 'SchNet':
        model = SchNet()
    elif model_type == 'DimeNet':
        model = ModelDimeNet()
    else:
        # BUG FIX: previously an unknown type fell through and raised
        # UnboundLocalError on `model`; fail early with a clear message.
        raise ValueError(f'Unknown model type: {model_type!r}')

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    return model, optimizer
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def train_epoch(model, optimizer, geoms, energies_n, mean_grad=32):
    """Run one training epoch with gradient accumulation.

    Each molecule is processed as its own single-graph batch; gradients are
    averaged over `mean_grad` molecules before one optimizer step.

    Args:
        model: Module mapping (atomic_numbers, coords, batch) -> energy.
        optimizer: Optimizer over model.parameters().
        geoms: Iterable of ASE-like Atoms objects exposing get_positions()
            and get_atomic_numbers().
        energies_n: 1-D tensor of target energies aligned with geoms.
        mean_grad: Number of molecules to accumulate before each step.
    """
    model.train()

    n_accum = 0
    for geom, energy in zip(geoms, energies_n.clone().detach()):
        if n_accum == 0:
            optimizer.zero_grad()

        coords = torch.tensor(geom.get_positions(), dtype=torch.float32)
        atoms = torch.tensor(geom.get_atomic_numbers())
        # Single-molecule batch: every atom belongs to graph 0.
        batch = torch.zeros_like(atoms)

        en = energy.clone().detach()
        pred = model(atoms, coords, batch)
        loss = F.huber_loss(pred.squeeze(), en)
        # Scale so the summed gradient equals the mean over mean_grad samples.
        (loss / mean_grad).backward()
        n_accum += 1

        # BUG FIX: was `j == mean_grad - 1`, which stepped after only
        # mean_grad - 1 accumulations while the loss was divided by mean_grad.
        if n_accum == mean_grad:
            optimizer.step()
            n_accum = 0

    # BUG FIX: flush leftover accumulated gradients at the end of the epoch;
    # the original silently discarded them.
    if n_accum > 0:
        optimizer.step()
|
| 49 |
+
|
| 50 |
+
def test_epoch(model, optimizer, geoms, energies_n):
    """Evaluate the model over all molecules and report R^2 and MAE.

    The optimizer argument is unused; it is kept for signature symmetry
    with train_epoch.

    Returns:
        dict with 'r2_score' and 'mae' (mean absolute error per molecule).
    """
    model.eval()

    preds, trues = [], []
    total_l1 = 0.0
    count = 0

    for geom, energy in zip(geoms, energies_n.clone().detach()):
        coords = torch.tensor(geom.get_positions(), dtype=torch.float32)
        atoms = torch.tensor(geom.get_atomic_numbers())
        # Single-molecule batch: all atoms map to graph 0.
        batch = torch.zeros_like(atoms)

        en = energy.clone().detach()
        with torch.no_grad():
            pred = model(atoms, coords, batch)
        preds.append(pred.item())
        trues.append(en.item())
        total_l1 += F.l1_loss(pred.squeeze(), en).item()
        count += 1

    return {
        'r2_score': r2_score(np.array(trues), np.array(preds)),
        'mae': total_l1 / count,
    }
|
| 73 |
+
|
| 74 |
+
def train(model, optimizer, geoms, energies_n, n_epochs=100):
    """Train for n_epochs, tracking the best metrics seen across epochs.

    NOTE(review): evaluation runs on the same data used for training, so the
    returned numbers measure fit, not generalization.

    Returns:
        (best_r2score, best_mae) over all epochs.
    """
    best_r2score = -1e100
    best_mae = 1e100

    for _ in tqdm(range(n_epochs)):
        train_epoch(model, optimizer, geoms, energies_n)
        metrics = test_epoch(model, optimizer, geoms, energies_n)

        best_r2score = max(best_r2score, metrics['r2_score'])
        best_mae = min(best_mae, metrics['mae'])

    return best_r2score, best_mae
|
| 89 |
+
|
| 90 |
+
def main(trajectory_file, model_type):
    """Train the chosen model on an xyz trajectory and print best metrics.

    Args:
        trajectory_file: Path to a multi-frame xyz file whose comment lines
            carry an 'energy: <float>' field.
        model_type: 'SchNet' or 'DimeNet' (see get_model_and_optimizer).
    """
    # Geometries come from ASE; energies are re-parsed manually below.
    geoms = ase.io.read(trajectory_file, format='xyz', index=':')

    with open(trajectory_file) as f:
        cont = f.read()

    # Walk the xyz file frame by frame: an atom-count line, a comment line,
    # then n coordinate lines per frame.
    energies = []
    lines = cont.split('\n'); i = 0
    while i < len(lines):
        try:
            n = int(lines[i].strip())
        except ValueError:
            # Line is not an atom count -> end of trajectory / trailing blank.
            break
        comment = lines[i+1]
        # NOTE(review): lowercase 'energy' here vs capitalized 'Energy' in
        # utils_data.read_data — confirm which spelling the data files use.
        energy = float(re.findall('energy\\:\\s+(-?\\d*\.\\d*)', comment)[0])
        energies.append(energy)
        i += n + 2

    energies = torch.tensor(energies)
    # Shift to zero minimum and scale; 627.5 is presumably the
    # Hartree -> kcal/mol conversion factor (the MAE printout assumes kcal/mol).
    energies_n = (energies - energies.min()) * 627.5

    model, optimizer = get_model_and_optimizer(model_type)

    best_r2score, best_mae = train(model, optimizer, geoms, energies_n)

    print(f'R2_score: {best_r2score:.4f}')
    print(f'MAE: {best_mae:.3f} kcal/mol')
|
| 117 |
+
|
| 118 |
+
# Model architectures selectable from the CLI. (The 'avaliable' spelling is
# kept: this is a module-level name other code could reference.)
avaliable_models = ['SchNet', 'DimeNet']

if __name__ == "__main__":
    # CLI: positional xyz trajectory path and model choice.
    parser = argparse.ArgumentParser(description="Обработчик файлов с различными моделями")
    parser.add_argument("filename", help="Путь к обрабатываемому файлу")
    parser.add_argument("model",
                        choices=avaliable_models,
                        help=f"Выбор модели из доступных: {', '.join(avaliable_models)}")

    args = parser.parse_args()

    main(args.filename, args.model)
|
| 130 |
+
|
example_eval.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import argparse
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from utils_data import read_data, MolDataset, collate_mol, get_train_test_data
|
| 7 |
+
from utils_model import ModelDimeNet
|
| 8 |
+
|
| 9 |
+
def main(denormalize, checkpoint_path, data_filename):
    """Evaluate a trained DimeNet checkpoint on the fine-tune split and fill
    the resulting energies into template.rst, writing filled.rst.

    Args:
        denormalize: If True, the dataset was normalized as
            E -> sign(E) * |E| ** 0.1 and predictions are mapped back with
            the inverse power of 10.
        checkpoint_path: Path to the saved state_dict.
        data_filename: Path to the xyz dataset file parsed by read_data.
    """
    model = ModelDimeNet()
    model.load_state_dict(torch.load(checkpoint_path, weights_only=True))

    all_numbers, all_coords, energies, groups = read_data(data_filename)
    ds_all = MolDataset(all_numbers, all_coords, energies, normalize=denormalize)

    ds_train, ds_test = get_train_test_data(ds_all, groups, 'finetune')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    model.eval()

    ens = []
    for numbers, coords, energy in ds_test:
        if len(ens) in ds_train.indices:
            # Geometry was used for training: keep its reference energy.
            ens.append(float(energy))
        else:
            # MolDataset already returns tensors; just move them to device
            # (the original re-wrapped them in torch.tensor, same values).
            coords = coords.to(device)
            atoms = numbers.to(device)
            batch = torch.zeros_like(atoms).to(device)

            with torch.no_grad():
                ens.append(model(atoms, coords, batch).item())

    if denormalize:
        # BUG FIX: `ens` is a Python list; the original called .sign()/.abs()
        # on it directly, which raises AttributeError. Invert the dataset's
        # E -> sign(E) * |E| ** 0.1 normalization elementwise.
        ens_t = torch.tensor(ens, dtype=torch.float64)
        ens = (ens_t.sign() * ens_t.abs() ** 10).tolist()

    # Drop the first entry (reference geometry) before the grid layout below.
    ensa = np.array(ens[1:])

    # A 3-atom molecule has 3*3 - 6 = 3 internal vibrational modes.
    n_atoms = 3
    n_modes = n_atoms * 3 - 6

    # 1-D scans: 16 consecutive energies per mode.
    modes_i = [ensa[16 * i:16 * i + 16] for i in range(n_modes)]

    # 2-D scans: for each mode pair (i, j), a 16x16 grid follows the 1-D
    # scans; subtract both 1-D contributions to isolate the coupling term.
    all_a = []
    m = 0
    for i in range(n_modes):
        for j in range(i + 1, n_modes):
            for k in range(16):
                all_a.append(ensa[16 * n_modes + m:16 * n_modes + m + 16] - modes_i[j] - modes_i[i][k])
                m += 16

    arr = np.concatenate([modes_i, all_a]).reshape((-1))

    with open('template.rst', 'r') as f:
        content = f.read()

    # BUG FIX (robustness): the output file was opened bare and "closed" with
    # `del fo`; use a context manager so it is flushed/closed even on error.
    with open('filled.rst', 'w') as fo:
        # Substitute one value per '{}' placeholder, in document order.
        i = 0
        for line in content.split('\n'):
            if i < len(arr):
                new_line = line.replace('{}', f'{arr[i]:.10f}')
                if line != new_line:
                    i += 1
            else:
                new_line = line
            fo.write(new_line + '\n')
|
| 75 |
+
|
| 76 |
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Required positional arguments.

    parser.add_argument('checkpoint_path',
                        type=str,
                        )
    parser.add_argument('data_filename',
                        type=str,
                        )

    # Flags (boolean parameters).
    parser.add_argument('--denormalize',
                        action='store_true',
                        )


    args = parser.parse_args()


    # Invoke the main entry point.
    main(
        denormalize=args.denormalize,
        checkpoint_path=args.checkpoint_path,
        data_filename=args.data_filename
    )
|
example_plot.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import argparse
|
| 4 |
+
|
| 5 |
+
def main(picklefile):
    """Plot per-epoch MAE and R^2 curves from a pickled metrics list.

    The pickle holds a list of (r2_score, mae, lr) tuples, one per epoch,
    as written by train.py.

    NOTE(review): pickle.load must only be used on trusted files.

    Args:
        picklefile: Path to the metrics pickle.
    """
    with open(picklefile, 'rb') as f:
        all_metrics = pickle.load(f)

    epochs = list(range(len(all_metrics)))

    # MAE curve (second element of each tuple).
    plt.plot(epochs, [m[1] for m in all_metrics])
    plt.grid(True)
    plt.ylim(0.0, 0.05)
    plt.xlabel('Эпоха')
    # BUG FIX: the label was 'MAE$' with an unbalanced mathtext '$'.
    plt.ylabel('MAE')
    plt.show()

    # R^2 curve (first element of each tuple).
    plt.plot(epochs, [m[0] for m in all_metrics])
    plt.grid(True)
    plt.ylim(0.0, 1.0)
    plt.xlabel('Эпоха')
    plt.ylabel('$R^2$')
    plt.show()
|
| 22 |
+
|
| 23 |
+
if __name__ == "__main__":
    # CLI: single positional argument naming the metrics pickle.
    parser = argparse.ArgumentParser(description='Process a pickle file')
    parser.add_argument('picklefile', type=str, help='Path to pickle file')
    main(parser.parse_args().picklefile)
|
train.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sklearn.metrics import r2_score
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import argparse
|
| 8 |
+
|
| 9 |
+
import pickle
|
| 10 |
+
|
| 11 |
+
from utils_data import read_data, MolDataset, collate_mol, get_train_test_data
|
| 12 |
+
from utils_model import ModelDimeNet
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_model():
    """Construct a fresh DimeNet++-based regression model."""
    return ModelDimeNet()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_optimizer(model, e_start):
    """Create an RMSprop optimizer with learning rate 10 ** (-e_start).

    Args:
        model: Module whose parameters will be optimized.
        e_start: Negative decimal exponent of the initial learning rate.

    Returns:
        A torch.optim.RMSprop instance over model.parameters().
    """
    initial_lr = 10 ** -e_start
    return torch.optim.RMSprop(model.parameters(), lr=initial_lr)
|
| 24 |
+
|
| 25 |
+
def get_loss(mode):
    """Build a loss callable from a mode tuple.

    Args:
        mode: ('mae',) for mean absolute error, or ('adaptive', k) for
            absolute error scaled by 1 / (|target| + 1e-5) ** k.

    Returns:
        Callable (pred, en) -> scalar tensor loss.

    Raises:
        ValueError: On an unrecognized mode. (Previously the function fell
            through and returned None, which later failed with a confusing
            'NoneType is not callable' at training time.)
    """
    if mode[0] == 'mae':
        return lambda pred, en: (pred - en).abs().mean()

    if mode[0] == 'adaptive':
        # 1e-5 guards against division by zero for near-zero targets.
        return lambda pred, en: ((pred - en).abs() / (en.abs() + 1e-5) ** mode[1]).mean()

    raise ValueError(f'Unknown loss mode: {mode[0]!r}')
|
| 31 |
+
|
| 32 |
+
def train_epoch(model, optimizer, dl_train, loss_fn, device):
    """Run one optimization pass over dl_train.

    Each batch is (atoms, coords, energy, batch_index) as produced by
    collate_mol; one optimizer step is taken per batch.
    """
    model.train()

    for atoms, coords, energy, batch in dl_train:
        optimizer.zero_grad()

        # Move the whole batch to the target device.
        atoms = atoms.to(device)
        coords = coords.to(device)
        energy = energy.to(device)
        batch = batch.to(device)

        target = energy.squeeze()
        prediction = model(atoms, coords, batch).squeeze()

        loss_fn(prediction, target).backward()
        optimizer.step()
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def test_epoch(model, optimizer, dl_test, device):
    """Evaluate the model over dl_test.

    MAE is weighted by the number of molecules in each batch. The optimizer
    argument is unused; it is kept for signature symmetry with train_epoch.

    Returns:
        dict with 'r2_score' and 'mae'.
    """
    model.eval()

    preds, trues = [], []
    l1_sum = 0.0
    n_mols = 0

    for atoms, coords, energy, batch in dl_test:
        atoms = atoms.to(device)
        coords = coords.to(device)
        energy = energy.to(device)
        batch = batch.to(device)

        target = energy.squeeze()
        with torch.no_grad():
            pred = model(atoms, coords, batch).squeeze()
        preds.append(pred.cpu().numpy())
        trues.append(target.cpu().numpy())
        l1_sum += F.l1_loss(pred.squeeze(), target).item() * len(pred)
        n_mols += len(pred)

    trues_cat = np.concatenate(trues)
    preds_cat = np.concatenate(preds)

    return {
        'r2_score': r2_score(trues_cat, preds_cat),
        'mae': l1_sum / n_mols,
    }
|
| 78 |
+
|
| 79 |
+
def refresh_lr(optimizer, i, n, e_start, downscale=2.0):
    """Set a log-linearly decayed learning rate on all param groups.

    The exponent moves from e_start at i == 0 to e_start + downscale at
    i == n, i.e. lr = 10 ** -(e_start + (i / n) * downscale).

    Returns:
        The new learning rate.
    """
    new_lr = 10 ** -(e_start + i / n * downscale)
    for group in optimizer.param_groups:
        group['lr'] = new_lr
    return new_lr
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def train(n_epoch, model, optimizer, loss_fn, e_start, dl_train, dl_test, device, checkpoint_prefix):
    """Train for n_epoch epochs with per-epoch LR decay and checkpointing.

    After every epoch the model is evaluated, the learning rate is decayed
    via refresh_lr, and the state_dict is saved to
    checkpoint_prefix + '.ckpt' (overwritten each epoch).

    Returns:
        List of (r2_score, mae, lr) tuples, one per epoch.
    """
    history = []
    # NOTE(review): the first tuple records e_start itself (an exponent),
    # not an actual learning rate — behavior preserved from the original.
    upcoming_lr = e_start

    for epoch in tqdm(range(n_epoch)):
        train_epoch(model, optimizer, dl_train, loss_fn, device)
        metrics = test_epoch(model, optimizer, dl_test, device)

        lr_used = upcoming_lr
        upcoming_lr = refresh_lr(optimizer, epoch, n_epoch, e_start)

        history.append((metrics['r2_score'], metrics['mae'], lr_used))

        torch.save(model.state_dict(), checkpoint_prefix + '.ckpt')

    return history
|
| 108 |
+
|
| 109 |
+
def main(loss_mode, normalize, pretrain, checkpoint_prefix, data_filename):
    """Train ModelDimeNet on a dataset, optionally pretraining first.

    Args:
        loss_mode: ('mae',) or ('adaptive', k) — see get_loss.
        normalize: Whether MolDataset compresses energies to sign(E)*|E|**0.1.
        pretrain: If True, first train 100 epochs on the 'pretrain' split.
        checkpoint_prefix: Prefix for '.ckpt' and metrics '.pkl' outputs.
        data_filename: Path to the dataset file parsed by read_data.
    """
    all_numbers, all_coords, energies, groups = read_data(data_filename)
    ds_all = MolDataset(all_numbers, all_coords, energies, normalize=normalize)

    loss_fn = get_loss(loss_mode)

    model = get_model()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    # Pretraining: larger initial lr (1e-4) on the broad 'pretrain' split.
    if pretrain:
        e_start = 4

        ds_train, ds_test = get_train_test_data(ds_all, groups, 'pretrain')
        dl_train = torch.utils.data.DataLoader(ds_train, batch_size=32, shuffle=True, collate_fn=collate_mol)
        dl_test = torch.utils.data.DataLoader(ds_test, batch_size=32, shuffle=False, collate_fn=collate_mol)

        optimizer = get_optimizer(model, e_start=e_start)

        all_metrics = train(100, model, optimizer, loss_fn, e_start, dl_train, dl_test, device, checkpoint_prefix + '_pretrain_model')
        with open(checkpoint_prefix + '_pretrain_metrics.pkl', 'wb') as f:
            pickle.dump(all_metrics, f)

    # Fine-tuning: smaller initial lr (1e-5) on the 'finetune' split,
    # continuing from the (possibly pretrained) model weights.
    e_start = 5

    ds_train, ds_test = get_train_test_data(ds_all, groups, 'finetune')
    dl_train = torch.utils.data.DataLoader(ds_train, batch_size=32, shuffle=True, collate_fn=collate_mol)
    dl_test = torch.utils.data.DataLoader(ds_test, batch_size=32, shuffle=False, collate_fn=collate_mol)

    optimizer = get_optimizer(model, e_start=e_start)

    all_metrics = train(100, model, optimizer, loss_fn, e_start, dl_train, dl_test, device, checkpoint_prefix + '_finetune_model')
    with open(checkpoint_prefix + '_finetune_metrics.pkl', 'wb') as f:
        pickle.dump(all_metrics, f)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Параметры для обучения модели')

    # Required positional arguments.
    parser.add_argument('loss_mode',
                        choices=['mae', 'adaptive'],
                        help="Режим потерь: 'mae' или 'adaptive'")
    parser.add_argument('checkpoint_prefix',
                        type=str,
                        help="Префикс для чекпоинтов")
    parser.add_argument('data_filename',
                        type=str,
                        help="Путь к файлу с датасетом")

    # Flags (boolean parameters).
    parser.add_argument('--normalize',
                        action='store_true',
                        help="Применить нормализацию (только для loss_mode='mae')")
    parser.add_argument('--pretrain',
                        action='store_true',
                        help="Использовать предобучение")

    # Parameter used only in 'adaptive' mode.
    parser.add_argument('--loss_k',
                        type=float,
                        default=None,
                        help="Коэффициент k для adaptive loss (требуется при loss_mode='adaptive')")

    args = parser.parse_args()

    # Validate parameter compatibility.
    if args.loss_mode == 'adaptive':
        if args.normalize:
            raise ValueError("Параметр --normalize несовместим с loss_mode='adaptive'")
        if args.loss_k is None:
            raise ValueError("Для adaptive loss требуется параметр --loss_k")
        # Build the mode tuple for the adaptive loss.
        loss_mode_arg = ('adaptive', args.loss_k)
    else:  # loss_mode == 'mae'
        if args.loss_k is not None:
            raise ValueError("Параметр --loss_k можно использовать только с loss_mode='adaptive'")
        loss_mode_arg = ('mae', )

    # Invoke the main entry point.
    main(
        loss_mode=loss_mode_arg,
        normalize=args.normalize,
        pretrain=args.pretrain,
        checkpoint_prefix=args.checkpoint_prefix,
        data_filename=args.data_filename,
    )
|
utils_data.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def read_data(filename):
    """Parse a multi-frame xyz-like file into numbers, coords, energies, groups.

    Each frame is: an atom-count line, a comment line containing
    'Energy: <float>' plus optional 'Grid: ...' tags, then n atom lines of
    'atomic_number x y z'.

    Returns:
        all_numbers: list (per molecule) of lists of atomic numbers.
        all_coords: list (per molecule) of lists of (x, y, z) tuples.
        energies: 1-D tensor of energies relative to the most recent
            'Grid: 0' reference frame.
        groups: list of grid-index tuples: () for the reference, a 1-tuple
            for 1-D scan points, a 2-tuple for 2-D scan points.
    """
    all_coords = []
    all_numbers = []

    with open(filename) as f:
        cont = f.read()

    energies = []
    groups = []
    lines = cont.split('\n'); i = 0
    # Reference energy from the most recent 'Grid: 0' frame.
    mol_en = None
    while i < len(lines):
        try:
            n = int(lines[i].strip())
        except ValueError:
            # Line is not an atom count: treat as end of the file.
            break
        comment = lines[i+1]
        energy = float(re.findall('Energy\\:\\s+(-?\\d*\.\\d*)', comment)[0])
        g0 = re.findall('Grid: 0', comment)
        if g0:
            # This frame defines the reference energy for what follows.
            mol_en = energy
        grp = ()
        # 1-D grid tag: 'Grid: <axis>: <index>'.
        g1 = re.findall('Grid: (\\d+): (\\d+)', comment)
        if g1:
            grp = (g1[0][1], )
        # 2-D grid tag overrides the 1-D match when both patterns hit.
        g2 = re.findall('Grid: (\\d+): (\\d+), (\\d+): (\\d+)', comment)
        if g2:
            grp = (g2[0][1], g2[0][3])
        # NOTE(review): if the first frame lacks 'Grid: 0', mol_en is still
        # None here and this subtraction raises — confirm input ordering.
        energies.append(energy - mol_en)
        groups.append(grp)
        j = 0
        all_coords.append([])
        all_numbers.append([])
        while j < n:
            # Split on spaces, dropping empty fields from column alignment.
            at, x, y, z = list(filter(lambda x: x != '', lines[i+j+2].strip().split(' ')))
            all_coords[-1].append((float(x), float(y), float(z)))
            all_numbers[-1].append(int(at))
            j += 1
        i += n + 2

    energies = torch.tensor(energies)

    return all_numbers, all_coords, energies, groups
|
| 48 |
+
|
| 49 |
+
class MolDataset(torch.utils.data.Dataset):
    """Dataset of molecules: atomic numbers, coordinates, and energies.

    Items are (atoms, coords, energy) where atoms is a 1-D int tensor,
    coords a float32 tensor of shape [n_atoms, 3], and energy a scalar
    tensor, optionally normalized as sign(E) * |E| ** 0.1.
    """

    def __init__(self, all_numbers, all_coords, energies, normalize=False):
        # Per-molecule python lists; energies is a 1-D tensor aligned with them.
        self.numbers = all_numbers
        self.coords = all_coords
        self.energies = energies
        self.normalize = normalize

    def __len__(self):
        return len(self.energies)

    def __getitem__(self, ind):
        atoms = torch.tensor(self.numbers[ind])
        coords = torch.tensor(self.coords[ind], dtype=torch.float32)

        energy = self.energies[ind]
        if self.normalize:
            # Compress the dynamic range while preserving the sign.
            energy = energy.sign() * energy.abs() ** 0.1

        return atoms, coords, energy
|
| 69 |
+
|
| 70 |
+
def collate_mol(batch):
    """Collate (atoms, coords, energy) samples into one flat graph batch.

    Args:
        batch: List of (atoms, coords, energy) tuples from MolDataset.

    Returns:
        atoms_cat: int tensor [total_atoms] of atomic numbers.
        coords_cat: float tensor [total_atoms, 3] of coordinates.
        energies: tensor [batch_size] of energies.
        batch_tensor: long tensor [total_atoms] mapping atom -> molecule index.
    """
    atoms_cat = torch.cat([sample[0] for sample in batch], dim=0)
    coords_cat = torch.cat([sample[1] for sample in batch], dim=0)
    energies = torch.stack([sample[2] for sample in batch])
    # For molecule i with k atoms, emit k copies of i as its graph index.
    batch_tensor = torch.cat(
        [torch.full((sample[0].size(0),), mol_idx, dtype=torch.long)
         for mol_idx, sample in enumerate(batch)],
        dim=0,
    )
    return atoms_cat, coords_cat, energies, batch_tensor
|
| 107 |
+
|
| 108 |
+
def get_train_test_data(ds_all, groups, mode, test_idcs=range(28986, 29803)):
    """Split ds_all into train/test Subsets according to grid groups.

    In 'pretrain' mode the train set is every index outside test_idcs; in
    'finetune' mode it is restricted to test_idcs entries whose grid indices
    fall in the hand-picked selections below. The test Subset is always
    exactly test_idcs.

    Returns:
        (ds_train, ds_test) as torch.utils.data.Subset instances.
    """
    grid1_selection = ['1', '3', '5', '7', '8', '9', '10', '12', '14', '16']
    grid2_selection = ['1', '5', '8', '9', '12', '16']

    assert(mode in ['pretrain', 'finetune'])

    pretrain = mode == 'pretrain'

    train_idces = []
    for idx, grp in enumerate(groups):
        # Pretrain keeps indices NOT in test_idcs; finetune keeps indices
        # that ARE in test_idcs.
        selected = (idx in test_idcs) != pretrain
        if len(grp) == 0:
            if selected:
                train_idces.append(idx)
        elif len(grp) == 1:
            if (pretrain or grp[0] in grid1_selection) and selected:
                train_idces.append(idx)
        elif len(grp) == 2:
            # Operator precedence preserved from the original:
            # pretrain or (grp[0] in sel and grp[1] in sel).
            if (pretrain or (grp[0] in grid2_selection and grp[1] in grid2_selection)) and selected:
                train_idces.append(idx)

    ds_train = torch.utils.data.Subset(ds_all, train_idces)
    ds_test = torch.utils.data.Subset(ds_all, test_idcs)

    return ds_train, ds_test
|
utils_model.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch_geometric
|
| 4 |
+
from torch_geometric.nn.models import DimeNetPlusPlus
|
| 5 |
+
|
| 6 |
+
class ModelDimeNet(nn.Module):
    """DimeNet++ backbone with a linear head producing one scalar per graph."""

    def __init__(self):
        super().__init__()

        # 256-dim per-graph embedding; the remaining sizes are DimeNet++
        # architecture hyperparameters.
        self.net = DimeNetPlusPlus(hidden_channels=256, out_channels=256, num_blocks=4, num_spherical=8, num_radial=8, int_emb_size=64, basis_emb_size=64, out_emb_channels=64)
        self.head = nn.Linear(256, 1)

    def forward(self, atoms, coords, batch):
        # atoms: [n_atoms] atomic numbers; coords: [n_atoms, 3]; batch maps
        # each atom to its graph index. The linear head projects the
        # backbone output to a single value.
        emb = self.net(atoms, coords, batch)
        return self.head(emb)
|