vscf_mlff / compare_models.py
timcryt's picture
Initial commit
5fae7ca verified
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn.models import SchNet, DimeNetPlusPlus
import ase
import ase.io
import re
from sklearn.metrics import r2_score
import numpy as np
import sys
from tqdm import tqdm
import argparse
from utils_model import ModellDimeNet
def get_model_and_optimizer(model_type):
    """Build the requested model together with its AdamW optimizer.

    Args:
        model_type: Either ``'SchNet'`` or ``'DimeNet'``.

    Returns:
        A ``(model, optimizer)`` tuple; the optimizer is AdamW over all model
        parameters with a learning rate of 3e-4.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    if model_type == 'SchNet':
        model = SchNet()
    elif model_type == 'DimeNet':
        # The name must match the `from utils_model import ...` line at the
        # top of this file, which imports the class as `ModellDimeNet`.
        model = ModellDimeNet()
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            f"Unknown model type {model_type!r}; expected 'SchNet' or 'DimeNet'"
        )
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    return model, optimizer
def train_epoch(model, optimizer, geoms, energies_n, mean_grad=32):
    """Run one training pass over the data with gradient accumulation.

    Every geometry is passed through the model on its own. The Huber losses
    of up to ``mean_grad`` consecutive samples are averaged (by accumulating
    scaled gradients) and followed by a single optimizer step.

    Args:
        model: Callable as ``model(atoms, coords, batch)``; its output is
            squeezed and compared against the target energy.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        geoms: Iterable of objects providing ``get_positions()`` and
            ``get_atomic_numbers()``.
        energies_n: Tensor of target energies, iterated in lockstep with
            ``geoms``.
        mean_grad: Number of samples whose gradients are averaged per
            optimizer step. Must be at least 1.

    Raises:
        ValueError: If ``mean_grad`` is smaller than 1.
    """
    if mean_grad < 1:
        raise ValueError(f"mean_grad must be >= 1, got {mean_grad}")

    model.train()
    # Targets never need gradients; detach once instead of per sample.
    samples = list(zip(geoms, energies_n.detach()))

    for start in range(0, len(samples), mean_grad):
        group = samples[start:start + mean_grad]
        optimizer.zero_grad()
        for geom, energy in group:
            coords = torch.tensor(geom.get_positions(), dtype=torch.float32)
            atoms = torch.tensor(geom.get_atomic_numbers())
            # One molecule per forward pass: every atom gets batch index 0.
            batch = torch.zeros_like(atoms)
            pred = model(atoms, coords, batch)
            loss = F.huber_loss(pred.squeeze(), energy)
            # Scale by the actual group size so that the last, possibly
            # shorter, group is averaged correctly as well.
            (loss / len(group)).backward()
        # Exactly one step per group, including the trailing partial group,
        # so no accumulated gradient is ever thrown away.
        optimizer.step()
def test_epoch(model, optimizer, geoms, energies_n):
    """Evaluate the model on every geometry and summarise the errors.

    Args:
        model: Callable as ``model(atoms, coords, batch)`` returning a
            single-element tensor per geometry.
        optimizer: Accepted for signature symmetry with ``train_epoch``;
            it is not used here.
        geoms: Iterable of objects providing ``get_positions()`` and
            ``get_atomic_numbers()``.
        energies_n: Tensor of reference energies, iterated in lockstep with
            ``geoms``.

    Returns:
        A dict with ``'r2_score'`` (coefficient of determination between
        reference and predicted values) and ``'mae'`` (mean of the
        per-geometry absolute errors).
    """
    model.eval()

    predicted = []
    reference = []
    abs_error_sum = 0

    # Nothing in evaluation needs autograd, so the whole loop runs without it.
    with torch.no_grad():
        for geom, target in zip(geoms, energies_n.detach()):
            positions = torch.tensor(geom.get_positions(), dtype=torch.float32)
            numbers = torch.tensor(geom.get_atomic_numbers())
            graph_index = torch.zeros_like(numbers)

            output = model(numbers, positions, graph_index)

            predicted.append(output.item())
            reference.append(target.item())
            abs_error_sum += F.l1_loss(output.squeeze(), target).item()

    n_mols = len(predicted)
    return {
        'r2_score': r2_score(np.array(reference), np.array(predicted)),
        'mae': abs_error_sum / n_mols,
    }
def train(model, optimizer, geoms, energies_n, n_epochs=100):
    """Train for ``n_epochs`` epochs and report the best metrics observed.

    After each training epoch the model is evaluated on the same
    ``geoms`` / ``energies_n`` that were used for training.

    Args:
        model: Model forwarded to ``train_epoch`` and ``test_epoch``.
        optimizer: Optimizer forwarded to ``train_epoch`` and ``test_epoch``.
        geoms: Geometries forwarded to both epoch functions.
        energies_n: Target energies forwarded to both epoch functions.
        n_epochs: Number of train/evaluate rounds.

    Returns:
        ``(best_r2score, best_mae)``: the highest R2 score and the lowest MAE
        seen in any epoch (not necessarily the same epoch). With zero epochs
        the sentinels ``-1e100`` and ``1e100`` are returned.
    """
    best_r2score, best_mae = -1e100, 1e100
    for _ in tqdm(range(n_epochs)):
        train_epoch(model, optimizer, geoms, energies_n)
        epoch_metrics = test_epoch(model, optimizer, geoms, energies_n)
        # Keep the running best first so ties and NaNs leave it unchanged.
        best_r2score = max(best_r2score, epoch_metrics['r2_score'])
        best_mae = min(best_mae, epoch_metrics['mae'])
    return best_r2score, best_mae
def main(trajectory_file, model_type):
    """Train the chosen model on an XYZ trajectory and print its best metrics.

    The trajectory is read twice: once through ASE to obtain the geometries
    and once as plain text to pull the energy out of each frame's comment
    line (the line following the atom count), where it is expected in the
    form ``energy: <number>``.

    Args:
        trajectory_file: Path to the multi-frame XYZ file.
        model_type: Model name passed on to ``get_model_and_optimizer``.

    Raises:
        ValueError: If a frame header is malformed, a comment line carries no
            parsable energy, no energies are found at all, or the number of
            energies differs from the number of geometries read by ASE.
    """
    geoms = ase.io.read(trajectory_file, format='xyz', index=':')
    with open(trajectory_file) as f:
        lines = f.read().split('\n')

    # Raw string avoids invalid-escape warnings; the pattern also accepts
    # integers and scientific notation, which a plain `\d*\.\d*` would
    # truncate or fail to match.
    energy_pattern = re.compile(
        r'energy:\s+([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)'
    )

    energies = []
    i = 0
    while i < len(lines):
        try:
            n = int(lines[i].strip())
        except ValueError:
            # A non-numeric line (e.g. the trailing empty line) ends the scan.
            break
        if n < 0:
            # A negative count would move the cursor backwards or stall it.
            raise ValueError(f"Negative atom count on line {i + 1}: {n}")
        if i + 1 >= len(lines):
            raise ValueError(f"Frame starting on line {i + 1} has no comment line")
        found = energy_pattern.search(lines[i + 1])
        if found is None:
            raise ValueError(
                f"No 'energy: <number>' entry on line {i + 2}: {lines[i + 1]!r}"
            )
        energies.append(float(found.group(1)))
        # Skip the count line, the comment line and the n atom lines.
        i += n + 2

    if not energies:
        raise ValueError(f"No energies found in {trajectory_file!r}")
    if len(energies) != len(geoms):
        # zip() further down would otherwise silently drop the surplus.
        raise ValueError(
            f"Parsed {len(energies)} energies but read {len(geoms)} geometries"
        )

    # Subtract the minimum in float64: absolute energies are large compared
    # to the differences of interest and float32 would lose precision here.
    energies = torch.tensor(energies, dtype=torch.float64)
    # 627.5 is presumably the Hartree -> kcal/mol factor (the MAE below is
    # reported in kcal/mol) -- TODO confirm the unit of the input energies.
    energies_n = ((energies - energies.min()) * 627.5).to(torch.float32)

    model, optimizer = get_model_and_optimizer(model_type)
    best_r2score, best_mae = train(model, optimizer, geoms, energies_n)
    print(f'R2_score: {best_r2score:.4f}')
    print(f'MAE: {best_mae:.3f} kcal/mol')
# Model names accepted on the command line; each must be handled by
# get_model_and_optimizer().
avaliable_models = ['SchNet', 'DimeNet']

if __name__ == "__main__":
    # CLI: two positional arguments -- the trajectory file and the model name.
    model_list = ', '.join(avaliable_models)
    cli = argparse.ArgumentParser(
        description="Обработчик файлов с различными моделями",
    )
    cli.add_argument(
        "filename",
        help="Путь к обрабатываемому файлу",
    )
    cli.add_argument(
        "model",
        choices=avaliable_models,
        help=f"Выбор модели из доступных: {model_list}",
    )
    parsed = cli.parse_args()
    main(parsed.filename, parsed.model)