import torch
import torch.nn as nn
from rdkit import Chem

# =============================================================================
# Featurization Utils
# =============================================================================

# Categorical atom-feature vocabularies. Each list is one-hot encoded with one
# extra trailing "unknown" slot (see onek_encoding_unk), so a vocabulary of
# size k contributes k + 1 dimensions.
ATOM_FEATURES = {
    'atomic_num': list(range(1, 101)),
    'degree': [0, 1, 2, 3, 4, 5],
    'formal_charge': [-1, -2, 1, 2, 0],
    'chiral_tag': [0, 1, 2, 3],
    'num_hs': [0, 1, 2, 3, 4],
    'hybridization': [
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2,
    ],
}

# 4 bond-type flags + conjugated + in-ring + 7-way stereo one-hot (6 choices
# plus the "unknown" slot) = 13.
BOND_FDIM = 13


def get_atom_fdim():
    """Return the atom feature dimension.

    One slot per vocabulary entry plus an "unknown" slot for each categorical
    feature, plus 2 scalar features (aromatic flag, scaled mass).
    """
    return sum(len(choices) + 1 for choices in ATOM_FEATURES.values()) + 2


def get_bond_fdim():
    """Return the bond feature dimension."""
    return BOND_FDIM


def onek_encoding_unk(value, choices):
    """One-hot encode ``value`` over ``choices`` with a trailing unknown slot.

    If ``value`` is not in ``choices``, the last (extra) position is set
    instead, so the encoding always has exactly one hot entry.
    """
    encoding = [0] * (len(choices) + 1)
    index = choices.index(value) if value in choices else -1
    encoding[index] = 1
    return encoding


def atom_features(atom):
    """Build the feature vector for an RDKit atom.

    Concatenates one-hot (+unknown) encodings of the categorical properties in
    ATOM_FEATURES with an aromaticity flag and the atomic mass scaled by 0.01
    (keeps the scalar in roughly the same range as the one-hot entries).
    """
    return (
        onek_encoding_unk(atom.GetAtomicNum(), ATOM_FEATURES['atomic_num']) +
        onek_encoding_unk(atom.GetTotalDegree(), ATOM_FEATURES['degree']) +
        onek_encoding_unk(atom.GetFormalCharge(), ATOM_FEATURES['formal_charge']) +
        onek_encoding_unk(int(atom.GetChiralTag()), ATOM_FEATURES['chiral_tag']) +
        onek_encoding_unk(int(atom.GetTotalNumHs()), ATOM_FEATURES['num_hs']) +
        onek_encoding_unk(int(atom.GetHybridization()), ATOM_FEATURES['hybridization']) +
        [1 if atom.GetIsAromatic() else 0] +
        [atom.GetMass() * 0.01]
    )


def bond_features(bond):
    """Build the feature vector for an RDKit bond (length == BOND_FDIM).

    NOTE: the previous implementation guarded conjugation/ring flags with
    ``... if bt else 0``. ``bt`` is a BondType enum and BondType.UNSPECIFIED
    has value 0 (falsy), so those flags were silently zeroed for unspecified
    bond types; ``bond`` itself is never None here, so no guard is needed.
    """
    bt = bond.GetBondType()
    feats = [
        bt == Chem.rdchem.BondType.SINGLE,
        bt == Chem.rdchem.BondType.DOUBLE,
        bt == Chem.rdchem.BondType.TRIPLE,
        bt == Chem.rdchem.BondType.AROMATIC,
        bond.GetIsConjugated(),
        bond.IsInRing(),
    ]
    feats += onek_encoding_unk(int(bond.GetStereo()), list(range(6)))
    return feats


class MolGraph:
    """Directed-bond graph of a single molecule (chemprop-style).

    Each chemical bond yields two directed bonds (a1->a2 and a2->a1).

    Attributes:
        f_atoms: list of per-atom feature vectors.
        f_bonds: list of per-directed-bond feature vectors.
        a2b: a2b[a] lists the directed bonds whose *target* atom is a.
        b2a: b2a[b] is the *source* atom of directed bond b.
        b2revb: b2revb[b] is the index of b's reverse directed bond.
        n_atoms / n_bonds: counts (n_bonds counts directed bonds).
    """

    def __init__(self, smiles):
        self.smiles = smiles
        self.f_atoms = []
        self.f_bonds = []
        self.a2b = []
        self.b2a = []
        self.b2revb = []
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # Unparseable SMILES -> empty graph; downstream code handles n == 0.
            self.n_atoms = 0
            self.n_bonds = 0
            return
        self.n_atoms = mol.GetNumAtoms()
        for atom in mol.GetAtoms():
            self.f_atoms.append(atom_features(atom))
        self.a2b = [[] for _ in range(self.n_atoms)]
        self.n_bonds = 0
        for bond in mol.GetBonds():
            a1, a2 = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            b_feat = bond_features(bond)
            # Directed bond b1: a1 -> a2 (incoming at a2).
            b1 = self.n_bonds
            self.f_bonds.append(b_feat)
            self.b2a.append(a1)
            self.a2b[a2].append(b1)
            # Directed bond b2: a2 -> a1 (incoming at a1).
            b2 = self.n_bonds + 1
            self.f_bonds.append(b_feat)
            self.b2a.append(a2)
            self.a2b[a1].append(b2)
            # b2revb[b1] = b2 and b2revb[b2] = b1.
            self.b2revb.extend([b2, b1])
            self.n_bonds += 2


