Upload 13 files
Browse files- code/GNN/__init__.py +0 -0
- code/GNN/featurizer.py +138 -0
- code/GNN/layers.py +443 -0
- code/GNN/subgraphfp.py +138 -0
- code/GNN/utils.py +240 -0
- code/cliplayers.py +432 -0
- code/config.py +95 -0
- code/dataset.py +142 -0
- code/modules.py +158 -0
- code/predict.py +347 -0
- code/separate_posneg.py +29 -0
- code/train.py +251 -0
- code/utils.py +370 -0
code/GNN/__init__.py
ADDED
|
File without changes
|
code/GNN/featurizer.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from rdkit import Chem
|
| 3 |
+
from rdkit.Chem import AllChem, MACCSkeys, rdMolDescriptors as rdDesc
|
| 4 |
+
from utils import *
|
| 5 |
+
import torch
|
| 6 |
+
import copy
|
| 7 |
+
from . import subgraphfp as subfp
|
| 8 |
+
|
| 9 |
+
# Shared RDKit periodic-table instance, used below for per-atom lookups
# (e.g. van der Waals radius).
PERIODIC_TABLE = Chem.GetPeriodicTable()
# Element symbols covered by the one-hot atom encoding; anything else is
# bucketed into the last slot ('B') by one_of_k_encoding_unk.
POSSIBLE_ATOMS = ['H', 'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br','I', 'B']
# Hybridization states used for one-hot encoding (unknowns bucket to SP3D2).
HYBRIDS = [ Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
            Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D, Chem.rdchem.HybridizationType.SP3D2]
# Chirality tags used for one-hot encoding.
CHIRALS = [ Chem.rdchem.ChiralType.CHI_UNSPECIFIED, Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, Chem.rdchem.ChiralType.CHI_OTHER]
# Bond types that each get their own channel in the adjacency tensor built by
# calc_adjacent_tensor.
BOND_TYPES = [ Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC ]
|
| 16 |
+
|
| 17 |
+
def one_of_k_encoding(x, allowable_set):
    """Strict one-hot encode *x* against *allowable_set*.

    :param x: value to encode; must be a member of ``allowable_set``.
    :param allowable_set: ordered collection of permitted values.
    :return: list of booleans, True exactly at the position of ``x``.
    :raises ValueError: if ``x`` is not in ``allowable_set``.
    """
    if x not in allowable_set:
        # ValueError is more precise than the original bare Exception (and, as
        # an Exception subclass, keeps any existing broad handlers working).
        # The original message also had a misplaced colon ("set{1}:").
        raise ValueError("input {0} not in allowable set {1}".format(x, allowable_set))
    return [x == s for s in allowable_set]
|
| 21 |
+
|
| 22 |
+
def one_of_k_encoding_unk(x, allowable_set):
    """One-hot encode *x*; values outside the set are bucketed into the last slot."""
    target = x if x in allowable_set else allowable_set[-1]
    return [candidate == target for candidate in allowable_set]
|
| 28 |
+
|
| 29 |
+
def calc_atom_features_onehot(atom, feature):
    '''
    Method that computes atom level features from rdkit atom object.

    :param atom: rdkit Atom to featurize
    :param feature: integer feature-invariant bitmask for this atom (from
        rdMolDescriptors.GetFeatureInvariants); expanded into 6 bits at the end
    :return: flat Python list mixing one-hot booleans and scalar descriptors
    '''
    # Element symbol (unknown elements bucket into the last slot of POSSIBLE_ATOMS).
    atom_features = one_of_k_encoding_unk(atom.GetSymbol(), POSSIBLE_ATOMS)
    # Valence / hydrogen / radical / degree counts, each one-hot over a small range
    # (values beyond the range bucket into the last slot).
    atom_features += one_of_k_encoding_unk(atom.GetExplicitValence(), list(range(7)))
    atom_features += one_of_k_encoding_unk(atom.GetImplicitValence(), list(range(7)))
    atom_features += one_of_k_encoding_unk(atom.GetTotalNumHs(), list(range(5)))
    atom_features += one_of_k_encoding_unk(atom.GetNumRadicalElectrons(), list(range(5)))
    atom_features += one_of_k_encoding_unk(atom.GetTotalDegree(), list(range(7)))
    # Formal charge one-hot over [-2, 2].
    atom_features += one_of_k_encoding_unk(atom.GetFormalCharge(), list(range(-2, 3)))
    atom_features += one_of_k_encoding_unk(atom.GetHybridization(), HYBRIDS)
    atom_features += one_of_k_encoding_unk(atom.GetIsAromatic(), [False, True])
    atom_features += one_of_k_encoding_unk(atom.IsInRing(), [False, True])
    atom_features += one_of_k_encoding_unk(atom.GetChiralTag(), CHIRALS)
    # NOTE(review): HasProp returns a 0/1 flag, not the 'R'/'S' CIP code, so this
    # one-hot always falls into the unknown bucket — presumably GetProp('_CIPCode')
    # was intended; confirm before relying on this feature.
    atom_features += one_of_k_encoding_unk(atom.HasProp('_CIPCode'), ['R', 'S'])
    # Scalar descriptors: van der Waals radius, chirality-possible flag,
    # atomic number, mass scaled down by 100, and heavy-atom degree.
    atom_features += [PERIODIC_TABLE.GetRvdw(atom.GetSymbol())]
    atom_features += [atom.HasProp('_ChiralityPossible')]
    atom_features += [atom.GetAtomicNum()]
    atom_features += [atom.GetMass() * 0.01]
    atom_features += [atom.GetDegree()]
    # Feature-invariant bitmask expanded into six 0/1 entries (zero-padded binary).
    atom_features += [int(i) for i in list('{0:06b}'.format(feature))]

    return atom_features
|
| 53 |
+
|
| 54 |
+
def calc_adjacent_tensor(bonds, atom_num, with_ring_conj=False):
    '''
    Method that constructs an adjacency tensor with one channel per bond type.

    :param bonds: bonds of an rdkit mol
    :param atom_num: the atom number of the rdkit mol
    :param with_ring_conj: add two extra channels carrying the "bond is in a
        ring" and "bond is conjugated" flags
    :return: adjacency tensor A shaped [N, F, N], where N is atom_num and F is
        the number of bond channels (symmetric in the two node axes)
    '''
    bond_types = len(BOND_TYPES)
    if with_ring_conj:
        bond_types += 2

    A = np.zeros([atom_num, bond_types, atom_num])

    for bond in bonds:
        b, e = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        try:
            bond_type = BOND_TYPES.index(bond.GetBondType())
        except ValueError:
            # Bond type not in BOND_TYPES: skip this bond. The original bare
            # `except: pass` also silently hid programming errors (and even
            # KeyboardInterrupt); only the .index lookup can legitimately fail.
            continue
        A[b, bond_type, e] = 1
        A[e, bond_type, b] = 1
        if with_ring_conj:
            if bond.IsInRing():
                A[b, bond_types-2, e] = 1
                A[e, bond_types-2, b] = 1
            if bond.GetIsConjugated():
                A[b, bond_types-1, e] = 1
                A[e, bond_types-1, b] = 1
    return A
|
| 85 |
+
|
| 86 |
+
def calc_data_from_smile(smiles, addh=False, with_ring_conj=False, with_atom_feats=True, with_submol_fp=True, radius=2):
    '''
    Method that constructs the data of a molecular.

    :param smiles: SMILES representation of a molecule
    :param addh: should we add all the Hs of the mol
    :param with_ring_conj: should the AdjecentTensor contains bond in ring and
            is conjugated info
    :param with_atom_feats: include the one-hot atom descriptors per node
    :param with_submol_fp: append a fingerprint of each atom's radius-n
            neighbourhood submolecule to its node features
    :param radius: neighbourhood radius used for the submol fingerprints
    :return: dict with 'V' ([N, C] node features), 'A' ([N, F, N] adjacency
            tensor) and 'mol_size' ([1] atom count); None when V is not 2-D
    '''
    mol = Chem.MolFromSmiles(smiles, sanitize=True)
    #mol.UpdatePropertyCache(strict=False)

    if addh:
        mol = Chem.AddHs(mol)
    #else:
    #    mol = Chem.RemoveHs(mol, sanitize=False)

    mol_size = torch.IntTensor([mol.GetNumAtoms()])

    V = []

    if with_atom_feats:
        # Per-atom pharmacophore feature bitmask, consumed by calc_atom_features_onehot.
        features = rdDesc.GetFeatureInvariants(mol)

    submoldict = {}
    if with_submol_fp:
        # NOTE(review): get_atom_submol_radn can return fewer submols than atoms
        # when extraction fails for some atom; zip then mis-pairs the remaining
        # entries and the lookup below may KeyError — confirm upstream guarantees.
        atoms, submols = subfp.get_atom_submol_radn(mol, radius, sanitize=True)
        submoldict = dict(zip([a.GetIdx() for a in atoms], submols))

    for i in range(mol.GetNumAtoms()):
        atom_i = mol.GetAtomWithIdx(i)
        if with_atom_feats:
            atom_i_features = calc_atom_features_onehot(atom_i, features[i])
        else:
            atom_i_features = []

        if with_submol_fp:
            submol = submoldict[i]
            #print(Chem.MolToSmiles(submol))
            submolfp = subfp.gen_fps_from_mol(submol)
            atom_i_features.extend(submolfp)

        V.append(atom_i_features)

    V = torch.FloatTensor(V)

    # Defensive check: a malformed/ragged feature list does not produce a 2-D
    # tensor; signal failure with None rather than emitting bad data.
    if len(V.shape) != 2:
        return None

    A = calc_adjacent_tensor(mol.GetBonds(), mol.GetNumAtoms(), with_ring_conj)
    A = torch.FloatTensor(A)

    return {'V': V, 'A': A, 'mol_size': mol_size}
|
code/GNN/layers.py
ADDED
|
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import utils
import pickle

