File size: 3,519 Bytes
233f6d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
class Tokenizer(object):
    """Map arbitrary hashable items to integer class labels.

    Label 1 is reserved for the ``'unk'`` entry. While ``train`` is True,
    every previously unseen item is assigned the next free label (2, 3,
    ...). Once ``train`` is set to False, unseen items are not added:
    they are appended to ``unknown`` and mapped to the ``'unk'`` label.
    """

    def __init__(self):
        # Item -> label table, seeded with the fallback label used for
        # entries that were not seen during training.
        self._data = {'unk': 1}
        # Largest label assigned so far; the next new item gets this + 1.
        self.num_classes = 1
        # Whether unseen items should be given new labels.
        self.train = True
        # Unseen items encountered while train is False, in arrival order.
        self.unknown = []

    def __call__(self, item):
        """Return the integer label for ``item``.

        A known item returns its stored label. An unknown item gets a
        fresh label when ``self.train`` is True. Otherwise it is recorded
        in ``self.unknown`` and the ``'unk'`` label is returned.
        """
        label = self._data.get(item)
        if label is not None:
            return label

        if not self.train:
            self.unknown.append(item)
            return self._data['unk']

        self._add_token(item)
        return self._data[item]

    def _add_token(self, item):
        """Assign ``item`` the next unused label (``num_classes + 1``)."""
        # Bump the counter first so the stored label equals the new count.
        self.num_classes += 1
        self._data[item] = self.num_classes

# The remaining functions in this module compute atom and bond features.
# They are passed directly to the Preprocessor class, and new ones are easy
# to add.

def get_ring_size(obj, max_size=12):
    """Return a ring-size descriptor for an atom or bond.

    obj : object exposing ``IsInRing()`` and ``IsInRingSize(n)``
        (presumably an RDKit ``Atom`` or ``Bond`` -- confirm with callers).
    max_size : int
        Sizes ``0 .. max_size - 1`` are probed in increasing order.

    Returns 0 when ``obj.IsInRing()`` is false. Otherwise it returns the
    first probed size for which ``IsInRingSize`` is true. If no probed
    size matches, it returns the string ``'max'``. Note that ``max_size``
    itself is never probed.
    """
    if not obj.IsInRing():
        return 0
    candidate_sizes = range(max_size)
    return next(
        (size for size in candidate_sizes if obj.IsInRingSize(size)), 'max')

def atom_features(atom):
    """Use the atom's atomic number (``atom.GetAtomicNum()``) as its feature."""
    atomic_number = atom.GetAtomicNum()
    return atomic_number


def atom_features_v1(atom):
    """ Return a string key representing the atom type.

    The key is ``str()`` of a tuple ``(symbol, degree, total H count,
    implicit valence, aromatic flag)``. It is a hashable ``str``, not an
    integer, and is suitable as input to a ``Tokenizer``.

    """

    return str((
        atom.GetSymbol(),
        atom.GetDegree(),
        atom.GetTotalNumHs(),
        atom.GetImplicitValence(),
        atom.GetIsAromatic(),
    ))


def atom_features_v2(atom):
    """Return a string key built from many per-atom properties.

    ``atom`` must expose every zero-argument getter named in
    ``getter_names`` below (presumably an RDKit ``Atom`` -- confirm with
    callers). The key is ``str()`` of a tuple holding each getter's
    result, in the listed order, followed by ``get_ring_size(atom)``.
    """
    getter_names = (
        'GetChiralTag', 'GetDegree', 'GetExplicitValence',
        'GetFormalCharge', 'GetHybridization', 'GetImplicitValence',
        'GetIsAromatic', 'GetNoImplicit', 'GetNumExplicitHs',
        'GetNumImplicitHs', 'GetNumRadicalElectrons', 'GetSymbol',
        'GetTotalDegree', 'GetTotalNumHs', 'GetTotalValence',
    )

    property_values = tuple(getattr(atom, name)() for name in getter_names)
    # The ring-size descriptor is appended after all getter results.
    return str(property_values + (get_ring_size(atom),))


def bond_features_v1(bond, **kwargs):
    """ Return a string key representing the bond type.

    The key is ``str()`` of a tuple ``(bond type, conjugation flag,
    in-ring flag, sorted list of the two end-atom element symbols)``.
    Because the symbols are sorted, the key does not depend on which atom
    is the begin atom.

    **kwargs
        Accepted and ignored. For example, ``flipped`` only has an effect
        in ``bond_features_v3``. This presumably lets callers pass the same
        keyword arguments to every bond feature function.

    """

    return str((
        bond.GetBondType(),
        bond.GetIsConjugated(),
        bond.IsInRing(),
        sorted([
            bond.GetBeginAtom().GetSymbol(),
            bond.GetEndAtom().GetSymbol()]),
        ))


def bond_features_v2(bond, **kwargs):
    """Return a string key describing a bond, independent of atom order.

    The key is ``str()`` of a tuple ``(bond type, conjugation flag, stereo
    tag, get_ring_size(bond), sorted list of the two end-atom element
    symbols)``. Sorting the symbols makes the key the same whichever atom
    is the begin atom. Extra keyword arguments (e.g. ``flipped``) are
    accepted and ignored.
    """
    endpoint_symbols = sorted(
        atom.GetSymbol() for atom in (bond.GetBeginAtom(), bond.GetEndAtom()))

    descriptor = (
        bond.GetBondType(),
        bond.GetIsConjugated(),
        bond.GetStereo(),
        get_ring_size(bond),
        endpoint_symbols,
    )
    return str(descriptor)


def bond_features_v3(bond, flipped=False):
    """Return a string key describing a bond viewed in one direction.

    flipped : bool
        When True, the ``atom_features`` values of the begin and end atoms
        are swapped. The two directions of the same bond can therefore
        produce different keys.

    The key is ``str()`` of a tuple ``(bond type, conjugation flag, stereo
    tag, get_ring_size(bond), element symbol of the bond's end atom,
    atom_features of the first atom, atom_features of the second atom)``.
    """
    first_atom, second_atom = bond.GetBeginAtom(), bond.GetEndAtom()

    # NOTE(review): this symbol always comes from the bond's own end atom,
    # even when ``flipped`` swaps the atom order, so it is not mirrored
    # along with the atomic features. Changing it would alter the keys an
    # already-trained Tokenizer has learned; confirm before "fixing".
    end_symbol = second_atom.GetSymbol()

    if flipped:
        first_atom, second_atom = second_atom, first_atom

    first_feature = atom_features(first_atom)
    second_feature = atom_features(second_atom)

    return str((
        bond.GetBondType(),
        bond.GetIsConjugated(),
        bond.GetStereo(),
        get_ring_size(bond),
        end_symbol,
        first_feature,
        second_feature,
    ))