# vscf_mlff/utils_data.py
# Uploaded by timcryt -- commit 5fae7ca ("Initial commit")
import re
import torch
def read_data(filename):
    """Parse a multi-frame XYZ-style file of geometries and energies.

    Each frame consists of:
      * a line holding the atom count ``n``;
      * a comment line containing ``Energy: <float>`` and optionally a
        ``Grid: ...`` tag;
      * ``n`` lines of ``<atomic number> <x> <y> <z>`` (whitespace-separated).

    Parsing stops at the first line that is not an integer atom count
    (e.g. a trailing blank line).

    Energies are returned relative to the energy of the most recent frame
    whose comment contains the substring ``Grid: 0``; such a reference
    frame must appear before any frame that depends on it.

    Args:
        filename: path of the file to read.

    Returns:
        ``(all_numbers, all_coords, energies, groups)`` where
          * ``all_numbers`` is a list (one entry per frame) of lists of int
            atomic numbers;
          * ``all_coords`` is a list (one entry per frame) of lists of
            ``(x, y, z)`` float tuples;
          * ``energies`` is a 1-D ``torch.Tensor`` of energies relative to
            the reference frame;
          * ``groups`` is a list of tuples of grid labels (strings):
            ``(b, d)`` for a ``Grid: a: b, c: d`` tag, ``(b,)`` for a
            ``Grid: a: b`` tag, ``()`` otherwise.

    Raises:
        ValueError: if a comment line has no parsable energy, an atom line
            is malformed or missing, or a frame precedes every ``Grid: 0``
            reference frame.
    """
    # Raw strings avoid the invalid "\." escape the old pattern relied on
    # (a SyntaxWarning on Python 3.12+). Compiled once, outside the loop.
    energy_re = re.compile(r'Energy:\s+(-?\d*\.\d*)')
    grid1_re = re.compile(r'Grid: (\d+): (\d+)')
    grid2_re = re.compile(r'Grid: (\d+): (\d+), (\d+): (\d+)')

    with open(filename) as f:
        lines = f.read().split('\n')

    all_coords = []
    all_numbers = []
    energies = []
    groups = []
    mol_en = None
    i = 0
    while i < len(lines):
        try:
            n = int(lines[i].strip())
        except ValueError:
            # Not an atom-count line: treat as end of data.
            break

        comment = lines[i + 1]
        match = energy_re.search(comment)
        if match is None:
            raise ValueError(
                f'line {i + 2}: no energy found in comment {comment!r}')
        energy = float(match.group(1))

        # A "Grid: 0" frame becomes the reference for the frames that follow.
        if 'Grid: 0' in comment:
            mol_en = energy
        if mol_en is None:
            raise ValueError(
                f'line {i + 2}: frame precedes any "Grid: 0" reference frame')

        # The two-index tag takes precedence over the one-index tag.
        # Only the second number of each "a: b" pair is kept.
        m2 = grid2_re.search(comment)
        m1 = grid1_re.search(comment)
        if m2:
            grp = (m2.group(2), m2.group(4))
        elif m1:
            grp = (m1.group(2),)
        else:
            grp = ()

        atom_lines = lines[i + 2:i + 2 + n]
        if len(atom_lines) < n:
            raise ValueError(
                f'line {i + 1}: frame declares {n} atoms but the file ends early')

        frame_numbers = []
        frame_coords = []
        for line in atom_lines:
            at, x, y, z = line.split()
            frame_coords.append((float(x), float(y), float(z)))
            frame_numbers.append(int(at))

        energies.append(energy - mol_en)
        groups.append(grp)
        all_coords.append(frame_coords)
        all_numbers.append(frame_numbers)
        i += n + 2

    return all_numbers, all_coords, torch.tensor(energies), groups
class MolDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(atoms, coords, energy)`` per molecule.

    Args:
        all_numbers: per-molecule sequences of atomic numbers.
        all_coords: per-molecule sequences of ``(x, y, z)`` coordinates.
        energies: indexable collection of per-molecule energies. It is
            presumably a tensor (e.g. the one ``read_data`` returns), because
            normalisation calls ``.sign()`` and ``.abs()`` on each element.
        normalize: if true, each energy ``e`` is mapped to
            ``sign(e) * |e| ** 0.1`` (a signed tenth root) on access.
    """

    def __init__(self, all_numbers, all_coords, energies, normalize=False):
        self.numbers, self.coords = all_numbers, all_coords
        self.energies = energies
        self.normalize = normalize

    def __len__(self):
        return len(self.energies)

    def __getitem__(self, ind):
        target = self.energies[ind]
        if self.normalize:
            # Signed tenth root; the original sign is kept.
            target = target.sign() * target.abs() ** 0.1
        numbers_t = torch.tensor(self.numbers[ind])
        coords_t = torch.tensor(self.coords[ind], dtype=torch.float32)
        return numbers_t, coords_t, target
def collate_mol(batch):
    """Merge ``MolDataset`` samples into one flat, graph-style batch.

    Args:
        batch: list of ``(atoms, coords, energy)`` tuples as produced by
            ``MolDataset.__getitem__``.

    Returns:
        A tuple ``(atoms_cat, coords_cat, energies, batch_tensor)``:
          * ``atoms_cat`` -- atomic numbers of all molecules, concatenated
            along dim 0, shape ``[total_atoms]``;
          * ``coords_cat`` -- coordinates concatenated along dim 0, shape
            ``[total_atoms, 3]``;
          * ``energies`` -- per-molecule energies stacked, shape
            ``[batch_size]``;
          * ``batch_tensor`` -- ``torch.long`` index of the owning molecule
            for each atom, shape ``[total_atoms]``.
    """
    atoms_seq, coords_seq, energy_seq, sizes = [], [], [], []
    for atoms, coords, energy in batch:
        atoms_seq.append(atoms)
        coords_seq.append(coords)
        energy_seq.append(energy)
        sizes.append(atoms.size(0))

    atoms_cat = torch.cat(atoms_seq, dim=0)
    coords_cat = torch.cat(coords_seq, dim=0)
    energies = torch.stack(energy_seq)
    # Molecule k contributes its index sizes[k] times, e.g. [0, 0, 1, 1, 1].
    batch_tensor = torch.repeat_interleave(
        torch.arange(len(sizes), dtype=torch.long),
        torch.tensor(sizes, dtype=torch.long),
    )
    return atoms_cat, coords_cat, energies, batch_tensor
def get_train_test_data(ds_all, groups, mode, test_idcs=range(28986, 29803), *,
                        grid1_selection=None, grid2_selection=None):
    """Split ``ds_all`` into training and test subsets.

    The test subset is always ``Subset(ds_all, test_idcs)``. The training
    indices depend on ``mode``:

    * ``'pretrain'`` -- every index *not* in ``test_idcs`` whose group has
      at most two labels.
    * ``'finetune'`` -- indices *in* ``test_idcs`` whose group is empty,
      or whose labels all belong to the corresponding selection
      (``grid1_selection`` for one-label groups, ``grid2_selection`` for
      two-label groups).

    In both modes, groups with more than two labels are never used for
    training.

    Args:
        ds_all: the full dataset (anything ``torch.utils.data.Subset`` accepts).
        groups: per-sample tuples of grid labels (strings), as returned by
            ``read_data``.
        mode: ``'pretrain'`` or ``'finetune'``.
        test_idcs: indices of the held-out samples.
        grid1_selection: labels allowed for one-label groups in fine-tuning.
            Defaults to ``('1', '3', '5', '7', '8', '9', '10', '12', '14', '16')``.
            Labels are compared as strings.
        grid2_selection: labels allowed (for both entries) in two-label
            groups in fine-tuning. Defaults to
            ``('1', '5', '8', '9', '12', '16')``.

    Returns:
        ``(ds_train, ds_test)`` as ``torch.utils.data.Subset`` objects.

    Raises:
        ValueError: if ``mode`` is not ``'pretrain'`` or ``'finetune'``.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if mode not in ('pretrain', 'finetune'):
        raise ValueError(f"mode must be 'pretrain' or 'finetune', got {mode!r}")
    if grid1_selection is None:
        grid1_selection = ('1', '3', '5', '7', '8', '9', '10', '12', '14', '16')
    if grid2_selection is None:
        grid2_selection = ('1', '5', '8', '9', '12', '16')
    grid1 = set(grid1_selection)
    grid2 = set(grid2_selection)
    pretrain = mode == 'pretrain'

    # Build the membership set once. `int()` normalises numpy/torch scalars
    # so that membership works whatever index container is passed.
    test_set = {int(j) for j in test_idcs}

    def _eligible(grp):
        # Pretraining accepts every grid point; fine-tuning only selected ones.
        if len(grp) == 0:
            return True
        if len(grp) == 1:
            return pretrain or grp[0] in grid1
        if len(grp) == 2:
            return pretrain or (grp[0] in grid2 and grp[1] in grid2)
        return False

    # pretrain -> indices outside the test set; finetune -> inside it.
    train_idces = [
        i for i, grp in enumerate(groups)
        if _eligible(grp) and (i in test_set) != pretrain
    ]
    ds_train = torch.utils.data.Subset(ds_all, train_idces)
    ds_test = torch.utils.data.Subset(ds_all, test_idcs)
    return ds_train, ds_test