# Prefer the GPU when one is available. The conditional expression replaces
# the fragile `cond and a or b` idiom, which silently picks the wrong branch
# whenever the first alternative is falsy.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 10 |
+
|
| 11 |
+
class GraphCNNLayer(nn.Module):
    """Multi-channel graph convolution: V_out = V @ W0 + GConv(V, A) @ We + b.

    V is [b, N, C] node features; A is [b, N, L, N] with one adjacency matrix
    per edge channel; the output is [b, N, n_filters].
    """

    def __init__(self, n_feats, adj_chans=4, n_filters=64, bias=True):
        super(GraphCNNLayer, self).__init__()
        self.n_feats = n_feats
        self.adj_chans = adj_chans
        self.n_filters = n_filters
        self.has_bias = bias

        # Edge weight [L*C, F]: mixes the per-channel aggregated neighbour
        # features down to n_filters.
        self.weight_e = nn.Parameter(torch.FloatTensor(adj_chans*n_feats, n_filters))
        # Self weight [C, F]: the identity (skip) term V @ W0.
        self.weight_i = nn.Parameter(torch.FloatTensor(n_feats, self.n_filters))

        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(n_filters))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-init both weight matrices; small constant for the bias."""
        nn.init.xavier_uniform_(self.weight_e)
        nn.init.xavier_uniform_(self.weight_i)

        if self.bias is not None:
            self.bias.data.fill_(0.01)

    def forward(self, V, A):
        '''V node features: [b, N, C], A adjs: [b, N, L, N], L = adj_chans'''
        n_nodes = V.shape[1]
        n_chans = A.shape[2]

        # Aggregate neighbours for every edge channel at once:
        # [b, N*L, N] @ [b, N, C] -> [b, N*L, C] -> [b, N, L*C]
        gathered = torch.bmm(A.view(-1, n_nodes*n_chans, n_nodes), V)
        gathered = gathered.view(-1, n_nodes, n_chans*self.n_feats)

        # Edge term plus self (identity) term -> [b, N, F]
        out = torch.matmul(gathered, self.weight_e) + torch.matmul(V, self.weight_i)

        if self.has_bias:
            out = out + self.bias

        # out: [b, N, F]
        return out

    def __repr__(self):
        return f'{self.__class__.__name__}(n_feats={self.n_feats},adj_chans={self.adj_chans},n_filters={self.n_filters},bias={self.has_bias}) -> [b, N, {self.n_filters}]'
|
| 64 |
+
|
| 65 |
+
class GraphResidualCNNLayer(nn.Module):
    """Residual graph convolution: per edge channel i, V <- V + A_i @ relu(W_i V).

    Keeps the feature width unchanged ([b, N, C] in and out).
    """

    def __init__(self, n_feats, adj_chans=4, bias=True):
        super(GraphResidualCNNLayer, self).__init__()
        self.n_feats = n_feats
        self.adj_chans = adj_chans
        self.has_bias = bias

        # One linear transform per adjacency channel.
        self.weight_layers = nn.ModuleList([nn.Linear(n_feats, n_feats) for _ in range(adj_chans)])

        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(n_feats))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        # The Linear submodules use their own default init; only the extra
        # bias needs an explicit value.
        if self.bias is not None:
            self.bias.data.fill_(0.01)

    def forward(self, V, A):
        '''V node features: [b, N, C], A adjs: [b, N, L, N], L = adj_chans'''
        n_nodes = V.shape[1]

        for chan, linear in enumerate(self.weight_layers):
            # Channel-specific transform: [b, N, C] -> [b, N, C]
            transformed = F.relu(linear(V))
            # Channel adjacency: [b, N, N]
            adj = A[:, :, chan, :].view(-1, n_nodes, n_nodes)
            # Residual update with neighbourhood aggregation.
            V = V + torch.bmm(adj, transformed)

        if self.has_bias:
            V = V + self.bias

        # output: [b, N, C]
        return V

    def __repr__(self):
        return f'{self.__class__.__name__}(n_feats={self.n_feats},adj_chans={self.adj_chans},bias={self.has_bias}) -> [b, N, {self.n_feats}]'
|
| 108 |
+
|
| 109 |
+
class GraphAttentionLayer(nn.Module):
    """GAT-style attention convolution with one attention head per adjacency channel.

    Each channel gets its own projection matrix and attention-vector pair; the
    per-channel attended features are summed into one [b, N, n_filters] output.
    """
    def __init__(self, n_feats, adj_chans=4, n_filters=64, bias=True, dropout=0., alpha=0.2):
        super(GraphAttentionLayer, self).__init__()
        self.n_feats = n_feats
        self.adj_chans = adj_chans
        self.n_filters = n_filters
        self.has_bias = bias
        self.dropout = dropout    # dropout rate applied to attention weights
        self.alpha = alpha        # negative slope of the attention leaky_relu

        # Per-channel projection [C, F] plus the two halves (a1, a2) of the
        # split attention vector a^T [h_i || h_j] = a1^T h_i + a2^T h_j.
        self.weight_list = nn.ParameterList([nn.Parameter(torch.FloatTensor(n_feats, n_filters)) for _ in range(adj_chans)])
        self.a1_list = nn.ParameterList([nn.Parameter(torch.FloatTensor(n_filters, 1)) for _ in range(adj_chans)])
        self.a2_list = nn.ParameterList([nn.Parameter(torch.FloatTensor(n_filters, 1)) for _ in range(adj_chans)])

        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(n_filters))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        for w in self.weight_list:
            nn.init.xavier_uniform_(w)
        for w in self.a1_list:
            nn.init.xavier_uniform_(w)
        for w in self.a2_list:
            nn.init.xavier_uniform_(w)
        if self.bias is not None:
            self.bias.data.fill_(0.01)

    def forward(self, V, A):
        '''V node features: [b, N, C], A adjs: [b, N, L, N], L = adj_chans'''
        b, N, C = V.shape
        b, N, L, _ = A.shape

        output = None

        for i in range(self.adj_chans):
            # [b, N, 1, N] -> [b, N, N] adjacency of this channel
            adj = A[:, :, i, :].view(-1, N, N)

            # [b, N, C] * [C, F] -> [b, N, F]
            h = torch.matmul(V, self.weight_list[i])
            # [b, N, F] * [F, 1] -> [b, N, 1]
            f_1 = torch.matmul(h, self.a1_list[i])
            # [b, N, F] * [F, 1] -> [b, N, 1]
            f_2 = torch.matmul(h, self.a2_list[i])

            # Pairwise attention logits via broadcasting:
            # leaky_relu([b, N, 1] + [b, 1, N]) -> [b, N, N]
            e = F.leaky_relu(f_1 + f_2.transpose(1, 2), self.alpha)

            # Mask non-edges with a huge negative so softmax drives them to ~0.
            zero_vec = -9e15 * torch.ones_like(e)
            # [b, N, N]
            att = torch.where(adj > 0, e, zero_vec)
            # NOTE(review): softmax over dim=1 normalises down the columns;
            # canonical GAT normalises over each node's neighbours (dim=2) —
            # confirm this axis is intended.
            att = F.softmax(att, dim=1)
            att = F.dropout(att, self.dropout, training=self.training)
            # [b, N, N] * [b, N, F] -> [b, N, F]; channels accumulate by sum.
            if output is None:
                output = torch.matmul(att, h)
            else:
                output += torch.matmul(att, h)

        if self.has_bias:
            output += self.bias

        # output: [b, N, F]
        return output

    def __repr__(self):
        return f'{self.__class__.__name__}(n_feats={self.n_feats},adj_chans={self.adj_chans},n_filters={self.n_filters},bias={self.has_bias},dropout={self.dropout},alpha={self.alpha}) -> [b, N, {self.n_filters}]'
|
| 182 |
+
|
| 183 |
+
class GraphNodeCatGlobalFeatures(nn.Module):
    """Project the global state and concatenate it onto every node feature.

    With mols > 1 the incoming global state is split per molecule, each slice
    gets its own projection, and every node receives the slice belonging to
    its own molecule; padding nodes receive zeros.
    """
    def __init__(self, global_feats, out_feats, mols=1, bias=True):
        super(GraphNodeCatGlobalFeatures, self).__init__()
        self.global_feats = global_feats   # incoming global-state width F
        self.out_feats = out_feats         # projected width O per molecule
        self.mols = mols                   # molecules per batch sample
        self.has_bias = bias

        # One projection matrix [F/mols, O] per molecule.
        self.weights = nn.ParameterList([nn.Parameter(torch.FloatTensor(int(global_feats/mols), out_feats)) for _ in range(mols)])

        # NOTE(review): the bias terms live in `biass`; the 'bias' parameter
        # registered below exists only in the no-bias branch and is always
        # None — confirm this asymmetric naming is intended.
        self.biass = []
        if bias:
            self.biass = nn.ParameterList([nn.Parameter(torch.FloatTensor(out_feats)) for _ in range(mols)])
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        for weight in self.weights:
            nn.init.xavier_uniform_(weight)
        for bias in self.biass:
            bias.data.fill_(0.01)

    def forward(self, V, global_state, graph_size, subgraph_size=None):
        # V: [b, N, Ov], global_state: [b, F], subgraph_size: [b, mols]
        b, N, Ov = V.shape
        O = self.out_feats
        if self.mols == 1:
            # Single molecule: the whole graph is one segment.
            # NOTE(review): this branch applies neither bias nor ReLU, unlike
            # the multi-mol branch below — confirm the asymmetry is intended.
            subgraph_size = graph_size.view(-1, 1)
            global_state = torch.mm(global_state, self.weights[0])
        else:
            # global_state: [b, F] view -> [b*mols, F/mols]
            global_state_view = global_state.view(b*self.mols, -1)

            # split global_state into that of individual mols
            idxmols = []
            for i in range(self.mols):
                idxmols.append(torch.IntTensor(list(range(i, b*self.mols, self.mols))).to(self.weights[0].device))

            global_states = []
            for i, idx in enumerate(idxmols):
                # selected global_state of mols from global_state_view [b*mols, F/mols]. Out shape is [b, F/mols]
                gs = global_state_view.index_select(dim=0, index=idx)
                # gs: [b, F/mols] * weight: [F/mols, O] -> [b, O]; F = global_feats, O = out_feats
                gs = torch.mm(gs, self.weights[i])

                if self.has_bias:
                    gs += self.biass[i]

                global_states.append(F.relu(gs))

            # convert global_states back to global_state
            # [[b, O] ... ] -> [b, mols*O]
            global_state = torch.cat(global_states, dim=1)

        # Append a zero block that padding nodes will pick up below:
        # [b, mols*O] || [b, O] -> [b, (mols+1)*O]
        global_state_new = torch.cat([global_state, torch.zeros(b, O).to(self.weights[0].device)], dim=-1)
        # [b*(mols+1), O]
        global_state_new = global_state_new.view(-1, O)

        # One repeat count per molecule plus one for the padding rows, so
        # repeat_interleave maps molecule-level rows onto node-level rows.
        repeats = []
        for sz in subgraph_size:
            repeats.extend(sz.tolist() + [N-sz.sum()])
        repeats = torch.tensor(repeats).to(self.weights[0].device)

        # repeat form [b*(mols+1), O] -> [b*N, O], the content like [m1_feats, m2_feats, ... mn_feats, pads, ...]
        global_state_new = global_state_new.repeat_interleave(repeats, dim=0)

        # V view: [b*N, Ov], global_state_new: [b*N, O]
        output = torch.cat([V.contiguous().view(-1, Ov), global_state_new], dim=1)

        # output: [b, N, Ov+O]
        return output.view(-1, N, Ov+O), global_state

    def __repr__(self):
        return f'{self.__class__.__name__}(global_feats={self.global_feats},out_feats={self.out_feats},bias={self.has_bias}) -> [b, N, {self.global_feats+self.out_feats}], [b, out_feats]'
|
| 260 |
+
|
| 261 |
+
class MultiHeadGlobalAttention(nn.Module):
    '''Input [b, N, C] -> output [b, n_head*C] if concat or else [b, n_head]'''
    # NOTE(review): the non-concat branch below appears to produce [b, n_feats],
    # not [b, n_head] as the docstring above claims — confirm which is right.
    def __init__(self, n_feats, n_head=5, alpha=0.2, concat=True, bias=True):
        super(MultiHeadGlobalAttention, self).__init__()

        self.n_feats = n_feats
        self.n_head = n_head     # number of attention heads
        self.alpha = alpha       # negative slope of the leaky_relu on logits
        self.concat = concat     # concatenate head outputs vs. average them
        self.has_bias = bias

        # Projects each node into n_head per-head feature vectors.
        self.weight = nn.Parameter(torch.FloatTensor(n_feats, n_head*n_feats))
        # Per-head scoring vector used to form scalar attention logits.
        self.tune_weight = nn.Parameter(torch.FloatTensor(1, n_head, n_feats))

        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(n_head*n_feats))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight)
        nn.init.xavier_uniform_(self.tune_weight)
        if self.bias is not None:
            self.bias.data.fill_(0.01)

    def forward(self, V, graph_size):
        # Gather V of mols in a batch, after this, the pad was removed.
        #print(248, V.shape, graph_size)
        if V.shape[0] == 1:
            Vg = torch.squeeze(V)
            graph_size = [graph_size]
        else:
            # Keep only the first graph_size[i] (real, non-padding) rows of
            # each sample, then stack all graphs into one [sum(sizes), C] matrix.
            Vg = torch.cat([torch.split(v.view(-1, v.shape[-1]), graph_size[i])[0] for i,v in enumerate(torch.split(V, 1))], dim=0)

        # [sum(sizes), C] @ [C, n_head*C] -> per-head projections.
        Vg = torch.matmul(Vg, self.weight)
        if self.has_bias:
            Vg += self.bias
        Vg = Vg.view(-1, self.n_head, self.n_feats)

        # Attention logits: dot each node's head vector with the head's scorer.
        alpha = torch.mul(self.tune_weight, Vg)
        alpha = torch.sum(alpha, dim=-1)
        alpha = F.leaky_relu(alpha, self.alpha) # original code is "alpha = tf.nn.leaky_relu(alpha, alpha=0.2)"
        # Project helper: presumably a softmax over each graph's own nodes —
        # confirm against utils.segment_softmax.
        alpha = utils.segment_softmax(alpha, graph_size)

        #alpha_collect = torch.mean(alpha, dim=-1) # origin code like this. alpha_collect not used?
        alpha = alpha.view(-1, self.n_head, 1)
        V = torch.mul(Vg, alpha)

        if self.concat:
            # Sum attended nodes per graph, then flatten heads -> [b, n_head*C].
            V = utils.segment_sum(V, graph_size)
            V = V.view(-1, self.n_head*self.n_feats)
        else:
            # Average the heads first, then sum per graph.
            V = torch.mean(V, dim=1)
            V = utils.segment_sum(V, graph_size)

        return V

    def __repr__(self):
        if self.concat:
            outc = self.n_head*self.n_feats
        else:
            outc = self.n_head
        return f'{self.__class__.__name__}(n_feats={self.n_feats},n_head={self.n_head},alpha={self.alpha},concat={self.concat},bias={self.has_bias}) -> [b, {outc}]'
|
| 326 |
+
|
| 327 |
+
class GraphEmbedPoolingLayer(nn.Module):
    """Soft graph pooling: learn per-node assignment factors and pool V (and A).

    With n_filters == 1 the graph collapses to a single [b, n_feats] vector and
    A is returned untouched; otherwise the output is a pooled graph with
    n_filters super-nodes and a correspondingly pooled adjacency tensor.
    """

    def __init__(self, n_feats, n_filters=1, mask=None, bias=True):
        super(GraphEmbedPoolingLayer, self).__init__()
        self.n_feats = n_feats
        self.n_filters = n_filters
        self.mask = mask
        self.has_bias = bias

        # Produces the [b, N, n_filters] soft-assignment logits.
        self.emb = nn.Linear(n_feats, n_filters, bias=bias)

    def forward(self, V, A):
        # Assignment logits: [b, N, F]
        assign = self.emb(V)

        if self.mask is not None:
            assign = torch.mul(assign, self.mask)

        # Normalise over the node axis: each super-node's weights sum to one.
        assign = F.softmax(assign, dim=1)

        # Pool node features: [b, F, N] @ [b, N, C] -> [b, F, C]
        pooled_V = torch.matmul(assign.transpose(1, 2).contiguous(), V)

        if self.n_filters == 1:
            # Single super-node: drop the pooling axis, keep A as-is.
            return pooled_V.view(-1, self.n_feats), A

        # Pool the adjacency tensor on both of its node axes.
        pooled_A = A.view(A.shape[0], -1, A.shape[-1])
        pooled_A = torch.matmul(pooled_A, assign)
        pooled_A = pooled_A.view(A.shape[0], A.shape[-1], -1)
        pooled_A = torch.matmul(assign.transpose(1, 2).contiguous(), pooled_A)
        pooled_A = pooled_A.view(A.shape[0], self.n_filters, A.shape[2], self.n_filters)

        return pooled_V, pooled_A

    def __repr__(self):
        return f'{self.__class__.__name__}(n_feats={self.n_feats},n_filters={self.n_filters},mask={self.mask},bias={self.has_bias}) -> [b, {self.n_filters}, {self.n_feats}], [b, {self.n_filters}, L, {self.n_filters}]'
|
| 361 |
+
|
| 362 |
+
class GConvBlockWithGF(nn.Module):
    """Graph-convolution block that mixes the global state into node features.

    Pipeline: broadcast per-mol global features onto nodes
    (GraphNodeCatGlobalFeatures) -> graph convolution (attention-based or plain
    GraphCNN) -> BatchNorm + ReLU on nodes, and BatchNorm + ReLU on the
    transformed global state.
    """
    def __init__( self,
                  n_feats,
                  n_filters,
                  global_feats,
                  global_out_feats,
                  mols=1,
                  adj_chans=4,
                  bias=True,
                  usegat=False):

        super(GConvBlockWithGF, self).__init__()

        self.n_feats = n_feats                    # incoming node-feature width
        self.n_filters = n_filters                # node-feature width after conv
        self.global_out_feats = global_out_feats  # per-mol projected global width
        self.global_feats = global_feats          # incoming global-state width
        self.mols = mols
        self.adj_chans = adj_chans
        self.has_bias = bias
        self.usegat = usegat                      # use attention conv instead of GraphCNN

        self.broadcast_global_state = GraphNodeCatGlobalFeatures(global_feats, global_out_feats, mols, bias)
        if usegat:
            # NOTE(review): `bias` is not forwarded here, so GraphAttentionLayer
            # falls back to its default bias=True — confirm intended.
            self.graph_conv = GraphAttentionLayer(n_feats+global_out_feats, adj_chans, n_filters)
        else:
            self.graph_conv = GraphCNNLayer(n_feats+global_out_feats, adj_chans, n_filters, bias)

        self.bn_global = nn.BatchNorm1d(global_out_feats*mols)
        self.bn_graph = nn.BatchNorm1d(n_filters)

    def forward(self, V, A, global_state, graph_size, subgraph_size):
        ######## transfer global_state #########
        # V shape from [b, N, C] to [b, N, C+F], F is n_filters
        V, global_state = self.broadcast_global_state(V, global_state, graph_size, subgraph_size)

        ######## Graph Convolution #########
        # V shape from [b, N, C+F] to [b, N, F1], F1 is n_filters
        V = self.graph_conv(V, A)
        # BatchNorm1d expects channels in dim 1, hence the transpose round-trip.
        V = self.bn_graph(V.transpose(1, 2).contiguous())
        V = F.relu(V.transpose(1, 2))

        global_state = F.relu(self.bn_global(global_state))

        return V, global_state

    def __repr__(self):
        return f'{self.__class__.__name__}(n_feats={self.n_feats},n_filters={self.n_filters},global_feats={self.global_feats},global_out_feats={self.global_out_feats},mols={self.mols},adj_chans={self.adj_chans},bias={self.has_bias},usegat={self.usegat}) -> [b, N, {self.n_filters}], [b, {self.global_out_feats*self.mols}]'
|
| 410 |
+
|
| 411 |
+
class GConvBlockNoGF(nn.Module):
    """Graph-convolution block without global features: GraphCNN -> BatchNorm -> ReLU."""

    def __init__( self,
                  n_feats,
                  n_filters,
                  mols=1,
                  adj_chans=4,
                  bias=True):

        super(GConvBlockNoGF, self).__init__()

        self.n_feats = n_feats        # incoming node-feature width
        self.n_filters = n_filters    # node-feature width after the conv
        self.mols = mols
        self.adj_chans = adj_chans
        self.has_bias = bias

        self.graph_conv = GraphCNNLayer(n_feats, adj_chans, n_filters, bias)
        self.bn_graph = nn.BatchNorm1d(n_filters)

    def forward(self, V, A):
        # Convolve: [b, N, C] -> [b, N, F1]. BatchNorm1d wants channels in
        # dim 1, hence the transpose round-trip around it.
        convolved = self.graph_conv(V, A)
        normed = self.bn_graph(convolved.transpose(1, 2).contiguous())
        return F.relu(normed.transpose(1, 2))

    def __repr__(self):
        return f'{self.__class__.__name__}(n_feats={self.n_feats},n_filters={self.n_filters},mols={self.mols},adj_chans={self.adj_chans},bias={self.has_bias}) -> [b, N, {self.n_filters}]'
|
code/GNN/subgraphfp.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import AllChem, MACCSkeys, rdMolDescriptors as rdDesc
from collections import defaultdict
import numpy as np
import os, pickle, hashlib

AllChem.SetPreferCoordGen(True)

# Maps a fragment SMILES (or element symbol) to a stable integer id; unseen
# keys are assigned the next free index on first access.
FINGERPRINT_DICT = defaultdict(lambda : len(FINGERPRINT_DICT))

ELEMENTS = ['H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 'Mg', 'Al',
            'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn',
            'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb',
            'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In',
            'Sn', 'Sb', 'Te', 'I', 'Xe', 'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm',
            'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 'Lu', 'Hf', 'Ta',
            'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Po', 'At',
            'Rn', 'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk',
            'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt',
            'Ds', 'Rg', 'Cn', 'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og']

# Pre-register every element symbol so their ids are stable regardless of the
# fragment vocabulary loaded below.
for e in ELEMENTS:
    FINGERPRINT_DICT[e]

# Extend the vocabulary with previously collected radius-1 fragment SMILES.
if os.path.exists('rdkit_fingerprint_list_r1.pkl'):
    # `with` guarantees the file handle is closed (the original leaked it).
    with open('rdkit_fingerprint_list_r1.pkl', 'rb') as fh:
        fragment_smiles = pickle.load(fh)

    for smi in fragment_smiles:
        FINGERPRINT_DICT[smi]

    # Bug fix: the original evaluated `print(...) + len(ELEMENTS)`, adding an
    # int to print's None return value and raising TypeError at import time.
    # The element count belongs inside the reported total.
    print('Len fingerprint_list: %s' % (len(FINGERPRINT_DICT) + len(ELEMENTS)))
|
| 33 |
+
|
| 34 |
+
def mol_with_atom_index(mol):
|
| 35 |
+
atoms = mol.GetNumAtoms()
|
| 36 |
+
for idx in range(atoms):
|
| 37 |
+
mol.GetAtomWithIdx(idx).SetProp('molAtomMapNumber', str(mol.GetAtomWithIdx(idx).GetIdx()))
|
| 38 |
+
return mol
|
| 39 |
+
|
| 40 |
+
def prepare_mol_for_drawing(mol):
|
| 41 |
+
try:
|
| 42 |
+
mol_draw = Draw.rdMolDraw2D.PrepareMolForDrawing(mol)
|
| 43 |
+
except Chem.KekulizeException:
|
| 44 |
+
mol_draw = Draw.rdMolDraw2D.PrepareMolForDrawing(mol, kekulize=False)
|
| 45 |
+
Chem.SanitizeMol(mol_draw, Chem.SANITIZE_ALL ^ Chem.SANITIZE_KEKULIZE)
|
| 46 |
+
return mol_draw
|
| 47 |
+
|
def get_atom_submol_radn(mol, radius, sanitize=True):
    """For every atom, extract the submolecule spanned by its radius-N bond environment.

    Starting at `radius`, the extraction is retried with progressively smaller
    radii whenever RDKit fails (e.g. sanitization errors on the fragment).

    Args:
        mol: RDKit Mol to decompose.
        radius: maximum environment radius (in bonds) to try per atom.
        sanitize: sanitize each submol (kekulization is skipped, since it
            frequently fails on open-valence fragments).

    Returns:
        (atoms, submols) — intended to be parallel lists, one entry per atom.
        NOTE(review): if extraction fails for every radius down to 1 for some
        atom, no submol is appended for it and the lists fall out of sync;
        callers should verify len(atoms) == len(submols).
    """
    atoms = []
    submols = []
    for atom in mol.GetAtoms():
        atoms.append(atom)
        r = radius
        while r > 0:
            try:
                env = Chem.FindAtomEnvironmentOfRadiusN(mol, r, atom.GetIdx())
                amap = {}
                submol = Chem.PathToSubmol(mol, env, atomMap=amap)
                if sanitize:
                    Chem.SanitizeMol(submol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL^Chem.SanitizeFlags.SANITIZE_KEKULIZE)
                submols.append(submol)
                break
            except Exception as e:
                # BUG FIX: was `print(64, e)` — a leftover debug marker.
                # Emit an informative message and retry with a smaller radius.
                print('get_atom_submol_radn: radius %d failed for atom %d: %s'
                      % (r, atom.GetIdx(), e))
                r -= 1

    return atoms, submols
| 70 |
+
|
def gen_fps_from_mol(mol, nbits=256, use_morgan=True, use_macc=False, use_rdkit=False):
    """Build a concatenated 0/1 fingerprint list for `mol`.

    Optionally appends, in order: a Morgan (radius 2, `nbits`) fingerprint,
    the 167-bit MACCS keys, and the RDKit topological fingerprint.
    """
    def _bits(vec):
        # ExplicitBitVect -> list of 0/1 ints via its bit-string form.
        return (np.frombuffer(vec.ToBitString().encode(), 'u1') - ord('0')).tolist()

    fp = []
    if use_morgan:
        fp = _bits(AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=nbits))
    if use_macc:
        fp.extend(_bits(MACCSkeys.GenMACCSKeys(mol)))
    if use_rdkit:
        fp.extend(_bits(Chem.RDKFingerprint(mol)))

    return fp
| 89 |
+
|
def gen_subgraph_fps_from_str(s, wordsdict={}):
    """Look up string `s` in the vocabulary.

    Returns a single-element list with the known id, or with the
    out-of-vocabulary id ``len(wordsdict)`` when `s` is unseen.
    """
    return [wordsdict[s]] if s in wordsdict else [len(wordsdict)]
| 95 |
+
|
| 96 |
+
def gen_subgraph_fps_from_mol(mol, wordsdict={}):
|
| 97 |
+
try:
|
| 98 |
+
k = Chem.MolToSmiles(mol)
|
| 99 |
+
return gen_subgraph_fps_from_str(k, wordsdict)
|
| 100 |
+
except Exception as e:
|
| 101 |
+
print(e)
|
| 102 |
+
return [len(wordsdict)]
|
| 103 |
+
|
| 104 |
+
def calc_subgraph_fps_from_mol(mol, radius=2, nbits=128, use_macc=True, fptype=1, wordsdict={}):
|
| 105 |
+
#atoms, submols, smis = get_atom_submol_radn(mol, radius, True)
|
| 106 |
+
atoms, submols = get_atom_submol_radn(mol, radius, True)
|
| 107 |
+
feats = []
|
| 108 |
+
for idx, submol in enumerate(submols):
|
| 109 |
+
if fptype == 1:
|
| 110 |
+
feat = gen_fps_from_mol(submol, nbits, use_macc)
|
| 111 |
+
feats.append(feat)
|
| 112 |
+
elif fptype == 2:
|
| 113 |
+
feat = gen_subgraph_fps_from_mol(submol, wordsdict)
|
| 114 |
+
feats.append(feat)
|
| 115 |
+
|
| 116 |
+
return np.array(feats)
|
| 117 |
+
|
| 118 |
+
if __name__ == '__main__':
|
| 119 |
+
smi = 'C=C(S)C(N)(O)C'
|
| 120 |
+
smi = 'CC1CCN(CC1N(C)C2=NC=NC3=C2C=CN3)C(=O)CC#N'
|
| 121 |
+
|
| 122 |
+
mol = Chem.MolFromSmiles(smi, sanitize=False)
|
| 123 |
+
|
| 124 |
+
print(calc_subgraph_fps_from_mol(mol, 3))
|
| 125 |
+
|
| 126 |
+
mol = mol_with_atom_index(mol)
|
| 127 |
+
submols = get_atom_submol_radn(mol, 3)
|
| 128 |
+
submols = [prepare_mol_for_drawing(m) for m in submols]
|
| 129 |
+
hl = []
|
| 130 |
+
for idx, m in enumerate(submols):
|
| 131 |
+
for a in m.GetAtoms():
|
| 132 |
+
if int(a.GetProp('molAtomMapNumber')) == idx:
|
| 133 |
+
hl.append([a.GetIdx()])
|
| 134 |
+
break
|
| 135 |
+
|
| 136 |
+
draw = Draw.MolsToGridImage([mol] + submols, highlightAtomLists=[[]] + hl, molsPerRow=5)
|
| 137 |
+
draw.show()
|
| 138 |
+
|
code/GNN/utils.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
def gather(x, indices):
    """Row-gather: for each index row in `indices`, select those rows of `x`
    and concatenate all selections along dim 0."""
    index_rows = indices.view(-1, indices.shape[-1]).tolist()
    gathered = [x[row] for row in index_rows]
    return torch.cat(gathered)
| 8 |
+
|
def gather_nd(x, indices):
    """TensorFlow-style gather_nd: each innermost index vector selects a
    slice of `x`; results are reassembled into indices.shape[:-1] + the
    remaining dims of `x`."""
    out_shape = indices.shape[:-1] + x.shape[indices.shape[-1]:]
    flat_indices = indices.view(-1, indices.shape[-1]).tolist()
    gathered = torch.cat([x[tuple(ix)] for ix in flat_indices])
    return gathered.reshape(out_shape)
| 15 |
+
|
def gen_node_indices(size_list):
    '''generate node index for extraction of nodes of each graph from batched data'''
    graph_ids = []
    offsets = []
    for gid, count in enumerate(int(s) for s in size_list):
        graph_ids += [gid] * count
        offsets += range(count)

    graph_ids = torch.tensor(graph_ids)
    offsets = torch.tensor(offsets)
    # Column 0: graph id of each node; column 1: node offset within its graph.
    pairs = torch.stack([graph_ids, offsets], dim=1)
    return pairs, graph_ids, offsets
| 29 |
+
|
def segment_max(x, size_list):
    """Per-segment max over dim 0, segments sized by `size_list`."""
    sizes = [int(s) for s in size_list]
    segments = torch.split(x, sizes)
    return torch.stack([seg.max(0).values for seg in segments])
| 33 |
+
|
def segment_sum(x, size_list):
    """Per-segment sum over dim 0, segments sized by `size_list`."""
    sizes = [int(s) for s in size_list]
    segments = torch.split(x, sizes)
    return torch.stack([seg.sum(0) for seg in segments])
| 37 |
+
|
def segment_softmax(gate, size_list):
    """Numerically stable softmax computed independently inside each segment.

    Rows of `gate` are grouped by `size_list`; within each group the rows are
    max-shifted, exponentiated, and normalized by the group sum.
    """
    per_seg_max = segment_max(gate, size_list)
    # Broadcast the per-segment max back to one row per input row.
    max_rows = torch.cat(
        [per_seg_max[i].repeat(n, 1) for i, n in enumerate(size_list)], dim=0)
    exp = torch.exp(gate - max_rows)

    per_seg_sum = segment_sum(exp, size_list)
    # Broadcast the per-segment sum the same way.
    sum_rows = torch.cat(
        [per_seg_sum[i].repeat(n, 1) for i, n in enumerate(size_list)], dim=0)
    # Small epsilon guards against division by zero.
    return exp / (sum_rows + 1e-16)
| 50 |
+
|
def pad_V(V, max_n):
    """Zero-pad a node-feature matrix V of shape (N, C) to (max_n, C)."""
    n_rows, n_cols = V.shape
    if max_n > n_rows:
        padding = torch.zeros(max_n - n_rows, n_cols)
        V = torch.cat([V, padding], dim=0)
    return V
| 57 |
+
|
def pad_A(A, max_n):
    """Zero-pad an adjacency tensor A of shape (N, L, N) to (max_n, L, max_n)."""
    n, n_types, _ = A.shape
    if max_n > n:
        # Widen the last (neighbor) dimension first, then append empty rows.
        A = torch.cat([A, torch.zeros(n, n_types, max_n - n)], dim=-1)
        A = torch.cat([A, torch.zeros(max_n - n, n_types, max_n)], dim=0)
    return A
| 67 |
+
|
def pad_prot(P, max_n):
    """Zero-pad a 1-D protein index sequence to length max_n, as int32."""
    (length,) = P.shape
    if max_n > length:
        P = torch.cat([P, torch.zeros(max_n - length)], dim=0)
    return P.type(torch.IntTensor)
| 75 |
+
|
def create_batch(input, pad=False, device=torch.device('cpu')):
    """Collate per-graph sample dicts into a single batch dict.

    Each element of `input` must provide 'V', 'A', 'G', 'mol_size',
    'subgraph_size', 'label', 'index' and 'smiles'. When pad=True, V and A
    are zero-padded to the largest graph in the batch before stacking.
    """
    Vs = [d['V'] for d in input]
    As = [d['A'] for d in input]
    Gs = [d['G'] for d in input]
    mol_sizes = [d['mol_size'] for d in input]
    sub_sizes = [d['subgraph_size'] for d in input]
    labels = [d['label'] for d in input]
    sample_ids = [d['index'] for d in input]
    smiles = [d['smiles'] for d in input]

    # Global features may be absent (None); keep the plain list in that case.
    if Gs[0] is not None:
        Gs = torch.stack(Gs, dim=0).to(device)

    if pad:
        max_n = max(v.shape[0] for v in Vs)
        Vs = [pad_V(v, max_n) for v in Vs]
        As = [pad_A(a, max_n) for a in As]

    return {'V': torch.stack(Vs, dim=0).to(device),
            'A': torch.stack(As, dim=0).to(device),
            'G': Gs,
            'mol_size': torch.cat(mol_sizes, dim=0).to(device),
            'subgraph_size': torch.stack(sub_sizes, dim=0).to(device),
            'label': torch.stack(labels, dim=0).to(device),
            'index': sample_ids,
            'smiles': smiles}
| 125 |
+
|
def create_mol_protein_batch(input, pad=False, device=torch.device('cpu'), pr=True):
    """Collate molecule-graph + protein-sequence samples into one batch dict.

    Args:
        input: list of sample dicts with keys 'V', 'A', 'G', 'mol_size',
            'subgraph_size', 'protein_seq', 'protein', 'label', 'index',
            'smiles' and optionally 'fp'.
        pad: when True, pad V/A to the largest graph and protein_seq to the
            longest sequence, then stack into dense tensors; when False,
            return per-sample lists of unsqueezed tensors.
        device: target device for the stacked tensors.
        pr: print padding sizes (debug output).

    Returns:
        A batch dict mirroring the input keys; 'fp' is a stacked tensor when
        the samples provide fingerprints, otherwise None.
    """
    vl = []
    al = []
    gsl = []
    msl = []
    ssl = []
    prot = []
    seq = []
    lbl = []
    idxs = []
    smis = []
    fpl = []

    for d in input:
        vl.append(d['V'])
        al.append(d['A'])
        gsl.append(d['G'])
        msl.append(d['mol_size'])
        ssl.append(d['subgraph_size'])
        prot.append(d['protein_seq'])
        seq.append(d['protein'])
        lbl.append(d['label'])
        idxs.append(d['index'])
        smis.append(d['smiles'])
        if 'fp' in d:
            fpl.append(d['fp'])

    if gsl[0] is not None:
        if pad:
            gsl = torch.stack(gsl, dim=0).to(device)
        else:
            gsl = [torch.unsqueeze(g, 0) for g in gsl]

    # BUG FIX: fpt was previously assigned only inside the `if pad:` branch,
    # so the pad=False return below raised NameError. Compute it up front.
    fpt = torch.stack(fpl, dim=0).to(device) if fpl else None

    if pad:
        max_n = max(map(lambda x: x.shape[0], vl))
        if pr:
            print('\tPadding V to max_n:', max_n)
        vl1 = [pad_V(v, max_n) for v in vl]

        if pr:
            print('\tPadding A to max_n:', max_n)
        al1 = [pad_A(a, max_n) for a in al]

        max_prot = max(map(lambda x: x.shape[0], prot))
        if pr:
            print('\tPadding protein_seq to max_n:', max_prot)
        prot1 = [pad_prot(p, max_prot) for p in prot]

        return {'V': torch.stack(vl1, dim=0).to(device),
                'A': torch.stack(al1, dim=0).to(device),
                'G': gsl,
                'fp': fpt,
                'mol_size': torch.cat(msl, dim=0).to(device),
                'subgraph_size': torch.stack(ssl, dim=0).to(device),
                'protein_seq': torch.stack(prot1, dim=0).to(device),
                'label': torch.stack(lbl, dim=0).view(-1).to(device),
                'index': idxs,
                'smiles': smis,
                'protein': seq}

    return {'V': [torch.unsqueeze(v, 0) for v in vl],
            'A': [torch.unsqueeze(a, 0) for a in al],
            'G': gsl,
            'fp': fpt,
            'mol_size': torch.cat(msl, dim=0).to(device),
            'subgraph_size': [torch.unsqueeze(s, 0) for s in ssl],
            'protein_seq': [torch.unsqueeze(p, 0) for p in prot],
            'label': torch.stack(lbl, dim=0).view(-1).to(device),
            'index': idxs,
            'smiles': smis,
            'protein': seq}
| 207 |
+
|
def create_mol_protein_fp_batch(input, pad=False, device=torch.device('cpu'), pr=True):
    """Collate fingerprint + protein-sequence samples into one batch dict.

    When pad=True, protein sequences are zero-padded to the longest one and
    stacked; otherwise per-sample lists of unsqueezed tensors are returned.
    """
    fps = [d['fp'] for d in input]
    prots = [d['protein_seq'] for d in input]
    labels = [d['label'] for d in input]
    sample_ids = [d['index'] for d in input]
    smiles = [d['smiles'] for d in input]

    if pad:
        max_prot = max(p.shape[0] for p in prots)
        if pr:
            print('\tPadding protein_seq to max_n:', max_prot)
        padded = [pad_prot(p, max_prot) for p in prots]

        return {'fp': torch.stack(fps, dim=0).to(device),
                'protein_seq': torch.stack(padded, dim=0).to(device),
                'label': torch.stack(labels, dim=0).view(-1).to(device),
                'index': sample_ids,
                'smiles': smiles}

    return {'fp': [torch.unsqueeze(f, 0) for f in fps],
            'protein_seq': [torch.unsqueeze(p, 0) for p in prots],
            'label': torch.stack(labels, dim=0).view(-1).to(device),
            'index': sample_ids,
            'smiles': smiles}
code/cliplayers.py
ADDED
|
@@ -0,0 +1,432 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
from typing import Tuple, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
|
class Bottleneck(nn.Module):
    """ResNet bottleneck block (CLIP variant): 1x1 -> 3x3 -> 1x1 convolutions,
    each followed by BatchNorm, with anti-aliased downsampling: strides are
    replaced by an AvgPool after the 3x3 conv (and before the downsample path).
    """
    # Channel expansion of the final 1x1 conv relative to `planes`.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        # Project the residual when the shapes differ (stride/channel change).
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu3(out)
        return out
| 56 |
+
|
| 57 |
+
|
class AttentionPool2d(nn.Module):
    """Attention-based global pooling over a 2D feature map.

    Flattens an NCHW feature map into a token sequence, prepends the mean
    token as the query, adds a learned positional embedding, and runs one
    multi-head attention step — the output is an attention-weighted pooling
    of the spatial positions, projected to `output_dim`.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        # +1 position accounts for the prepended mean token.
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        # Only the mean token attends (query=x[:1]); keys/values span all tokens.
        x, _ = F.multi_head_attention_forward(
            query=x[:1], key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        return x.squeeze(0)
| 92 |
+
|
| 93 |
+
|
class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        # Attention pooling replaces global average pooling.
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        # First block may downsample; the rest keep stride 1.
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            x = self.relu1(self.bn1(self.conv1(x)))
            x = self.relu2(self.bn2(self.conv2(x)))
            x = self.relu3(self.bn3(self.conv3(x)))
            x = self.avgpool(x)
            return x

        # Run in the weights' dtype (fp16 after convert_weights).
        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x
| 155 |
+
|
class LayerNorm(nn.LayerNorm):
    """LayerNorm that tolerates fp16 inputs: normalizes in float32 for
    numerical stability, then casts back to the input dtype."""

    def forward(self, x: torch.Tensor):
        original_dtype = x.dtype
        normalized = super().forward(x.type(torch.float32))
        return normalized.type(original_dtype)
| 163 |
+
|
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x)."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
| 167 |
+
|
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: x + attn(ln_1(x)) followed by
    x + mlp(ln_2(x)), with a 4x-wide QuickGELU MLP."""

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        # Optional additive attention mask (e.g. causal mask for text).
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        # Keep the mask on the input's dtype/device; re-checked on every call.
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
| 190 |
+
|
class Transformer(nn.Module):
    """A stack of `layers` ResidualAttentionBlocks applied sequentially,
    all sharing the same width, head count, and (optional) attention mask."""

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        blocks = [ResidualAttentionBlock(width, heads, attn_mask)
                  for _ in range(layers)]
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)
| 200 |
+
|
| 201 |
+
|
class VisionTransformer(nn.Module):
    """ViT image encoder: patchify with a strided conv, prepend a class
    token, add positional embeddings, run the transformer, and project the
    class token to `output_dim`."""

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        # Patch embedding: non-overlapping patches via stride == kernel size.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # Pool by taking the class token (position 0).
        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class CLIP(nn.Module):
|
| 240 |
+
def __init__(self,
|
| 241 |
+
embed_dim: int,
|
| 242 |
+
# vision
|
| 243 |
+
image_resolution: int,
|
| 244 |
+
vision_layers: Union[Tuple[int, int, int, int], int],
|
| 245 |
+
vision_width: int,
|
| 246 |
+
vision_patch_size: int,
|
| 247 |
+
# text
|
| 248 |
+
context_length: int,
|
| 249 |
+
vocab_size: int,
|
| 250 |
+
transformer_width: int,
|
| 251 |
+
transformer_heads: int,
|
| 252 |
+
transformer_layers: int
|
| 253 |
+
):
|
| 254 |
+
super().__init__()
|
| 255 |
+
|
| 256 |
+
self.context_length = context_length
|
| 257 |
+
|
| 258 |
+
if isinstance(vision_layers, (tuple, list)):
|
| 259 |
+
vision_heads = vision_width * 32 // 64
|
| 260 |
+
self.visual = ModifiedResNet(
|
| 261 |
+
layers=vision_layers,
|
| 262 |
+
output_dim=embed_dim,
|
| 263 |
+
heads=vision_heads,
|
| 264 |
+
input_resolution=image_resolution,
|
| 265 |
+
width=vision_width
|
| 266 |
+
)
|
| 267 |
+
else:
|
| 268 |
+
vision_heads = vision_width // 64
|
| 269 |
+
self.visual = VisionTransformer(
|
| 270 |
+
input_resolution=image_resolution,
|
| 271 |
+
patch_size=vision_patch_size,
|
| 272 |
+
width=vision_width,
|
| 273 |
+
layers=vision_layers,
|
| 274 |
+
heads=vision_heads,
|
| 275 |
+
output_dim=embed_dim
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
self.transformer = Transformer(
|
| 279 |
+
width=transformer_width,
|
| 280 |
+
layers=transformer_layers,
|
| 281 |
+
heads=transformer_heads,
|
| 282 |
+
attn_mask=self.build_attention_mask()
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
self.vocab_size = vocab_size
|
| 286 |
+
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
| 287 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
|
| 288 |
+
self.ln_final = LayerNorm(transformer_width)
|
| 289 |
+
|
| 290 |
+
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
|
| 291 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
| 292 |
+
|
| 293 |
+
self.initialize_parameters()
|
| 294 |
+
|
| 295 |
+
def initialize_parameters(self):
|
| 296 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
| 297 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
| 298 |
+
|
| 299 |
+
if isinstance(self.visual, ModifiedResNet):
|
| 300 |
+
if self.visual.attnpool is not None:
|
| 301 |
+
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
| 302 |
+
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
| 303 |
+
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
| 304 |
+
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
| 305 |
+
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
| 306 |
+
|
| 307 |
+
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
| 308 |
+
for name, param in resnet_block.named_parameters():
|
| 309 |
+
if name.endswith("bn3.weight"):
|
| 310 |
+
nn.init.zeros_(param)
|
| 311 |
+
|
| 312 |
+
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
| 313 |
+
attn_std = self.transformer.width ** -0.5
|
| 314 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
| 315 |
+
for block in self.transformer.resblocks:
|
| 316 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
| 317 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
| 318 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
| 319 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
| 320 |
+
|
| 321 |
+
if self.text_projection is not None:
|
| 322 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
| 323 |
+
|
| 324 |
+
def build_attention_mask(self):
|
| 325 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
| 326 |
+
# pytorch uses additive attention mask; fill with -inf
|
| 327 |
+
mask = torch.empty(self.context_length, self.context_length)
|
| 328 |
+
mask.fill_(float("-inf"))
|
| 329 |
+
mask.triu_(1) # zero out the lower diagonal
|
| 330 |
+
return mask
|
| 331 |
+
|
| 332 |
+
@property
|
| 333 |
+
def dtype(self):
|
| 334 |
+
return self.visual.conv1.weight.dtype
|
| 335 |
+
|
| 336 |
+
def encode_image(self, image):
|
| 337 |
+
return self.visual(image.type(self.dtype))
|
| 338 |
+
|
| 339 |
+
def encode_text(self, text):
|
| 340 |
+
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
|
| 341 |
+
|
| 342 |
+
x = x + self.positional_embedding.type(self.dtype)
|
| 343 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
| 344 |
+
x = self.transformer(x)
|
| 345 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
| 346 |
+
x = self.ln_final(x).type(self.dtype)
|
| 347 |
+
|
| 348 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
| 349 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
| 350 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
| 351 |
+
|
| 352 |
+
return x
|
| 353 |
+
|
| 354 |
+
def forward(self, image, text):
    """Encode both modalities and return scaled cosine-similarity logits.

    Returns (logits_per_image, logits_per_text), each of shape
    [batch, batch]; logits_per_text is simply the transpose.
    """
    img_emb = self.encode_image(image)
    txt_emb = self.encode_text(text)

    # L2-normalize so the dot product below is a cosine similarity
    img_emb = img_emb / img_emb.norm(dim=1, keepdim=True)
    txt_emb = txt_emb / txt_emb.norm(dim=1, keepdim=True)

    scale = self.logit_scale.exp()
    logits_per_image = scale * img_emb @ txt_emb.t()
    return logits_per_image, logits_per_image.t()
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16, in place."""

    def _to_half(module):
        # conv / linear layers: weight and (optional) bias
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.half()
            if module.bias is not None:
                module.bias.data = module.bias.data.half()

        # multi-head attention keeps several packed projection tensors
        if isinstance(module, nn.MultiheadAttention):
            attn_attrs = ["in_proj_weight", "q_proj_weight", "k_proj_weight",
                          "v_proj_weight", "in_proj_bias", "bias_k", "bias_v"]
            for attr in attn_attrs:
                tensor = getattr(module, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        # CLIP-specific raw projection parameters
        for name in ("text_projection", "proj"):
            proj = getattr(module, name, None)
            if proj is not None:
                proj.data = proj.data.half()

    model.apply(_to_half)
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def build_model(state_dict: dict):
    """Instantiate a CLIP model whose hyper-parameters are inferred from `state_dict`.

    Handles both ViT checkpoints (detected via the "visual.proj" key) and
    modified-ResNet checkpoints. Weights are converted to fp16 before loading,
    and the model is returned in eval mode.
    """
    is_vit = "visual.proj" in state_dict

    if is_vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([key for key in state_dict.keys()
                             if key.startswith("visual.") and key.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        # positional embedding holds 1 class token + grid*grid patch tokens
        grid = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid
    else:
        stage_block_counts: list = []
        for stage in [1, 2, 3, 4]:
            block_ids = set(key.split(".")[2] for key in state_dict
                            if key.startswith(f"visual.layer{stage}"))
            stage_block_counts.append(len(block_ids))
        vision_layers = tuple(stage_block_counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64  # 64-dim attention heads
    transformer_layers = len(set(key.split(".")[2] for key in state_dict
                                 if key.startswith("transformer.resblocks")))

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
    )

    # these keys are checkpoint metadata, not module parameters
    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    convert_weights(model)
    model.load_state_dict(state_dict)
    return model.eval()
|
code/config.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch, json, math, os

# Default configuration values. They are loaded into the module-level CFG
# singleton at import time (see bottom of this file). Key meanings are
# inferred from where they are consumed elsewhere in the project — the
# hedged ones should be confirmed against the training code.
d = {
    'debug': True,
    'dataset_path': 'data/path_to_your_dataset.json',
    'fptype': 'morgan',                 # fingerprint type passed to utils.mol_fp_encoder
    'valid_ratio': 0.1,
    'batch_size': 128,
    'lr': 1e-3,
    'weight_decay': 1e-3,
    'patience': 2,                      # presumably LR-scheduler patience — TODO confirm
    'factor': 0.5,                      # presumably LR-scheduler decay factor — TODO confirm
    'add_nl': True,                     # include neutral-loss bins in the binned MS vector
    'binary_intn': False,               # binarize peak intensities in utils.ms_binner
    'max_mz': 2000,                     # m/z window used for spectrum binning
    'min_mz': 20,
    'energy': 'Energy1',                # collision-energy tag read by PathDataset.proc_data
    'epochs': 50,
    'bin_size': 0.05,                   # m/z bin width; determines ms_embedding_dim
    'ms_embedding_dim': 300,            # recomputed by ConfigDict.calc_ms_embedding_dim
    'projection_dim': 256,              # shared embedding size for MS and molecule heads
    'ms_projection_layers': 1,
    'mol_embedding_dim': 2048,          # fingerprint length / GNN readout size
    'mol_projection_layers': 1,
    'tsfm_in_ms': True,                 # use a transformer in the MS projection head
    'tsfm_in_mol': False,
    'tsfm_layers': 6,
    'tsfm_heads': 8,
    'lstm_layers': 2,
    'lstm_in_ms': False,
    'lstm_in_mol': False,
    'dropout': 0.1,
    'nmodels': 1,
    'mol_encoder': 'fp', # fp, gnn or gnn+fp
    'molgnn_n_filters_list': [256, 256, 256],
    'molgnn_nhead': 4,
    'molgnn_readout_layers': 2,
    'seed': 1234,
    'dev_name': 'cuda',                 # torch device string; ConfigDict.device falls back to CPU
    'keep_best_models_num': 3
}
|
| 42 |
+
|
| 43 |
+
class ConfigDict(dict):
    '''
    Makes a dictionary behave like an object, with attribute-style access.
    '''

    def __getattr__(self, name):
        """Attribute access falls back to the dict entry; missing keys raise AttributeError."""
        try:
            return self[name]
        except KeyError:
            # re-raise as AttributeError so hasattr()/getattr() behave correctly
            raise AttributeError(name)

    def __setattr__(self, name, value):
        self[name] = value

    def save(self, fn, onlyprint=False):
        """Dump the config to `fn` as JSON, or just print it when onlyprint=True."""
        if onlyprint:
            print(self)
        else:
            # use a context manager instead of leaking the file handle
            with open(fn, 'w') as f:
                json.dump(self, f, indent=2)

    def load_dict(self, dic):
        """Merge `dic` into the config and refresh the derived MS embedding size."""
        for k, v in dic.items():
            self[k] = v
        self.calc_ms_embedding_dim()

    def load(self, fn):
        """Load config from a dict, a JSON file path, or a JSON string (best effort)."""
        try:
            if type(fn) is dict:
                d = fn
            elif type(fn) is str:
                if os.path.exists(fn):
                    with open(fn, 'r') as f:
                        d = json.load(f)
                else:
                    d = json.loads(fn)
            self.load_dict(d)
        except Exception as e:
            # deliberate best effort: report the problem and keep current config
            print(e)

    def calc_ms_embedding_dim(self):
        """Derive ms_embedding_dim from the m/z window and bin size (+ neutral-loss bins)."""
        if 'bin_size' in self:
            self['ms_embedding_dim'] = math.ceil((self['max_mz'] - self['min_mz']) / self['bin_size'])
        if 'ms_embedding_dim' in self and 'add_nl' in self and self['add_nl']:
            # the neutral-loss channel adds a fixed 200 m/z worth of bins
            self['ms_embedding_dim'] += math.ceil((200) / self['bin_size'])

    @property
    def device(self):
        """torch.device named by 'dev_name'; falls back to CPU on any failure."""
        try:
            return torch.device(self['dev_name'])
        except (KeyError, RuntimeError, TypeError):
            return torch.device('cpu')
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# Module-level singleton configuration, initialized from the defaults above.
# Other modules import it with `from config import CFG` and read attributes.
CFG = ConfigDict()
CFG.load_dict(d)
|
code/dataset.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, json
|
| 2 |
+
import torch
|
| 3 |
+
import utils
|
| 4 |
+
|
| 5 |
+
def calc_feats(smi, ms, nls, cfg):
    """Featurize one (spectrum, molecule) pair according to cfg.mol_encoder.

    Always produces 'ms_bins'; depending on cfg.mol_encoder the item also
    gains 'mol_fps' (fingerprint), the graph tensors produced by
    utils.mol_graph_featurizer, and/or 'mol_fmvec'. Returns None when
    graph featurization fails.
    """
    item = {
        'ms_bins': utils.ms_binner(ms, nls,
                                   min_mz=cfg.min_mz,
                                   max_mz=cfg.max_mz,
                                   bin_size=cfg.bin_size,
                                   add_nl=cfg.add_nl,
                                   binary_intn=cfg.binary_intn)
    }

    encoder = cfg.mol_encoder
    fm_done = False

    if 'fp' in encoder:
        if 'fm' in encoder:
            # the combined encoder returns fingerprint and fm vector together
            item['mol_fps'], item['mol_fmvec'] = utils.mol_fp_fm_encoder(
                smi, tp=cfg.fptype, nbits=cfg.mol_embedding_dim)
            fm_done = True
        else:
            item['mol_fps'] = utils.mol_fp_encoder(
                smi, tp=cfg.fptype, nbits=cfg.mol_embedding_dim)

    if 'gnn' in encoder:
        graph = utils.mol_graph_featurizer(smi)
        if not graph:
            return None
        item.update(graph)

    if 'fm' in encoder and not fm_done:
        item['mol_fmvec'] = utils.smi2fmvec(smi)

    return item
|
| 34 |
+
|
| 35 |
+
class Dataset(torch.utils.data.Dataset):
    """In-memory dataset.

    `inp` is either a JSON file path or an already-loaded list of records,
    each with 'smiles', 'ms' and optionally 'nls' (neutral losses).
    """

    def __init__(self, inp, cfg):
        self.data = json.load(open(inp)) if type(inp) is str else inp
        self.cfg = cfg

    def __getitem__(self, idx):
        # returns None on any failure so a filtering collate_fn can drop it
        try:
            record = self.data[idx]
            neutral_losses = record['nls'] if 'nls' in record else []
            return calc_feats(record['smiles'], record['ms'], neutral_losses, self.cfg)
        except Exception as e:
            print('=' * 50, idx, str(e))
            return None

    def __len__(self):
        return len(self.data)
|
| 65 |
+
|
| 66 |
+
class DatasetGNNFP(torch.utils.data.Dataset):
    """Molecule-only dataset yielding fingerprint + graph features per record.

    `inp` is a JSON file path or an already-loaded list of records with a
    'smiles' field.
    """

    def __init__(self, inp, cfg):
        self.data = json.load(open(inp)) if type(inp) is str else inp
        self.cfg = cfg

    def __getitem__(self, idx):
        # returns None on failure so a filtering collate_fn can drop the sample
        try:
            smi = self.data[idx]['smiles']
            item = {'mol_fps': utils.mol_fp_encoder(smi,
                                                    tp=self.cfg.fptype,
                                                    nbits=self.cfg.mol_embedding_dim)}
            item.update(utils.mol_graph_featurizer(smi))
            return item
        except Exception as e:
            print('=' * 50, idx, str(e))
            return None

    def __len__(self):
        return len(self.data)
|
| 91 |
+
|
| 92 |
+
class PathDataset(torch.utils.data.Dataset):
    """Lazy dataset over a list of spectrum-file paths.

    Each file is parsed on first access (see proc_data) and the parsed
    {'ms', 'smiles'} record is cached in self.data. __getitem__ returns
    None on any failure so a filtering collate_fn can drop the sample.
    """

    def __init__(self, pathlist, cfg):
        self.fns = pathlist
        self.cfg = cfg
        self.data = {}  # idx -> parsed record cache

    def __getitem__(self, idx):
        try:
            nls = []  # neutral losses are not available in this file format
            if idx not in self.data:
                out = self.proc_data(self.fns[idx], self.cfg.energy)
                if out is None:
                    return None
                self.data[idx] = out

            ms = self.data[idx]['ms']
            smi = self.data[idx]['smiles']
            item = calc_feats(smi, ms, nls, self.cfg)
        except Exception as e:
            print('=' * 50, idx, str(e))
            return None

        return item

    def proc_data(self, fn, energy='Energy1'):
        """Parse one MGF-style text file for the section tagged with `energy`.

        Returns {'ms': [(mz, intensity), ...], 'smiles': smi} where the
        SMILES is the second-to-last ';'-separated field of the header line
        containing `energy`. Returns None when the energy tag is missing
        (previously this raised NameError on the unbound `smi`) or when a
        peak line is malformed.
        """
        peaks = []
        smi = None
        try:
            with open(fn) as f:  # context manager: no leaked file handle
                in_section = False
                for line in f:
                    if energy in line:
                        smi = line.split(';')[-2]
                        in_section = True
                        continue
                    if 'END IONS' in line:
                        if in_section:
                            break
                    if in_section:
                        mz, intn = line.split(' ')
                        peaks.append((float(mz), float(intn)))
        except Exception:
            return None

        if smi is None:
            # energy tag never found: no usable spectrum in this file
            return None
        return {'ms': peaks, 'smiles': smi}

    def __len__(self):
        return len(self.fns)
|
code/modules.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from config import CFG
|
| 5 |
+
import utils
|
| 6 |
+
import math
|
| 7 |
+
import numpy as np
|
| 8 |
+
from cliplayers import QuickGELU, Transformer as MSTsfmEncoder
|
| 9 |
+
from GNN import layers as gly
|
| 10 |
+
|
| 11 |
+
# Symmetric contrastive objective: FragSimiModel.forward applies cross-entropy
# over the in-batch similarity matrix in both directions using loss_func.
loss_func_ms = nn.CrossEntropyLoss()  # NOTE(review): unused within this file
loss_func = nn.CrossEntropyLoss()
|
| 13 |
+
|
| 14 |
+
class MolGNNEncoder(nn.Module):
    """Molecule graph encoder.

    Stacked graph-conv blocks, multi-head global attention pooling, then a
    small GELU MLP readout producing `outdim` features per molecule.
    """

    def __init__(self,
                 outdim,
                 n_feats=74,  # 330, # 74+256 morgan 256
                 n_filters_list=None,  # conv widths; None -> [256, 256, 256]
                 n_head=4,
                 mols=1,
                 adj_chans=6,
                 readout_layers=2,
                 bias=True):

        super().__init__()

        # avoid the shared-mutable-default pitfall (old default was a list
        # literal) and drop placeholder None entries
        if n_filters_list is None:
            n_filters_list = [256, 256, 256]
        widths = [w for w in n_filters_list if w is not None]

        blocks = []
        in_dim = n_feats
        for width in widths:
            blocks.append(gly.GConvBlockNoGF(in_dim, width, mols, adj_chans, bias))
            in_dim = width
        # in_dim now holds the last conv width; previously an empty filter
        # list raised NameError on the unbound loop variable — now the
        # attention layer simply works on the raw n_feats features.

        self.block_layers = nn.ModuleList(blocks)
        self.attention_layer = gly.MultiHeadGlobalAttention(in_dim, n_head=n_head, concat=True, bias=bias)
        readout = [nn.Linear(in_dim * n_head, outdim, bias=bias)]
        readout += [nn.Linear(outdim, outdim) for _ in range(readout_layers - 1)]
        self.readout_layers = nn.ModuleList(readout)
        self.gelu = QuickGELU()

    def forward(self, batch):
        """Encode a batch dict with node features 'V', adjacency 'A' and per-graph 'mol_size'."""
        node_feats = batch['V']
        adj = batch['A']
        mol_size = batch['mol_size']

        for conv in self.block_layers:
            node_feats = conv(node_feats, adj)

        pooled = self.attention_layer(node_feats, mol_size)

        out = pooled
        for linear in self.readout_layers:
            out = self.gelu(linear(out))
        return out
|
| 60 |
+
|
| 61 |
+
class ProjectionHead(nn.Module):
    """Project input features into the shared embedding space.

    A linear projection followed by either a GELU (default) or — when
    `transformer` is set — a transformer encoder (only one of the two is
    applied to the projection); an optional LSTM and a dropout follow.
    """

    def __init__(self,
                 embedding_dim,
                 projection_dim,
                 cfg,
                 transformer=True,
                 lstm=False):

        super().__init__()

        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()  # QuickGELU()
        self.transformer = (MSTsfmEncoder(projection_dim, cfg.tsfm_layers, cfg.tsfm_heads)
                            if transformer else None)
        self.lstm = (nn.LSTM(input_size=projection_dim,
                             hidden_size=projection_dim,
                             num_layers=cfg.lstm_layers,
                             batch_first=True)
                     if lstm else None)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x):
        projected = self.projection(x)
        # the transformer branch bypasses the GELU (matching the original flow)
        if self.transformer is None:
            out = self.gelu(projected)
        else:
            out = self.transformer(projected)
        if self.lstm is not None:
            out, (_, _) = self.lstm(out)
        return self.dropout(out)
|
| 92 |
+
|
| 93 |
+
# New name in paper is CMSSPModel
class FragSimiModel(nn.Module):
    """Contrastive MS-spectrum / molecule model.

    Both modalities are projected into a shared `projection_dim` space and a
    symmetric cross-entropy over the in-batch similarity matrix is returned
    as the loss.
    """

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg
        self.mol_gnn_encoder = None
        mol_embedding_dim = cfg.mol_embedding_dim

        if 'gnn' in self.cfg.mol_encoder:
            self.mol_gnn_encoder = MolGNNEncoder(outdim=cfg.mol_embedding_dim,
                                                 n_filters_list=cfg.molgnn_n_filters_list,
                                                 n_head=cfg.molgnn_nhead,
                                                 readout_layers=cfg.molgnn_readout_layers)
            if 'fp' in self.cfg.mol_encoder:
                # GNN output and fingerprint are concatenated in forward(),
                # doubling the molecule feature width
                mol_embedding_dim = 2 * cfg.mol_embedding_dim

        if 'fm' in self.cfg.mol_encoder:
            # fm vector appears to be 10-dim — TODO confirm against utils.smi2fmvec
            mol_embedding_dim += 10

        self.ms_projection = ProjectionHead(cfg.ms_embedding_dim,
                                            cfg.projection_dim,
                                            cfg,
                                            cfg.tsfm_in_ms,
                                            cfg.lstm_in_ms)

        self.mol_projection = ProjectionHead(mol_embedding_dim,
                                             cfg.projection_dim,
                                             cfg,
                                             cfg.tsfm_in_mol,
                                             cfg.lstm_in_mol)

    def forward(self, batch):
        ms_features = batch["ms_bins"]

        mol_feat_list = []
        if 'gnn' in self.cfg.mol_encoder:
            mol_feat_list.append(self.mol_gnn_encoder(batch))
        if 'fp' in self.cfg.mol_encoder:
            mol_feat_list.append(batch["mol_fps"])
        if 'fm' in self.cfg.mol_encoder:
            mol_feat_list.append(batch["mol_fmvec"])

        if len(mol_feat_list) > 1:
            mol_features = torch.cat(mol_feat_list, dim=1)
        else:
            mol_features = mol_feat_list[0]

        # project both modalities into the shared embedding space
        ms_embeddings = self.ms_projection(ms_features)
        mol_embeddings = self.mol_projection(mol_features)

        # in-batch contrastive loss: each spectrum matches its own molecule
        #logits = (mol_embeddings @ ms_embeddings.t())
        #logit_scale = self.logit_scale.exp()
        logits = mol_embeddings @ ms_embeddings.t()

        ground_truth = torch.arange(ms_features.shape[0], dtype=torch.long, device=self.cfg.device)

        ms_loss = loss_func(logits, ground_truth)
        mol_loss = loss_func(logits.t(), ground_truth)
        loss = (ms_loss + mol_loss) / 2.0

        return loss.mean()
|
code/predict.py
ADDED
|
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from modules import *
|
| 2 |
+
import os, sys
|
| 3 |
+
import numpy as np
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
from config import CFG
|
| 8 |
+
import utils
|
| 9 |
+
import json
|
| 10 |
+
import pandas as pd
|
| 11 |
+
import pickle
|
| 12 |
+
|
| 13 |
+
# Process-wide cache of per-SMILES molecule features, keyed by
# hash(smiles + encoder tag); a None value marks a SMILES that failed.
MolFeatsCached = {}
|
| 14 |
+
|
| 15 |
+
def calc_mol_embeddings0(model, smis, cfg):
    """Fingerprint-only embedding of `smis`, one molecule at a time (no caching).

    SMILES that fail to featurize are skipped with a printed warning.
    Returns a stacked (n_ok, projection_dim) tensor.
    """
    model.eval()

    embeddings = []
    with torch.no_grad():
        for smi in smis:
            try:
                fp = utils.mol_fp_encoder(smi, tp=cfg.fptype,
                                          nbits=cfg.mol_embedding_dim).to(cfg.device)
                emb = model.mol_projection(fp.unsqueeze(0))
                embeddings.append(emb.squeeze(0))
            except Exception as e:
                print(smi, e)
                continue

    return torch.stack(embeddings)
|
| 30 |
+
|
| 31 |
+
def calc_mol_embeddings1(model, smis, cfg):
    """Embed `smis` one molecule at a time, caching raw features in MolFeatsCached.

    Supports only the single-modality encoders ('fp' or 'gnn'); any SMILES
    that fails is skipped with a printed warning. Returns a stacked
    (n_ok, projection_dim) tensor.
    """
    model.eval()
    mol_embeddings = []

    with torch.no_grad():
        for smi in smis:
            try:
                if cfg.mol_encoder == 'fp':
                    # cache key encodes the fingerprint settings so that
                    # different configurations do not collide
                    k = hash(smi + f'fp-{cfg.fptype}-{cfg.mol_embedding_dim}')
                    if k in MolFeatsCached:
                        feats = MolFeatsCached[k]
                    else:
                        feats = utils.mol_fp_encoder(smi, tp=cfg.fptype, nbits=cfg.mol_embedding_dim).to(cfg.device)
                        MolFeatsCached[k] = feats
                    me = model.mol_projection(feats.unsqueeze(0))
                    mol_embeddings.append(me.squeeze(0))
                elif cfg.mol_encoder == 'gnn':
                    k = hash(smi + 'gnn')
                    if k in MolFeatsCached:
                        gfeats = MolFeatsCached[k]
                    else:
                        gfeats = utils.mol_graph_featurizer(smi)
                        MolFeatsCached[k] = gfeats

                    # single-molecule "batch" for the GNN encoder
                    bat = {'A': gfeats['A'].unsqueeze(0).to(cfg.device),
                           'V': gfeats['V'].unsqueeze(0).to(cfg.device),
                           'mol_size': gfeats['mol_size'].unsqueeze(0).to(cfg.device)}

                    feats = model.mol_gnn_encoder(bat)
                    me = model.mol_projection(feats)
                    mol_embeddings.append(me.squeeze(0))
            except Exception as e:
                print(smi, e)
                continue

    return torch.stack(mol_embeddings)
|
| 67 |
+
|
| 68 |
+
def calc_mol_embeddings(model, smis, cfg):
    """Embed `smis` with whichever molecule encoders cfg.mol_encoder enables.

    Raw per-SMILES features are cached in MolFeatsCached (a None value marks
    a SMILES that previously failed). Per-modality features are batched,
    concatenated when several modalities are enabled, and pushed through
    model.mol_projection.

    Fixes over the previous version: the except handler no longer touches an
    unbound/stale cache key (NameError / wrong entry poisoned), and a sample
    is appended to the per-modality lists only after EVERY enabled modality
    succeeded, so the lists cannot become index-misaligned.
    """
    model.eval()
    fp_featsl, gnn_featsl, fm_featsl = [], [], []

    for smi in smis:
        k = None  # last cache key touched; guarded in the except handler
        gnn_feats = fp_feats = fm_feats = None
        try:
            if 'gnn' in cfg.mol_encoder:
                k = hash(smi + 'gnn')
                if k in MolFeatsCached:
                    gnn_feats = MolFeatsCached[k]
                else:
                    gnn_feats = utils.mol_graph_featurizer(smi)
                    MolFeatsCached[k] = gnn_feats
                if gnn_feats is None:
                    continue
            if 'fp' in cfg.mol_encoder:
                k = hash(smi + f'fp-{cfg.fptype}-{cfg.mol_embedding_dim}')
                if k in MolFeatsCached:
                    fp_feats = MolFeatsCached[k]
                else:
                    fp_feats = utils.mol_fp_encoder(smi, tp=cfg.fptype, nbits=cfg.mol_embedding_dim).to(cfg.device)
                    MolFeatsCached[k] = fp_feats
                if fp_feats is None:
                    continue
            if 'fm' in cfg.mol_encoder:
                k = hash(smi + f'fm-{cfg.fptype}-{cfg.mol_embedding_dim}')
                if k in MolFeatsCached:
                    fm_feats = MolFeatsCached[k]
                else:
                    fm_feats = utils.smi2fmvec(smi).to(cfg.device)
                    MolFeatsCached[k] = fm_feats
                if fm_feats is None:
                    continue
        except Exception as e:
            print(smi, e)
            if k is not None:
                # remember the failure so this SMILES is not retried
                MolFeatsCached[k] = None
            continue

        # append only after all enabled modalities succeeded — keeps the
        # per-modality lists index-aligned with each other
        if gnn_feats is not None:
            gnn_featsl.append(gnn_feats)
        if fp_feats is not None:
            fp_featsl.append(fp_feats)
        if fm_feats is not None:
            fm_featsl.append(fm_feats)

    mol_feat_list = []
    if 'gnn' in cfg.mol_encoder:
        bat = {}
        vl = [b['V'] for b in gnn_featsl if 'V' in b]
        al = [b['A'] for b in gnn_featsl if 'A' in b]
        msl = [b['mol_size'] for b in gnn_featsl if 'mol_size' in b]

        if vl and al and msl:
            # pad node/adjacency tensors up to the largest molecule in the batch
            max_n = max(v.shape[0] for v in vl)
            bat['V'] = torch.stack([utils.pad_V(v, max_n) for v in vl]).to(cfg.device)
            bat['A'] = torch.stack([utils.pad_A(a, max_n) for a in al]).to(cfg.device)
            bat['mol_size'] = torch.cat(msl, dim=0).to(cfg.device)

        mol_feat_list.append(model.mol_gnn_encoder(bat))

    if 'fp' in cfg.mol_encoder:
        mol_feat_list.append(torch.stack(fp_featsl).to(cfg.device))

    if 'fm' in cfg.mol_encoder:
        mol_feat_list.append(torch.stack(fm_featsl).to(cfg.device))

    if len(mol_feat_list) > 1:
        mol_features = torch.cat(mol_feat_list, dim=1).to(cfg.device)
    else:
        mol_features = mol_feat_list[0].to(cfg.device)

    with torch.no_grad():
        mol_embeddings = model.mol_projection(mol_features)

    return mol_embeddings
|
| 154 |
+
|
| 155 |
+
def find_matches(model, ms, smis, cfg, n=10):
    """Rank candidate SMILES by cosine similarity to the spectrum `ms`.

    Returns (top_smiles, scores, indices): scores are cosine similarities
    scaled by 100, indices point into the embedded candidate pool.
    n=-1 (or n larger than the pool) returns the whole ranking.

    NOTE(review): if calc_mol_embeddings skipped any SMILES, the returned
    indices can be misaligned with `smis` — verify before relying on them.
    """
    model.eval()
    with torch.no_grad():
        spectrum = utils.ms_binner(ms, min_mz=cfg.min_mz, max_mz=cfg.max_mz,
                                   bin_size=cfg.bin_size, add_nl=cfg.add_nl,
                                   binary_intn=cfg.binary_intn).to(cfg.device)
        ms_emb = model.ms_projection(spectrum.unsqueeze(0)).squeeze(0)

    mol_emb = calc_mol_embeddings(model, smis, cfg)

    # cosine similarity = dot product of L2-normalized embeddings
    sim = F.normalize(mol_emb, p=2, dim=-1) @ F.normalize(ms_emb, p=2, dim=-1).t()

    top_n = len(mol_emb) if (n == -1 or n > len(mol_emb)) else n
    values, indices = torch.topk(sim.squeeze(0), top_n)

    matched = [smis[idx] for idx in indices]
    return matched, values.to('cpu').data.numpy() * 100, indices.to('cpu').data.numpy()
|
| 178 |
+
|
| 179 |
+
def calc(models, datal, cfg, saveout=True):
    """Score every spectrum in `datal` against its candidate pool and tally hit ranks.

    For each (model, spectrum) pair the top-50 candidates are retrieved via
    find_matches and their scores accumulated per InChIKey across models and
    spectra. Returns (sumtop3, dc, dicall): the number of spectra whose
    correct candidate ranked in the top 3, a per-rank hit-count table keyed
    'Hit 001'..'Hit 100', and the raw per-InChIKey score dictionaries.
    """
    dicall = {}   # ikey -> {smiles -> {'score', 'iscor', 'idx'}}
    coridxd = {}  # 'Hit 00k' -> count of spectra whose correct answer ranked k

    for idx, model in enumerate(models):
        # NOTE(review): loop variable `nn` shadows the `torch.nn` import in
        # this module's namespace
        for nn, data in enumerate(datal):
            print(f'Calculating {nn}-th MS...')
            #smipool = [d[1] for d in data['candidates'][:50]]
            smipool = [d[1] for d in data['candidates']]

            try:
                smis, scores, indices = find_matches(model, data['ms'], smipool, cfg, 50)
            except Exception as e:
                print(131, e)
                continue

            # merge duplicate SMILES within this spectrum's top hits,
            # summing their scores
            dic = {}
            for n, smi in enumerate(smis):
                if smi in dic:
                    dic[smi]['score'] += scores[n]
                    dic[smi]['iscor'] = data['candidates'][indices[n]][-1]
                    dic[smi]['idx'] = data['candidates'][indices[n]][0]
                else:
                    dic[smi] = {'score': scores[n], 'iscor': data['candidates'][indices[n]][-1], 'idx': data['candidates'][indices[n]][0]}

            # accumulate per-SMILES scores for this spectrum's InChIKey
            # across models/spectra
            ikey = data['ikey']
            if ikey in dicall:
                for k, v in dic.items():
                    if k in dicall[ikey]:
                        dicall[ikey][k]['score'] += v['score']
                    else:
                        dicall[ikey][k] = v
            else:
                dicall[ikey] = dic

    for ikey, dic in dicall.items():
        smis = [k for k in dic.keys()]
        scorel = [d['score'] for d in dic.values()]
        iscorl = [d['iscor'] for d in dic.values()]
        indexl = [d['idx'] for d in dic.values()]

        scoretsor = torch.tensor(scorel)
        # rank at most the top-100 candidates per spectrum
        n = 100
        if n > len(scorel):
            n = len(scorel)

        values, indices = torch.topk(scoretsor, n)

        # reorder all parallel lists by descending accumulated score
        scorel = values
        smis = [smis[i] for i in indices]
        iscorl = [iscorl[i] for i in indices]
        indexl = [indexl[i] for i in indices]

        try:
            # 1-based rank of the correct candidate; .index raises
            # ValueError when the correct answer is absent from the top-n
            i = iscorl.index(True)
            k = 'Hit %.3d' %(i+1)
            if k in coridxd:
                coridxd[k] += 1
            else:
                coridxd[k] = 1
        except:
            pass

    ks = sorted(list(coridxd.keys()))
    dc = {}
    sumtop3 = 0

    for k in ks:
        dc[k] = [coridxd[k]]
        if k in ['Hit 001', 'Hit 002', 'Hit 003']:
            sumtop3 += coridxd[k]

    # ensure every rank key exists so downstream tables are rectangular
    for i in range(100):
        k = 'Hit %.3d' %(i+1)
        if not k in dc:
            dc[k] = [0]

    '''if saveout:
        df0 = pd.DataFrame(dc)
        df0.to_csv('summary.csv', index=False)

        df = pd.DataFrame({
            'MSFn': ikeysl,
            'Item': iteml,
            'Index': smisidl,
            'Smiles': smis,
            'Score': scoresl,
            'IsCorrect': iscorl})

        df.to_csv('predicted.csv', index=False)'''

    return sumtop3, dc, dicall
|
| 271 |
+
|
| 272 |
+
def test(modelfnl, datal, datafn=''):
    """Evaluate each checkpoint in `modelfnl` on `datal` and dump per-model results.

    For every checkpoint a raw-score pickle (<model>-<data>-tstrlt.pkl) and a
    cumulative Top-k CSV (<model>-<data>-tstrlt.csv) are written next to the
    checkpoint file. Returns (best_summary_string, best_top3_count) across
    all evaluated checkpoints, selected by top-3 hit count.
    """
    maxtop3 = 0
    maxoutt = ''

    for fn in modelfnl:
        # each checkpoint carries its own config snapshot
        d = torch.load(fn)
        CFG.load(d['config'])
        print(d['config'])
        CFG.save('', True)

        model = FragSimiModel(CFG).to(CFG.device)
        model.load_state_dict(d['state_dict'])
        model.to(CFG.device)

        sumtop3, dc, dicall = calc([model], datal, CFG, saveout=False)

        sumtop10 = 0
        for k in ['Hit %.3d' %(i+1) for i in range(10)]:
            if k in dc:
                sumtop10 += dc[k][0]

        sumtop50 = 0
        for k in ['Hit %.3d' %(i+1) for i in range(50)]:
            if k in dc:
                sumtop50 += dc[k][0]

        # cumulative Top-k table: Top k = sum of Hit 1..k
        # (calc fills dc with all 100 'Hit …' keys, so `k in dc` holds)
        tops = {}
        for i in range(100):
            k = 'Hit %.3d' %(i+1)
            key = k.replace('Hit', 'Top')
            if not key in tops:
                tops[key] = [0]
            if k in dc:
                for n in range(i+1):
                    kk = 'Hit %.3d' %(n+1)
                    if kk in dc:
                        tops[key][0] += dc[kk][0]

        outt = f'Top1: {dc.setdefault("Hit 001", [0])[0]}, top3: {sumtop3}, top10: {sumtop10}, top50: {sumtop50} of {len(datal)}'

        if sumtop3 > maxtop3:
            maxtop3 = sumtop3
            maxoutt = outt

        # persist raw scores plus the textual summary alongside the model
        dicall['testdata'] = datafn
        dicall['testrlt'] = outt
        pickle.dump(dicall, open(fn.replace('.pth', f'-{os.path.basename(datafn).split(".")[0]}-tstrlt.pkl'), 'wb'))

        df = pd.DataFrame(tops)
        df.to_csv(fn.replace('.pth', f'-{os.path.basename(datafn).split(".")[0]}-tstrlt.csv'), index=False)

    return maxoutt, maxtop3
|
| 324 |
+
|
| 325 |
+
def main(datafn, fnl):
    """Evaluate each checkpoint in `fnl` on the test set in `datafn`.

    Appends one CSV row per checkpoint to 'predict_results.csv' (created with
    a header on first use) and prints the collected summary strings.

    Parameters
    ----------
    datafn : str
        Path to a JSON test-set file.
    fnl : list[str]
        Checkpoint (.pth) paths to evaluate one by one.
    """
    outl = []

    # Use context managers so file handles are closed deterministically
    # (the original relied on GC to close them).
    with open(datafn) as f:
        datal = json.load(f)

    logfn = 'predict_results.csv'
    if not os.path.exists(logfn):
        with open(logfn, 'w') as f:
            f.write('Index,Results,Model,Data\n')

    for n, fn in enumerate(fnl):
        out, _ = test([fn], datal, datafn)
        print(out, os.path.basename(fn))
        outl.append(out)
        # Append immediately so partial runs still leave a usable log.
        with open(logfn, 'a') as f:
            f.write(f'{n},"{out}",{fn},{datafn}\n')

    print(outl)
|
| 342 |
+
|
| 343 |
+
if __name__ == '__main__':
    import time
    t0 = time.time()
    # argv[1]: test-set JSON file; argv[2:]: one or more checkpoint paths.
    main(sys.argv[1], sys.argv[2:])
    # '300' is presumably a source-line/debug marker for grepping logs.
    print(300, time.time()-t0)
|
code/separate_posneg.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
|
| 4 |
+
if __name__ == '__main__':
    # Split a spectrum-library JSON into positive- and negative-mode files.
    # Usage: python separate_posneg.py <library.json>
    import sys
    fn = sys.argv[1]
    with open(fn) as f:
        d = json.load(f)

    lpos = []
    lneg = []

    for n, it in enumerate(d):
        print(f'processing {n+1}th...')

        # Prefer the explicit 'Ion_Mode' field; fall back to the adduct
        # string (e.g. '[M-H]-') when it is absent. The original used a
        # bare `except:` here, which also hid unrelated errors.
        mode = it.get('Ion_Mode')
        if mode is not None:
            is_negative = mode.strip().lower() == 'negative'
        else:
            is_negative = it['species'].strip().endswith('-')

        if is_negative:
            lneg.append(it)
        else:
            lpos.append(it)

    print(f'Len lpos = {len(lpos)}, len lneg = {len(lneg)}, sum = {len(lpos)+len(lneg)}')

    with open(fn.replace('.json', '-pos.json'), 'w') as f:
        json.dump(lpos, f, indent=2)
    with open(fn.replace('.json', '-neg.json'), 'w') as f:
        json.dump(lneg, f, indent=2)
|
code/train.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from utils import *
|
| 2 |
+
from modules import *
|
| 3 |
+
import os, sys
|
| 4 |
+
import numpy as np
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import random
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
from config import CFG
|
| 10 |
+
from dataset import *
|
| 11 |
+
import torch.utils.data
|
| 12 |
+
import copy, json, pickle
|
| 13 |
+
import itertools as it
|
| 14 |
+
|
| 15 |
+
def make_next_record_dir(basedir, prefix=''):
    """Create the next free numbered record directory under *basedir*.

    Candidate names are '<basedir>/<prefix>001/', '<prefix>002/', ...;
    the first one that does not exist yet is created and returned
    (including the trailing slash).
    """
    counter = 1
    while True:
        candidate = f'{basedir}/{prefix}{counter:03d}/'
        if not os.path.exists(candidate):
            break
        counter += 1

    os.makedirs(candidate)
    return candidate
|
| 25 |
+
|
| 26 |
+
def setup_seed(seed):
    """Seed all RNG sources (torch CPU/CUDA, numpy, python) for reproducibility."""
    torch.manual_seed(seed)
    # NOTE(review): seeds only the current CUDA device; manual_seed_all would
    # cover multi-GPU setups — confirm single-GPU training is assumed.
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Force deterministic cuDNN kernels (may cost some speed).
    torch.backends.cudnn.deterministic = True
|
| 32 |
+
|
| 33 |
+
def my_collate(batch):
    """Collate dataset items (dicts of tensors) into one batched dict.

    Drops None items (failed featurizations), stacks the fixed-size tensors
    ('ms_bins', 'mol_fps', 'mol_fmvec'), and zero-pads the variable-size graph
    tensors V/A to the largest molecule in the batch. Keys absent from the
    items are simply absent from the returned dict.
    """
    # Dataset may yield None for items that failed to featurize.
    batch = list(filter(lambda x:(x is not None), batch))
    msbinl, molfpl, molfml, vl, al, msl = [], [], [], [], [], []
    bat = {}

    for b in batch:
        if 'ms_bins' in b:
            msbinl.append(b['ms_bins'])
        if 'mol_fps' in b:
            molfpl.append(b['mol_fps'])
        if 'mol_fmvec' in b:
            molfml.append(b['mol_fmvec'])
        if 'V' in b:
            vl.append(b['V'])
        if 'A' in b:
            al.append(b['A'])
        if 'mol_size' in b:
            msl.append(b['mol_size'])

    if msbinl:
        bat['ms_bins'] = torch.stack(msbinl)
    if molfpl:
        bat['mol_fps'] = torch.stack(molfpl)
    if molfml:
        bat['mol_fmvec'] = torch.stack(molfml)
    # Graph inputs are only batched when all three pieces are present.
    if vl and al and msl:
        # Pad node features (V) and adjacency tensors (A) to the max node count.
        max_n = max(map(lambda x:x.shape[0], vl))
        vl1, al1 = [], []
        for v in vl:
            vl1.append(pad_V(v, max_n))
        for a in al:
            al1.append(pad_A(a, max_n))

        bat['V'] = torch.stack(vl1)
        bat['A'] = torch.stack(al1)
        # mol_size keeps the true (unpadded) node counts.
        bat['mol_size'] = torch.cat(msl, dim=0)

    #return torch.utils.data.dataloader.default_collate(batch)
    return bat
|
| 72 |
+
|
| 73 |
+
def make_train_valid(data, valid_ratio, seed=1234):
    """Deterministically shuffle *data* and split it into (train, valid) lists.

    *valid_ratio* is the fraction of items assigned to the validation split;
    the same seed always produces the same split.
    """
    np.random.seed(seed)
    order = np.arange(len(data))
    np.random.shuffle(order)

    n_valid = int(valid_ratio * len(data))

    valid_part = [data[i] for i in order[:n_valid]]
    train_part = [data[i] for i in order[n_valid:]]

    return train_part, valid_part
|
| 84 |
+
|
| 85 |
+
def build_loaders(inp, mode, cfg, num_workers):
    """Wrap *inp* in the appropriate Dataset and return a DataLoader.

    Items that are dicts go to the in-memory Dataset; otherwise *inp* is
    assumed to be a list of file paths and handled by PathDataset.
    Shuffling is enabled only for mode == "train".
    """
    if type(inp[0]) is dict:
        dataset = Dataset(inp, cfg)
    else:
        dataset = PathDataset(inp, cfg)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        num_workers=num_workers,
        shuffle=True if mode == "train" else False,
        collate_fn=my_collate
    )
    return dataloader
|
| 98 |
+
|
| 99 |
+
def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
    """Run one training epoch and return an AvgMeter with the running loss.

    When step == "batch" the lr scheduler is stepped after every batch;
    otherwise scheduling is left to the caller.
    """
    loss_meter = AvgMeter()
    tqdm_object = tqdm(train_loader, total=len(train_loader))

    for batch in tqdm_object:
        # Move every tensor of the collated batch onto the configured device.
        for k, v in batch.items():
            batch[k] = v.to(CFG.device)

        # The model computes its own loss from the batch dict.
        loss = model(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step == "batch":
            lr_scheduler.step()

        # Batch size = number of spectra; weights the running average.
        count = batch["ms_bins"].size(0)
        loss_meter.update(loss.item(), count)

        tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
    return loss_meter
|
| 119 |
+
|
| 120 |
+
def valid_epoch(model, valid_loader):
    """Run one validation pass and return an AvgMeter with the running loss.

    Caller is expected to set model.eval() and wrap in torch.no_grad().
    """
    loss_meter = AvgMeter()

    tqdm_object = tqdm(valid_loader, total=len(valid_loader))
    for batch in tqdm_object:
        # Move every tensor of the collated batch onto the configured device.
        for k, v in batch.items():
            batch[k] = v.to(CFG.device)

        loss = model(batch)

        count = batch["ms_bins"].size(0)
        loss_meter.update(loss.item(), count)

        tqdm_object.set_postfix(valid_loss=loss_meter.avg)

    return loss_meter
|
| 136 |
+
|
| 137 |
+
def main(data, cfg=CFG, savedir='data/train', encmodel=None, ratio=1):
    """Train FragSimiModel on *data*, checkpointing on best validation loss.

    Parameters
    ----------
    data : list
        Dataset items (dicts) or file paths, split via make_train_valid.
    cfg : ConfigDict
        Training configuration (seed, lr, epochs, patience, ...).
    savedir : str
        Directory for checkpoint files.
    encmodel : model or None
        Optional model whose mol_gnn_encoder weights initialize this model.
    ratio : float
        Fraction of the training split to actually use (< 1 subsamples).

    Returns (list of kept best-checkpoint paths, best validation loss).
    """
    setup_seed(cfg.seed)

    train_set, valid_set = make_train_valid(data, valid_ratio=cfg.valid_ratio, seed=cfg.seed)

    n = len(train_set)
    if ratio < 1:
        # Subsample the training set for data-efficiency experiments.
        train_set = random.sample(train_set, int(n*ratio))
        print(f'Ratio {ratio}, lenall {n}, newtrainset {len(train_set)}')

    train_loader = build_loaders(train_set, "train", cfg, 10)
    valid_loader = build_loaders(valid_set, "valid", cfg, 10)

    step = "epoch"

    best_loss = float('inf')
    best_model_fn = ''
    best_model_fns = []

    model = FragSimiModel(cfg).to(cfg.device)

    # Optionally warm-start the molecule encoder from a pretrained model.
    if not encmodel is None:
        model.mol_gnn_encoder.load_state_dict(encmodel.mol_gnn_encoder.state_dict())
    # fraze mol_gnn_encoder weights
    '''for name, param in model.named_parameters():
        if 'mol_gnn_encoder' in name:
            print(152, 'fraze mol_gnn_encoder weights')
            param.requires_grad = False'''

    print(model)

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
    )

    # NOTE(review): with step == "epoch", lr_scheduler.step() is never called
    # (train_epoch only steps on "batch"), so ReduceLROnPlateau is effectively
    # inactive — confirm whether a step(valid_loss.avg) per epoch was intended.
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=cfg.patience, factor=cfg.factor
    )

    for epoch in range(cfg.epochs):
        print(f"Epoch: {epoch + 1}/{cfg.epochs}")
        model.train()
        train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
        model.eval()
        with torch.no_grad():
            valid_loss = valid_epoch(model, valid_loader)

        # Checkpoint only when validation loss improves.
        if valid_loss.avg < best_loss:
            best_loss = valid_loss.avg
            best_model_fn = f"{savedir}/model-tloss{round(train_loss.avg, 3)}-vloss{round(valid_loss.avg, 3)}-epoch{epoch}.pth"
            best_model_fn_base = best_model_fn.replace('.pth', '')
            # Avoid overwriting an existing file with the same loss/epoch tag.
            n = 1
            while os.path.exists(best_model_fn):
                best_model_fn = best_model_fn_base + f'-{n}.pth'
                n += 1

            # Embed the config so predict/test can restore it from the .pth.
            checkpoint = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 'config': dict(CFG)}
            best_model_fns.append(best_model_fn)
            torch.save(checkpoint, best_model_fn)
            print("Saved Best Model!")

    # Keep only the newest cfg.keep_best_models_num checkpoints on disk.
    best_model_fnl = []
    for fn in best_model_fns:
        if os.path.exists(fn):
            best_model_fnl.append(fn)

    for fn in best_model_fnl[:-cfg.keep_best_models_num]:
        os.remove(fn)

    best_model_fnl = best_model_fnl[-cfg.keep_best_models_num:]

    print(best_model_fnl, best_loss)
    return best_model_fnl, best_loss
