import re

import torch

# Patterns for the per-frame comment line of the input file.
# Raw strings avoid invalid escape sequences such as "\." in a normal string
# literal (DeprecationWarning / SyntaxWarning on recent Python versions).
_ENERGY_RE = re.compile(r'Energy:\s+(-?\d*\.\d*)')
_GRID1_RE = re.compile(r'Grid: (\d+): (\d+)')
_GRID2_RE = re.compile(r'Grid: (\d+): (\d+), (\d+): (\d+)')
# Any comment line containing this substring marks a reference frame; its
# energy is subtracted from the energies of the frames that follow it.
_REFERENCE_MARKER = 'Grid: 0'

# Default grid points kept for fine-tuning (see get_train_test_data).
GRID1_SELECTION = ('1', '3', '5', '7', '8', '9', '10', '12', '14', '16')
GRID2_SELECTION = ('1', '5', '8', '9', '12', '16')


def read_data(filename):
    """Parse a multi-frame XYZ-style file with energies in the comment lines.

    Each frame is laid out as::

        <n_atoms>
        <comment containing "Energy: <float>" and optionally "Grid: ...">
        <atomic_number> <x> <y> <z>      (n_atoms lines)

    Parsing stops at the first line that is not an integer atom count
    (e.g. a trailing blank line).

    Args:
        filename: Path of the file to read.

    Returns:
        Tuple ``(all_numbers, all_coords, energies, groups)``:
            all_numbers: list (one per frame) of lists of int atomic numbers.
            all_coords: list (one per frame) of lists of ``(x, y, z)`` floats.
            energies: 1-D float tensor; each frame's energy minus the energy
                of the most recent reference ("Grid: 0") frame.
            groups: list of tuples of grid-index strings: ``()`` when no grid
                entry is found, ``(a,)`` for "Grid: i: a", and ``(a, b)`` for
                "Grid: i: a, j: b".

    Raises:
        ValueError: If a comment line lacks an "Energy:" field, a frame
            appears before any reference frame, or a frame has fewer atom
            lines than declared.
    """
    with open(filename) as f:
        lines = f.read().split('\n')

    all_numbers = []
    all_coords = []
    energies = []
    groups = []
    mol_en = None  # energy of the most recent reference frame
    i = 0
    while i < len(lines):
        try:
            n = int(lines[i].strip())
        except ValueError:
            break  # end of frame data (blank line or trailing text)

        comment = lines[i + 1]
        energy_match = _ENERGY_RE.search(comment)
        if energy_match is None:
            raise ValueError(
                f'{filename}:{i + 2}: no "Energy:" field in comment line')
        energy = float(energy_match.group(1))

        if _REFERENCE_MARKER in comment:
            mol_en = energy
        if mol_en is None:
            raise ValueError(
                f'{filename}:{i + 2}: frame appears before any '
                f'"{_REFERENCE_MARKER}" reference frame')

        # A two-coordinate grid entry overrides the one-coordinate match,
        # since the one-coordinate pattern also matches its prefix.
        grp = ()
        m1 = _GRID1_RE.search(comment)
        if m1:
            grp = (m1.group(2),)
        m2 = _GRID2_RE.search(comment)
        if m2:
            grp = (m2.group(2), m2.group(4))

        energies.append(energy - mol_en)
        groups.append(grp)

        atom_lines = lines[i + 2:i + 2 + n]
        if len(atom_lines) != n:
            raise ValueError(
                f'{filename}:{i + 1}: expected {n} atom lines, '
                f'found {len(atom_lines)}')
        numbers = []
        coords = []
        for atom_line in atom_lines:
            # split() tolerates tabs and repeated whitespace between columns.
            at, x, y, z = atom_line.split()
            numbers.append(int(at))
            coords.append((float(x), float(y), float(z)))
        all_numbers.append(numbers)
        all_coords.append(coords)

        i += n + 2

    return all_numbers, all_coords, torch.tensor(energies), groups


