vscf_mlff / example_eval.py
timcryt's picture
Initial commit
5fae7ca verified
import torch
import argparse
import numpy as np
from utils_data import read_data, MolDataset, collate_mol, get_train_test_data
from utils_model import ModelDimeNet
def _predict_energies(model, ds_train, ds_test, device):
    """Collect one energy per sample of ``ds_test``.

    For a sample whose running position is listed in ``ds_train.indices`` the
    energy stored in the dataset is used as-is; for every other sample the
    energy is predicted by ``model``.

    Returns:
        A list with one entry per sample, in iteration order.
    """
    # NOTE(review): positions while iterating ds_test are compared against
    # ds_train.indices; this is only meaningful if ds_test walks the samples
    # in the same order that those indices refer to -- confirm against
    # get_train_test_data.
    energies = []
    for numbers, coords, energy in ds_test:
        if len(energies) in ds_train.indices:
            energies.append(energy)
            continue
        coords_t = torch.tensor(coords, dtype=torch.float32, device=device)
        atoms = torch.tensor(numbers, device=device)
        # All atoms belong to a single molecule, hence an all-zero batch index.
        batch = torch.zeros_like(atoms)
        with torch.no_grad():
            energies.append(model(atoms, coords_t, batch).item())
    return energies


def _build_potential_terms(energies, n_modes, n_points):
    """Assemble the flat array of values written into the template.

    ``energies`` is read as ``n_modes`` consecutive blocks of ``n_points``
    values (one block per mode), followed, for every mode pair ``i < j``, by
    ``n_points`` blocks of ``n_points`` values. From the k-th block of a pair
    the block of mode ``j`` and the k-th value of mode ``i`` are subtracted.

    Raises:
        ValueError: if ``energies`` is shorter than this layout requires.
    """
    n_pairs = n_modes * (n_modes - 1) // 2
    required = n_points * n_modes + n_pairs * n_points ** 2
    if len(energies) < required:
        raise ValueError(
            f'expected at least {required} energies after the first one, '
            f'got {len(energies)}'
        )

    single = [
        energies[n_points * i:n_points * (i + 1)] for i in range(n_modes)
    ]

    pair_terms = []
    offset = n_points * n_modes  # pair blocks start after the per-mode blocks
    for i in range(n_modes):
        for j in range(i + 1, n_modes):
            for k in range(n_points):
                block = energies[offset:offset + n_points]
                pair_terms.append(block - single[j] - single[i][k])
                offset += n_points

    return np.concatenate([single, pair_terms]).reshape(-1)


def _fill_template(values, template_path='template.rst',
                   output_path='filled.rst'):
    """Copy the template to the output, substituting values for ``{}``.

    Each line containing ``{}`` consumes the next value (formatted with ten
    decimals); every ``{}`` on that line receives the same value. Once the
    values are exhausted the remaining lines are copied unchanged.
    """
    with open(template_path, 'r') as src:
        content = src.read()

    used = 0
    # The context manager guarantees the output is flushed and closed.
    with open(output_path, 'w') as dst:
        for line in content.split('\n'):
            new_line = line
            if used < len(values):
                new_line = line.replace('{}', f'{values[used]:.10f}')
                if new_line != line:
                    used += 1
            dst.write(new_line + '\n')


def main(denormalize, checkpoint_path, data_filename):
    """Evaluate a DimeNet checkpoint and fill ``template.rst`` with the result.

    Loads the model weights, gathers one energy per sample (dataset value or
    model prediction), optionally denormalizes them, combines them into
    per-mode and mode-pair terms and writes ``filled.rst``.

    Args:
        denormalize: when true, the dataset is built with ``normalize=True``
            and the collected energies are mapped through
            ``sign(x) * |x| ** 10`` before further processing.
        checkpoint_path: path to the saved ``state_dict`` of the model.
        data_filename: path passed to ``read_data``.
    """
    model = ModelDimeNet()
    # Load onto the CPU first so a checkpoint saved on a GPU can still be
    # opened on a CPU-only machine; the model is moved to `device` below.
    model.load_state_dict(
        torch.load(checkpoint_path, map_location='cpu', weights_only=True)
    )

    all_numbers, all_coords, energies, groups = read_data(data_filename)
    ds_all = MolDataset(all_numbers, all_coords, energies,
                        normalize=denormalize)
    ds_train, ds_test = get_train_test_data(ds_all, groups, 'finetune')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    # Inference only: switch off training-mode layer behaviour.
    model.eval()

    # Convert to an array up front: the denormalization below needs
    # element-wise math, which a plain Python list does not support.
    ens = np.asarray(_predict_energies(model, ds_train, ds_test, device),
                     dtype=float)
    if denormalize:
        ens = np.sign(ens) * np.abs(ens) ** 10

    # The first entry is not part of the grid layout and is skipped.
    ensa = ens[1:]

    n_atoms = 3
    n_modes = n_atoms * 3 - 6  # presumably 3N - 6 vibrational modes
    n_points = 16  # number of values per block in the energy layout

    arr = _build_potential_terms(ensa, n_modes, n_points)
    _fill_template(arr)
if __name__ == '__main__':
    # Command-line entry point: parse the arguments and run the evaluation.
    cli = argparse.ArgumentParser()

    # Required positional arguments
    cli.add_argument('checkpoint_path', type=str)
    cli.add_argument('data_filename', type=str)

    # Flags (boolean options)
    cli.add_argument('--denormalize', action='store_true')

    parsed = cli.parse_args()

    # Call the main function
    main(
        checkpoint_path=parsed.checkpoint_path,
        data_filename=parsed.data_filename,
        denormalize=parsed.denormalize,
    )