|
| 210 |
+
|
| 211 |
+
if __name__ == "__main__":
    # Optional argv[1]: a config .json, or a .pth checkpoint whose embedded
    # config is reused (preserving the currently configured dataset_path).
    try:
        conffn = sys.argv[1]
        if conffn.endswith('.json'):
            CFG.load(sys.argv[1])
        elif conffn.endswith('.pth'):
            dpath = CFG.dataset_path
            d = torch.load(conffn)
            CFG.load(d['config'])
            CFG.dataset_path = dpath
            print('Use config from', conffn)
    except:
        # No/invalid argument: fall back to the default CFG.
        pass

    # Optional argv[2]: base directory for training outputs.
    try:
        savedir = sys.argv[2]
    except:
        savedir = 'data/'

    os.system('mkdir -p %s' %savedir)

    mg = None

    print(CFG)

    # Dataset may be a directory of .mgf files, a cached .pkl, or a .json
    # (which is cached to .pkl on first load for faster subsequent runs).
    if os.path.isdir(CFG.dataset_path):
        data = [os.path.join(CFG.dataset_path, i) for i in os.listdir(CFG.dataset_path) if i.endswith('mgf')]
    elif os.path.isfile(CFG.dataset_path):
        if CFG.dataset_path.endswith('.pkl'):
            data = pickle.load(open(CFG.dataset_path, 'rb'))
        else:
            data = json.load(open(CFG.dataset_path))
            pklfn = CFG.dataset_path.replace('.json', '.pkl')
            if not os.path.exists(pklfn):
                pickle.dump(data, open(pklfn, 'wb'))

    # Snapshot the source tree and config into a fresh numbered run directory.
    subdir = make_next_record_dir(savedir, f'train-')
    os.system(f'cp -a *py {subdir}; cp -a GNN {subdir}')
    CFG.save(f'{subdir}/config.json')

    modelfnl, _ = main(data, CFG, subdir, mg)
