"""Molecular featurization utilities and a D-MPNN model (chemprop-style)."""
| import torch | |
| import torch.nn as nn | |
| from rdkit import Chem | |
# =============================================================================
# Featurization Utils
# =============================================================================
# Categorical vocabularies for atom-level features.  Each feature is one-hot
# encoded with one extra trailing slot for out-of-vocabulary values (see
# onek_encoding_unk below).
ATOM_FEATURES = {
    'atomic_num': list(range(1, 101)),
    'degree': [0, 1, 2, 3, 4, 5],
    'formal_charge': [-1, -2, 1, 2, 0],
    'chiral_tag': [0, 1, 2, 3],
    'num_hs': [0, 1, 2, 3, 4],
    'hybridization': [
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2
    ],
}
# Length of the bond feature vector produced by bond_features():
# 4 bond-type flags + conjugated + in-ring + 7-slot stereo one-hot = 13.
BOND_FDIM = 13
def get_atom_fdim():
    """Return the length of the atom feature vector.

    One slot per choice plus an "unknown" slot for every categorical
    feature, plus two scalar features (aromaticity flag and scaled mass).
    """
    total = 2  # aromatic flag + scaled mass
    for choices in ATOM_FEATURES.values():
        total += len(choices) + 1  # +1 for the unknown slot
    return total
def get_bond_fdim():
    """Return the length of the bond feature vector (BOND_FDIM)."""
    return BOND_FDIM
def onek_encoding_unk(value, choices):
    """One-hot encode `value` over `choices` with a trailing "unknown" slot.

    Returns a list of ``len(choices) + 1`` ints.  A value not present in
    `choices` lights up the final (unknown) position instead.
    """
    vec = [0] * (len(choices) + 1)
    try:
        vec[choices.index(value)] = 1
    except ValueError:
        vec[-1] = 1  # out-of-vocabulary -> last slot
    return vec
def atom_features(atom):
    """Build the feature vector (plain Python list) for one RDKit atom.

    Concatenates one-hot encodings (each with an unknown slot) for every
    categorical feature in ATOM_FEATURES, then appends an aromaticity flag
    and the atomic mass scaled by 0.01.
    """
    feats = []
    feats += onek_encoding_unk(atom.GetAtomicNum(), ATOM_FEATURES['atomic_num'])
    feats += onek_encoding_unk(atom.GetTotalDegree(), ATOM_FEATURES['degree'])
    feats += onek_encoding_unk(atom.GetFormalCharge(), ATOM_FEATURES['formal_charge'])
    feats += onek_encoding_unk(int(atom.GetChiralTag()), ATOM_FEATURES['chiral_tag'])
    feats += onek_encoding_unk(int(atom.GetTotalNumHs()), ATOM_FEATURES['num_hs'])
    feats += onek_encoding_unk(int(atom.GetHybridization()), ATOM_FEATURES['hybridization'])
    feats.append(1 if atom.GetIsAromatic() else 0)
    feats.append(atom.GetMass() * 0.01)  # scale mass to a comparable range
    return feats
def bond_features(bond):
    """Build the BOND_FDIM-length feature vector for one RDKit bond.

    Four bond-type flags, conjugation and ring membership, followed by a
    7-slot one-hot for the stereo code (6 known codes + unknown).
    """
    bond_type = bond.GetBondType()
    btypes = Chem.rdchem.BondType
    feats = [
        bond_type == btypes.SINGLE,
        bond_type == btypes.DOUBLE,
        bond_type == btypes.TRIPLE,
        bond_type == btypes.AROMATIC,
        # NOTE(review): the `if bond_type else 0` guard only trips for a
        # zero-valued (unspecified) bond type — preserved as-is.
        bond.GetIsConjugated() if bond_type else 0,
        bond.IsInRing() if bond_type else 0,
    ]
    feats.extend(onek_encoding_unk(int(bond.GetStereo()), list(range(6))))
    return feats