class MolDataset(torch.utils.data.Dataset):
    """Map-style dataset of molecules, as produced by ``read_data``.

    Args:
        all_numbers: Per-molecule lists of integer atomic numbers.
        all_coords: Per-molecule lists of ``(x, y, z)`` coordinates.
        energies: 1-D tensor of per-molecule energies.
        normalize: If True, returned energies are compressed with the signed
            power transform ``sign(e) * |e| ** 0.1``.
    """

    def __init__(self, all_numbers, all_coords, energies, normalize=False):
        self.numbers = all_numbers
        self.coords = all_coords
        self.energies = energies
        self.normalize = normalize

    def __len__(self):
        return len(self.energies)

    def __getitem__(self, ind):
        """Return ``(atoms, coords, energy)`` for molecule ``ind``.

        ``atoms`` is an integer tensor of shape ``[n_atoms]``, ``coords`` is a
        float32 tensor of shape ``[n_atoms, 3]``, and ``energy`` is a scalar
        element of ``self.energies`` (optionally normalized).
        """
        energy = self.energies[ind]
        atoms = torch.tensor(self.numbers[ind])
        coords = torch.tensor(self.coords[ind], dtype=torch.float32)
        if self.normalize:
            energy = energy.sign() * energy.abs() ** 0.1
        return atoms, coords, energy


def collate_mol(batch):
    """
    Collate function for molecular dataset.

    Molecules of different sizes are concatenated along the atom axis, and a
    batch-index tensor records which molecule each atom belongs to.

    Args:
        batch: Non-empty list of tuples (atoms, coords, energy) from MolDataset

    Returns:
        atoms_cat: Concatenated atomic numbers tensor of shape [total_atoms]
        coords_cat: Concatenated coordinates tensor of shape [total_atoms, 3]
        energies: Energy tensor of shape [batch_size]
        batch_tensor: Batch indices tensor of shape [total_atoms]

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError('collate_mol received an empty batch')

    atoms_list, coords_list, energies_list = zip(*batch)

    # Batch index i repeated once per atom of molecule i: [0, 0, ..., 1, ...]
    counts = torch.tensor([atoms.size(0) for atoms in atoms_list],
                          dtype=torch.long)
    batch_tensor = torch.repeat_interleave(
        torch.arange(len(atoms_list), dtype=torch.long), counts)

    atoms_cat = torch.cat(atoms_list, dim=0)    # shape: [total_atoms]
    coords_cat = torch.cat(coords_list, dim=0)  # shape: [total_atoms, 3]
    energies = torch.stack(energies_list)       # shape: [batch_size]
    return atoms_cat, coords_cat, energies, batch_tensor


def _is_selected(grp, pretrain, grid1, grid2):
    """Return whether a frame with grid group ``grp`` may be used for training."""
    if len(grp) == 0:
        return True
    if len(grp) == 1:
        return pretrain or grp[0] in grid1
    if len(grp) == 2:
        return pretrain or (grp[0] in grid2 and grp[1] in grid2)
    return False  # groups of any other length are never used for training


def get_train_test_data(ds_all, groups, mode, test_idcs=range(28986, 29803),
                        *, grid1_selection=GRID1_SELECTION,
                        grid2_selection=GRID2_SELECTION):
    """Split ``ds_all`` into train/test subsets for pretraining or finetuning.

    * ``'pretrain'``: train on every index NOT in ``test_idcs`` whose group
      has 0, 1 or 2 entries.
    * ``'finetune'``: train on the indices IN ``test_idcs`` whose grid
      entries all belong to the selected grid points (frames with an empty
      group are always kept).

    In both modes the test subset is all of ``test_idcs``. In finetune mode
    the training frames are therefore a subset of the test frames.

    Args:
        ds_all: Dataset indexed in the same order as ``groups``.
        groups: Per-frame grid groups as returned by ``read_data``.
        mode: ``'pretrain'`` or ``'finetune'``.
        test_idcs: Indices forming the test subset.
        grid1_selection: Grid indices allowed for one-coordinate frames
            during finetuning.
        grid2_selection: Grid indices allowed (for both coordinates) for
            two-coordinate frames during finetuning.

    Returns:
        ``(ds_train, ds_test)`` as ``torch.utils.data.Subset`` objects.

    Raises:
        ValueError: If ``mode`` is not ``'pretrain'`` or ``'finetune'``.
    """
    if mode not in ('pretrain', 'finetune'):
        raise ValueError(
            f"mode must be 'pretrain' or 'finetune', got {mode!r}")
    pretrain = mode == 'pretrain'

    # Build sets once so membership tests are O(1) even for list inputs.
    test_set = set(test_idcs)
    grid1 = set(grid1_selection)
    grid2 = set(grid2_selection)

    # Pretraining draws from outside the test indices, finetuning from inside.
    train_idces = [
        i for i, grp in enumerate(groups)
        if (i in test_set) != pretrain
        and _is_selected(grp, pretrain, grid1, grid2)
    ]

    ds_train = torch.utils.data.Subset(ds_all, train_idces)
    ds_test = torch.utils.data.Subset(ds_all, test_idcs)
    return ds_train, ds_test