|
code/utils.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rdkit import Chem
|
| 2 |
+
from rdkit.Chem import AllChem, MACCSkeys
|
| 3 |
+
from rdkit.Chem.rdmolops import FastFindRings
|
| 4 |
+
from rdkit.Chem.rdMolDescriptors import CalcMolFormula
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
import scipy
|
| 8 |
+
import scipy.sparse as ss
|
| 9 |
+
import scipy.sparse.linalg
|
| 10 |
+
import math
|
| 11 |
+
import json
|
| 12 |
+
import itertools as it
|
| 13 |
+
import re
|
| 14 |
+
from GNN import featurizer as ft
|
| 15 |
+
|
| 16 |
+
import rdkit.RDLogger as rkl
|
| 17 |
+
logger = rkl.logger()
|
| 18 |
+
logger.setLevel(rkl.ERROR)
|
| 19 |
+
|
| 20 |
+
import rdkit.rdBase as rkrb
|
| 21 |
+
rkrb.DisableLog('rdApp.error')
|
| 22 |
+
|
| 23 |
+
# 50w metabolites fpbit relative aboundance > 5%
|
| 24 |
+
FPBitIdx = [1, 5, 13, 41, 69, 80, 84, 94, 114, 117, 118, 119, 125, 133, 145,
|
| 25 |
+
147, 191, 192, 197, 202, 222, 227, 231, 249, 283, 294, 310, 314,
|
| 26 |
+
322, 333, 352, 361, 378, 387, 389, 392, 401, 406, 441, 478, 486,
|
| 27 |
+
489, 519, 521, 524, 555, 561, 591, 598, 599, 610, 622, 650, 656,
|
| 28 |
+
667, 675, 677, 679, 680, 694, 695, 715, 718, 722, 729, 736, 739,
|
| 29 |
+
745, 750, 760, 775, 781, 787, 794, 798, 802, 807, 811, 823, 835,
|
| 30 |
+
841, 849, 869, 872, 874, 875, 881, 890, 896, 926, 935, 980, 991,
|
| 31 |
+
1004, 1009, 1017, 1019, 1027, 1028, 1035, 1037, 1039, 1057, 1060,
|
| 32 |
+
1066, 1070, 1077, 1088, 1097, 1114, 1126, 1136, 1142, 1143, 1145,
|
| 33 |
+
1152, 1154, 1160, 1162, 1171, 1181, 1195, 1199, 1202, 1218, 1234,
|
| 34 |
+
1236, 1243, 1257, 1267, 1274, 1279, 1283, 1292, 1294, 1309, 1313,
|
| 35 |
+
1323, 1325, 1349, 1356, 1357, 1366, 1380, 1381, 1385, 1386, 1391,
|
| 36 |
+
1399, 1436, 1440, 1441, 1444, 1452, 1454, 1457, 1475, 1476, 1477,
|
| 37 |
+
1480, 1487, 1516, 1536, 1544, 1558, 1564, 1573, 1599, 1602, 1604,
|
| 38 |
+
1607, 1619, 1648, 1670, 1683, 1693, 1716, 1722, 1737, 1738, 1745,
|
| 39 |
+
1747, 1750, 1754, 1755, 1764, 1781, 1803, 1808, 1810, 1816, 1838,
|
| 40 |
+
1844, 1847, 1855, 1860, 1866, 1873, 1905, 1911, 1917, 1921, 1923,
|
| 41 |
+
1928, 1933, 1950, 1951, 1970, 1977, 1980, 1984, 1991, 2002, 2033, 2034, 2038]
|
| 42 |
+
|
| 43 |
+
class ConfigDict(dict):
    '''
    A dict with attribute-style access plus JSON save/load helpers.
    '''
    def __getattr__(self, name):
        # Narrowed from a bare except: only a missing key should become
        # an AttributeError; other errors must propagate.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        self[name] = value

    def save(self, fn):
        """Write the config to *fn* as indented JSON (handle closed via with)."""
        with open(fn, 'w') as f:
            json.dump(self, f, indent=2)

    def load_dict(self, dic):
        """Merge key/value pairs from *dic* into this config."""
        for k, v in dic.items():
            self[k] = v

    def load(self, fn):
        """Load config from a JSON file path — or directly from a dict.

        Callers elsewhere pass the dict stored in a checkpoint
        (CFG.load(d['config'])); previously that raised inside json.load
        and was silently swallowed, so the config was never applied.
        Errors are still printed rather than raised, matching the old
        best-effort behavior.
        """
        try:
            if isinstance(fn, dict):
                self.load_dict(fn)
            else:
                with open(fn, 'r') as f:
                    self.load_dict(json.load(f))
        except Exception as e:
            print(e)