class MolGraph:
    """Directed-bond graph representation of a single molecule.

    Every chemical bond contributes two directed bonds (forward and
    reverse, consecutive indices).  Attributes:
        f_atoms: per-atom feature lists.
        f_bonds: per-directed-bond feature lists.
        a2b:     a2b[a] = indices of directed bonds whose target atom is a.
        b2a:     b2a[b] = source atom index of directed bond b.
        b2revb:  b2revb[b] = index of the reverse of directed bond b.
    An unparsable SMILES yields an empty graph (n_atoms == n_bonds == 0).
    """

    def __init__(self, smiles):
        self.smiles = smiles
        self.f_atoms = []
        self.f_bonds = []
        self.a2b = []
        self.b2a = []
        self.b2revb = []
        self.n_atoms = 0
        self.n_bonds = 0

        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return  # invalid SMILES -> empty graph

        self.n_atoms = mol.GetNumAtoms()
        self.f_atoms = [atom_features(a) for a in mol.GetAtoms()]
        self.a2b = [[] for _ in range(self.n_atoms)]

        for bond in mol.GetBonds():
            src, dst = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            feat = bond_features(bond)
            fwd, rev = self.n_bonds, self.n_bonds + 1
            # forward direction: src -> dst (incoming at dst)
            self.f_bonds.append(feat)
            self.b2a.append(src)
            self.a2b[dst].append(fwd)
            # reverse direction: dst -> src (incoming at src)
            self.f_bonds.append(feat)
            self.b2a.append(dst)
            self.a2b[src].append(rev)
            self.b2revb.extend([rev, fwd])
            self.n_bonds += 2
class BatchMolGraph:
    """Concatenate several MolGraphs into one disjoint batched graph.

    Atom and directed-bond indices of each graph are offset by the totals
    of the preceding graphs; features become float tensors and the index
    maps become long tensors (a2b and a_scope stay Python lists).
    """

    def __init__(self, mol_graphs):
        atom_feats, bond_feats = [], []
        a2b, b2a, b2revb, a_scope = [], [], [], []
        n_atoms = 0  # running atom-index offset
        n_bonds = 0  # running directed-bond-index offset
        for graph in mol_graphs:
            atom_feats.extend(graph.f_atoms)
            bond_feats.extend(graph.f_bonds)
            a2b.extend([b + n_bonds for b in incoming] for incoming in graph.a2b)
            b2a.extend(a + n_atoms for a in graph.b2a)
            b2revb.extend(b + n_bonds for b in graph.b2revb)
            a_scope.append((n_atoms, graph.n_atoms))
            n_atoms += graph.n_atoms
            n_bonds += graph.n_bonds

        self.atom_features = torch.tensor(atom_feats, dtype=torch.float)
        self.bond_features = torch.tensor(bond_feats, dtype=torch.float)
        self.a2b = a2b
        self.b2a = torch.tensor(b2a, dtype=torch.long)
        self.b2revb = torch.tensor(b2revb, dtype=torch.long)
        self.a_scope = a_scope

    def get_components(self):
        """Return the components consumed by DMPNN.forward."""
        return (
            self.atom_features,
            self.bond_features,
            self.a2b,
            self.b2a,
            self.b2revb,
            self.a_scope,
        )
