|
|
import re |
|
|
import torch |
|
|
|
|
|
|
|
|
def read_data(filename):
    """Parse a multi-frame XYZ-style file into atom types, coordinates and energies.

    Each frame is laid out as:

    * one line holding the atom count ``n``;
    * one comment line that must contain ``Energy: <float>`` and may contain a
      ``Grid: ...`` tag;
    * ``n`` whitespace-separated lines ``<int atom type> <x> <y> <z>``
      (the integer is presumably the atomic number; TODO confirm).

    Grid tags in the comment line:

    * ``Grid: 0`` marks a reference frame. Its energy becomes the baseline
      subtracted from that frame and from every following frame, until the
      next ``Grid: 0``.
    * ``Grid: a: g`` yields the one-element group ``(g,)``.
    * ``Grid: a: g, b: h`` yields the two-element group ``(g, h)``.
    * Anything else, including ``Grid: 0``, yields the empty group ``()``.

    Parsing stops at the first line that is not an integer atom count, for
    example a trailing blank line.

    Args:
        filename: Path of the file to read.

    Returns:
        Tuple ``(all_numbers, all_coords, energies, groups)``:

        * ``all_numbers[k]`` is the list of integer atom types of frame ``k``.
        * ``all_coords[k]`` is the list of ``(x, y, z)`` float tuples of frame ``k``.
        * ``energies`` is a 1-D tensor of energies relative to the latest
          ``Grid: 0`` frame.
        * ``groups[k]`` is a tuple of 0, 1 or 2 group-id strings.

    Raises:
        ValueError: If a comment line has no ``Energy:`` field, if a frame
            appears before any ``Grid: 0`` reference frame, or if a frame is
            truncated.
    """
    # Raw strings avoid invalid-escape warnings; compile once outside the loop.
    energy_re = re.compile(r'Energy:\s+(-?\d*\.\d*)')
    grid0_re = re.compile(r'Grid: 0')
    grid1_re = re.compile(r'Grid: (\d+): (\d+)')
    grid2_re = re.compile(r'Grid: (\d+): (\d+), (\d+): (\d+)')

    with open(filename) as f:
        lines = f.read().split('\n')

    all_numbers = []
    all_coords = []
    energies = []
    groups = []

    mol_en = None  # energy of the most recent 'Grid: 0' reference frame
    i = 0
    while i < len(lines):
        try:
            n = int(lines[i].strip())
        except ValueError:
            break  # no further frames (e.g. trailing blank line)

        if i + 2 + n > len(lines):
            raise ValueError(
                f'Frame starting at line {i + 1} of {filename!r} is truncated: '
                f'expected a comment line and {n} atom lines')

        comment = lines[i + 1]
        energy_match = energy_re.search(comment)
        if energy_match is None:
            raise ValueError(
                f'No "Energy:" field in comment line {i + 2} of {filename!r}: {comment!r}')
        energy = float(energy_match.group(1))

        if grid0_re.search(comment):
            mol_en = energy
        elif mol_en is None:
            raise ValueError(
                f'Frame at line {i + 1} of {filename!r} precedes any "Grid: 0" '
                'reference frame, so its relative energy is undefined')

        # The two-pair pattern is a superset of the one-pair pattern, so test it first.
        grp = ()
        g2 = grid2_re.search(comment)
        if g2:
            grp = (g2.group(2), g2.group(4))
        else:
            g1 = grid1_re.search(comment)
            if g1:
                grp = (g1.group(2),)

        frame_numbers = []
        frame_coords = []
        for atom_line in lines[i + 2:i + 2 + n]:
            at, x, y, z = atom_line.split()  # any run of whitespace separates columns
            frame_numbers.append(int(at))
            frame_coords.append((float(x), float(y), float(z)))

        energies.append(energy - mol_en)
        groups.append(grp)
        all_numbers.append(frame_numbers)
        all_coords.append(frame_coords)

        i += n + 2

    return all_numbers, all_coords, torch.tensor(energies), groups