|
| 69 |
+
|
| 70 |
+
def conv_out_dim(length_in, kernel, stride, padding, dilation):
    """Output length of a 1-D convolution (standard PyTorch Conv1d formula)."""
    effective_kernel = dilation * (kernel - 1) + 1
    return (length_in + 2 * padding - effective_kernel) // stride + 1
|
| 73 |
+
|
| 74 |
+
def filter_ms(ms, thr=0.05, max_mz=2000):
    """Normalize a peak list and drop weak or out-of-range peaks.

    Parameters
    ----------
    ms : iterable of (mz, intensity) pairs
    thr : float
        Minimum relative intensity (fraction of the base peak) to keep.
    max_mz : float
        Peaks at or above this m/z are discarded.

    Returns
    -------
    (mz, intn) : parallel lists; intensities rescaled to percent of the
        base peak, rounded to 2 decimals.
    """
    maxi = 0
    for m, i in ms:
        if m < max_mz and i > maxi:
            maxi = i

    # Bug fix: with no usable peak (empty spectrum, all peaks >= max_mz,
    # or all intensities zero) the original divided by zero below.
    if maxi <= 0:
        return [], []

    mz = []
    intn = []
    for m, i in ms:
        if m < max_mz and i/maxi > thr:
            mz.append(m)
            intn.append(round(i/maxi*100, 2))

    return mz, intn