# =============================================================================
# D-MPNN (Backward-Compatible)
# =============================================================================
class DMPNN(nn.Module):
    """Directed message passing neural network over batched molecular graphs.

    Messages live on directed bonds; after `depth` update rounds they are
    aggregated onto atoms, summed per molecule, optionally concatenated
    with precomputed global features, and fed through a two-layer readout.

    Fix vs. previous revision: a batch with zero bonds anywhere (e.g. only
    single-atom molecules, or all-invalid SMILES mixed with atoms) no
    longer crashes — BatchMolGraph yields a 1-D empty bond tensor in that
    case, and the old unconditional ``torch.cat([..., f_bonds], 1)`` raised
    a dimension-mismatch error.  Message passing is now skipped when there
    are no bonds.
    """

    def __init__(
        self,
        hidden_size=300,
        depth=3,
        tasks=12,
        global_feats_size=217,
        n_tasks=None,  # compatibility with old checkpoints
        **kwargs,      # ignore legacy args safely
    ):
        super().__init__()
        if n_tasks is not None:
            tasks = n_tasks  # legacy name wins when provided
        self.hidden_size = hidden_size
        self.depth = depth
        self.atom_fdim = get_atom_fdim()
        self.bond_fdim = get_bond_fdim()
        self.global_feats_size = global_feats_size
        # Input projection: [source-atom features || bond features] -> hidden.
        self.W_i = nn.Linear(self.atom_fdim + self.bond_fdim, hidden_size, bias=False)
        # Per-step message update.
        self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)
        # Atom readout: [atom features || aggregated messages] -> hidden.
        self.W_o = nn.Linear(self.atom_fdim + hidden_size, hidden_size)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        self.readout_1 = nn.Linear(hidden_size + global_feats_size, hidden_size)
        self.readout_2 = nn.Linear(hidden_size, tasks)

    def _aggregate_to_atoms(self, h, n_atoms, b2a, b2revb, device):
        """Sum directed-bond messages `h` into their *target* atoms.

        b2a[b2revb[b]] is the target atom of directed bond b, so this
        accumulates each atom's incoming messages.
        """
        atom_msg = torch.zeros(n_atoms, self.hidden_size, device=device)
        atom_msg.index_add_(0, b2a.index_select(0, b2revb), h)
        return atom_msg

    def forward(self, batch_graph, global_feats=None):
        """Predict `(n_molecules, tasks)` outputs for a BatchMolGraph.

        Raises ValueError when ``global_feats_size > 0`` but no global
        features are supplied.
        """
        f_atoms, f_bonds, a2b, b2a, b2revb, a_scope = batch_graph.get_components()
        device = self.W_i.weight.device
        if f_atoms.size(0) == 0:
            # No valid atoms at all: one all-zero prediction per molecule.
            return torch.zeros((len(a_scope), self.readout_2.out_features),
                               device=device)
        f_atoms, f_bonds = f_atoms.to(device), f_bonds.to(device)
        b2a, b2revb = b2a.to(device), b2revb.to(device)

        if b2a.numel() == 0:
            # No bonds in the whole batch: skip message passing entirely
            # (a 1-D empty f_bonds would otherwise break torch.cat below).
            atom_msg = torch.zeros(f_atoms.size(0), self.hidden_size, device=device)
        else:
            h0 = self.act(self.W_i(torch.cat([f_atoms.index_select(0, b2a), f_bonds], 1)))
            h = h0
            for _ in range(self.depth):
                atom_msg = self._aggregate_to_atoms(h, f_atoms.size(0), b2a, b2revb, device)
                # Message into bond b: incoming messages at its source atom,
                # minus the reverse bond's own contribution.
                m = atom_msg.index_select(0, b2a) - h.index_select(0, b2revb)
                h = self.dropout(self.act(h0 + self.W_h(m)))
            atom_msg = self._aggregate_to_atoms(h, f_atoms.size(0), b2a, b2revb, device)

        atom_h = self.act(self.W_o(torch.cat([f_atoms, atom_msg], 1)))
        # Sum-pool atoms per molecule; empty molecules get a zero vector.
        mol_vecs = torch.stack([
            atom_h.narrow(0, start, count).sum(0) if count > 0
            else torch.zeros(self.hidden_size, device=device)
            for start, count in a_scope
        ])
        if self.global_feats_size > 0:
            if global_feats is None:
                raise ValueError("Global features expected but not provided")
            mol_vecs = torch.cat([mol_vecs, global_feats.to(device)], 1)
        x = self.dropout(self.act(self.readout_1(mol_vecs)))
        return self.readout_2(x)