File size: 3,894 Bytes
5fae7ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn.models import SchNet, DimeNetPlusPlus
import ase
import ase.io
import re
from sklearn.metrics import r2_score
import numpy as np
import sys
from tqdm import tqdm
import argparse
from utils_model import ModellDimeNet
def get_model_and_optimizer(model_type, lr=3e-4):
    """Build the requested model and an AdamW optimizer over its parameters.

    Args:
        model_type: ``'SchNet'`` or ``'DimeNet'``.
        lr: Learning rate passed to ``torch.optim.AdamW`` (default 3e-4).

    Returns:
        A ``(model, optimizer)`` tuple.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    if model_type == 'SchNet':
        model = SchNet()
    elif model_type == 'DimeNet':
        # The module-level import brings in ``ModellDimeNet`` (double "l");
        # the former ``ModelDimeNet`` spelling was an undefined name and made
        # this branch fail with NameError.
        model = ModellDimeNet()
    else:
        # Previously an unknown name fell through and crashed with an opaque
        # UnboundLocalError on ``model``.
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected 'SchNet' or 'DimeNet'"
        )
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    return model, optimizer
def train_epoch(model, optimizer, geoms, energies_n, mean_grad=32):
    """Run one training pass over all geometries with gradient accumulation.

    Each geometry is fed to the model one at a time as
    ``model(atomic_numbers, positions, batch)`` with an all-zero ``batch``
    vector (a single molecule). The Huber loss of each sample is divided by
    ``mean_grad`` and back-propagated. The optimizer steps once every
    ``mean_grad`` samples, so each step uses the average gradient of that
    group.

    Args:
        model: Module called as ``model(atoms, coords, batch)``.
        optimizer: Optimizer over ``model``'s parameters.
        geoms: Iterable of objects exposing ``get_positions()`` and
            ``get_atomic_numbers()`` (ASE ``Atoms`` when called from ``main``).
        energies_n: 1-D tensor of target energies, aligned with ``geoms``.
        mean_grad: Number of samples whose gradients are accumulated per
            optimizer step; must be >= 1.

    Raises:
        ValueError: If ``mean_grad`` < 1.
    """
    if mean_grad < 1:
        raise ValueError(f"mean_grad must be >= 1, got {mean_grad}")

    model.train()
    optimizer.zero_grad()
    accumulated = 0
    for geom, energy in zip(geoms, energies_n.detach()):
        coords = torch.tensor(geom.get_positions(), dtype=torch.float32)
        atoms = torch.tensor(geom.get_atomic_numbers())
        batch = torch.zeros_like(atoms)  # every atom belongs to molecule 0
        pred = model(atoms, coords, batch)
        loss = F.huber_loss(pred.squeeze(), energy)
        (loss / mean_grad).backward()
        accumulated += 1
        # Step only after exactly ``mean_grad`` samples. The original checked
        # ``mean_grad - 1`` after incrementing, which stepped too early and
        # never stepped at all for mean_grad == 1.
        if accumulated == mean_grad:
            optimizer.step()
            optimizer.zero_grad()
            accumulated = 0
    # Apply the trailing partial group instead of silently discarding it.
    # Its loss is still scaled by 1/mean_grad, so this step is proportionally
    # smaller.
    if accumulated:
        optimizer.step()
        optimizer.zero_grad()
def test_epoch(model, optimizer, geoms, energies_n):
    """Score ``model`` on every geometry without tracking gradients.

    ``optimizer`` is accepted only so the signature mirrors ``train_epoch``;
    it is not used here.

    Returns:
        A dict with ``'r2_score'`` (sklearn R^2 of predictions against
        targets) and ``'mae'`` (mean over structures of the per-structure L1
        error).
    """
    model.eval()
    predictions, targets = [], []
    l1_total = 0
    for geom, target in zip(geoms, energies_n.clone().detach()):
        positions = torch.tensor(geom.get_positions(), dtype=torch.float32)
        numbers = torch.tensor(geom.get_atomic_numbers())
        with torch.no_grad():
            output = model(numbers, positions, torch.zeros_like(numbers))
        predictions.append(output.item())
        targets.append(target.item())
        l1_total += F.l1_loss(output.squeeze(), target).item()
    return {
        'r2_score': r2_score(np.array(targets), np.array(predictions)),
        'mae': l1_total / len(targets),
    }
def train(model, optimizer, geoms, energies_n, n_epochs=100,
          val_geoms=None, val_energies_n=None, mean_grad=32):
    """Train for ``n_epochs`` epochs and return the best metrics observed.

    After every ``train_epoch`` call the model is scored with ``test_epoch``.
    By default it is scored on the training data itself (the historical
    behaviour), which overstates generalisation. Pass ``val_geoms`` and
    ``val_energies_n`` to score on a held-out set instead.

    The best R^2 and the best MAE are tracked independently, so the two
    returned values may come from different epochs.

    Args:
        model, optimizer: Forwarded to ``train_epoch`` / ``test_epoch``.
        geoms, energies_n: Training geometries and aligned target energies.
        n_epochs: Number of training epochs.
        val_geoms, val_energies_n: Optional evaluation set; give both or
            neither.
        mean_grad: Gradient-accumulation group size forwarded to
            ``train_epoch``.

    Returns:
        ``(best_r2score, best_mae)``.

    Raises:
        ValueError: If only one of ``val_geoms`` / ``val_energies_n`` is given.
    """
    if (val_geoms is None) != (val_energies_n is None):
        raise ValueError("val_geoms and val_energies_n must be given together")
    if val_geoms is None:
        val_geoms, val_energies_n = geoms, energies_n

    best_r2score = -1e100
    best_mae = 1e100
    for _ in tqdm(range(n_epochs)):
        train_epoch(model, optimizer, geoms, energies_n, mean_grad=mean_grad)
        metrics = test_epoch(model, optimizer, val_geoms, val_energies_n)
        if metrics['r2_score'] > best_r2score:
            best_r2score = metrics['r2_score']
        if metrics['mae'] < best_mae:
            best_mae = metrics['mae']
    return best_r2score, best_mae
# Multiplier applied to the parsed energies. 627.5 is the approximate
# Hartree -> kcal/mol factor, matching the "kcal/mol" unit printed by main().
# NOTE(review): this presumes the comment-line energies are in Hartree.
_HARTREE_TO_KCAL_MOL = 627.5

# Energy on an XYZ comment line, e.g. "energy: -76.41" or "energy: -7.641e+01".
# Raw string avoids the invalid "\." escape of the former pattern, and the
# optional exponent keeps scientific notation from being silently truncated.
_ENERGY_PATTERN = re.compile(
    r'energy:\s+([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)'
)


def _read_energies(trajectory_file):
    """Return the energy parsed from each frame's comment line of an XYZ file.

    Raises:
        ValueError: If a frame is truncated or its comment line has no
            ``energy:`` value.
    """
    with open(trajectory_file) as f:
        lines = f.read().split('\n')
    energies = []
    i = 0
    while i < len(lines):
        try:
            n_atoms = int(lines[i].strip())
        except ValueError:
            # Stop at the first line that is not an atom count
            # (e.g. the trailing empty line).
            break
        if i + 1 >= len(lines):
            raise ValueError(
                f"{trajectory_file}: frame {len(energies)} has no comment line"
            )
        match = _ENERGY_PATTERN.search(lines[i + 1])
        if match is None:
            raise ValueError(
                f"{trajectory_file}: no 'energy:' value on the comment line of "
                f"frame {len(energies)}: {lines[i + 1]!r}"
            )
        energies.append(float(match.group(1)))
        i += n_atoms + 2  # count line + comment line + one line per atom
    return energies


def main(trajectory_file, model_type):
    """Train ``model_type`` on an XYZ trajectory and print the best metrics.

    Geometries are read with ASE. Energies are parsed from each frame's
    comment line, shifted so the lowest is 0, and scaled by
    ``_HARTREE_TO_KCAL_MOL``.

    Raises:
        ValueError: If no energies are found, or if their count differs from
            the number of geometries ASE read.
    """
    geoms = ase.io.read(trajectory_file, format='xyz', index=':')
    energy_values = _read_energies(trajectory_file)
    if not energy_values:
        raise ValueError(f"{trajectory_file}: no frames with energies found")
    # zip() in train/test would silently truncate to the shorter sequence.
    if len(energy_values) != len(geoms):
        raise ValueError(
            f"{trajectory_file}: parsed {len(energy_values)} energies but "
            f"{len(geoms)} geometries"
        )
    energies = torch.tensor(energy_values)
    energies_n = (energies - energies.min()) * _HARTREE_TO_KCAL_MOL
    model, optimizer = get_model_and_optimizer(model_type)
    best_r2score, best_mae = train(model, optimizer, geoms, energies_n)
    print(f'R2_score: {best_r2score:.4f}')
    print(f'MAE: {best_mae:.3f} kcal/mol')
# Model names accepted on the command line; each corresponds to a branch in
# get_model_and_optimizer.
avaliable_models = ['SchNet', 'DimeNet']

if __name__ == "__main__":
    # CLI: <filename> <model>; the help texts are user-facing and kept as-is.
    models_listing = ', '.join(avaliable_models)
    cli = argparse.ArgumentParser(description="Обработчик файлов с различными моделями")
    cli.add_argument("filename", help="Путь к обрабатываемому файлу")
    cli.add_argument(
        "model",
        choices=avaliable_models,
        help=f"Выбор модели из доступных: {models_listing}",
    )
    cli_args = cli.parse_args()
    main(cli_args.filename, cli_args.model)
|