|
| 88 |
+
|
| 89 |
+
def calc_nls(ms, thr=0.05, max_mz=2000):
    """Compute neutral-loss peaks from a spectrum.

    For every pair of retained peaks, the m/z difference (0, 200) Da is
    recorded with the mean of the two peak intensities.

    Returns a list of (loss_mass, mean_intensity) tuples sorted by mass.
    """
    # Bug fix: thr and max_mz were hard-coded to 0.05/2000 in this call,
    # silently ignoring the caller's arguments.
    mz, intn = filter_ms(ms, thr=thr, max_mz=max_mz)

    nlmass = []
    nlintn = []
    # NOTE(review): the [::-1] reversal yields positive differences only if
    # the input peak list is sorted ascending by m/z — confirm callers do so.
    for a, b in it.combinations(mz[::-1], 2):
        nl = a - b
        if 0 < nl < 200:
            nlmass.append(round(nl, 5))
            idxa = mz.index(a)
            idxb = mz.index(b)
            nlintn.append(round((intn[idxa]+intn[idxb])/2., 5))

    nls = sorted(list(zip(nlmass, nlintn)))
    return nls
|
| 104 |
+
|
| 105 |
+
def ms_binner(ms, nls=[], min_mz=20, max_mz=2000, bin_size=0.05, add_nl=False, binary_intn=False):
    """
    Convert a spectrum to a dense binned intensity vector (FloatTensor).

    Parameters
    ----------
    ms : iterable of (mz, intensity) pairs
        The spectrum to be binned.
    nls : list of (loss_mass, intensity)
        Precomputed neutral losses; computed via calc_nls when empty and
        add_nl is True. NOTE(review): mutable default argument — it is only
        rebound here, never mutated, so this is safe but fragile.
    min_mz, max_mz : float
        m/z range covered by the fragment bins.
    bin_size : float
        Bin width in m/z for both fragment and neutral-loss bins.
    add_nl : bool
        Prepend a neutral-loss bin vector (losses < 200 Da).
    binary_intn : bool
        Use 0/1 presence instead of L2-normalized intensities.

    Returns
    -------
    torch.FloatTensor
        The fragment bin vector, with the neutral-loss vector concatenated
        in front when add_nl is True.
    """
    if add_nl and not nls:
        nls = calc_nls(ms, max_mz=max_mz)

    nltensor = None
    # NOTE(review): filter_ms is called with its defaults (thr=0.05,
    # max_mz=2000), not this function's min_mz/max_mz — peaks between 2000
    # and a larger caller-supplied max_mz are lost. Confirm intended.
    mz, intn = filter_ms(ms)

    if add_nl:
        nlmass = []
        nlintn = []

        if not nls:
            nls = calc_nls(ms, max_mz=max_mz)

        # Keep only losses under 200 Da; optionally binarize intensities.
        for m, i in nls:
            if m < 200:
                if binary_intn:
                    i = 1
                nlmass.append(m)
                nlintn.append(i)

        nlmass = np.array(nlmass)
        nlintn = np.array(nlintn)
        if len(nlintn) > 0:
            nlintn = nlintn/nlintn.max()
        num_nlbins = math.ceil((200) / bin_size)
        #print('num_nlbins', num_nlbins)
        nlbins = (nlmass / bin_size).astype(np.int32)

        if len(nlmass) > 0:
            # Sparse 1 x num_nlbins row; duplicate bins are summed by scipy.
            vecnl = ss.csr_matrix(
                (nlintn,
                 (np.repeat(0, len(nlintn)), nlbins)),
                shape=(1, num_nlbins),
                dtype=np.float32)

            # L2-normalize and scale to ~100.
            vecnl = (vecnl / scipy.sparse.linalg.norm(vecnl)*100)
            nltensor = torch.FloatTensor(vecnl.todense()).view(-1)
        else:
            nltensor = torch.zeros(num_nlbins)

    mz = np.array(mz)
    keepidx = (mz <= max_mz)
    mz = mz[keepidx]
    intn = np.array(intn)
    intn = intn[keepidx]

    if binary_intn:
        intn[intn > 0] = 1.0
    elif len(intn) > 0:
        intn = intn/intn.max()

    num_bins = math.ceil((max_mz - min_mz) / bin_size)
    #print('num_bins', num_bins)
    # NOTE(review): peaks below min_mz are not filtered out here, which would
    # yield negative bin indices — presumably upstream filtering guarantees
    # mz >= min_mz; confirm.
    bins = ((mz - min_mz) / bin_size).astype(np.int32)

    #print(num_bins, intn, bins)

    if len(mz) > 0:
        vec = ss.csr_matrix(
            (intn,
             (np.repeat(0, len(intn)), bins)),
            shape=(1, num_bins),
            dtype=np.float32)

        if not binary_intn:
            vec = (vec / scipy.sparse.linalg.norm(vec)*100)

        mstensor = torch.FloatTensor(vec.todense()).view(-1)
    else:
        mstensor = torch.zeros(num_bins)

    if not nltensor is None:
        return torch.cat([nltensor, mstensor], dim=0)

    return mstensor