|
|
|
|
|
class MolDataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-parsed molecules.

    Args:
        all_numbers: Per-molecule sequences of integer atom types.
        all_coords: Per-molecule sequences of ``(x, y, z)`` coordinates.
        energies: Indexable of per-molecule energies. Its length defines the
            dataset length. When ``normalize`` is True, each element must
            support ``.sign()``, ``.abs()`` and ``**`` (e.g. a tensor element).
        normalize: If True, every returned energy ``e`` is replaced by
            ``sign(e) * |e| ** 0.1``.
    """

    def __init__(self, all_numbers, all_coords, energies, normalize=False):
        self.numbers = all_numbers
        self.coords = all_coords
        self.energies = energies
        self.normalize = normalize

    def __len__(self):
        """Number of molecules, taken from the energies container."""
        return len(self.energies)

    def __getitem__(self, ind):
        """Return ``(atoms, coords, energy)`` for molecule *ind*.

        ``atoms`` is built with ``torch.tensor``'s default dtype inference;
        ``coords`` is float32.
        """
        target = self.energies[ind]
        atom_types = torch.tensor(self.numbers[ind])
        positions = torch.tensor(self.coords[ind], dtype=torch.float32)

        if not self.normalize:
            return atom_types, positions, target

        # Signed 10th root: compresses large magnitudes while keeping the sign.
        return atom_types, positions, target.sign() * target.abs() ** 0.1
|
|
|
|
|
def collate_mol(batch):
    """Merge a list of ``(atoms, coords, energy)`` samples into one flat batch.

    Atoms from all molecules are concatenated along the first dimension, and a
    parallel index tensor records which molecule each atom came from.

    Args:
        batch: Iterable of ``(atoms, coords, energy)`` tuples, as produced by
            ``MolDataset``.

    Returns:
        atoms_cat: Concatenated atomic numbers, shape ``[total_atoms]``.
        coords_cat: Concatenated coordinates, shape ``[total_atoms, 3]``.
        energies: Stacked energies, shape ``[batch_size]``.
        batch_tensor: Molecule index (long) of every atom, shape ``[total_atoms]``.
    """
    samples = list(batch)  # materialize once; the fields are read in several passes

    atom_parts = [atoms for atoms, _, _ in samples]
    coord_parts = [coords for _, coords, _ in samples]
    energy_parts = [energy for _, _, energy in samples]

    # Tag every atom with the position of its molecule inside the batch.
    owner_parts = [
        torch.full((atoms.size(0),), mol_idx, dtype=torch.long)
        for mol_idx, atoms in enumerate(atom_parts)
    ]

    return (
        torch.cat(atom_parts, dim=0),
        torch.cat(coord_parts, dim=0),
        torch.stack(energy_parts),
        torch.cat(owner_parts, dim=0),
    )
|
|
|
|
|
def get_train_test_data(ds_all, groups, mode, test_idcs=range(28986, 29803),
                        grid1_selection=('1', '3', '5', '7', '8', '9', '10', '12', '14', '16'),
                        grid2_selection=('1', '5', '8', '9', '12', '16')):
    """Split *ds_all* into train/test subsets for pre-training or fine-tuning.

    ``groups[i]`` is the group tuple of sample ``i``: 0, 1 or 2 group-id
    strings, as returned by ``read_data``. Samples with more than two group
    ids are never used for training.

    * ``mode='pretrain'``: train on every sample whose index is NOT in
      *test_idcs*, whatever its groups.
    * ``mode='finetune'``: train only on samples whose index IS in
      *test_idcs* and whose groups pass the grid selection. Empty groups
      always pass. One-element groups pass when ``grp[0]`` is in
      *grid1_selection*. Two-element groups pass when both ids are in
      *grid2_selection*.

    In both modes the test subset is every index in *test_idcs*, so in
    ``finetune`` mode the training indices are a subset of the test indices.
    NOTE(review): confirm this overlap is intended.

    Args:
        ds_all: Dataset to split.
        groups: Per-sample group tuples, aligned with *ds_all*.
        mode: ``'pretrain'`` or ``'finetune'``.
        test_idcs: Indices forming the test subset.
        grid1_selection: Group ids accepted for one-element groups in finetune mode.
        grid2_selection: Group ids accepted for two-element groups in finetune mode.

    Returns:
        ``(ds_train, ds_test)`` as ``torch.utils.data.Subset`` instances.

    Raises:
        ValueError: If *mode* is not ``'pretrain'`` or ``'finetune'``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if mode not in ('pretrain', 'finetune'):
        raise ValueError(f"mode must be 'pretrain' or 'finetune', got {mode!r}")
    pretrain = mode == 'pretrain'

    # Build the lookup set once. Membership in a list/array would cost O(len)
    # per sample. int() also normalizes tensor/numpy elements so that lookups
    # compare by value.
    test_set = {int(j) for j in test_idcs}
    grid1 = set(grid1_selection)
    grid2 = set(grid2_selection)

    def _eligible(grp):
        """Return True if a sample with group tuple *grp* may be trained on in this mode."""
        if len(grp) == 0:
            return True
        if len(grp) > 2:
            return False
        if pretrain:
            return True
        if len(grp) == 1:
            return grp[0] in grid1
        return grp[0] in grid2 and grp[1] in grid2

    # Pretrain keeps indices outside the test set; finetune keeps indices inside it.
    train_idces = [
        i for i, grp in enumerate(groups)
        if (i in test_set) != pretrain and _eligible(grp)
    ]

    ds_train = torch.utils.data.Subset(ds_all, train_idces)
    ds_test = torch.utils.data.Subset(ds_all, test_idcs)
    return ds_train, ds_test