class BatchMolGraph:
    """Concatenation of several MolGraphs into one disconnected graph.

    Atom and directed-bond indices of each graph are offset by the running
    totals so the per-graph index structures remain valid in the batch.
    ``a_scope`` records (start_atom, n_atoms) per molecule for readout.
    """

    def __init__(self, mol_graphs):
        self.atom_features = []
        self.bond_features = []
        self.a2b = []
        self.b2a = []
        self.b2revb = []
        self.a_scope = []
        total_atoms = 0
        total_bonds = 0
        for g in mol_graphs:
            self.atom_features.extend(g.f_atoms)
            self.bond_features.extend(g.f_bonds)
            for lst in g.a2b:
                self.a2b.append([b + total_bonds for b in lst])
            self.b2a.extend(a + total_atoms for a in g.b2a)
            self.b2revb.extend(b + total_bonds for b in g.b2revb)
            self.a_scope.append((total_atoms, g.n_atoms))
            total_atoms += g.n_atoms
            total_bonds += g.n_bonds
        # torch.tensor([]) would collapse to shape (0,), losing the feature
        # dimension; build explicit (0, fdim) tensors for empty batches so
        # downstream cat/index_select keep working.
        if self.atom_features:
            self.atom_features = torch.tensor(self.atom_features, dtype=torch.float)
        else:
            self.atom_features = torch.zeros((0, get_atom_fdim()), dtype=torch.float)
        if self.bond_features:
            self.bond_features = torch.tensor(self.bond_features, dtype=torch.float)
        else:
            self.bond_features = torch.zeros((0, get_bond_fdim()), dtype=torch.float)
        self.b2a = torch.tensor(self.b2a, dtype=torch.long)
        self.b2revb = torch.tensor(self.b2revb, dtype=torch.long)

    def get_components(self):
        """Return the tensors/index structures consumed by DMPNN.forward."""
        return (
            self.atom_features,
            self.bond_features,
            self.a2b,
            self.b2a,
            self.b2revb,
            self.a_scope,
        )


# =============================================================================
# D-MPNN (Backward-Compatible)
# =============================================================================
class DMPNN(nn.Module):
    """Directed message-passing neural network over BatchMolGraph inputs.

    Args:
        hidden_size: width of the bond message vectors.
        depth: number of message-passing iterations.
        tasks: number of output targets.
        global_feats_size: size of the per-molecule global feature vector
            concatenated before readout (0 disables it).
        n_tasks: legacy alias for ``tasks`` (old checkpoints).
        **kwargs: ignored, for compatibility with legacy constructor args.
    """

    def __init__(
        self,
        hidden_size=300,
        depth=3,
        tasks=12,
        global_feats_size=217,
        n_tasks=None,  # compatibility with old checkpoints
        **kwargs,      # ignore legacy args safely
    ):
        super().__init__()
        if n_tasks is not None:
            tasks = n_tasks
        self.hidden_size = hidden_size
        self.depth = depth
        self.atom_fdim = get_atom_fdim()
        self.bond_fdim = get_bond_fdim()
        self.global_feats_size = global_feats_size
        # Input projection of [source-atom features || bond features].
        self.W_i = nn.Linear(self.atom_fdim + self.bond_fdim, hidden_size, bias=False)
        # Message update applied at every depth step.
        self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)
        # Atom-level output projection of [atom features || aggregated messages].
        self.W_o = nn.Linear(self.atom_fdim + hidden_size, hidden_size)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        self.readout_1 = nn.Linear(hidden_size + global_feats_size, hidden_size)
        self.readout_2 = nn.Linear(hidden_size, tasks)

    def forward(self, batch_graph, global_feats=None):
        """Run message passing + readout; returns (n_mols, tasks) predictions.

        Raises:
            ValueError: if ``global_feats_size > 0`` but ``global_feats`` is
                not provided.
        """
        f_atoms, f_bonds, a2b, b2a, b2revb, a_scope = batch_graph.get_components()
        if f_atoms.size(0) == 0:
            # Entire batch is empty (e.g. all-invalid SMILES): zero predictions.
            return torch.zeros(
                (len(a_scope), self.readout_2.out_features),
                device=self.W_i.weight.device,
            )
        device = self.W_i.weight.device
        f_atoms, f_bonds = f_atoms.to(device), f_bonds.to(device)
        b2a, b2revb = b2a.to(device), b2revb.to(device)

        # Initial directed-bond messages from source atom + bond features.
        h0 = self.act(self.W_i(torch.cat([f_atoms.index_select(0, b2a), f_bonds], 1)))
        h = h0
        for _ in range(self.depth):
            # b2a[b2revb[b]] is the *target* atom of bond b, so this scatters
            # each message onto the atom it points into.
            atom_msg = torch.zeros(f_atoms.size(0), self.hidden_size, device=device)
            atom_msg.index_add_(0, b2a.index_select(0, b2revb), h)
            # Message into bond b = sum over bonds incoming at its source atom,
            # excluding the reverse bond (the defining D-MPNN update).
            m = atom_msg.index_select(0, b2a) - h.index_select(0, b2revb)
            h = self.dropout(self.act(h0 + self.W_h(m)))

        # Final aggregation of incoming messages per atom.
        atom_msg = torch.zeros(f_atoms.size(0), self.hidden_size, device=device)
        atom_msg.index_add_(0, b2a.index_select(0, b2revb), h)
        atom_h = self.act(self.W_o(torch.cat([f_atoms, atom_msg], 1)))

        # Sum-pool atoms per molecule; empty molecules contribute zeros.
        mol_vecs = [
            atom_h.narrow(0, s, n).sum(0) if n > 0
            else torch.zeros(self.hidden_size, device=device)
            for s, n in a_scope
        ]
        mol_vecs = torch.stack(mol_vecs)
        if self.global_feats_size > 0:
            if global_feats is None:
                raise ValueError("Global features expected but not provided")
            mol_vecs = torch.cat([mol_vecs, global_feats.to(device)], 1)
        x = self.dropout(self.act(self.readout_1(mol_vecs)))
        return self.readout_2(x)