|
| 202 |
+
|
| 203 |
+
def formula2vec(formula, elements=['C', 'H', 'O', 'N', 'P', 'S', 'P', 'F', 'Cl', 'Br']):
    """Convert a molecular formula string (e.g. 'C6H12O6') into an element-count vector.

    Each element's count lands at the index of its FIRST occurrence in
    *elements*; symbols not in the list are ignored.
    NOTE(review): the default list contains 'P' twice, so its second slot is
    never populated — confirm whether another element was intended there.
    """
    counts = np.zeros(len(elements))
    for symbol, digits in re.findall(r'([A-Z][a-z]*)(\d*)', formula):
        if symbol in elements:
            counts[elements.index(symbol)] += int(digits) if digits else 1
    return np.array(counts)
|
| 216 |
+
|
| 217 |
+
def mol_fp_encoder0(smiles, tp='rdkit', nbits=2048):
    """Parse *smiles* and compute a molecular fingerprint tensor.

    tp : 'morgan'  -> ECFP4 bit vector of length nbits
         'morgan1' -> 2048-bit ECFP4 reduced to the FPBitIdx subset
         'macc'    -> MACCS keys
         'rdkit'   -> RDKit topological fingerprint (nBitsPerHash=1)

    Returns (torch.FloatTensor fingerprint, rdkit Mol), or (None, None)
    when the SMILES cannot be parsed at all.
    Raises ValueError for an unknown *tp* (previously crashed with
    UnboundLocalError).
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Retry without sanitization, then recompute the valence/ring info
        # that sanitization would normally provide.
        mol = Chem.MolFromSmiles(smiles, sanitize=False)
        if not mol is None:
            mol.UpdatePropertyCache()
            FastFindRings(mol)

    if mol is None:
        return None, None

    def _bits(fp_vec):
        # Bit vector -> 0/1 numpy array via its bit string.
        return np.frombuffer(fp_vec.ToBitString().encode(), 'u1') - ord('0')

    if tp == 'morgan':
        fp = _bits(AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=nbits)).tolist()
    elif tp == 'morgan1':
        # Fixed 2048 bits, then keep only the frequently-set FPBitIdx bits.
        fp = _bits(AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048))[FPBitIdx].tolist()
    elif tp == 'macc':
        # MACCSkeys
        fp = _bits(MACCSkeys.GenMACCSKeys(mol)).tolist()
    elif tp == 'rdkit':
        fp = _bits(Chem.RDKFingerprint(mol, nBitsPerHash=1)).tolist()
    else:
        raise ValueError(f'unknown fingerprint type: {tp!r}')

    return torch.FloatTensor(fp), mol
|
| 247 |
+
|
| 248 |
+
def mol_fp_encoder(smiles, tp='rdkit', nbits=2048):
    """Fingerprint tensor for *smiles* (see mol_fp_encoder0); None if parsing fails."""
    fp_tensor, _mol = mol_fp_encoder0(smiles, tp, nbits)
    return fp_tensor
|
| 251 |
+
|
| 252 |
+
def mol_fp_fm_encoder(smiles, tp='rdkit', nbits=2048):
    """Return (fingerprint tensor, formula-count tensor) for *smiles*.

    Both are None when the SMILES cannot be parsed; fmenc alone is None
    when only the formula step fails.
    """
    fmenc = None
    fpenc, mol = mol_fp_encoder0(smiles, tp, nbits)
    if not mol is None:
        # Element-count vector derived from the molecular formula.
        fm = CalcMolFormula(mol)
        fmenc = torch.FloatTensor(formula2vec(fm))
    return fpenc, fmenc
|
| 259 |
+
|
| 260 |
+
def smi2fmvec(smiles):
    """SMILES -> element-count tensor of its molecular formula (None if unparseable)."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    fm = CalcMolFormula(mol)
    fmenc = torch.FloatTensor(formula2vec(fm))

    return fmenc
