File size: 2,730 Bytes
5fae7ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import torch
import argparse

import numpy as np

from utils_data import read_data, MolDataset, collate_mol, get_train_test_data
from utils_model import ModelDimeNet

def main(denormalize, checkpoint_path, data_filename):
    """Predict energies with a trained DimeNet model and fill them into a template.

    Reads molecules via ``read_data``, runs inference on the test split,
    optionally undoes the dataset's energy normalization, derives per-mode
    and two-mode-coupling energy arrays, and substitutes the values into
    every ``{}`` placeholder of ``template.rst``, writing ``filled.rst``.

    Parameters
    ----------
    denormalize : bool
        If True, the dataset is built with ``normalize=True`` and the
        predicted energies are mapped back through the inverse transform
        ``sign(e) * |e| ** 10`` before being written out.
    checkpoint_path : str
        Path to the saved ``ModelDimeNet`` state dict.
    data_filename : str
        Input file in the format understood by ``read_data``.
    """
    model = ModelDimeNet()
    model.load_state_dict(torch.load(checkpoint_path, weights_only=True))

    all_numbers, all_coords, energies, groups = read_data(data_filename)
    # NOTE(review): the *denormalize* flag doubles as the dataset's
    # *normalize* switch — confirm this coupling is intended.
    ds_all = MolDataset(all_numbers, all_coords, energies, normalize=denormalize)

    ds_train, ds_test = get_train_test_data(ds_all, groups, 'finetune')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    model.eval()  # inference mode: disable dropout / batch-norm updates

    ens = []

    # NOTE(review): we iterate ds_test but membership is checked against
    # ds_train.indices using len(ens) as a running position — looks like it
    # may have been meant to iterate ds_all; confirm against the data split.
    for numbers, coords, energy in ds_test:
        if len(ens) in ds_train.indices:
            # Training samples keep their reference energy.
            ens.append(energy)
        else:
            coords = torch.tensor(coords, dtype=torch.float32).to(device)
            atoms = torch.tensor(numbers).to(device)
            batch = torch.zeros_like(atoms).to(device)

            with torch.no_grad():
                ens.append(model(atoms, coords, batch).item())

    if denormalize:
        # BUG FIX: ``ens`` is a plain list, which has no ``.sign()``/``.abs()``
        # methods — the original raised AttributeError here. Apply the inverse
        # of the normalization through numpy instead.
        ens_arr = np.asarray(ens, dtype=np.float64)
        ens = list(np.sign(ens_arr) * np.abs(ens_arr) ** 10)

    # Drop the first entry (presumably the equilibrium/reference point —
    # TODO confirm) and work with the remaining scan energies.
    ensa = np.array(ens[1:])

    n_atoms = 3
    n_modes = n_atoms * 3 - 6  # 3 vibrational modes for a non-linear triatomic

    # 16 displaced geometries per single mode, laid out consecutively.
    modes_i = [ensa[16 * i:16 * i + 16] for i in range(n_modes)]

    all_a = []

    m = 0
    for i in range(n_modes):
        for j in range(i + 1, n_modes):
            for k in range(16):
                # Two-mode coupling surface: subtract both one-mode terms.
                all_a.append(ensa[16 * n_modes + m:16 * n_modes + m + 16]
                             - modes_i[j] - modes_i[i][k])
                m += 16

    arr = np.concatenate([modes_i, all_a]).reshape((-1))

    with open('template.rst', 'r') as f:
        content = f.read()

    # Replace each '{}' placeholder with the next energy value; lines without
    # a placeholder are copied through unchanged.  Context manager guarantees
    # the output file is flushed and closed (original used ``del fo``).
    with open('filled.rst', 'w') as fo:
        i = 0
        for line in content.split('\n'):
            if i < len(arr):
                new_line = line.replace('{}', f'{arr[i]:.10f}')
                if line != new_line:
                    i += 1
            else:
                new_line = line
            fo.write(new_line + '\n')

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Run model inference and fill predicted energies '
                    'into template.rst, producing filled.rst.',
    )

    # Required positional arguments
    parser.add_argument('checkpoint_path',
                        type=str,
                        help='path to the saved model state dict (.pt)',
                        )
    parser.add_argument('data_filename',
                        type=str,
                        help='input data file readable by read_data',
                        )

    # Boolean flags
    parser.add_argument('--denormalize',
                        action='store_true',
                        help='build the dataset normalized and undo the '
                             'normalization (sign(e) * |e|**10) on outputs',
                        )

    args = parser.parse_args()

    # Dispatch to the main entry point
    main(
        denormalize=args.denormalize,
        checkpoint_path=args.checkpoint_path,
        data_filename=args.data_filename,
    )