# Source snapshot: commit 5fae7ca (file size 4,272 bytes).
import re
import torch
# Compiled once at import time; raw strings avoid invalid-escape warnings.
# The energy pattern accepts an optional sign, integer/decimal forms and an
# optional exponent, so values like "-1.15e+01" are not silently truncated.
_ENERGY_RE = re.compile(r'Energy:\s+([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)')
_GRID1_RE = re.compile(r'Grid: (\d+): (\d+)')
_GRID2_RE = re.compile(r'Grid: (\d+): (\d+), (\d+): (\d+)')


def read_data(filename):
    """Parse a concatenated multi-frame XYZ-style file.

    Each frame is laid out as::

        <n>                  number of atom lines that follow
        <comment>            must contain "Energy: <number>"; may carry Grid tags
        <Z> <x> <y> <z>      n whitespace-separated lines: integer atom type
        ...                  (presumably atomic number) and three coordinates

    Parsing stops at end of file or at the first line where a frame's atom
    count is expected but which is not an integer (e.g. a trailing blank line).

    Energies are reported relative to the most recent frame (at or before the
    current one) whose comment contains the substring ``"Grid: 0"``. That
    reference frame itself therefore gets 0. The test is a plain substring
    match.

    Grid tags become a per-frame tuple of strings:
    ``"Grid: a: b"`` -> ``(b,)``, ``"Grid: a: b, c: d"`` -> ``(b, d)``,
    and anything else -> ``()``.

    Args:
        filename: Path of the file to read.

    Returns:
        ``(all_numbers, all_coords, energies, groups)`` where
        ``all_numbers`` is a list (one per frame) of lists of ints,
        ``all_coords`` is a list of lists of ``(x, y, z)`` float tuples,
        ``energies`` is a 1-D tensor of relative energies in torch's default
        floating dtype, and ``groups`` is a list of the grid tuples above.

    Raises:
        ValueError: if a comment line has no parsable ``Energy:`` value, if a
            frame precedes every ``"Grid: 0"`` reference frame, or if an atom
            line does not have exactly four fields.
        IndexError: if the file ends in the middle of a frame.
    """
    all_coords = []
    all_numbers = []
    energies = []
    groups = []
    with open(filename) as f:
        lines = f.read().splitlines()

    mol_en = None  # energy of the current "Grid: 0" reference frame
    i = 0
    while i < len(lines):
        try:
            n = int(lines[i].strip())
        except ValueError:
            # A non-integer where an atom count belongs ends the frame list.
            break

        comment = lines[i + 1]
        energy_match = _ENERGY_RE.search(comment)
        if energy_match is None:
            raise ValueError(
                f"line {i + 2}: no 'Energy:' value in comment {comment!r}")
        energy = float(energy_match.group(1))

        if 'Grid: 0' in comment:
            mol_en = energy
        if mol_en is None:
            raise ValueError(
                f"line {i + 2}: frame appears before any 'Grid: 0' reference frame")

        # The two-pair tag also matches the one-pair pattern, so it must win last.
        grp = ()
        g1 = _GRID1_RE.search(comment)
        if g1:
            grp = (g1.group(2),)
        g2 = _GRID2_RE.search(comment)
        if g2:
            grp = (g2.group(2), g2.group(4))

        energies.append(energy - mol_en)
        groups.append(grp)

        frame_numbers = []
        frame_coords = []
        for j in range(n):
            # str.split() tolerates tabs and runs of mixed whitespace.
            at, x, y, z = lines[i + j + 2].split()
            frame_coords.append((float(x), float(y), float(z)))
            frame_numbers.append(int(at))
        all_numbers.append(frame_numbers)
        all_coords.append(frame_coords)

        i += n + 2

    energies = torch.tensor(energies)
    return all_numbers, all_coords, energies, groups
class MolDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding one ``(atoms, coords, energy)`` triple per molecule.

    Args:
        all_numbers: Per-molecule sequences of integer atom types.
        all_coords: Per-molecule sequences of ``(x, y, z)`` coordinates.
        energies: Per-molecule energies, indexed as ``energies[ind]``. When
            ``normalize`` is true the indexed value must support ``.sign()``
            and ``.abs()``, e.g. a torch tensor as returned by ``read_data``.
        normalize: If true, return ``sign(e) * |e| ** 0.1`` instead of ``e``.
    """

    def __init__(self, all_numbers, all_coords, energies, normalize=False):
        self.numbers, self.coords, self.energies = all_numbers, all_coords, energies
        self.normalize = normalize

    def __len__(self):
        # The number of samples is defined by the energy sequence.
        return len(self.energies)

    def __getitem__(self, ind):
        target = self.energies[ind]
        # dtype is inferred for atom types (int64 for Python ints); coords are float32.
        atom_ids = torch.tensor(self.numbers[ind])
        positions = torch.tensor(self.coords[ind], dtype=torch.float32)
        if not self.normalize:
            return atom_ids, positions, target
        # Signed 10th-root compression: keeps the sign, shrinks large magnitudes.
        return atom_ids, positions, target.sign() * target.abs() ** 0.1
def collate_mol(batch):
    """Merge a list of molecules into one flat, batch-indexed sample.

    Suitable as a DataLoader ``collate_fn`` for ``MolDataset`` items.

    Args:
        batch: Sequence of ``(atoms, coords, energy)`` tuples.

    Returns:
        ``(atoms_cat, coords_cat, energies, batch_tensor)`` where
        ``atoms_cat`` is ``[total_atoms]``, ``coords_cat`` is
        ``[total_atoms, 3]``, ``energies`` is ``[batch_size]`` and
        ``batch_tensor`` is a ``[total_atoms]`` long tensor giving, for every
        atom, the position of its molecule within ``batch``.
    """
    atom_parts, coord_parts, energy_parts, sizes = [], [], [], []
    for atoms, coords, energy in batch:
        atom_parts.append(atoms)
        coord_parts.append(coords)
        energy_parts.append(energy)
        sizes.append(atoms.size(0))

    atoms_cat = torch.cat(atom_parts, dim=0)
    coords_cat = torch.cat(coord_parts, dim=0)
    energies = torch.stack(energy_parts)
    # Molecule k contributes sizes[k] copies of k, mapping each atom to its molecule.
    batch_tensor = torch.repeat_interleave(
        torch.arange(len(sizes)), torch.tensor(sizes, dtype=torch.long)
    )
    return atoms_cat, coords_cat, energies, batch_tensor
def get_train_test_data(ds_all, groups, mode, test_idcs=range(28986, 29803)):
    """Split ``ds_all`` into train/test subsets for pretraining or finetuning.

    In ``'pretrain'`` mode the training set is every sample *outside*
    ``test_idcs``, whatever its grid tags. In ``'finetune'`` mode it is the
    samples *inside* ``test_idcs`` whose grid tags all fall in the fixed
    selections below. Untagged samples are always eligible. Samples with
    more than two tags are never used for training. The test set is always
    ``Subset(ds_all, test_idcs)``.

    Args:
        ds_all: Dataset to index into (e.g. a ``MolDataset``).
        groups: Per-sample grid tags as produced by ``read_data``: ``()``,
            or a 1- or 2-tuple of index strings.
        mode: ``'pretrain'`` or ``'finetune'``.
        test_idcs: Indices of the held-out samples. The default range is a
            dataset-specific constant kept for backward compatibility.

    Returns:
        ``(ds_train, ds_test)`` as ``torch.utils.data.Subset`` objects.

    Raises:
        ValueError: if ``mode`` is not ``'pretrain'`` or ``'finetune'``
            (raised explicitly so the check survives ``python -O``).
    """
    grid1_selection = frozenset(['1', '3', '5', '7', '8', '9', '10', '12', '14', '16'])
    grid2_selection = frozenset(['1', '5', '8', '9', '12', '16'])
    if mode not in ('pretrain', 'finetune'):
        raise ValueError(f"mode must be 'pretrain' or 'finetune', got {mode!r}")
    pretrain = mode == 'pretrain'

    # Build the membership set once; `i in test_idcs` is O(len) for lists/tensors.
    test_set = {int(j) for j in test_idcs}

    train_idces = []
    for i, grp in enumerate(groups):
        if len(grp) == 0:
            selected = True
        elif len(grp) == 1:
            selected = pretrain or grp[0] in grid1_selection
        elif len(grp) == 2:
            selected = pretrain or (grp[0] in grid2_selection and grp[1] in grid2_selection)
        else:
            selected = False
        # pretrain keeps samples outside the test set; finetune keeps those inside it.
        if selected and (i in test_set) != pretrain:
            train_idces.append(i)

    ds_train = torch.utils.data.Subset(ds_all, train_idces)
    ds_test = torch.utils.data.Subset(ds_all, test_idcs)
    return ds_train, ds_test