|
| 268 |
+
|
| 269 |
+
def mol_graph_featurizer(smiles):
    """Build the GNN input graph dict for a SMILES string via the featurizer."""
    # mol_graph = {V, A, mol_size}
    '''mol_graph = ft.calc_data_from_smile(smiles,
                    addh=True,
                    with_ring_conj=True,
                    with_atom_feats=True,
                    with_submol_fp=True,
                    radius=2)
    '''
    # Current setting: heavy atoms only (no explicit hydrogens), with
    # ring/conjugation and atom features, without per-substructure FPs.
    mol_graph = ft.calc_data_from_smile(smiles,
                    addh=False,
                    with_ring_conj=True,
                    with_atom_feats=True,
                    with_submol_fp=False,
                    radius=2)
    return mol_graph
|
| 285 |
+
|
| 286 |
+
def pad_V(V, max_n):
    """Zero-pad a node-feature matrix V of shape (N, C) to (max_n, C).

    Returned unchanged (same object) when N >= max_n.
    """
    n_nodes, n_feats = V.shape
    if max_n > n_nodes:
        padding = torch.zeros(max_n - n_nodes, n_feats)
        V = torch.cat((V, padding), dim=0)
    return V
|
| 292 |
+
|
| 293 |
+
def pad_A(A, max_n):
    """Zero-pad an adjacency tensor A of shape (N, L, N) to (max_n, L, max_n).

    Padding is appended first along the last axis, then along the first;
    returned unchanged (same object) when N >= max_n.
    """
    n_nodes, n_rel, _ = A.shape
    if max_n > n_nodes:
        extra = max_n - n_nodes
        A = torch.cat((A, torch.zeros(n_nodes, n_rel, extra)), dim=-1)
        A = torch.cat((A, torch.zeros(extra, n_rel, max_n)), dim=0)
    return A
|
| 301 |
+
|
| 302 |
+
class AvgMeter:
    """Running (count-weighted) average tracker, e.g. for per-batch losses."""

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, count=1):
        """Fold in *val*, observed *count* times, and refresh the average."""
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        return f"{self.name}: {self.avg:.4f}"
|
| 318 |
+
|
| 319 |
+
def get_lr(optimizer):
    """Learning rate of the optimizer's first param group (None when empty)."""
    return next((group["lr"] for group in optimizer.param_groups), None)
|
| 322 |
+
|
| 323 |
+
def segment_max(x, size_list):
    """Per-segment max over dim 0; segments are consecutive chunks of the given sizes."""
    sizes = [int(s) for s in size_list]
    return torch.stack([chunk.max(dim=0).values for chunk in torch.split(x, sizes)])
|
| 326 |
+
|
| 327 |
+
def segment_sum(x, size_list):
    """Per-segment sum over dim 0; segments are consecutive chunks of the given sizes."""
    sizes = [int(s) for s in size_list]
    return torch.stack([chunk.sum(dim=0) for chunk in torch.split(x, sizes)])
|
| 330 |
+
|
| 331 |
+
def segment_softmax(gate, size_list):
    """Softmax of *gate* computed independently within each consecutive segment.

    Same computation as the original segment_max/segment_sum composition,
    with those helpers inlined: the per-segment max is subtracted before
    exponentiation for numerical stability, and a small epsilon guards the
    denominator.
    """
    sizes = [int(s) for s in size_list]

    # Per-segment max, broadcast back over each segment's rows.
    seg_max = torch.stack([c.max(dim=0).values for c in torch.split(gate, sizes)])
    max_rows = torch.cat([seg_max[i].repeat(n, 1) for i, n in enumerate(sizes)], dim=0)
    exp = torch.exp(gate - max_rows)

    # Per-segment sum of exponentials, broadcast the same way.
    seg_sum = torch.stack([c.sum(dim=0) for c in torch.split(exp, sizes)])
    sum_rows = torch.cat([seg_sum[i].repeat(n, 1) for i, n in enumerate(sizes)], dim=0)

    return exp / (sum_rows + 1e-16)
|
| 343 |
+
|
| 344 |
+
def pad_ms_list(ms_list, thr=0.05, min_mz=20, max_mz=2000):
    """Normalize, threshold, and zero-pad a batch of peak lists.

    Each spectrum is rescaled so its base peak is 100, peaks below
    thr*100 (relative) or outside [min_mz, max_mz] are dropped, and every
    spectrum is padded with (0, 0) rows to the longest length in the batch.

    Returns (FloatTensor of shape (B, maxlen, 2), IntTensor of true lengths).
    """
    rel_thr = thr * 100
    cleaned = []
    for spec in ms_list:
        # np.array copies, so the caller's data is never modified; the
        # in-place column assignment keeps the original dtype semantics.
        arr = np.array(spec)
        arr[:, 1] = arr[:, 1] / arr[:, 1].max() * 100

        if rel_thr > 0:
            arr = arr[arr[:, 1] >= rel_thr]

        arr = arr[arr[:, 0] >= min_mz]
        arr = arr[arr[:, 0] <= max_mz]

        cleaned.append(arr)

    lengths = [a.shape[0] for a in cleaned]
    maxlen = max(lengths)

    padded = []
    for arr in cleaned:
        short_by = maxlen - len(arr)
        if short_by > 0:
            padded.append(np.concatenate([arr, [[0, 0]] * short_by], axis=0))
        else:
            padded.append(arr)

    return torch.FloatTensor(np.stack(padded)), torch.IntTensor(lengths)
|