File size: 4,272 Bytes
5fae7ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import re
import torch


def read_data(filename):
    """Parse a multi-frame, XYZ-like text file into atoms, coordinates and energies.

    Every frame consists of:

    * a line holding the atom count ``n``;
    * a comment line containing ``Energy: <decimal number>`` and a grid
      annotation (``Grid: 0``, ``Grid: a: b`` or ``Grid: a: b, c: d``);
    * ``n`` lines of ``<integer atom id> <x> <y> <z>`` separated by any
      whitespace.

    Parsing stops at the first line that should be an atom count but is not
    an integer (for example the empty string after the final newline).

    Energies are stored relative to the most recent ``Grid: 0`` frame. A frame
    whose comment carries no grid annotation keeps the group of the previous
    frame.

    Args:
        filename: Path of the file to read.

    Returns:
        A tuple ``(all_numbers, all_coords, energies, groups)`` where
        ``all_numbers`` is a list (one entry per frame) of lists of ints,
        ``all_coords`` is a list of lists of ``(x, y, z)`` float tuples,
        ``energies`` is a 1-D ``torch`` tensor of relative energies, and
        ``groups`` is a list of tuples of grid-value strings (empty tuple,
        1-tuple or 2-tuple).

    Raises:
        IndexError: If a comment line has no ``Energy:`` field, or the file
            ends in the middle of a frame.
        TypeError: If a frame appears before any ``Grid: 0`` frame, so no
            reference energy is available.
        ValueError: If an atom line does not have exactly four fields or a
            field cannot be converted to a number.
    """
    # Raw strings keep the regex escapes intact; the previous non-raw literal
    # relied on the invalid string escape "\.".  Compiled once, used per frame.
    energy_re = re.compile(r'Energy:\s+(-?\d*\.\d*)')
    grid0_re = re.compile(r'Grid: 0')
    grid1_re = re.compile(r'Grid: (\d+): (\d+)')
    grid2_re = re.compile(r'Grid: (\d+): (\d+), (\d+): (\d+)')

    with open(filename) as f:
        cont = f.read()

    all_coords = []
    all_numbers = []
    energies = []
    groups = []

    lines = cont.split('\n')
    i = 0
    mol_en = None  # energy of the latest "Grid: 0" frame
    grp = ()
    while i < len(lines):
        try:
            n = int(lines[i].strip())
        except ValueError:
            # Not an atom count: treat as end of data.
            break

        comment = lines[i + 1]
        energy = float(energy_re.findall(comment)[0])

        # The checks are applied in order so that the most specific
        # annotation (two grid entries) wins over the single-entry one.
        if grid0_re.search(comment):
            mol_en = energy
            grp = ()
        one_grid = grid1_re.search(comment)
        if one_grid:
            grp = (one_grid.group(2),)
        two_grids = grid2_re.search(comment)
        if two_grids:
            grp = (two_grids.group(2), two_grids.group(4))

        energies.append(energy - mol_en)
        groups.append(grp)

        frame_coords = []
        frame_numbers = []
        all_coords.append(frame_coords)
        all_numbers.append(frame_numbers)
        for j in range(n):
            # split() with no argument tolerates runs of spaces and tabs.
            at, x, y, z = lines[i + 2 + j].split()
            frame_coords.append((float(x), float(y), float(z)))
            frame_numbers.append(int(at))

        i += n + 2

    energies = torch.tensor(energies)

    return all_numbers, all_coords, energies, groups

class MolDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(atoms, coords, energy)`` per molecule.

    Args:
        all_numbers: Per-molecule sequences of integer atom identifiers.
        all_coords: Per-molecule sequences of ``(x, y, z)`` coordinates.
        energies: Indexable container of per-molecule energies; its length
            defines the dataset length.  Elements must offer ``sign()`` and
            ``abs()`` when ``normalize`` is enabled.
        normalize: When true, each returned energy is replaced by
            ``sign(e) * |e| ** 0.1``.
    """

    def __init__(self, all_numbers, all_coords, energies, normalize=False):
        self.numbers, self.coords = all_numbers, all_coords
        self.energies = energies
        self.normalize = normalize

    def __len__(self):
        """Return the number of stored energies."""
        return len(self.energies)

    def __getitem__(self, ind):
        """Build the tensors for molecule ``ind``.

        Returns:
            ``(atoms, coords, energy)``: atoms as a tensor built from the
            stored numbers, coords as a float32 tensor, and the (optionally
            compressed) energy.
        """
        target = self.energies[ind]
        atom_ids = torch.tensor(self.numbers[ind])
        positions = torch.tensor(self.coords[ind], dtype=torch.float32)

        if not self.normalize:
            return atom_ids, positions, target

        # Signed tenth root: shrinks the magnitude while keeping the sign.
        magnitude = target.abs() ** 0.1
        return atom_ids, positions, target.sign() * magnitude

def collate_mol(batch):
    """
    Merge a list of molecules into flat, concatenated tensors.

    Args:
        batch: Iterable of ``(atoms, coords, energy)`` tuples, as produced by
            MolDataset.

    Returns:
        atoms_cat: All atom tensors concatenated along dim 0, [total_atoms]
        coords_cat: All coordinate tensors concatenated along dim 0,
            [total_atoms, 3]
        energies: Stacked energies, one per molecule, [batch_size]
        batch_tensor: For every atom, the position of its molecule within
            the batch (dtype long), [total_atoms]
    """
    molecules = list(batch)

    atoms_cat = torch.cat([atoms for atoms, _, _ in molecules], dim=0)
    coords_cat = torch.cat([coords for _, coords, _ in molecules], dim=0)
    energies = torch.stack([energy for _, _, energy in molecules])

    # One constant-valued index vector per molecule, as long as its atom list.
    owner_ids = [
        torch.full((atoms.size(0),), mol_idx, dtype=torch.long)
        for mol_idx, (atoms, _, _) in enumerate(molecules)
    ]
    batch_tensor = torch.cat(owner_ids, dim=0)

    return atoms_cat, coords_cat, energies, batch_tensor

def get_train_test_data(ds_all, groups, mode, test_idcs=range(28986, 29803)):
    """Split ``ds_all`` into a training subset and a test subset.

    The test subset is always ``ds_all`` restricted to ``test_idcs``.  The
    training subset depends on ``mode``:

    * ``'pretrain'``: every sample with a group of length 0, 1 or 2 whose
      index is NOT in ``test_idcs``.
    * ``'finetune'``: samples whose index IS in ``test_idcs`` and whose group
      is empty, a single value from the one-entry selection list, or a pair
      of values both from the two-entry selection list.

    Samples whose group has more than two entries are never used for training.

    Args:
        ds_all: Dataset to draw both subsets from.
        groups: Per-sample tuples of grid-value strings, aligned with
            ``ds_all`` by index.
        mode: Either ``'pretrain'`` or ``'finetune'``.
        test_idcs: Indices forming the test subset.

    Returns:
        ``(ds_train, ds_test)`` as ``torch.utils.data.Subset`` objects.
    """
    grid1_selection = ['1', '3', '5', '7', '8', '9', '10', '12', '14', '16']
    grid2_selection = ['1', '5', '8', '9', '12', '16']

    assert(mode in ['pretrain', 'finetune'])
    is_pretrain = (mode == 'pretrain')

    train_idces = []
    for idx, group in enumerate(groups):
        size = len(group)
        if size == 0:
            eligible = True
        elif size == 1:
            eligible = is_pretrain or group[0] in grid1_selection
        elif size == 2:
            eligible = is_pretrain or (
                group[0] in grid2_selection and group[1] in grid2_selection
            )
        else:
            eligible = False

        # Pretraining takes samples outside the test indices; finetuning
        # takes (selected) samples inside them.
        if eligible and (idx in test_idcs) != is_pretrain:
            train_idces.append(idx)

    ds_train = torch.utils.data.Subset(ds_all, train_idces)
    ds_test = torch.utils.data.Subset(ds_all, test_idcs)
    return ds_train, ds_test