import argparse
import warnings

import numpy as np
import torch

from utils_data import read_data, MolDataset, collate_mol, get_train_test_data
from utils_model import ModelDimeNet


def _predict_energies(model, ds_train, ds_test, device):
    """Return one float energy per sample of ``ds_test``, in iteration order.

    The sample's position in the iteration is looked up in
    ``ds_train.indices``: if present, the stored dataset energy is used
    as-is; otherwise the model is evaluated on that single molecule.
    """
    # NOTE(review): this compares the *position within ds_test* with the
    # training indices, which only makes sense if ds_test enumerates the full
    # dataset in order (presumably what the 'finetune' split returns) —
    # confirm against get_train_test_data.
    # Build the lookup set once; `in` on a list/tensor is O(n) per sample.
    train_idx = {int(idx) for idx in ds_train.indices}
    energies = []
    model.eval()  # inference mode for dropout/batch-norm layers, if any
    with torch.no_grad():
        for position, (numbers, coords, energy) in enumerate(ds_test):
            if position in train_idx:
                energies.append(float(energy))
                continue
            atoms = torch.as_tensor(numbers, device=device)
            xyz = torch.as_tensor(coords, dtype=torch.float32, device=device)
            # A single molecule: every atom belongs to graph 0.
            batch = torch.zeros_like(atoms)
            energies.append(model(atoms, xyz, batch).item())
    return np.asarray(energies, dtype=float)


def _assemble_values(ensa, n_modes, n_points):
    """Flatten single-mode values and pair-corrected values into one vector.

    ``ensa`` is read as ``n_modes`` consecutive blocks of ``n_points``
    single-mode energies ``E_i``, followed by one ``n_points x n_points``
    block ``E_ij`` per mode pair ``(i, j)`` with ``i < j`` (row = point along
    mode ``i``, column = point along mode ``j``).  Each pair block is
    returned as ``E_ij[k, l] - E_j[l] - E_i[k]``.  Values beyond the expected
    length are ignored.

    Raises:
        ValueError: if ``n_modes`` is negative or ``ensa`` is too short.
    """
    if n_modes < 0:
        raise ValueError(f'n_modes must be non-negative, got {n_modes}')
    pairs = [(i, j) for i in range(n_modes) for j in range(i + 1, n_modes)]
    n_single = n_modes * n_points
    expected = n_single + len(pairs) * n_points * n_points
    if ensa.size < expected:
        raise ValueError(
            f'expected at least {expected} energies, got {ensa.size}')
    singles = ensa[:n_single].reshape(n_modes, n_points)
    blocks = ensa[n_single:expected].reshape(len(pairs), n_points, n_points)
    parts = [singles.reshape(-1)]
    for block, (i, j) in zip(blocks, pairs):
        # Same operation order as the scalar form: (E_ij - E_j) - E_i[k].
        corrected = block - singles[j][None, :] - singles[i][:, None]
        parts.append(corrected.reshape(-1))
    return np.concatenate(parts)


def _fill_template(content, values):
    """Fill ``{}`` placeholders in ``content`` line by line.

    Each line containing ``{}`` consumes the next value, and every ``{}`` on
    that line is replaced by it (formatted with 10 decimals).  Once
    ``values`` is exhausted, remaining lines are copied unchanged.  The line
    structure of ``content`` (including a trailing newline, or its absence)
    is preserved.

    Returns:
        ``(filled_text, number_of_values_used)``.
    """
    out = []
    used = 0
    for line in content.split('\n'):
        if used < len(values) and '{}' in line:
            line = line.replace('{}', f'{values[used]:.10f}')
            used += 1
        out.append(line)
    return '\n'.join(out), used


def main(denormalize, checkpoint_path, data_filename, *,
         template_path='template.rst', output_path='filled.rst',
         n_atoms=3, n_points=16):
    """Predict energies with a trained model and write them into a template.

    Steps: load the ``ModelDimeNet`` checkpoint, read the data, split it with
    the ``'finetune'`` mode, and collect one energy per ``ds_test`` sample.
    Training positions use stored energies; the others use model
    predictions.  The values are then optionally denormalized and the first
    energy is dropped.  The rest are assembled into single-mode and
    pair-corrected values, which fill the ``{}`` placeholders of
    ``template_path``.  The result is written to ``output_path``.

    Args:
        denormalize: passed to ``MolDataset`` as ``normalize``; when true,
            the collected energies are mapped back with ``sign(x) * |x|**10``.
        checkpoint_path: path of the ``state_dict`` to load.
        data_filename: file passed to ``read_data``.
        template_path: text template containing ``{}`` placeholders.
        output_path: where the filled template is written.
        n_atoms: atom count; the mode count is ``3 * n_atoms - 6``
            (nonlinear-molecule formula).
        n_points: number of grid points per mode.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = ModelDimeNet()
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    state = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state)
    model.to(device)

    all_numbers, all_coords, energies, groups = read_data(data_filename)
    # NOTE(review): the same flag drives dataset normalization, presumably so
    # stored energies and model outputs share one scale — confirm in MolDataset.
    ds_all = MolDataset(all_numbers, all_coords, energies,
                        normalize=denormalize)
    ds_train, ds_test = get_train_test_data(ds_all, groups, 'finetune')

    ens = _predict_energies(model, ds_train, ds_test, device)
    if denormalize:
        # Presumably inverts a signed 10th-root normalization — confirm.
        ens = np.sign(ens) * np.abs(ens) ** 10
    # The first energy is skipped (presumably a reference geometry — confirm).
    ensa = ens[1:]

    n_modes = n_atoms * 3 - 6
    values = _assemble_values(ensa, n_modes, n_points)

    with open(template_path, 'r') as f:
        content = f.read()
    filled, used = _fill_template(content, values)
    if used < len(values):
        warnings.warn(f'{template_path} has placeholders for only {used} '
                      f'of {len(values)} values; the rest were dropped')
    with open(output_path, 'w') as fo:
        fo.write(filled)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Required positional arguments
    parser.add_argument('checkpoint_path', type=str)
    parser.add_argument('data_filename', type=str)
    # Flags (boolean options)
    parser.add_argument('--denormalize', action='store_true')
    # Optional overrides (defaults match the previously hard-coded values)
    parser.add_argument('--template', default='template.rst')
    parser.add_argument('--output', default='filled.rst')
    parser.add_argument('--n-atoms', type=int, default=3)
    parser.add_argument('--n-points', type=int, default=16)
    args = parser.parse_args()
    # Run the main routine
    main(
        denormalize=args.denormalize,
        checkpoint_path=args.checkpoint_path,
        data_filename=args.data_filename,
        template_path=args.template,
        output_path=args.output,
        n_atoms=args.n_atoms,
        n_points=args.n_points,
    )