|
|
import torch |
|
|
import argparse |
|
|
|
|
|
import numpy as np |
|
|
|
|
|
from utils_data import read_data, MolDataset, collate_mol, get_train_test_data |
|
|
from utils_model import ModelDimeNet |
|
|
|
|
|
def _predict_energies(model, ds_test, train_indices, device):
    """Collect one energy per sample of *ds_test*.

    A sample whose running position (the number of energies gathered so far)
    appears in *train_indices* contributes the energy stored in the dataset;
    every other sample is evaluated with *model* on *device*.
    """
    ens = []
    for numbers, coords, energy in ds_test:
        if len(ens) in train_indices:
            ens.append(energy)
            continue

        coords_t = torch.tensor(coords, dtype=torch.float32).to(device)
        atoms = torch.tensor(numbers).to(device)
        # Single molecule per call: every atom gets batch index 0.
        # zeros_like already allocates on the same device as `atoms`.
        batch = torch.zeros_like(atoms)

        with torch.no_grad():
            ens.append(model(atoms, coords_t, batch).item())
    return ens


def _build_pes_values(ensa, n_modes, n_grid=16):
    """Arrange the flat energy array into the order written to the template.

    Layout read from *ensa*: first ``n_modes`` runs of ``n_grid`` energies
    (one run per mode), then, for every mode pair ``i < j``, ``n_grid`` runs
    of ``n_grid`` energies.  From each pair run the single-mode energies
    ``modes[j]`` (whole run) and ``modes[i][k]`` (scalar for run ``k``) are
    subtracted.  Returns the single-mode runs followed by the corrected pair
    runs as one flat array.

    Raises:
        ValueError: if *ensa* holds fewer values than this layout requires.
    """
    n_pairs = n_modes * (n_modes - 1) // 2
    required = n_grid * n_modes + n_grid * n_grid * n_pairs
    if len(ensa) < required:
        raise ValueError(
            f'expected at least {required} energies, got {len(ensa)}'
        )

    modes = [ensa[n_grid * i:n_grid * (i + 1)] for i in range(n_modes)]

    couplings = []
    offset = n_grid * n_modes  # pair data starts right after the mode runs
    for i in range(n_modes):
        for j in range(i + 1, n_modes):
            for k in range(n_grid):
                pair_run = ensa[offset:offset + n_grid]
                couplings.append(pair_run - modes[j] - modes[i][k])
                offset += n_grid

    return np.concatenate([modes, couplings]).reshape(-1)


def _fill_template(content, values):
    """Substitute *values* into the ``{}`` placeholders of *content*.

    Lines are processed in order; each line containing ``{}`` consumes the
    next value (formatted with 10 decimals, applied to every ``{}`` on that
    line).  Once the values are exhausted, remaining lines are copied as-is.
    Every processed line is terminated with a newline.
    """
    filled = []
    idx = 0
    for line in content.split('\n'):
        if idx < len(values):
            new_line = line.replace('{}', f'{values[idx]:.10f}')
            if new_line != line:
                idx += 1
        else:
            new_line = line
        filled.append(new_line + '\n')
    return ''.join(filled)


def main(denormalize, checkpoint_path, data_filename):
    """Evaluate a ModelDimeNet checkpoint and write energies into 'filled.rst'.

    Args:
        denormalize: forwarded to ``MolDataset`` as ``normalize``; when true,
            the collected energies are additionally mapped through
            ``sign(e) * |e| ** 10`` before being written.
        checkpoint_path: path of the state dict loaded into the model.
        data_filename: file passed to ``read_data``.

    Reads 'template.rst' and writes 'filled.rst' in the current directory.

    Raises:
        ValueError: if fewer energies are available than the output layout
            requires.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = ModelDimeNet()
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model.load_state_dict(
        torch.load(checkpoint_path, map_location=device, weights_only=True)
    )

    all_numbers, all_coords, energies, groups = read_data(data_filename)
    ds_all = MolDataset(all_numbers, all_coords, energies, normalize=denormalize)
    ds_train, ds_test = get_train_test_data(ds_all, groups, 'finetune')

    model.to(device)
    # Inference only: put dropout/normalisation layers (if any) in eval mode.
    model.eval()

    ens = _predict_energies(model, ds_test, ds_train.indices, device)

    # The first energy is skipped — presumably the reference geometry;
    # TODO confirm against the data layout.
    ensa = np.array(ens[1:])
    if denormalize:
        # `ens` is a plain list, so the element-wise maths must go through
        # NumPy rather than tensor-style .sign()/.abs() methods.
        ensa = np.sign(ensa) * np.abs(ensa) ** 10

    n_atoms = 3
    n_modes = n_atoms * 3 - 6
    arr = _build_pes_values(ensa, n_modes)

    with open('template.rst', 'r') as f:
        content = f.read()

    # Context manager guarantees the output is flushed and closed.
    with open('filled.rst', 'w') as fo:
        fo.write(_fill_template(content, arr))
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: two positional paths and one optional flag.
    cli = argparse.ArgumentParser()
    cli.add_argument('checkpoint_path', type=str)
    cli.add_argument('data_filename', type=str)
    cli.add_argument('--denormalize', action='store_true')

    # The parsed attribute names match main()'s keyword names one-to-one,
    # so the namespace can be forwarded directly as keyword arguments.
    main(**vars(cli.parse_args()))